From e87c79ca0cbab476a7d09853b5830b615a62f679 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 10 Nov 2022 03:04:57 +0000 Subject: [PATCH 001/453] [vision hash update] update the pinned vision hash (#88742) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88742 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index b985bb4d5e30..d8180093d885 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -bf58902b2fd881c760cd2eeacfae2d7c468ebf1f +ffd5a567eb90abf6b5555063da434d3c130d540f From dcefea2706fb35ece5e49fc138d952a2acd15824 Mon Sep 17 00:00:00 2001 From: efiks <5167930+efiks@users.noreply.github.com> Date: Thu, 10 Nov 2022 06:11:05 +0000 Subject: [PATCH 002/453] [caffe2][tourch] Optimize BatchBoxCox (#87585) Differential Revision: D40215424 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87585 Approved by: https://github.com/hyuen --- caffe2/perfkernels/batch_box_cox_avx2.cc | 117 ++++++++++++++++++++--- caffe2/perfkernels/lstm_unit_cpu-impl.h | 22 +---- caffe2/perfkernels/vectorizer.h | 28 ++++++ 3 files changed, 131 insertions(+), 36 deletions(-) create mode 100644 caffe2/perfkernels/vectorizer.h diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc index cf0801b4733e..8b93293646db 100644 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ b/caffe2/perfkernels/batch_box_cox_avx2.cc @@ -3,6 +3,35 @@ #include #include +#include "vectorizer.h" + +#ifndef VECTORIZED_KERNEL +#define CPU_CAPABILITY_AVX2 +#include + +namespace at::vec { + +template +Vectorized max(const Vectorized& a, const Vectorized& b); + +// Implements the vectorized version of std::max() operation, +// which DOESNOT propagates NaN for second argument +template <> +Vectorized max(const Vectorized& a, const Vectorized& b) { + // std::max(NaN, nonNan) -> NaN + return _mm256_max_pd(b, a); +} + + +template <> +Vectorized max(const Vectorized& a, const Vectorized& b) { + // std::max(NaN, nonNan) -> NaN + return _mm256_max_ps(b, a); +} + +} +#endif + #include #include #include @@ -65,6 +94,7 @@ DELEGATE_SIMPLE_UNARY_FUNCTION(float, Ln, vsLn) DELEGATE_SIMPLE_UNARY_FUNCTION(double, Ln, vdLn) #undef DELEGATE_SIMPLE_UNARY_FUNCTION +#ifndef VECTORIZED_KERNEL template void box_cox_zero_lambda( size_t D, @@ -72,36 +102,93 @@ void box_cox_zero_lambda( const T* const lambda2_data, T k_eps, T* const output_data) { - Add(D, self_data, lambda2_data, output_data); - for (const auto j : c10::irange(D)) { - output_data[j] = std::max(output_data[j], k_eps); + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(self_data + j); + auto lambda2 = Vec::loadu(lambda2_data + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto res = max.log(); + res.store(output_data + j); + } + for ( ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); } - - Ln(D, output_data, output_data); } template void box_cox_nonzero_lambda( + int64_t 
D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + int j = 0; + using Vec = at::vec::Vectorized; + constexpr int64_t VLEN = Vec::size(); + auto k_eps_vec = Vec(k_eps); + for(; j + VLEN < D; j += VLEN) { + auto data = Vec::loadu(data_ptr + j); + auto lambda2 = Vec::loadu(lambda2_ptr + j); + auto sum = data + lambda2; + auto max = at::vec::max(sum, k_eps_vec); + auto lambda1 = Vec::loadu(lambda1_ptr + j); + auto lambda_over_1 = lambda1.reciprocal(); + auto pow = max.pow(lambda1); + auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); + res.store(out + j); + } + for ( ;j < D; ++j) { + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lambda_over_1 = 1 / lambda1_ptr[j]; + auto pow = std::pow(max, lambda1_ptr[j]); + out[j] = pow * lambda_over_1 - lambda_over_1; + } +} +#else +template +void box_cox_zero_lambda( size_t D, const T* const self_data, - const T* const lambda1_data, const T* const lambda2_data, T k_eps, T* const output_data) { - Add(D, self_data, lambda2_data, output_data); - for (const auto j : c10::irange(D)) { - output_data[j] = std::max(output_data[j], k_eps); + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + auto sum = self_data[j] + lambda2_data[j]; + auto max = std::max(sum, k_eps); + output_data[j] = std::log(max); } +} - // output = output ^ lambda1 - Pow(D, output_data, lambda1_data, output_data); - // output = (output - 1)/ lambda1 - for (const auto j : c10::irange(D)) { - output_data[j] -= 1.0; +template +void box_cox_nonzero_lambda( + int64_t D, + const T* data_ptr, + const T* lambda1_ptr, + const T* lambda2_ptr, + T k_eps, + T* out) { + + VECTOR_LOOP for (auto j=0 ;j < D; ++j) { + FAST_MATH + auto sum = data_ptr[j] + lambda2_ptr[j]; + auto max = std::max(sum, k_eps); + auto lambda_over_1 = 1 / lambda1_ptr[j]; + auto pow = std::pow(max, lambda1_ptr[j]); + out[j] = pow * lambda_over_1 - lambda_over_1; } - Div(D, output_data, lambda1_data, output_data); } +#endif + template void box_cox_mixed_lambda( const T* const self_data, diff --git a/caffe2/perfkernels/lstm_unit_cpu-impl.h b/caffe2/perfkernels/lstm_unit_cpu-impl.h index 5e76e1aa39fe..239d2807f778 100644 --- a/caffe2/perfkernels/lstm_unit_cpu-impl.h +++ b/caffe2/perfkernels/lstm_unit_cpu-impl.h @@ -5,27 +5,7 @@ #include "c10/util/irange.h" #include "caffe2/utils/conversions.h" -#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) -#if defined(__clang__) && (__clang_major__ > 7) -#define IS_SANITIZER \ - ((__has_feature(address_sanitizer) == 1) || \ - (__has_feature(memory_sanitizer) == 1) || \ - (__has_feature(thread_sanitizer) == 1) || \ - (__has_feature(undefined_sanitizer) == 1)) - -#if IS_SANITIZER == 0 -#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") -#endif -#elif defined(_OPENMP) && (_OPENMP >= 201511) -// Support with OpenMP4.5 and above -#define VECTOR_LOOP _Pragma("omp for simd") -#endif -#endif - -#ifndef VECTOR_LOOP -// Not supported -#define VECTOR_LOOP -#endif +#include "vectorizer.h" namespace caffe2 { namespace perfkernels { diff --git a/caffe2/perfkernels/vectorizer.h b/caffe2/perfkernels/vectorizer.h new file mode 100644 index 000000000000..be4e6bbc280f --- /dev/null +++ b/caffe2/perfkernels/vectorizer.h @@ -0,0 +1,28 @@ +#pragma once + +#if (ENABLE_VECTORIZATION > 0) && !defined(_DEBUG) && !defined(DEBUG) +#if defined(__clang__) && (__clang_major__ > 7) +#define IS_SANITIZER \ + ((__has_feature(address_sanitizer) == 1) || \ + (__has_feature(memory_sanitizer) == 1) || \ + 
(__has_feature(thread_sanitizer) == 1) || \ + (__has_feature(undefined_sanitizer) == 1)) + +#if IS_SANITIZER == 0 +#define VECTOR_LOOP _Pragma("clang loop vectorize(enable)") +#define FAST_MATH _Pragma("clang fp contract(fast)") +#define VECTORIZED_KERNEL 1 +#endif +#elif defined(_OPENMP) && (_OPENMP >= 201511) +// Support with OpenMP4.5 and above +#define VECTOR_LOOP _Pragma("omp for simd") +#define VECTORIZED_KERNEL 1 +#define FAST_MATH +#endif +#endif + +#ifndef VECTOR_LOOP +// Not supported +#define VECTOR_LOOP +#define FAST_MATH +#endif From 7ad87f63e248b629d435a199cb61f4ed1f3dfcab Mon Sep 17 00:00:00 2001 From: Grigory Sizov Date: Thu, 10 Nov 2022 08:12:56 +0000 Subject: [PATCH 003/453] Support src_mask and src_key_padding_mask for Better Transformer (#88488) Fixes T135842750 (follow-up for #87377) ## Description At present, having both `src_key_padding_mask` and `src_mask` at the same time is not supported on the fastpath in Transformer and Multi-Head Attention. This PR enables using both masks on the fastpath on CPU and GPU: if both masks are passed, we merge them into a 4D mask in Python and change mask type to 2 before passing downstream. Downstream processing in native code is not changed, as it already supports 4D mask. Indeed, it is done depending on the device: - on CUDA, by `SoftMax.cu::masked_softmax_cuda`. When mask type is 2, it calls either `dispatch_softmax_forward` -> `softmax_warp_forward` or `at::softmax` (depending on the input size). In both cases 4D mask is supported. - on CPU, by `SoftMax.cpp::masked_softmax_cpp`. It calls `hosted_softmax` which supports 4D mask. ## Tests - Extended `test_mask_check_fastpath` to check that fast path is indeed taken in Transformer when two masks are passed - Added `test_multihead_self_attn_two_masks_fast_path_mock` to check that fast path is taken in MHA when two masks are passed - Added `test_multihead_self_attn_two_masks_fast_path` to check that fast and slow paths give the same result when two masks are passed in MHA - `test_masked_softmax_mask_types` now covers mask type 2 - `test_transformerencoderlayer_fast_path` (CPU smoke test) is expanded to the case of both masks provided simultaneously - `test_masked_softmax_devices_parity` checks that mask type 2 is accepted by CPU and CUDA paths Pull Request resolved: https://github.com/pytorch/pytorch/pull/88488 Approved by: https://github.com/mikekgfb --- test/test_nn.py | 132 ++++++++++++++++++++++++++++---- test/test_transformers.py | 31 ++++---- torch/nn/modules/activation.py | 48 ++++++++++-- torch/nn/modules/transformer.py | 12 +-- 4 files changed, 182 insertions(+), 41 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index d2eac6a277d7..b07793e79f48 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13122,10 +13122,10 @@ def _slow_masked_softmax(self, input, mask): s = exp.sum(dim=3, keepdim=True).expand(exp.size()) return exp / s - def test_masked_softmax_mask_types_0_1(self, device): - # Test that mask type 0 (LxL attention mask) and mask type 1 (BxL padding mask) - # are processed correctly on the fast path and the results match explicit slow - # calculation. + def test_masked_softmax_mask_types(self, device): + # Test that mask type 0 (LxL attention mask), mask type 1 (BxL padding mask), + # and mask type 2 (generic BxHxLxL mask) are processed correctly on the + # fast path and the results match explicit slow calculation. 
sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] for (B, num_heads, L) in sizes: @@ -13138,7 +13138,12 @@ def test_masked_softmax_mask_types_0_1(self, device): src_key_padding_mask_orig = torch.randint(0, 2, (B, L)).bool() src_key_padding_mask = src_key_padding_mask_orig.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool() - masks = [(src_mask_orig, src_mask, 0), (src_key_padding_mask_orig, src_key_padding_mask, 1)] + # mask_type == 2 => shape BxHxLxL + generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool() + masks = [(src_mask_orig, src_mask, 0), + (src_key_padding_mask_orig, src_key_padding_mask, 1), + (generic_mask, generic_mask, 2) + ] for dim in [0, 3]: for mask_orig, mask, mask_type in masks: if (self.device_type == "cuda") and (num_heads % 2) and (mask_type == 1): @@ -13173,8 +13178,8 @@ def slow_masked_softmax(input, mask): @onlyCUDA def test_masked_softmax_devices_parity(self): - # Test that softmax with mask type 0 (LxL attention mask) and mask type 1 (BxL padding mask) - # gives the same result on CPU and on CUDA + # Test that softmax with mask type 0 (LxL attention mask), mask type 1 (BxL padding mask), + # and mask type 2 (BxHxLxL generic mask) gives the same result on CPU and on CUDA. sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] for (B, num_heads, L) in sizes: @@ -13182,7 +13187,9 @@ def test_masked_softmax_devices_parity(self): src_mask = torch.randint(0, 2, (L, L)).bool() # mask_type == 1 => padding mask of shape BxL src_key_padding_mask = torch.randint(0, 2, (B, L)).bool() - masks = [(src_mask, 0), (src_key_padding_mask, 1)] + # mask_type == 2 => generic mask of shape BxHxLxL + generic_mask = torch.randint(0, 2, (B, num_heads, L, L)).bool() + masks = [(src_mask, 0), (src_key_padding_mask, 1), (generic_mask, 2)] input = torch.randn((B, num_heads, L, L)) for dim in [0, 3]: for mask, mask_type in masks: @@ -13197,8 +13204,10 @@ def softmax_on_device(mask, input, device): softmax_res = torch._masked_softmax(input_device, mask_device, dim, mask_type) if mask_type == 0: mask_expanded = mask_device.reshape(1, 1, L, L).expand(B, num_heads, L, L).bool() - else: + elif mask_type == 1: mask_expanded = mask_device.reshape(B, 1, 1, L).expand(B, num_heads, L, L).bool() + else: + mask_expanded = mask_device # In result, should only fill the entirely masked out rows since those are non-deterministic (*may* be 0) # Fill rows with all True's with 0 mask_out = mask_expanded.all(dim, keepdim=True).expand(mask_expanded.shape) @@ -13209,6 +13218,93 @@ def softmax_on_device(mask, input, device): cuda_res = softmax_on_device(mask, input, "cuda") self.assertEqual(cpu_res, cuda_res, exact_dtype=True) + def test_multihead_self_attn_two_masks_fast_path(self, device): + """ + Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path + when both attention mask (mask type 0) and key padding mask (mask type 1) are provided + """ + with torch.no_grad(): + embed_dim = 14 + num_heads = 7 + batch_size = 8 + src_len = 5 + + query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) + # Create masks of two different types + attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) + key_padding_mask = torch.randint(0, 2, (batch_size, src_len)).bool().to(device) + + # We'll need expanded versions of the masks for masking out the outputs below + attn_mask_expanded = attn_mask.reshape(1, 1, src_len, src_len) \ + .expand(batch_size, num_heads, src_len, src_len) + key_padding_mask_expanded = 
key_padding_mask.reshape(batch_size, 1, 1, src_len) \ + .expand(batch_size, num_heads, src_len, src_len) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + + # Compute attention on the fast path + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, device=device) + mta_model.training = False + result_fast_path, _ = mta_model(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + + # Compute attention on the slow path + result_ref, _ = torch.nn.functional.multi_head_attention_forward(query.transpose(0, 1), + key.transpose(0, 1), + value.transpose(0, 1), + embed_dim, num_heads, + mta_model.in_proj_weight, + mta_model.in_proj_bias, + mta_model.bias_k, mta_model.bias_v, + mta_model.add_zero_attn, + mta_model.dropout, + mta_model.out_proj.weight, + mta_model.out_proj.bias, + training=mta_model.training, + key_padding_mask=key_padding_mask, + need_weights=False, + attn_mask=attn_mask, + use_separate_proj_weight=False, + q_proj_weight=mta_model.q_proj_weight, + k_proj_weight=mta_model.k_proj_weight, + v_proj_weight=mta_model.v_proj_weight, + average_attn_weights=False, + ) + result_ref = result_ref.transpose(0, 1) # Convert to batch-first + + # Rows which are completely masked out are nan, we need to exclude them from comparison + mask_out = merged_mask[:, 0, :, :].all(-1, keepdim=True).expand(batch_size, src_len, embed_dim) + result_fast_path_masked = result_fast_path.masked_fill(mask_out, 0) + result_ref_masked = result_ref.masked_fill(mask_out, 0) + + self.assertEqual(result_fast_path_masked, result_ref_masked) + + @torch.no_grad() + @unittest.skipIf(TEST_WITH_CROSSREF, 'CrossRef turns on TorchFunctionMode, and so disables fastpath.') + def test_multihead_self_attn_two_masks_fast_path_mock(self, device): + """ + Multihead self-attention should take fast path when both attention mask (mask type 0) + and key padding mask (mask type 1) are provided at the same time on CPU and CUDA + """ + if device not in ['cpu', 'cuda']: + self.skipTest("Fastpath only runs on CPU and CUDA.") + with torch.autocast(device_type=device, enabled=False): + embed_dim = 14 + num_heads = 7 + batch_size = 8 + src_len = 5 + + query = value = key = torch.rand(batch_size, src_len, embed_dim).to(device) + # Create masks of two different types + attn_mask = torch.randint(0, 2, (src_len, src_len)).bool().to(device) + key_padding_mask = torch.randint(0, 2, (batch_size, src_len)).bool().to(device) + + with mock.patch('torch._native_multi_head_attention') as fastpath_mock: + # Compute attention on the fast path + mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True, device=device).eval() + mta_model.training = False + mta_model(query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + # If mock was called, fastpath was taken + self.assertTrue(fastpath_mock.called) + def test_masked_softmax(self, device): sizes = [(1, 1, 32), (3, 16, 310), (12, 4, 1024), (4, 2, 1200)] for (B, num_heads, L) in sizes: @@ -15567,22 +15663,32 @@ def test_transformerencoderlayer_fast_path(self, device, dtype): """ Test transformer fast path on CPU with different valid mask types and shapes """ - model = torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True, device=device, dtype=dtype) + d_model = 512 + nhead = 8 + batch_size = 32 + src_len = 10 + + model = torch.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True, + device=device, dtype=dtype, dropout=0) model.eval() # Batched inputs - src = 
torch.rand(32, 10, 512) + src = torch.rand(batch_size, src_len, 512) # Attention mask of shape (src_len, src_len) - src_mask = torch.zeros(10, 10).to(torch.bool) + src_mask = torch.zeros(src_len, src_len).to(torch.bool) with torch.no_grad(): model(src, src_mask=src_mask) # Padding mask of shape (batch_size, src_len) - src_key_padding_mask = torch.zeros(32, 10).to(torch.bool) + src_key_padding_mask = torch.zeros(batch_size, src_len).to(torch.bool) with torch.no_grad(): model(src, src_key_padding_mask=src_key_padding_mask) + # Provide both masks + with torch.no_grad(): + model(src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask) + @dtypes(torch.float) @dtypesIfCUDA(torch.half, torch.float) diff --git a/test/test_transformers.py b/test/test_transformers.py index 656191c9ddda..a9d0d960fb9a 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -869,17 +869,18 @@ def rand_tensor(*shape): @torch.no_grad() def test_mask_check_fastpath(self): """ - Test that fastpath is executed independently of the mask that is passed. - If the passed mask is left aligned or mask_check=False, test that nested tensors are used (sparsity fastpath), - otherwise use fastpath with traditional tensors. + Test that fastpath is executed independently of the masks that are passed. + If the passed key padding mask is left aligned or mask_check=False, test that nested tensors are used + (sparsity fastpath), otherwise use fastpath with traditional tensors. + Also test that fast path is executed with both key padding mask and attention mask passed at the same time. """ x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float) - def _test_fastpath(model, mask, mock_return_value, nested_tensors=True): + def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, nested_tensors=True): with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock: fastpath_mock.return_value = mock_return_value - model(x, src_key_padding_mask=mask) + model(x, src_key_padding_mask=key_padding_mask, mask=attn_mask) # If mock was called, fastpath was taken self.assertTrue(fastpath_mock.called) @@ -893,31 +894,33 @@ def _test_fastpath(model, mask, mock_return_value, nested_tensors=True): model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True) model.eval() - aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) - not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) + aligned_key_padding_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool) + not_aligned_key_padding_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool) + attn_mask = torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]).to(torch.bool) nested_tensor_return_value = torch.nested.nested_tensor([torch.ones((2, 2), dtype=torch.float)]) tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float) # Left aligned mask results in sparsity fastpath - _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) # Not aligned mask results in fastpath - _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) + _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True) model.eval() # If nested tensor disabled, fastpath is always taken - _test_fastpath(model, aligned_mask, tensor_return_value, 
nested_tensors=False) - _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False) - + _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, nested_tensors=False) + _test_fastpath(model, not_aligned_key_padding_mask, tensor_return_value, nested_tensors=False) + # Fast path is taken if both attention mask and key padding mask are present + _test_fastpath(model, aligned_key_padding_mask, tensor_return_value, attn_mask=attn_mask, nested_tensors=False) model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False) model.eval() # Mask check disabled results in sparisty fastpath, independently of the mask - _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True) - _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) + _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("type", ["dense", "nested"]) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index d7a9f13809d6..5f5615b496d7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -901,6 +901,7 @@ class MultiheadAttention(Module): - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This restriction will be loosened in the future.) + - inputs are batched (3D) with ``batch_first==True`` - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` - training is disabled (using ``.eval()``) - dropout is 0 @@ -908,9 +909,9 @@ class MultiheadAttention(Module): - ``add_zero_attn`` is ``False`` - ``batch_first`` is ``True`` and the input is batched - ``kdim`` and ``vdim`` are equal to ``embed_dim`` - - at most one of ``key_padding_mask`` or ``attn_mask`` is passed - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` nor ``attn_mask`` is passed + - autocast is disabled If the optimized implementation is in use, a `NestedTensor `_ can be passed for @@ -1094,9 +1095,8 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O elif not self._qkv_same_embed_dim: why_not_fast_path = "_qkv_same_embed_dim was not True" elif query.is_nested and (key_padding_mask is not None or attn_mask is not None): - why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input" - elif not query.is_nested and key_padding_mask is not None and attn_mask is not None: - why_not_fast_path = "key_padding_mask and attn_mask were both supplied" + why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \ + is not supported with NestedTensor input" elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" @@ -1120,6 +1120,8 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = ("grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad") if not why_not_fast_path: + merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query) + return torch._native_multi_head_attention( query, key, @@ -1130,10 +1132,10 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, 
key_padding_mask: O self.in_proj_bias, self.out_proj.weight, self.out_proj.bias, - key_padding_mask if key_padding_mask is not None else attn_mask, + merged_mask, need_weights, average_attn_weights, - 1 if key_padding_mask is not None else 0 if attn_mask is not None else None) + mask_type) any_nested = query.is_nested or key.is_nested or value.is_nested assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " + @@ -1175,6 +1177,40 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O else: return attn_output, attn_output_weights + def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], + query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]: + r""" + Determine mask type and combine masks if necessary. If only one mask is provided, that mask + and the corresponding mask type will be returned. If both masks are provided, they will be both + expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined with logical ``or`` + and mask type 2 will be returned + Args: + attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0 + key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1 + query: query embeddings of shape ``(batch_size, seq_len, embed_dim)`` + Returns: + merged_mask: merged mask + mask_type: merged mask type (0, 1, or 2) + """ + mask_type: Optional[int] = None + merged_mask: Optional[Tensor] = None + if attn_mask is not None: + mask_type = 0 + merged_mask = attn_mask + if key_padding_mask is not None: + mask_type = 1 + merged_mask = key_padding_mask + if (attn_mask is not None) and (key_padding_mask is not None): + # In this branch query can't be a nested tensor, so it has a shape + batch_size, seq_len, _ = query.shape + mask_type = 2 + key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len) \ + .expand(-1, self.num_heads, -1, -1) + attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1) + merged_mask = attn_mask_expanded.logical_or(key_padding_mask_expanded) + return merged_mask, mask_type + + class PReLU(Module): r"""Applies the element-wise function: diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 34dde6fc224f..37e8823edf2c 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -467,9 +467,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, elif not (self.norm1.eps == self.norm2.eps): why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps" elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None): - why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input" - elif (not src.is_nested) and (src_key_padding_mask is not None and src_mask is not None): - why_not_sparsity_fast_path = "src_key_padding_mask and src_mask were both supplied" + why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input" elif self.self_attn.num_heads % 2 == 1: why_not_sparsity_fast_path = "num_head is odd" elif torch.is_autocast_enabled(): @@ -502,6 +500,7 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, "input/output projection weights or biases requires_grad") if not why_not_sparsity_fast_path: + merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src) return torch._transformer_encoder_layer_fwd( src, 
self.self_attn.embed_dim, @@ -521,11 +520,8 @@ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, self.linear1.bias, self.linear2.weight, self.linear2.bias, - # TODO: if src_mask and src_key_padding_mask merge to single 4-dim mask - src_mask if src_mask is not None else src_key_padding_mask, - 1 if src_key_padding_mask is not None else - 0 if src_mask is not None else - None, + merged_mask, + mask_type, ) From 7c353eb39559f2c8897a0580700dd0a6f943d34f Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Thu, 10 Nov 2022 09:40:05 +0000 Subject: [PATCH 004/453] [MPS] Fix softplus (#88555) 1. Fixes #87780 2. Fixes mps graph cache issue 3. Adds proper tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/88555 Approved by: https://github.com/kulinseth --- aten/src/ATen/native/mps/OperationUtils.h | 1 + aten/src/ATen/native/mps/OperationUtils.mm | 8 ++++- .../ATen/native/mps/operations/Activation.mm | 30 +++++++++++-------- test/test_mps.py | 11 ++++--- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 8d868386705a..93b014124339 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -109,6 +109,7 @@ void printTensorNDArray(const Tensor& t); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); string get_mem_format_string(c10::MemoryFormat memory_format); diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 13a88efbfb5d..f41484b27b14 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -339,6 +339,12 @@ void resize_tensor(Tensor* output) { name:nil]; } +MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) { + return [mpsGraph placeholderWithShape:@[@1] + dataType:dataType + name:nil]; +} + MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar) { return [mpsGraph placeholderWithShape:@[@1] dataType:getMPSScalarType(scalar.type()) @@ -382,4 +388,4 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override { } } // namespace mps } // namespace native -} // namespace at +} // namespace at \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index fca3f3f81b33..3837e407a76b 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -1464,13 +1464,15 @@ Tensor glu_backward_mps (const Tensor& grad_output, CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor *inputTensor_ = nil; MPSGraphTensor *betaTensor_ = nil; + MPSGraphTensor *thresholdTensor_ = nil; MPSGraphTensor *outputTensor_ = nil; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); MPSStream* stream = getCurrentMPSStream(); - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);; + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = 
getMPSScalar(threshold, ScalarType::Float); @autoreleasepool { string key = "softplus_out_mps:" + getTensorsStringKey({self}); @@ -1486,7 +1488,9 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph = new CachedGraph(mpsGraph); MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, beta); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); + + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float)); MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; @@ -1499,9 +1503,6 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor secondaryTensor:betaTensor name:nil]; - MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() - shape:@[@1] - dataType:getMPSDataType(self.scalar_type())]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor secondaryTensor:thresholdTensor name:nil]; @@ -1524,6 +1525,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -1536,7 +1538,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, // Create dictionary of inputs and outputs NSDictionary* feeds = @{ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar) + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), }; NSDictionary* results = @{ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() @@ -1559,7 +1562,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, if(grad_input.numel() == 0) return; - MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);; + MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float); + MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float); struct CachedGraph : public MPSCachedGraph { @@ -1567,6 +1571,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor *gradOutputTensor_ = nil; MPSGraphTensor *inputTensor_ = nil; MPSGraphTensor *betaTensor_ = nil; + MPSGraphTensor *thresholdTensor_ = nil; MPSGraphTensor *outputTensor_ = nil; }; @@ -1590,7 +1595,9 @@ Tensor glu_backward_mps (const Tensor& grad_output, MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, beta); + MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); + + MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float)); MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[@1] @@ -1609,9 +1616,6 @@ Tensor glu_backward_mps (const Tensor& grad_output, rTensor = [mpsGraph divisionWithPrimaryTensor:rTensor secondaryTensor:unitExpBxTensor name:nil]; - MPSGraphTensor* thresholdTensor = [mpsGraph constantWithScalar:threshold.to() - shape:@[@1] - dataType:getMPSDataType(self.scalar_type())]; MPSGraphTensor* predicateTensor = [mpsGraph greaterThanWithPrimaryTensor:bxTensor 
secondaryTensor:thresholdTensor name:nil]; @@ -1623,6 +1627,7 @@ Tensor glu_backward_mps (const Tensor& grad_output, newCachedGraph->gradOutputTensor_ = gradOutputTensor; newCachedGraph->inputTensor_ = inputTensor; newCachedGraph->betaTensor_ = betaTensor; + newCachedGraph->thresholdTensor_ = thresholdTensor; newCachedGraph->outputTensor_ = outputTensor; } return newCachedGraph; @@ -1637,7 +1642,8 @@ Tensor glu_backward_mps (const Tensor& grad_output, NSDictionary* feeds = @{ gradOutputPlaceholder.getMPSGraphTensor() : gradOutputPlaceholder.getMPSGraphTensorData(), selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(), - cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar) + cachedGraph->betaTensor_ : getMPSGraphTensorFromScalar(stream, beta_scalar), + cachedGraph->thresholdTensor_ : getMPSGraphTensorFromScalar(stream, threshold_scalar), }; NSDictionary* results = @{ gradInputPlaceholder.getMPSGraphTensor() : gradInputPlaceholder.getMPSGraphTensorData() diff --git a/test/test_mps.py b/test/test_mps.py index 6f6cf2d924f3..2ff5a9da71ef 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -3853,12 +3853,12 @@ def helper(shape, dim=0): # Test softplus def test_softplus(self): - def helper(shape): + def helper(shape, beta=0.5, threshold=0.5): cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True) x = cpu_x.detach().clone().to('mps').requires_grad_() - softplus_result = torch.nn.Softplus(beta=0.5, threshold=0.5)(x) - softplus_result_cpu = torch.nn.Softplus(beta=0.5, threshold=0.5)(cpu_x) + softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x) + softplus_result_cpu = torch.nn.Softplus(beta=beta, threshold=threshold)(cpu_x) cpu_grad = torch.randn(softplus_result.shape) grad = cpu_grad.to('mps') @@ -3872,6 +3872,8 @@ def helper(shape): # Test empty shape too for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]: helper(shape) + helper(shape, beta=0.6, threshold=0.6) # relu path + helper(shape, beta=1, threshold=20) # softplus path # Test silu @@ -7322,6 +7324,7 @@ class TestConsistency(TestCase): 'nn.functional.smooth_l1_loss': ['f16', 'f32'], 'nn.functional.soft_margin_loss': ['f32'], 'nn.functional.softmin': ['f32'], + 'nn.functional.softplus': ['f32'], 'nn.functional.softsign': ['f16', 'f32', 'i16', 'u8'], 'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'u8'], 'nn.functional.threshold': ['f32', 'i16', 'i32', 'i64', 'u8'], @@ -7522,6 +7525,7 @@ class TestConsistency(TestCase): 'nn.functional.silu': ['f32'], 'nn.functional.soft_margin_loss': ['f32'], 'nn.functional.softmin': ['f32'], + 'nn.functional.softplus': ['f32'], 'nn.functional.softsign': ['f16', 'f32'], 'nn.functional.threshold': ['f32'], 'nn.functional.triplet_margin_loss': ['f32'], @@ -7614,7 +7618,6 @@ class TestConsistency(TestCase): 'nn.functional.huber_loss': [torch.float16], 'nn.functional.local_response_norm': [torch.int64], 'nn.functional.padcircular': [torch.uint8], - 'nn.functional.softplus': [torch.float32], 'pow': [torch.int64], 'select_scatter': [torch.uint8], 'sigmoid': [torch.int64], From e6561291b89ecfbe35990decfcf16db47419d429 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Nov 2022 13:44:45 +0000 Subject: [PATCH 005/453] add hack to allow hybrid compressed sparse comparison in assertEqual (#88749) Hybrid sparse CSR tensors can currently not be compared to strided ones since `.to_dense` does not work: ```py import torch from torch.testing._internal.common_utils import TestCase assertEqual = TestCase().assertEqual actual = 
torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [[1, 11], [2, 12] ,[3, 13] ,[4, 14]]) expected = torch.stack([actual[0].to_dense(), actual[1].to_dense()]) assertEqual(actual, expected) ``` ``` main.py:4: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:54.) actual = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [[1, 11], [2, 12] ,[3, 13] ,[4, 14]]) Traceback (most recent call last): File "/home/philip/git/pytorch/torch/torch/testing/_comparison.py", line 1098, in assert_equal pair.compare() File "/home/philip/git/pytorch/torch/torch/testing/_comparison.py", line 619, in compare actual, expected = self._equalize_attributes(actual, expected) File "/home/philip/git/pytorch/torch/torch/testing/_comparison.py", line 706, in _equalize_attributes actual = actual.to_dense() if actual.layout != torch.strided else actual RuntimeError: sparse_compressed_to_dense: Hybrid tensors are not supported The above exception was the direct cause of the following exception: Traceback (most recent call last): File "main.py", line 10, in assertEqual(actual, expected) File "/home/philip/git/pytorch/torch/torch/testing/_internal/common_utils.py", line 2503, in assertEqual msg=(lambda generated_msg: f"{generated_msg}\n{msg}") if isinstance(msg, str) and self.longMessage else msg, File "/home/philip/git/pytorch/torch/torch/testing/_comparison.py", line 1112, in assert_equal ) from error RuntimeError: Comparing TensorOrArrayPair( id=(), actual=tensor(crow_indices=tensor([0, 2, 4]), col_indices=tensor([0, 1, 0, 1]), values=tensor([[ 1, 11], [ 2, 12], [ 3, 13], [ 4, 14]]), size=(2, 2, 2), nnz=4, layout=torch.sparse_csr), expected=tensor([[[ 1, 11], [ 2, 12]], [[ 3, 13], [ 4, 14]]]), rtol=0.0, atol=0.0, equal_nan=True, check_device=False, check_dtype=True, check_layout=False, check_stride=False, check_is_coalesced=False, ) resulted in the unexpected exception above. If you are a user and see this message during normal operation please file an issue at https://github.com/pytorch/pytorch/issues. If you are a developer and working on the comparison functions, please except the previous error and raise an expressive `ErrorMeta` instead. ``` This adds a temporary hack to `TestCase.assertEqual` to enable this. Basically, we are going through the individual CSR subtensors, call `.to_dense()` on them, and stack everything back together. I opted to not do this in the common machinery, since that way users are not affected by this (undocumented) hack. I also added an xfailed test that will trigger as soon as the behavior is supported natively so we don't forget to remove the hack when it is no longer needed. 
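Conceptually, the hack just converts the compressed sparse side to a strided tensor before handing both inputs to the regular comparison. A hand-written sketch of that idea, reusing the repro above (this is not the code added by the PR, which instead recurses over sub-tensors inside `TensorOrArrayPair`):

```py
# Densify each CSR sub-tensor individually and stack the results, since
# `.to_dense()` cannot yet be called on the hybrid tensor as a whole.
dense_actual = torch.stack([actual[i].to_dense() for i in range(actual.size(0))])
assertEqual(dense_actual, expected)
```
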
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88749 Approved by: https://github.com/mruberry, https://github.com/pearu --- test/test_testing.py | 11 +++++++++++ torch/testing/_internal/common_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/test/test_testing.py b/test/test_testing.py index ccb2471e71e7..5ce07ce454dc 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1178,6 +1178,17 @@ def test_mismatching_values_msg(self): with self.assertRaisesRegex(AssertionError, re.escape("Sparse CSR values")): fn() + @unittest.expectedFailure + def test_hybrid_support(self): + # If you read this after the test unexpectedly succeeded, this is a good thing. It means that you added support + # for `.to_dense()` for hybrid sparse CSR tensors and in turn enabled support for them in + # `torch.testing.assert_close` if comparing to strided tensors. You can safely remove this test as well as the + # patch on `TensorOrArrayPair` in `torch.testing._internal.common_utils`. + actual = torch.sparse_csr_tensor([0, 2, 4], [0, 1, 0, 1], [[1, 11], [2, 12], [3, 13], [4, 14]]) + expected = torch.stack([actual[0].to_dense(), actual[1].to_dense()]) + + torch.testing.assert_close(actual, expected, check_layout=False) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Not all sandcastle jobs support CSC testing") class TestAssertCloseSparseCSC(TestCase): diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 6fd64187581f..8f497d515eb5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1820,6 +1820,31 @@ def __init__(self, actual, expected, *, rtol_override=0.0, atol_override=0.0, ** self.rtol = max(self.rtol, rtol_override) self.atol = max(self.atol, atol_override) + # This is a slow and ugly hack to allow the comparison of hybrid sparse CSR tensors with strided ones. If + # `check_layout=False` (default), the tensors will be converted to strided by calling `.to_dense()` on them. + # However, this is not yet supported for hybrid sparse CSR and thus we need to do it manually for now. 
+ # FIXME: Remove this as soon as `.to_dense` is supported for hybrid sparse CSR tensors + if not self.check_layout: + self.actual, self.expected = self._handle_hybrid_sparse_csr(self.actual, self.expected) + + def _handle_hybrid_sparse_csr(self, actual, expected): + compressed_sparse_layouts = {torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc} + if not ((actual.layout in compressed_sparse_layouts) ^ (expected.layout in compressed_sparse_layouts)): + return actual, expected + + def to_dense(tensor): + if tensor.layout not in compressed_sparse_layouts: + return tensor + + def partial_to_dense(tensor): + if tensor.layout not in compressed_sparse_layouts or tensor.values().ndim == 1: + return tensor.to_dense() + return torch.stack([partial_to_dense(sub_tensor) for sub_tensor in tensor]) + + return partial_to_dense(tensor) + + return [to_dense(input) for input in [actual, expected]] + def _process_inputs(self, actual, expected, *, id, allow_subclasses): self._check_inputs_isinstance(actual, expected, cls=(torch.Tensor, np.ndarray)) From 3e43ff279428e5d07932968fbd7792200fa15a4d Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Thu, 10 Nov 2022 01:30:03 -0500 Subject: [PATCH 006/453] torchdynamo: add convolution add(relu) inplace fusion kernel (#88048) This PR is about add convolution add(relu) inplace fusion kernel which works for **other.add_(conv)**. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88048 Approved by: https://github.com/jgong5, https://github.com/jansel --- aten/src/ATen/native/mkldnn/Conv.cpp | 250 ++++++++++++++---- aten/src/ATen/native/mkldnn/Linear.cpp | 5 +- .../mkldnn/RegisterMkldnnOpContextClass.cpp | 4 +- aten/src/ATen/native/mkldnn/Utils.cpp | 9 +- aten/src/ATen/native/mkldnn/Utils.h | 4 +- .../check_forward_backward_compatibility.py | 1 + test/test_mkldnn_fusion.py | 22 +- torch/_inductor/ir.py | 14 +- torch/_inductor/lowering.py | 20 +- torch/_inductor/overrides.py | 18 +- 10 files changed, 286 insertions(+), 61 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 508aefe787ad..ec62715129f4 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -9,6 +9,7 @@ #include #else #include +#include #include #include #include @@ -175,51 +176,23 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo return memory_format; } -Tensor _mkldnn_convolution( +void _mkldnn_convolution_out ( const Tensor& input_t, const Tensor& weight_t, - const c10::optional& bias_opt, - IntArrayRef padding, + const Tensor& bias, + std::vector& output_sizes, + ideep::tensor& y, IntArrayRef stride, IntArrayRef dilation, + IntArrayRef padding, int64_t groups, - c10::string_view attr = "none", - torch::List> scalars = - torch::List>(), - c10::optional algorithm = c10::nullopt) { - ideep::attr_t op_attr = ideep::attr_t(); - if (attr != "none") { - auto it = fx_fusion_attr_map().find(attr); - TORCH_CHECK(it != fx_fusion_attr_map().end(), "Fusion behavior undefined."); - op_attr = it->second(scalars, algorithm); - } - // See [Note: hacky wrapper removal for optional tensor] - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - if (input_t.scalar_type() == ScalarType::BFloat16) { - TORCH_CHECK(mkldnn_bf16_device_check(), - "mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); - } - - check_shape_forward(input_t, 
weight_t, bias, padding, stride, dilation, groups); - - bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t); + bool is_channels_last, + const ideep::attr_t& op_attr) { auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last); - auto input = input_t.is_mkldnn() ? input_t : input_t.contiguous(memory_format); auto weight = weight_t.is_mkldnn() ? weight_t : weight_t.contiguous(memory_format); - auto output_sizes = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - auto output = at::empty({0}, input.options()); - const ideep::tensor x = itensor_from_tensor(input); const ideep::tensor w = itensor_from_tensor(weight); - - ideep::tensor y; - if (is_channels_last) { - output.resize_(output_sizes, memory_format); - y = itensor_from_tensor(output); - } if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_v3( @@ -249,11 +222,55 @@ Tensor _mkldnn_convolution( is_channels_last, op_attr); } +} + +Tensor _mkldnn_convolution( + const Tensor& input_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + c10::string_view attr = "none", + torch::List> scalars = + torch::List>(), + c10::optional algorithm = c10::nullopt) { + ideep::attr_t op_attr = ideep::attr_t(); + if (attr != "none") { + auto it = fusion_unary_attr_map().find(attr); + TORCH_CHECK( + it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); + op_attr = it->second(scalars, algorithm); + } + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + if (input_t.scalar_type() == ScalarType::BFloat16) { + TORCH_CHECK(mkldnn_bf16_device_check(), + "mkldnn_convolution: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq"); + } + + check_shape_forward(input_t, weight_t, bias, padding, stride, dilation, groups); + + bool is_channels_last = mkldnn_conv_use_channels_last(input_t, weight_t); + auto memory_format = mkldnn_convolution_memory_format(input_t.ndimension(), is_channels_last); + + auto output_sizes = conv_output_size(input_t.sizes(), weight_t.sizes(), padding, stride, dilation); + auto output = at::empty({0}, input_t.options()); + ideep::tensor y; + if (is_channels_last) { + output.resize_(output_sizes, memory_format); + y = itensor_from_tensor(output); + } + _mkldnn_convolution_out( + input_t, weight_t, bias, output_sizes, y, stride, dilation, padding, groups, is_channels_last, op_attr); - if (input.is_mkldnn()) { - return MKLDNNTensor(y, input.options()); + if (input_t.is_mkldnn()) { + return MKLDNNTensor(y, input_t.options()); } else if (!is_channels_last) { - return mkldnn_to_dense(MKLDNNTensor(y, input.options())); + return mkldnn_to_dense(MKLDNNTensor(y, input_t.options())); } else { TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc()); return output; @@ -297,6 +314,14 @@ Tensor mkldnn_convolution_pointwise( algorithm); } +// Fuse convolution+binary_op+unary_op for good performance, which doing such +// operation: output=unary_op(binary_op(conv(input_t, ...), other_t, alpha)). +// The binary_attr means which binary_op is, it can be "add", or +// other binary operation. the unary_attr means which unary_op is, +// it can be "relu" or other unary operation, if it is none, meaning that +// there doesn't have a unary post op. 
unary_scalars and unary_algorithm +// are the parameters of the unary op, such as "hardtanh" has scalar parameters, +// "gelu" has algorithm parameters. Tensor mkldnn_convolution_pointwise_binary( const Tensor& input_t, const Tensor& other_t, @@ -306,10 +331,17 @@ Tensor mkldnn_convolution_pointwise_binary( IntArrayRef stride, IntArrayRef dilation, int64_t groups, - c10::string_view attr) { + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm) { TORCH_CHECK( input_t.ndimension() == 4 || input_t.ndimension() == 5, "mkldnn_convolution_pointwise_binary: currently only support 2d and 3d") + TORCH_CHECK( + !alpha.has_value() || alpha.value().to() == 1.0, + "mkldnn_convolution_pointwise_binary: the alpha value should be none or 1.0"); c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); @@ -334,9 +366,22 @@ Tensor mkldnn_convolution_pointwise_binary( bool can_be_fused = groups == 1 && mkldnn_conv_use_channels_last(input_t, weight_t); - auto it_binary = fusion_binary_alg_map().find(attr); + c10::string_view unary_attr_value = "none"; + ideep::algorithm unary_alg; + if (unary_attr.has_value()) { + auto it_unary = fusion_unary_alg_map().find(unary_attr.value()); + // Now, we only support conv+binary+relu. + TORCH_CHECK( + it_unary != fusion_unary_alg_map().end(), + "Unary Fusion behavior undefined."); + unary_attr_value = unary_attr.value(); + unary_alg = it_unary->second; + } + auto it_binary = fusion_binary_alg_map().find(binary_attr); TORCH_CHECK( - it_binary != fusion_binary_alg_map().end(), "Fusion behavior undefined."); + it_binary != fusion_binary_alg_map().end(), + "Binary Fusion behavior undefined."); + if (can_be_fused) { c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); auto memory_format = @@ -356,7 +401,15 @@ Tensor mkldnn_convolution_pointwise_binary( } auto other_desc = ideep::tensor::desc( output_size, get_mkldnn_dtype(weight.scalar_type()), format_tag); - auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc); + + ideep::attr_t op_attr; + ideep::post_ops po; + po.append_binary(it_binary->second, other_desc); + if (unary_attr_value != "none") { + po.append_eltwise(1.0, unary_alg, 0.f, 0.f); + } + op_attr.set_post_ops(po); + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_binary( @@ -400,19 +453,123 @@ Tensor mkldnn_convolution_pointwise_binary( output = at::conv3d( input_t, weight_t, bias_opt, stride, padding, dilation, groups); } - if (attr == "add") { + if (binary_attr == "add" && unary_attr_value != "none") { + output = at::native::add_relu_(output, other_t); + return output; + } + if (binary_attr == "add") { output.add_(other_t); - } else if (attr == "sub") { + } else if (binary_attr == "sub") { output.sub_(other_t); - } else if (attr == "mul") { + } else if (binary_attr == "mul") { output.mul_(other_t); } else { output.div_(other_t); } + if (unary_attr_value != "none") { + output.relu_(); + } return output; } } +// Fuse convolution+binary_op+unary_op for good performance, which doing +// such operation: other_t=unary_op(binary_op(conv(input_t, ...), other_t, +// alpha)). The binary_attr means which binary_op is, it can be "add", or other +// binary operation. the unary_attr means which unary_op is, it can be "relu" or +// other unary operation, if it is none, meaning that there doesn't have a unary +// post op. 
unary_scalars and unary_algorithm are the parameters of the unary +// op, such as "hardtanh" has scalar parameters "gelu" has algorithm parameters. + +Tensor& mkldnn_convolution_pointwise_binary_( + const Tensor& input_t, + Tensor& other_t, + const Tensor& weight_t, + const c10::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm) { + // other_t += convolution(...), other_t = unary(other_t) + TORCH_CHECK( + input_t.ndimension() == 4 || input_t.ndimension() == 5, + "mkldnn_convolution_add_: currently only support 2d and 3d") + TORCH_CHECK( + binary_attr == "add", + "mkldnn_convolution_pointwise_binary_: only support binary op fusion") + TORCH_CHECK( + !alpha.has_value() || alpha.value().to() == 1.0, + "mkldnn_convolution_pointwise_binary: the alpha value for the binary op should be none(meaning 1.0) or 1.0"); + TORCH_CHECK( + !unary_attr.has_value() || unary_attr.value() == "relu", + "mkldnn_convolution_pointwise_binary: only support none or relu unary op fusion after binary op"); + + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + // Make sure inputs have same type(device, layout, dtype), device is cpu and + // dtype is float or bfloat16. + check_mkldnn_binary_fusion_inputs(input_t, other_t, weight_t, bias); + + check_shape_forward( + input_t, weight_t, bias, padding, stride, dilation, groups); + + auto output_sizes = conv_output_size( + input_t.sizes(), weight_t.sizes(), padding, stride, dilation); + TORCH_CHECK( + output_sizes == other_t.sizes(), + "Add Fusion's inputs should have same shape"); + // Only calling fusion path for channels_last path and the output is contiguous tensor(channels_last). + bool can_be_fused = mkldnn_conv_use_channels_last(input_t, weight_t) + && (other_t.is_contiguous(at::MemoryFormat::ChannelsLast) + || other_t.is_contiguous(at::MemoryFormat::ChannelsLast3d)); + if (can_be_fused) { + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + ideep::tensor y = itensor_from_tensor(other_t); + ideep::attr_t op_attr; + if (unary_attr.has_value()) { + op_attr = ideep::attr_t::residual(); + } else { + op_attr = ideep::attr_t::fuse_sum(); + } + _mkldnn_convolution_out( + input_t, + weight_t, + bias, + output_sizes, + y, + stride, + dilation, + padding, + groups, + true, + op_attr); + } else { + // Fallback case, if inputs are not channels last or have different dtype, + // OneDNN fusion may have performance regression. 
+ Tensor output; + if (input_t.ndimension() == 4) { + output = at::conv2d( + input_t, weight_t, bias_opt, stride, padding, dilation, groups); + } else { + output = at::conv3d( + input_t, weight_t, bias_opt, stride, padding, dilation, groups); + } + if (unary_attr.has_value()) { + other_t = at::native::add_relu_(other_t, output); + } else { + other_t.add_(output); + } + } + return other_t; +} + Tensor mkldnn_convolution_backward_input( IntArrayRef input_size, const Tensor& grad_output, @@ -540,6 +697,9 @@ TORCH_LIBRARY_IMPL(mkldnn, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise.binary"), TORCH_FN(mkldnn_convolution_pointwise_binary)); + m.impl( + TORCH_SELECTIVE_NAME("mkldnn::_convolution_pointwise_.binary"), + TORCH_FN(mkldnn_convolution_pointwise_binary_)); } }} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index b57d8e56a16d..24bf1282bfd6 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -215,8 +215,9 @@ Tensor mkldnn_linear_pointwise( } const ideep::tensor w = itensor_from_tensor(weight_t); - auto it = fx_fusion_attr_map().find(attr); - TORCH_CHECK(it != fx_fusion_attr_map().end(), "Fusion behavior undefined."); + auto it = fusion_unary_attr_map().find(attr); + TORCH_CHECK( + it != fusion_unary_attr_map().end(), "Fusion behavior undefined."); ideep::attr_t op_attr = it->second(scalars, algorithm); if (mkldnn_bias.has_value()) { diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index 0be8d8a100cd..08230827b58e 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -42,7 +42,9 @@ TORCH_LIBRARY(mkldnn, m) { m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn::_convolution_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA( - "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str attr) -> Tensor Y")); + "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y")); + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) 
Y")); } TORCH_LIBRARY(mkldnn_prepacked, m) { diff --git a/aten/src/ATen/native/mkldnn/Utils.cpp b/aten/src/ATen/native/mkldnn/Utils.cpp index 42f855d75cbe..5db6e0b07ff1 100644 --- a/aten/src/ATen/native/mkldnn/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/Utils.cpp @@ -127,7 +127,7 @@ AttrFunction attr_func_gelu = [](torch::List> scalars, return ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type); }; -const std::map& fx_fusion_attr_map() { +const std::map& fusion_unary_attr_map() { static const std::map fusion_attr_map{ {"relu", ATTR_FUNC(relu)}, {"sigmoid", ATTR_FUNC(sigmoid)}, @@ -140,6 +140,13 @@ const std::map& fx_fusion_attr_map() { return fusion_attr_map; }; +const std::map& fusion_unary_alg_map() { + static const std::map fusion_attr_map{ + {"relu", {ideep::algorithm::eltwise_relu}}, + }; + return fusion_attr_map; +}; + const std::map& fusion_binary_alg_map() { static const std::map fusion_attr_map{ {"add", {ideep::algorithm::binary_add}}, diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h index 314a7efc950e..a25be13c46da 100644 --- a/aten/src/ATen/native/mkldnn/Utils.h +++ b/aten/src/ATen/native/mkldnn/Utils.h @@ -39,7 +39,9 @@ using AttrFunction = std::function>, c10::optional)>; -const std::map& fx_fusion_attr_map(); +const std::map& fusion_unary_attr_map(); + +const std::map& fusion_unary_alg_map(); const std::map& fusion_binary_alg_map(); diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index e9d0834a812c..90080ab0934f 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -317,6 +317,7 @@ ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)), + ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)), ] ALLOW_LIST_COMPILED = [ diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index cdef4bcfd6a5..9f264337d956 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -271,8 +271,8 @@ def forward(self, x, other): for pointwise_name, pointwise_fn in self._binary_list().items(): for dim in [2, 3]: channels_last = torch.channels_last if dim == 2 else torch.channels_last_3d - options = itertools.product([True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]) - for bias, dilation, groups, memory_format in options: + options = itertools.product([False, True], [True, False], [1, 2], [1, 4], [torch.contiguous_format, channels_last]) + for fuse_relu, bias, dilation, groups, memory_format in options: oC = 32 * groups iC = 3 * groups x_shape = (1, iC) + input_shapes[dim] @@ -282,12 +282,26 @@ def forward(self, x, other): other = torch.randn_like(mod.conv(x)) with torch.no_grad(): ref = mod(x, other) + unary_attr = None + if fuse_relu: + ref.relu_() + unary_attr = "relu" attr = pointwise_name fused = torch.ops.mkldnn._convolution_pointwise( x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, - mod.conv.groups, attr + mod.conv.groups, attr, None, unary_attr, [], None ) - self.assertEqual(ref, fused) + # for binary add, we support inplace version. 
+ if attr == "add": + fused_inplace = torch.ops.mkldnn._convolution_pointwise_( + x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, + mod.conv.groups, attr, None, unary_attr, [], None + ) + self.assertEqual(ref, other) + self.assertEqual(ref, fused_inplace) + + self.assertEqual(ref, fused) + def test_linear_binary_fusion_ops(self): class M(nn.Module): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 629a8e94534d..0353bcc8b0be 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3425,7 +3425,11 @@ def create( stride_: List[int], dilation_: List[int], groups: int, - attr, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List], + unary_algorithm: Optional[str], ): kernel = "torch.ops.mkldnn._convolution_pointwise.binary" (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create( @@ -3433,7 +3437,13 @@ def create( ) other = cls.require_stride1(cls.realize_input(other)) inputs.insert(1, other) - constant_args = constant_args + [attr] + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ] return ConvolutionBinary( layout=kernel_layout, inputs=inputs, diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0ede696828a3..dedd39cd91c4 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -936,11 +936,27 @@ def convolution_binary( stride, dilation, groups, - attr, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, ): return TensorBox.create( ir.ConvolutionBinary.create( - x, other, weight, bias, padding, stride, dilation, groups, attr + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, ) ) diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 1ab55142619c..581e1996a436 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -151,7 +151,11 @@ def __init__( def _update_module_params(self, conv, binary_op_name): self.__dict__ = copy.deepcopy(conv.__dict__) - self.attr = binary_op_name + self.binary_attr = binary_op_name + self.binary_alpha = None + self.unary_attr = None + self.unary_scalars = [] + self.unary_algorithm = None def _conv_forward(self, input, other, weight, bias): if self.padding_mode != "zeros": @@ -166,7 +170,11 @@ def _conv_forward(self, input, other, weight, bias): self.stride, self.dilation, self.groups, - self.attr, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, ) return torch.ops.mkldnn._convolution_pointwise( input, @@ -177,7 +185,11 @@ def _conv_forward(self, input, other, weight, bias): self.stride, self.dilation, self.groups, - self.attr, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, ) def forward(self, input, other): From 7f28be10e5e71efda37800384fa897785499bed1 Mon Sep 17 00:00:00 2001 From: samdow Date: Tue, 1 Nov 2022 18:35:38 -0400 Subject: [PATCH 007/453] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) First half of #87990. 
This doesn't change any of the behavior and is just a rename Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218 Approved by: https://github.com/ezyang, https://github.com/zou3519 --- test/allowlist_for_publicAPI.json | 2 +- test/profiler/test_profiler_tree.py | 2 +- test/test_overrides.py | 4 +-- test/test_public_bindings.py | 2 +- torch/_C/__init__.pyi.in | 2 +- torch/__init__.py | 2 +- torch/_dynamo/variables/builder.py | 2 +- torch/_dynamo/variables/misc.py | 2 +- torch/_dynamo/variables/tensor.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/_tensor.py | 2 +- torch/csrc/Module.cpp | 4 +-- torch/csrc/autograd/init.cpp | 1 - torch/csrc/utils/disable_torch_function.cpp | 32 ++++++++++--------- torch/csrc/utils/disable_torch_function.h | 2 +- torch/distributed/_shard/common_op_utils.py | 4 +-- torch/distributed/_shard/partial_tensor.py | 2 +- torch/distributed/_shard/replicated_tensor.py | 4 +-- .../_shard/sharded_tensor/_ops/tensor_ops.py | 2 +- torch/masked/maskedtensor/core.py | 2 +- 20 files changed, 39 insertions(+), 38 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index ba4a2e96df21..8a66dc12d4b6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,7 +1128,7 @@ "BFloat16Tensor", "ComplexDoubleStorage", "ComplexFloatStorage", - "DisableTorchFunction", + "DisableTorchFunctionSubclass", "Generator", "HalfStorage", "HalfTensor", diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index d4a31c645613..210530250f92 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -26,7 +26,7 @@ "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES, - "": PRUNE_ALL, + "": PRUNE_ALL, "cudaStreamIsCapturing": PRUNE_ALL, "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL, } diff --git a/test/test_overrides.py b/test/test_overrides.py index 7082f75a2141..01c763a548fc 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1448,7 +1448,7 @@ class B(torch.Tensor): x = B(torch.randn(5)) with A(): - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): self.assertNotIsInstance(torch.sum(x), B) self.assertTrue(called) @@ -1460,7 +1460,7 @@ class A(torch.Tensor): pass x = A(torch.randn(5)) - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): g = torch._C._EnableTorchFunction() try: self.assertIsInstance(torch.sum(x), A) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 4d2df6512698..6897c3102df6 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,7 +99,7 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", - "DisableTorchFunction", + "DisableTorchFunctionSubclass", "DispatchKey", "DispatchKeySet", "dtype", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d20da2a04f3..79dd6386c378 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,7 +108,7 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunction(): ... +def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided : layout = ... 
diff --git a/torch/__init__.py b/torch/__init__.py index ae55f5975542..ef6138cb4866 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -288,7 +288,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunction', 'Generator']: + if name not in ['DisableTorchFunctionSubclass', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d3c5140fa4a9..9d8789746855 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -506,7 +506,7 @@ def wrap_tensor(self, value: torch.Tensor): ) # Disable __torch_function__ to prevent cloning of `value` to hit # us - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): if is_constant_source(self.get_source()): return self.tx.output.register_attr_or_module( value, diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index da327122a6a7..6e4325b6c0f4 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -538,7 +538,7 @@ def call_function( options = VariableTracker.propagate(self, new_args, new_kwargs.values()) # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): if isinstance(args[0], TorchVariable): return TensorVariable.create( tx=tx, diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 315c2b1a7e07..5a30f838e3f3 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -704,7 +704,7 @@ def inline_torch_function_unwrapped( # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. 
- with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): return tx.inline_user_function_return(tf_func_var, tf_args, {}) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 14f5cd2de0a7..79af51efc5b8 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1093,5 +1093,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None): memo[id(tensor)] = out return out else: - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 793034bb64ed..41b6569c06d8 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1297,7 +1297,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with _C.DisableTorchFunction(): + with _C.DisableTorchFunctionSubclass(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b8693a484ed9..efe6c18ea0cd 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1594,8 +1594,8 @@ Call this whenever a new thread is created in order to propagate values from (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); ASSERT_TRUE(set_module_attr( - "DisableTorchFunction", - (PyObject*)THPModule_DisableTorchFunctionType(), + "DisableTorchFunctionSubclass", + (PyObject*)THPModule_DisableTorchFunctionSubclassType(), /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index ee963232d316..d26db95f1295 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -343,7 +343,6 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { _C_m, "_RestorePythonTLSSnapshot") .def(py::init<>()); - // TODO: line up this binding with DisableTorchFunction py::class_(_C_m, "_DisableTorchDispatch") .def(py::init<>()); py::class_(_C_m, "_EnableTorchFunction") diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 682120d7e622..516e6b89d43a 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -35,18 +35,20 @@ typedef struct { PyObject_HEAD /* Type-specific fields go here. 
*/ bool old_state; -} DisableTorchFunction; +} DisableTorchFunctionSubclass; -PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { - ((DisableTorchFunction*)self)->old_state = +PyObject* DisableTorchFunctionSubclass__enter( + PyObject* self, + PyObject* unused) { + ((DisableTorchFunctionSubclass*)self)->old_state = at::impl::PythonTorchFunctionTLS::is_disabled(); at::impl::PythonTorchFunctionTLS::set_disabled(true); Py_RETURN_NONE; } -PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { +PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { at::impl::PythonTorchFunctionTLS::set_disabled( - ((DisableTorchFunction*)self)->old_state); + ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } @@ -58,16 +60,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) { } } -static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT - {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, +static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT + {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; -PyTypeObject DisableTorchFunctionType = { +PyTypeObject DisableTorchFunctionSubclassType = { PyVarObject_HEAD_INIT( nullptr, - 0) "torch._C.DisableTorchFunction", /* tp_name */ - sizeof(DisableTorchFunction), /* tp_basicsize */ + 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */ + sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */ 0, /* tp_itemsize */ nullptr, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ @@ -92,7 +94,7 @@ PyTypeObject DisableTorchFunctionType = { 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ - DisableTorchFunction_methods, /* tp_methods */ + DisableTorchFunctionSubclass_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ @@ -105,12 +107,12 @@ PyTypeObject DisableTorchFunctionType = { PyType_GenericNew, /* tp_new */ }; -PyObject* THPModule_DisableTorchFunctionType() { - if (PyType_Ready(&DisableTorchFunctionType) < 0) { +PyObject* THPModule_DisableTorchFunctionSubclassType() { + if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) { return nullptr; } - return (PyObject*)(&DisableTorchFunctionType); + return (PyObject*)(&DisableTorchFunctionSubclassType); } PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 3cdc33e90681..881a7adb13eb 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,7 +29,7 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); -PyObject* THPModule_DisableTorchFunctionType(); +PyObject* THPModule_DisableTorchFunctionSubclassType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg); diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 08aa13282abc..42d65923a536 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ b/torch/distributed/_shard/common_op_utils.py @@ 
-53,11 +53,11 @@ def tensor_default_op(types, args=(), kwargs=None, pg=None): Handles ``__torch_function__`` dispatch for the default tensor ops that behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or ``torch.Tensor.dtype``. We simply lower to the real op call with - DisableTorchFunction context like ``torch.Tensor.__torch_function__`` + DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` to avoid recursions. """ if kwargs is None: kwargs = {} - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): return op(*args, **kwargs) diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py index dc8d09bdd7f3..6a48163082c5 100644 --- a/torch/distributed/_shard/partial_tensor.py +++ b/torch/distributed/_shard/partial_tensor.py @@ -236,7 +236,7 @@ def find_process_group(e): # Need to disable all dispatch to print args and kwargs appropriately. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] try: - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for PartialTensor!") diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py index 1327f89e00aa..e3db6b0fac66 100644 --- a/torch/distributed/_shard/replicated_tensor.py +++ b/torch/distributed/_shard/replicated_tensor.py @@ -109,7 +109,7 @@ def dispatch_arg(arg): # We cann't do super().__torch_function__() as it implicitly convert the result # back to tensor subclasses, where in our case, we need to control the output type # base on the inter-op rules we defined. - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): rs = func(*args, **kwargs) if func in get_default_nowrap_functions(): return rs @@ -157,7 +157,7 @@ def validate(self) -> bool: return True def __setstate__(self, state): - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): self.data = state self.requires_grad = state.requires_grad from torch.distributed._shard.api import _get_current_process_group diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index e52c29238a62..9ed83ee33f61 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -203,7 +203,7 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): local_shard.tensor.requires_grad_(requires_grad) # update the wrapper class property - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): self_st.requires_grad_(requires_grad) # update the metadata in the meanwhile self_st._metadata.tensor_properties.requires_grad = requires_grad diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 3274ef2ef956..0459f24587bd 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -270,7 +270,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with torch._C.DisableTorchFunction(): + with torch._C.DisableTorchFunctionSubclass(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret From c0ecce15b5a54ff0185f9976e6bfb6f3a7de698d Mon Sep 17 00:00:00 2001 From: samdow 
Date: Mon, 7 Nov 2022 15:43:39 -0500 Subject: [PATCH 008/453] add DisableTorchFunction that matches DisableTorchDispatch (#88219) Closes #87990. This implements a new disable guard that matches DisableTorchDispatch (disables all subclasses and modes) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88219 Approved by: https://github.com/ezyang --- aten/src/ATen/PythonTorchFunctionTLS.cpp | 11 ++- aten/src/ATen/PythonTorchFunctionTLS.h | 12 ++- test/allowlist_for_publicAPI.json | 1 + test/test_overrides.py | 21 ++++ test/test_public_bindings.py | 1 + torch/_C/__init__.pyi.in | 1 + torch/__init__.py | 2 +- torch/csrc/Module.cpp | 4 + torch/csrc/autograd/init.cpp | 9 +- torch/csrc/utils/disable_torch_function.cpp | 100 ++++++++++++++++++-- torch/csrc/utils/disable_torch_function.h | 1 + 11 files changed, 139 insertions(+), 24 deletions(-) diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp index c9487c6958cb..00f372f370e6 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.cpp +++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp @@ -26,12 +26,12 @@ int64_t PythonTorchFunctionTLS::stack_len() { return pythonTorchFunctionState.stack_.size(); } -void PythonTorchFunctionTLS::set_disabled(bool disabled) { - pythonTorchFunctionState.disabled_ = disabled; +void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) { + pythonTorchFunctionState.disabled_state_ = disabled_state; } -bool PythonTorchFunctionTLS::is_disabled() { - return pythonTorchFunctionState.disabled_; +TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() { + return pythonTorchFunctionState.disabled_state_; } void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) { @@ -43,7 +43,8 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() { } bool torch_function_mode_enabled() { - return PythonTorchFunctionTLS::stack_len() > 0; + return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED && + PythonTorchFunctionTLS::stack_len() > 0; } } // namespace impl diff --git a/aten/src/ATen/PythonTorchFunctionTLS.h b/aten/src/ATen/PythonTorchFunctionTLS.h index 5940fb6f2dee..a1e3a61ea202 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.h +++ b/aten/src/ATen/PythonTorchFunctionTLS.h @@ -6,9 +6,11 @@ namespace at { namespace impl { +enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; + struct TORCH_API PythonTorchFunctionTLS { - static void set_disabled(bool); - static bool is_disabled(); + static void set_disabled_state(TorchFunctionDisabledState disabled_state_); + static TorchFunctionDisabledState get_disabled_state(); static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); @@ -20,11 +22,11 @@ struct TORCH_API PythonTorchFunctionTLS { private: // The mode TLS is split into - // - disabled_, which says whether or not to disable all torch function - // modes + // - disabled_state, which says which part of torch function are disabled // - stack_, which is a vector of modes representing the stack of user // defined modes - bool disabled_; + TorchFunctionDisabledState disabled_state_ = + TorchFunctionDisabledState::ENABLED; std::vector> stack_; }; diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 8a66dc12d4b6..45ba9ae94676 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,6 +1128,7 @@ "BFloat16Tensor", 
"ComplexDoubleStorage", "ComplexFloatStorage", + "DisableTorchFunction", "DisableTorchFunctionSubclass", "Generator", "HalfStorage", diff --git a/test/test_overrides.py b/test/test_overrides.py index 01c763a548fc..3b3a5ed063c7 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1453,6 +1453,27 @@ class B(torch.Tensor): self.assertTrue(called) + def test_disable_subclass_mode(self): + called = False + + class A(TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + nonlocal called + if kwargs is None: + kwargs = {} + called = True + return func(*args, **kwargs) + + class B(torch.Tensor): + pass + + x = B(torch.randn(5)) + with A(): + with torch._C.DisableTorchFunction(): + self.assertNotIsInstance(torch.sum(x), B) + + self.assertFalse(called) + def test_disable_enable_subclass(self): called = False diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 6897c3102df6..46c7396b9b07 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,6 +99,7 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", + "DisableTorchFunction", "DisableTorchFunctionSubclass", "DispatchKey", "DispatchKeySet", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 79dd6386c378..bc4bf03d8161 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,6 +108,7 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp +def DisableTorchFunction(): ... def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp diff --git a/torch/__init__.py b/torch/__init__.py index ef6138cb4866..2abf4ba4b07d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -288,7 +288,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'Generator']: + if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index efe6c18ea0cd..0a9aa53a0bbc 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1597,6 +1597,10 @@ Call this whenever a new thread is created in order to propagate values from "DisableTorchFunctionSubclass", (PyObject*)THPModule_DisableTorchFunctionSubclassType(), /* incref= */ false)); + ASSERT_TRUE(set_module_attr( + "DisableTorchFunction", + (PyObject*)THPModule_DisableTorchFunctionType(), + /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index d26db95f1295..6271cfd5cb99 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -60,13 +60,14 @@ struct DisableAutocast { struct EnableTorchFunction { EnableTorchFunction() - : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) { - at::impl::PythonTorchFunctionTLS::set_disabled(false); + : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) { + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::ENABLED); } ~EnableTorchFunction() { - at::impl::PythonTorchFunctionTLS::set_disabled(old_); + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_); } - bool old_; + 
at::impl::TorchFunctionDisabledState old_; }; struct EnablePythonDispatcher { diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 516e6b89d43a..589b069250a3 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -11,7 +11,8 @@ PyObject* disabled_torch_function = nullptr; PyObject* disabled_torch_dispatch = nullptr; bool torch_function_enabled() { - return !at::impl::PythonTorchFunctionTLS::is_disabled(); + return at::impl::PythonTorchFunctionTLS::get_disabled_state() == + at::impl::TorchFunctionDisabledState::ENABLED; } PyObject* disabled_torch_function_impl() { @@ -34,20 +35,23 @@ void set_disabled_torch_dispatch_impl(PyObject* value) { typedef struct { PyObject_HEAD /* Type-specific fields go here. */ - bool old_state; + at::impl::TorchFunctionDisabledState old_state; } DisableTorchFunctionSubclass; PyObject* DisableTorchFunctionSubclass__enter( PyObject* self, PyObject* unused) { - ((DisableTorchFunctionSubclass*)self)->old_state = - at::impl::PythonTorchFunctionTLS::is_disabled(); - at::impl::PythonTorchFunctionTLS::set_disabled(true); + const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state(); + ((DisableTorchFunctionSubclass*)self)->old_state = old_state; + if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) { + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); + } Py_RETURN_NONE; } PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { - at::impl::PythonTorchFunctionTLS::set_disabled( + at::impl::PythonTorchFunctionTLS::set_disabled_state( ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } @@ -115,6 +119,81 @@ PyObject* THPModule_DisableTorchFunctionSubclassType() { return (PyObject*)(&DisableTorchFunctionSubclassType); } +typedef struct { + PyObject_HEAD + /* Type-specific fields go here. 
*/ + at::impl::TorchFunctionDisabledState old_state; +} DisableTorchFunction; + +PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { + ((DisableTorchFunctionSubclass*)self)->old_state = + at::impl::PythonTorchFunctionTLS::get_disabled_state(); + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::ALL_DISABLED); + Py_RETURN_NONE; +} + +PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { + at::impl::PythonTorchFunctionTLS::set_disabled_state( + ((DisableTorchFunctionSubclass*)self)->old_state); + Py_RETURN_NONE; +} + +static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT + {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +PyTypeObject DisableTorchFunctionType = { + PyVarObject_HEAD_INIT( + nullptr, + 0) "torch._C.DisableTorchFunction", /* tp_name */ + sizeof(DisableTorchFunction), /* tp_basicsize */ + 0, /* tp_itemsize */ + nullptr, /* tp_dealloc */ + 0, /* tp_vectorcall_offset */ + nullptr, /* tp_getattr */ + nullptr, /* tp_setattr */ + nullptr, /* tp_reserved */ + nullptr, /* tp_repr */ + nullptr, /* tp_as_number */ + nullptr, /* tp_as_sequence */ + nullptr, /* tp_as_mapping */ + nullptr, /* tp_hash */ + nullptr, /* tp_call */ + nullptr, /* tp_str */ + nullptr, /* tp_getattro */ + nullptr, /* tp_setattro */ + nullptr, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + nullptr, /* tp_doc */ + nullptr, /* tp_traverse */ + nullptr, /* tp_clear */ + nullptr, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + nullptr, /* tp_iter */ + nullptr, /* tp_iternext */ + DisableTorchFunction_methods, /* tp_methods */ + nullptr, /* tp_members */ + nullptr, /* tp_getset */ + nullptr, /* tp_base */ + nullptr, /* tp_dict */ + nullptr, /* tp_descr_get */ + nullptr, /* tp_descr_set */ + 0, /* tp_dictoffset */ + nullptr, /* tp_init */ + PyType_GenericAlloc, /* tp_alloc */ + PyType_GenericNew, /* tp_new */ +}; + +PyObject* THPModule_DisableTorchFunctionType() { + if (PyType_Ready(&DisableTorchFunctionType) < 0) { + return nullptr; + } + + return (PyObject*)(&DisableTorchFunctionType); +} + PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS PyObject *func = nullptr, *types = nullptr, *args = nullptr, @@ -137,11 +216,14 @@ PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { // These are all C-API calls so no exceptions will be raised // and therefore no need for RAII approach to storing // the old value. - bool old_value = at::impl::PythonTorchFunctionTLS::is_disabled(); - at::impl::PythonTorchFunctionTLS::set_disabled(true); + auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state(); + if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) { + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); + } // kwargs can safely be nullptr here. 
PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs); - at::impl::PythonTorchFunctionTLS::set_disabled(old_value); + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value); return result; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 881a7adb13eb..8fc5118830eb 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,6 +29,7 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); +PyObject* THPModule_DisableTorchFunctionType(); PyObject* THPModule_DisableTorchFunctionSubclassType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); From 3a09d9a129406a05ca7e82c1438f9aa83019f48d Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Thu, 10 Nov 2022 11:48:31 +0100 Subject: [PATCH 009/453] Symintify `broadcast_to` (#88776) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88776 Approved by: https://github.com/ezyang --- .../src/ATen/functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/TensorShape.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 4 +++- test/functorch/test_aotdispatch.py | 8 -------- test/test_proxy_tensor.py | 13 ------------- 5 files changed, 6 insertions(+), 25 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 24a1c4ab507a..af58da07e048 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(bitwise_or, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); OP_DECOMPOSE(broadcast_tensors); - OP_DECOMPOSE(broadcast_to); + m.impl("broadcast_to", native::broadcast_to_symint); OP_DECOMPOSE(cartesian_prod); OP_DECOMPOSE(cdist); OP_DECOMPOSE(clip); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 31b4011c1281..deb9b949aa5d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced); } -Tensor broadcast_to(const Tensor& self, IntArrayRef size) { - return self.expand(size); +Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) { + return self.expand_symint(size); } std::vector broadcast_tensors(TensorList tensors) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 02b073a1ce78..94c56ce59fcd 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1195,8 +1195,10 @@ device_check: NoCheck device_guard: False -- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) +- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: broadcast_to_symint - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 09b65a32bfee..22d013642379 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ 
-1093,20 +1093,12 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('masked_fill', ''), # could not find kernel - xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ... xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... - xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... - xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... xfail('median', ''), # could not find kernel diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 242be9c78939..8caf41a73906 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1114,23 +1114,10 @@ def f(a, b, c, d, e): xfail('linalg.eig'), xfail('linalg.eigvals'), skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel - xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... 
- xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition - xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition - xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... - xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition From 4b898a7304246275b250b159dd0ac8e68a6df95d Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Thu, 10 Nov 2022 01:07:50 +0100 Subject: [PATCH 010/453] Symintify `adaptive_avg_pool3d` (#88783) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88783 Approved by: https://github.com/ezyang --- aten/src/ATen/functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/AdaptiveAveragePooling3d.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 6 ++++-- test/test_proxy_tensor.py | 1 - 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index af58da07e048..e31b36d11241 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -45,7 +45,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(adaptive_max_pool1d); OP_DECOMPOSE(adaptive_avg_pool1d); m.impl("adaptive_avg_pool2d", native::adaptive_avg_pool2d_symint); - OP_DECOMPOSE(adaptive_avg_pool3d); + m.impl("adaptive_avg_pool3d", native::adaptive_avg_pool3d_symint); OP_DECOMPOSE(adjoint); OP_DECOMPOSE(arccos); OP_DECOMPOSE(arccosh); diff --git a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp index 427368e2c06a..a0a02ca53160 100644 --- a/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp +++ b/aten/src/ATen/native/AdaptiveAveragePooling3d.cpp @@ -313,7 +313,7 @@ Tensor adaptive_avg_pool3d_cpu(Tensor const& input, IntArrayRef output_size) { return output; } -Tensor adaptive_avg_pool3d(Tensor const& input, IntArrayRef output_size) { +Tensor adaptive_avg_pool3d_symint(Tensor const& input, SymIntArrayRef output_size) { TORCH_CHECK(output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3"); TORCH_CHECK( (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), @@ -326,7 +326,7 @@ Tensor adaptive_avg_pool3d(Tensor const& input, IntArrayRef output_size) { Tensor out = input.mean({-1, -2, -3}, /* keepdim = */ true); return out; } else { - return _adaptive_avg_pool3d(input, output_size); + return _adaptive_avg_pool3d_symint(input, output_size); } } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 94c56ce59fcd..de087c0b8a89 
100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10595,15 +10595,17 @@ autogen: _adaptive_avg_pool2d_backward.out tags: canonical -- func: adaptive_avg_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out) -> Tensor(a!) +- func: adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) python_module: nn dispatch: CPU: adaptive_avg_pool3d_out_cpu CUDA: adaptive_avg_pool3d_out_cuda QuantizedCPU: adaptive_avg_pool3d_out_quantized_cpu -- func: adaptive_avg_pool3d(Tensor self, int[3] output_size) -> Tensor +- func: adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor python_module: nn + dispatch: + CompositeImplicitAutograd: adaptive_avg_pool3d_symint - func: _adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor dispatch: diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8caf41a73906..fcaefbed6635 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1233,7 +1233,6 @@ def f(a, b, c, d, e): xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d.default - couldn't find symbolic meta func... xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... 
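The two "Symintify" patches above share the same pattern: the size argument in native_functions.yaml changes from int[] to SymInt[], and the C++ entry point forwards to a SymInt-aware variant (broadcast_to to expand_symint, adaptive_avg_pool3d to _adaptive_avg_pool3d_symint), which is why several symbolic-tracing xfails could be dropped from the tests. A rough Python-level sketch of the idea follows; it is illustrative only, the real change is the C++ shown in the diffs, and the wrapper here is simplified.

```python
# Rough, simplified picture of the symint pattern (not the actual implementation):
# the public op forwards sizes, which may be symbolic under tracing, to a
# primitive that already understands SymInt shapes instead of forcing them
# down to concrete Python ints (which would break symbolic shape tracing).
def broadcast_to(tensor, size):     # `size` may contain SymInts while tracing
    return tensor.expand(size)      # mirrors broadcast_to_symint -> expand_symint
```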
From f98edfcc48c903d0d22a0105b0fafe4ca58121e6 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 10 Nov 2022 17:42:20 +0000 Subject: [PATCH 011/453] Make TorchElastic timer importable on Windows (#88522) Also, add `torch.distributed` to test imports, so that we would not regress in the future Fixes https://github.com/pytorch/pytorch/issues/85427 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88522 Approved by: https://github.com/d4l3k --- test/test_testing.py | 10 ++++++++-- .../elastic/timer/file_based_local_timer.py | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_testing.py b/test/test_testing.py index 5ce07ce454dc..8fe66043e5a1 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1794,8 +1794,14 @@ def test_circular_dependencies(self) -> None: if not sys.version_info >= (3, 9): ignored_modules.append("torch.utils.benchmark") if IS_WINDOWS or IS_MACOS: - # Distributed does not work on Windows or by default on Mac - ignored_modules.append("torch.distributed.") + # Distributed should be importable on Windows(except nn.api.), but not on Mac + if IS_MACOS: + ignored_modules.append("torch.distributed.") + else: + ignored_modules.append("torch.distributed.nn.api.") + ignored_modules.append("torch.distributed.optim.") + ignored_modules.append("torch.distributed.pipeline.") + ignored_modules.append("torch.distributed.rpc.") ignored_modules.append("torch.testing._internal.dist_utils") # And these both end up with transitive dependencies on distributed ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop") diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index 36ae944ec8e4..88fefe1dab81 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -10,6 +10,7 @@ import os import select import signal +import sys import threading import time from typing import Callable, Dict, List, Optional, Set, Tuple @@ -78,7 +79,8 @@ class FileTimerClient(TimerClient): signal: signal, the signal to use to kill the process. Using a negative or zero signal will not kill the process. """ - def __init__(self, file_path: str, signal=signal.SIGKILL) -> None: + def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else + signal.CTRL_C_EVENT)) -> None: # type: ignore[attr-defined] super().__init__() self._file_path = file_path self.signal = signal From 79b049af5ecbd8619acb4196f8c59228832ec99b Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 10 Nov 2022 17:48:16 +0000 Subject: [PATCH 012/453] Switch to setup-nvidia action (#88757) Use the new [setup-nvidia](https://github.com/pytorch/test-infra/blob/main/.github/actions/setup-nvidia/action.yml) action from test-infra. The new action is created so that it can be shared across different PyTorch repos. 
For examples: * [pytorch/pytorch](https://github.com/pytorch/pytorch/blob/master/.github/scripts/install_nvidia_utils_linux.sh) (fixed by this PR) * [pytorch/tau](https://github.com/pytorch/tau/blob/main/.github/workflows/install_nvidia_utils_linux.sh) (fixed by https://github.com/pytorch/tau/pull/595) * [pytorch/torchsnapshot](https://github.com/pytorch/torchsnapshot/blob/main/.github/scripts/install_nvidia_utils_linux.sh) (fixed by https://github.com/pytorch/torchsnapshot/pull/130) * [torch/multiply](https://github.com/pytorch/multipy/blob/main/.github/scripts/install_nvidia_utils_linux.sh) (fixed by https://github.com/pytorch/multipy/pull/264) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88757 Approved by: https://github.com/seemethere, https://github.com/atalman --- .github/scripts/install_nvidia_utils_linux.sh | 131 ------------------ .github/workflows/_binary-test-linux.yml | 11 +- .github/workflows/_linux-test.yml | 9 +- 3 files changed, 2 insertions(+), 149 deletions(-) delete mode 100755 .github/scripts/install_nvidia_utils_linux.sh diff --git a/.github/scripts/install_nvidia_utils_linux.sh b/.github/scripts/install_nvidia_utils_linux.sh deleted file mode 100755 index 37c6dccd4811..000000000000 --- a/.github/scripts/install_nvidia_utils_linux.sh +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env bash - -set -eou pipefail - - -DISTRIBUTION=$(. /etc/os-release;echo $ID$VERSION_ID) -DRIVER_VERSION="515.76" -DRIVER_FN="NVIDIA-Linux-x86_64-${DRIVER_VERSION}.run" -YUM_REPO_URL="https://nvidia.github.io/nvidia-docker/${DISTRIBUTION}/nvidia-docker.repo" - -install_nvidia_docker2_amzn2() { - ( - set -x - # Needed for yum-config-manager - sudo yum install -y yum-utils - sudo yum-config-manager --add-repo "${YUM_REPO_URL}" - sudo yum install -y nvidia-docker2 - sudo systemctl restart docker - ) -} - -install_nvidia_driver_amzn2() { - ( - set -x - - # Purge any nvidia driver installed from RHEL repo - sudo yum remove -y nvidia-driver-latest-dkms - - # Try to gather more information about the runner and its existing NVIDIA driver if any - echo "Before installing NVIDIA driver" - lspci - lsmod - modinfo nvidia || true - - HAS_NVIDIA_DRIVER=0 - # Check if NVIDIA driver has already been installed - if [ -x "$(command -v nvidia-smi)" ]; then - set +e - # The driver exists, check its version next. Also check only the first GPU if there are more than one of them - # so that the same driver version is not print over multiple lines - INSTALLED_DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader --id=0) - NVIDIA_SMI_STATUS=$? - - if [ "$NVIDIA_SMI_STATUS" -ne 0 ] && [ "$NVIDIA_SMI_STATUS" -ne 14 ]; then - echo "Failed to get NVIDIA driver version ($INSTALLED_DRIVER_VERSION). Continuing" - elif [ "$INSTALLED_DRIVER_VERSION" != "$DRIVER_VERSION" ]; then - echo "NVIDIA driver ($INSTALLED_DRIVER_VERSION) has been installed, but we expect to have $DRIVER_VERSION instead. Continuing" - else - HAS_NVIDIA_DRIVER=1 - echo "NVIDIA driver ($INSTALLED_DRIVER_VERSION) has already been installed. 
Skipping NVIDIA driver installation" - fi - set -e - fi - - if [ "$HAS_NVIDIA_DRIVER" -eq 0 ]; then - sudo yum groupinstall -y "Development Tools" - # ensure our kernel install is the same as our underlying kernel, - # groupinstall "Development Tools" has a habit of mismatching kernel headers - sudo yum install -y "kernel-devel-uname-r == $(uname -r)" - sudo modprobe backlight - sudo curl -fsL -o /tmp/nvidia_driver "https://s3.amazonaws.com/ossci-linux/nvidia_driver/$DRIVER_FN" - - set +e - sudo /bin/bash /tmp/nvidia_driver -s --no-drm - NVIDIA_INSTALLATION_STATUS=$? - - if [ "$NVIDIA_INSTALLATION_STATUS" -ne 0 ]; then - sudo cat /var/log/nvidia-installer.log - - NVIDIA_DEVICES=$(lspci -D | grep -i NVIDIA | cut -d' ' -f1) - # The GPU can get stuck in a failure state if somehow the test crashs the GPU microcode. When this - # happens, we'll try to reset all NVIDIA devices https://github.com/pytorch/pytorch/issues/88388 - for PCI_ID in "$NVIDIA_DEVICES"; do - DEVICE_ENABLED=$(cat /sys/bus/pci/devices/$PCI_ID/enable) - - echo "Reseting $PCI_ID (enabled state: $DEVICE_ENABLED)" - # This requires sudo permission of course - echo "1" | sudo tee /sys/bus/pci/devices/$PCI_ID/reset - sleep 1 - done - fi - - sudo rm -fv /tmp/nvidia_driver - set -e - fi - - sudo modprobe nvidia || true - echo "After installing NVIDIA driver" - lspci - lsmod - modinfo nvidia || true - - ( - set +e - nvidia-smi - NVIDIA_SMI_STATUS=$? - - # Allowable exit statuses for nvidia-smi, see: https://github.com/NVIDIA/gpu-operator/issues/285 - if [ "$NVIDIA_SMI_STATUS" -eq 0 ] || [ "$NVIDIA_SMI_STATUS" -eq 14 ]; then - echo "INFO: Ignoring allowed status ${NVIDIA_SMI_STATUS}" - else - echo "ERROR: nvidia-smi exited with unresolved status ${NVIDIA_SMI_STATUS}" - exit ${NVIDIA_SMI_STATUS} - fi - set -e - ) - ) -} - -echo "== Installing nvidia driver ${DRIVER_FN} ==" -case "${DISTRIBUTION}" in - amzn*) - install_nvidia_driver_amzn2 - ;; - *) - echo "ERROR: Unknown distribution ${DISTRIBUTION}" - exit 1 - ;; -esac - -# Install container toolkit based on distribution -echo "== Installing nvidia container toolkit for ${DISTRIBUTION} ==" -case "${DISTRIBUTION}" in - amzn*) - install_nvidia_docker2_amzn2 - ;; - *) - echo "ERROR: Unknown distribution ${DISTRIBUTION}" - exit 1 - ;; -esac diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index 12b3d4c64822..471a2af88b8f 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -171,17 +171,8 @@ jobs: path: "${{ runner.temp }}/artifacts/" - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 + uses: pytorch/test-infra/.github/actions/setup-nvidia@main if: ${{ inputs.GPU_ARCH_TYPE == 'cuda' }} - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - set -ex - pushd pytorch - bash .github/scripts/install_nvidia_utils_linux.sh - echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" - popd - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index d2f48acca4e8..dc1346205e63 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -74,15 +74,8 @@ jobs: docker-image: ${{ inputs.docker-image }} - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482 + uses: 
pytorch/test-infra/.github/actions/setup-nvidia@main if: contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') - with: - timeout_minutes: 10 - max_attempts: 3 - command: | - set -ex - bash .github/scripts/install_nvidia_utils_linux.sh - echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}" - name: Start monitoring script id: monitor-script From 656d0de6c50c373c7da2960ae6e9ca07b262384f Mon Sep 17 00:00:00 2001 From: Panagiotis Antoniadis Date: Thu, 10 Nov 2022 18:11:29 +0000 Subject: [PATCH 013/453] Change TORCH_INTERNAL_ASSERT to TORCH_CHECK and add a nice error message (#88804) Fixes #87672 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88804 Approved by: https://github.com/ezyang --- tools/autograd/templates/python_variable_methods.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 2cd847b73405..6ad042c0b903 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -1193,7 +1193,7 @@ static PyObject* THPVariable_set_( case 3: { // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) auto dispatch_set_ = [](const Tensor& self, const Tensor& source) -> Tensor { - TORCH_INTERNAL_ASSERT(source.dtype() == self.dtype()); + TORCH_CHECK(source.dtype() == self.dtype(), "Could not set tensor of type ", source.dtype(), " to a tensor of type ", self.dtype()); pybind11::gil_scoped_release no_gil; return self.set_(source); }; From 1e4079a4762f515406c7f4654e7a4340914898ef Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Thu, 10 Nov 2022 04:42:37 +0000 Subject: [PATCH 014/453] [nnc] Disable opaque pointers mode in LLVM backend to allow getPointerElementType (#88798) As of LLVM 15 typed pointers are going away: https://llvm.org/docs/OpaquePointers.html. Thus `getPointerElementType` is no longer legal, since pointers are all opaque. I don't totally remember why we use it so prolifically, or whether there's an easy change to get rid of it, or whether we'd need a significant refactor to carry around `Type`s alongside `Value`s. But in any case, NNC is deprecated (see: TorchInductor) and will hopefully be gone before LLVM 16 is a thing. For now, we can apply the hack of turning off opaque pointer mode on the LLVMContext. Differential Revision: [D41176215](https://our.internmc.facebook.com/intern/diff/D41176215) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88798 Approved by: https://github.com/desertfire --- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 78521efc240e..1ca5665b4432 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -446,6 +446,9 @@ LLVMCodeGenImpl::LLVMCodeGenImpl( irb_(getContext()), kernel_func_name_(std::move(kernel_func_name)), bufsExtAlloc_(ExternalAllocBufFinder::find(stmt)) { +#if LLVM_VERSION_MAJOR >= 15 + context_->setOpaquePointers(false); +#endif if (!triple) { triple = LLVMTargetTriple(); } From d3178465eed4895fa12430943db37d00dd2c483b Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Thu, 10 Nov 2022 18:17:20 +0000 Subject: [PATCH 015/453] [dynamo] `VariableTracker.call_method` requires a name (#88311) Summary: as title Test Plan: Before: N2743445, After: N2748186. 
Note there's a new error, but at least we got past the easy one. Differential Revision: D40938415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88311 Approved by: https://github.com/brad-mengchi --- test/test_datapipe.py | 3 ++- torch/_dynamo/variables/user_defined.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_datapipe.py b/test/test_datapipe.py index dbc5a5ae8071..b5de6a5f4006 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -33,7 +33,7 @@ import torch.utils.data.datapipes as dp import torch.utils.data.graph import torch.utils.data.graph_settings -from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings +from torch.testing._internal.common_utils import TestCase, run_tests, suppress_warnings, skipIfTorchDynamo from torch.utils.data import ( DataLoader, DataChunk, @@ -220,6 +220,7 @@ def test_dir(self): for api in ['open', 'read', 'close']: self.assertTrue(api in s) + @skipIfTorchDynamo def test_api(self): fd = TestStreamWrapper._FakeFD("") wrap_fd = StreamWrapper(fd) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2d33c8328268..09d7893bef66 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -68,7 +68,7 @@ def call_method( return variables.ListVariable(subs_as_vars, **options) - return super().call_method(tx, args, kwargs) + return super().call_method(tx, name, args, kwargs) def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" From 6bf2776ac1d16692778f052ba6796d3308ea97c6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 10 Nov 2022 15:17:51 +0000 Subject: [PATCH 016/453] [FSDP][Perf] Do not call `pad` in no-padding case (#88769) - Calling `F.pad()` issues a pad kernel from the CPU even if there is no padding needed, which can incur some non-negligible overhead. This PR removes that unnecessary call for the no-padding case. - This PR also does not zero the newly-allocated sharded gradient tensor before the reduce-scatter if `use_orig_params=True` because there is no need. The reduce-scatter will fill the tensor anyway, and we do not care about the values in the padding. For `use_orig_params=False`, the padding is exposed to the user, so we preserve the existing semantics of zeroing it. I left a to-do to follow-up since we may optimize that. 
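For illustration only (not part of the patch): a minimal Python sketch of the two optimizations described above, assuming a flattened 1-D unsharded gradient; the helper name and the standalone example are hypothetical, and the authoritative change is the `_runtime_utils.py` diff below.

```python
import torch
import torch.nn.functional as F

def pad_and_alloc_sharded_grad(unsharded_grad: torch.Tensor, world_size: int):
    """Illustrative sketch (hypothetical helper): pad a flat gradient only when
    needed and allocate the reduce-scatter output without zero-filling it."""
    chunks = unsharded_grad.chunk(world_size)  # the last chunk may be smaller
    numel_to_pad = world_size * chunks[0].numel() - unsharded_grad.numel()
    # Skip issuing a pad kernel entirely in the common no-padding case.
    padded_unsharded_grad = (
        F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad
    )
    # The reduce-scatter overwrites every element of the output, so an
    # uninitialized buffer suffices here (no zero-fill needed).
    new_sharded_grad = torch.empty_like(chunks[0])
    return padded_unsharded_grad, new_sharded_grad

# Example: world size 4 with 10 elements needs 2 elements of padding.
padded, out = pad_and_alloc_sharded_grad(torch.randn(10), world_size=4)
assert padded.numel() == 12 and out.numel() == 3
```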
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88769 Approved by: https://github.com/zhaojuanmao --- torch/distributed/fsdp/_runtime_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index e0fa12e19c2a..9aee15a016c4 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -537,8 +537,12 @@ def _post_backward_hook( numel_to_pad = ( state.world_size * chunks[0].numel() - unsharded_grad.numel() ) - padded_unsharded_grad = F.pad(unsharded_grad, [0, numel_to_pad]) - new_sharded_grad = torch.zeros_like(chunks[0]) # padded + padded_unsharded_grad = ( + F.pad(unsharded_grad, [0, numel_to_pad]) + if numel_to_pad > 0 + else unsharded_grad + ) + new_sharded_grad = torch.empty_like(chunks[0]) # padded state._communication_hook( state._communication_hook_state, padded_unsharded_grad, From d157fca59c3f28b532f5e845c48df0e2bedbfa39 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 10 Nov 2022 18:19:51 +0000 Subject: [PATCH 017/453] Revert "Symintify `broadcast_to` (#88776)" This reverts commit 3a09d9a129406a05ca7e82c1438f9aa83019f48d. Reverted https://github.com/pytorch/pytorch/pull/88776 on behalf of https://github.com/malfet due to Broke functorch/test_aotdispatch on M1, see https://hud.pytorch.org/pytorch/pytorch/commit/3a09d9a129406a05ca7e82c1438f9aa83019f48d --- .../src/ATen/functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/TensorShape.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 4 +--- test/functorch/test_aotdispatch.py | 8 ++++++++ test/test_proxy_tensor.py | 13 +++++++++++++ 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index e31b36d11241..66aaa53bfcc1 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(bitwise_or, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); OP_DECOMPOSE(broadcast_tensors); - m.impl("broadcast_to", native::broadcast_to_symint); + OP_DECOMPOSE(broadcast_to); OP_DECOMPOSE(cartesian_prod); OP_DECOMPOSE(cdist); OP_DECOMPOSE(clip); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index deb9b949aa5d..31b4011c1281 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced); } -Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) { - return self.expand_symint(size); +Tensor broadcast_to(const Tensor& self, IntArrayRef size) { + return self.expand(size); } std::vector broadcast_tensors(TensorList tensors) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index de087c0b8a89..0ea606f5e1fb 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1195,10 +1195,8 @@ device_check: NoCheck device_guard: False -- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) +- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function, method - dispatch: - CompositeImplicitAutograd: 
broadcast_to_symint - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 22d013642379..09b65a32bfee 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1093,12 +1093,20 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('masked_fill', ''), # could not find kernel + xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ... xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... + xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc... + xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc... + xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... + xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... 
xfail('median', ''), # could not find kernel diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fcaefbed6635..fbeaa04aa65d 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1114,10 +1114,23 @@ def f(a, b, c, d, e): xfail('linalg.eig'), xfail('linalg.eigvals'), skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel + xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition + xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... + xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition + xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition + xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... + xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition + xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... 
xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition From 48b58930cbfa725ac25a9303d496c76bf983574d Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Thu, 10 Nov 2022 18:32:25 +0000 Subject: [PATCH 018/453] [Inductor] Build FX Linear + Permute Vertical Fusion in Inductor (#88566) Summary: Build fx-based linear/matmul/bmm + permute/transpose vertical fusion in Inductor For an internal Ads model: 1.15x -> 1.36x speedup Differential Revision: D41071665 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88566 Approved by: https://github.com/jansel, https://github.com/jianyuh --- test/inductor/test_torchinductor.py | 106 +++++++++++++++ torch/_inductor/config.py | 3 + torch/_inductor/overrides.py | 199 ++++++++++++++++++++++++++++ 3 files changed, 308 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index db6c5dfc2bd1..064f04291a8e 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10,6 +10,7 @@ import typing import unittest import weakref +from typing import Any, Callable from unittest.mock import patch import torch @@ -18,6 +19,7 @@ from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.testing._internal.common_utils import ( IS_FBCODE, @@ -40,6 +42,14 @@ from torch._inductor import codecache, config, metrics from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing + from torch._inductor.overrides import ( + linear_permute_fusion, + linear_transpose, + permute_linear_fusion, + permute_matmul_fusion, + transpose_linear, + transpose_matmul, + ) from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed @@ -129,6 +139,29 @@ def maybe_test(*args, **kwargs): return wrap_test +PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] + + +def chain_passes(*passes: PassFunc) -> PassFunc: + def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module) + return module + + return parent_pass + + +def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: + return sum( + [ + 1 if (n.op == "call_function" and n.target == target_op) else 0 + for n in module.graph.nodes + ] + ) + + class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -1582,6 +1615,79 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) + def test_linear_permute_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + a0 = torch.nn.functional.linear(input, self.weight, self.bias) + b0 = a0.permute(0, 2, 1) + return b0 + + m, k, n = 16, 8, 4 + trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) + 
module = TestModule(k, n).eval() + input = torch.randn(6, m, k) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_linear_transpose = count_call_function(traced, linear_transpose) + self.assertEqual(num_linear, 0) + self.assertEqual(num_linear_transpose, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_linear_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.nn.functional.linear(input1, self.weight, self.bias) + return output + + m, k, n = 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, k, m) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_transpose_linear = count_call_function(traced, transpose_linear) + self.assertEqual(num_linear, 0) + self.assertEqual(num_transpose_linear, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_bmm_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, batch: int, k: int, n: int): + super().__init__() + self.other = torch.randn(batch, k, n) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.bmm(input1, self.other) + return output + + batch, m, k, n = 6, 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) + module = TestModule(batch, k, n).eval() + input = torch.randn(batch, k, m) + traced = trace_func(module, [input]) + num_bmm = count_call_function(traced, torch.bmm) + num_transpose_matmul = count_call_function(traced, transpose_matmul) + self.assertEqual(num_bmm, 0) + self.assertEqual(num_transpose_matmul, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + def test_slice1(self): def fn(a): return ( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 910e6d20b4d6..c9b7623cf528 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -67,6 +67,9 @@ # How to import torchdynamo, either torchdynamo or torch.dynamo dynamo_import = inductor_import.replace("inductor", "dynamo") +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + # config specific to codegen/cpp.pp class cpp: diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 581e1996a436..69a5bc6710f8 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -19,6 +19,8 @@ from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.overrides import TorchFunctionMode +from . import config + log = logging.getLogger(__name__) @@ -313,6 +315,14 @@ def check_node_is_binary(node): def fuse_fx(gm: torch.fx.GraphModule, example_inputs): + if config.permute_fusion: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm).propagate(*example_inputs) + gm = linear_permute_fusion(gm) + gm = permute_linear_fusion(gm) + gm = permute_matmul_fusion(gm) + # make sure the autograd is disabled. 
if torch.is_grad_enabled(): return gm @@ -408,6 +418,195 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["weight"] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] + else: + return self.node.kwargs["bias"] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["other"] + + +def check_permute(node: torch.fx.Node): + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if ( + node.op == "call_method" + and node.target == "permute" + and check_permute(node) + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and node.target == torch.nn.functional.linear: + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = 
normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target == torch.bmm or node.target == torch.matmul + ): + normalized = NormalizedMatmulNode(node) + A = normalized.get_input() + B = normalized.get_other() + Atrans = Btrans = False + if A.op == "call_method" and A.target == "permute" and check_permute(A): + Atrans = True + if len(A.args) > 0: + A = A.args[0] + else: + A = A.kwargs["input"] + + if B.op == "call_method" and B.target == "permute" and check_permute(B): + Btrans = True + if len(B.args) > 0: + B = B.args[0] + else: + B = B.kwargs["input"] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(A, B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) + + def replace_and_fuse_for_binary( computation_node, node, fuse_func, attr, modules, index_node, index_pointwise ): From 3b8245ab12d54723b6e7bcceb176235f13f0348b Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Thu, 10 Nov 2022 18:34:19 +0000 Subject: [PATCH 019/453] [LTC] Make ComputePostOrder accept const T pointers (#88773) Summary: Since `c10::ArrayRef` now support `c10::ArrayRef`, let's restore `ComputePostOrder` to accept `const Node*` again, which is more suitable for the context of the given helpers. Test Plan: CI. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88773 Approved by: https://github.com/JackCaoG --- .github/ci_commit_pins/xla.txt | 2 +- test/cpp/lazy/test_ir_util.cpp | 2 +- torch/csrc/lazy/backend/backend_interface.cpp | 2 +- torch/csrc/lazy/backend/backend_interface.h | 4 +-- torch/csrc/lazy/backend/lowering_context.cpp | 2 +- torch/csrc/lazy/backend/lowering_context.h | 4 +-- torch/csrc/lazy/core/debug_util.cpp | 2 +- torch/csrc/lazy/core/ir_dump_util.cpp | 16 +++++----- torch/csrc/lazy/core/ir_dump_util.h | 12 ++++---- torch/csrc/lazy/core/ir_util.cpp | 30 +++++++++---------- torch/csrc/lazy/core/ir_util.h | 11 +++---- torch/csrc/lazy/core/lazy_graph_executor.cpp | 2 +- torch/csrc/lazy/core/lazy_graph_executor.h | 2 +- torch/csrc/lazy/python/init.cpp | 10 +++---- .../csrc/lazy/ts_backend/ts_backend_impl.cpp | 7 +++-- .../lazy/ts_backend/ts_lowering_context.cpp | 2 +- .../lazy/ts_backend/ts_lowering_context.h | 2 +- 17 files changed, 56 insertions(+), 56 deletions(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 7ec9661a1ce4..957272e8578b 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -7889d2d3be16675943d84e4a4133ed7c245a623f +08121e41079319cd369f82f523f5a714a0563f9d diff --git a/test/cpp/lazy/test_ir_util.cpp b/test/cpp/lazy/test_ir_util.cpp index 2befb04236ab..0b2bfc7614b1 100644 --- a/test/cpp/lazy/test_ir_util.cpp +++ b/test/cpp/lazy/test_ir_util.cpp @@ -52,7 +52,7 @@ TEST(IrUtilTest, BasicTest) { dynamic_cast(b.get())->AddOperand(Value(d, 0)); dynamic_cast(c.get())->AddOperand(Value(d, 0)); - std::vector postorder = Util::ComputePostOrder({a.get()}); + auto postorder = Util::ComputePostOrder({a.get()}); EXPECT_EQ(postorder.size(), 4); EXPECT_EQ(postorder.at(0), d.get()); EXPECT_EQ(postorder.at(1), c.get()); diff --git a/torch/csrc/lazy/backend/backend_interface.cpp b/torch/csrc/lazy/backend/backend_interface.cpp index 250a8847351c..0fb3257c90a9 100644 --- a/torch/csrc/lazy/backend/backend_interface.cpp +++ b/torch/csrc/lazy/backend/backend_interface.cpp @@ -38,7 +38,7 @@ at::Tensor MakeTensorFromComputationData( std::unique_ptr LoweringContext::Create( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) { return getBackend()->CreateLoweringContext( name, device, post_order, emit_status); diff --git a/torch/csrc/lazy/backend/backend_interface.h b/torch/csrc/lazy/backend/backend_interface.h index a70591c2a19c..f94d3b602e52 100644 --- a/torch/csrc/lazy/backend/backend_interface.h +++ b/torch/csrc/lazy/backend/backend_interface.h @@ -59,7 +59,7 @@ class TORCH_API BackendImplInterface { // Gets backend data if the node is a device data node. 
Otherwise returns // nullptr - virtual BackendDataPtr GetComputationDataFromNode(Node*) const = 0; + virtual BackendDataPtr GetComputationDataFromNode(const Node*) const = 0; virtual at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, @@ -72,7 +72,7 @@ class TORCH_API BackendImplInterface { virtual std::unique_ptr CreateLoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) const = 0; virtual std::unique_ptr CreateLoweringContext( diff --git a/torch/csrc/lazy/backend/lowering_context.cpp b/torch/csrc/lazy/backend/lowering_context.cpp index 64922a1b3e13..635ee4891cc7 100644 --- a/torch/csrc/lazy/backend/lowering_context.cpp +++ b/torch/csrc/lazy/backend/lowering_context.cpp @@ -9,7 +9,7 @@ LoweringContext::LoweringContext(const std::string& name, BackendDevice device) LoweringContext::LoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) : device_(std::move(device)), emit_status_(std::move(emit_status)) {} diff --git a/torch/csrc/lazy/backend/lowering_context.h b/torch/csrc/lazy/backend/lowering_context.h index 6f487aef7f74..49e7b8be58cb 100644 --- a/torch/csrc/lazy/backend/lowering_context.h +++ b/torch/csrc/lazy/backend/lowering_context.h @@ -42,7 +42,7 @@ class TORCH_API LoweringContext { LoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status); virtual ~LoweringContext() = default; @@ -50,7 +50,7 @@ class TORCH_API LoweringContext { static std::unique_ptr Create( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status); static std::unique_ptr Create( diff --git a/torch/csrc/lazy/core/debug_util.cpp b/torch/csrc/lazy/core/debug_util.cpp index 50f42b718128..50077d498a75 100644 --- a/torch/csrc/lazy/core/debug_util.cpp +++ b/torch/csrc/lazy/core/debug_util.cpp @@ -88,7 +88,7 @@ std::string DebugUtil::GetTensorsGraphInfo( c10::ArrayRef tensors, const std::vector* indices, GraphFormat format) { - std::vector root_nodes; + std::vector root_nodes; std::vector root_values; std::vector root_hashes; torch::lazy::Unique unique_device; diff --git a/torch/csrc/lazy/core/ir_dump_util.cpp b/torch/csrc/lazy/core/ir_dump_util.cpp index eff2873d668d..19cb2ae7b162 100644 --- a/torch/csrc/lazy/core/ir_dump_util.cpp +++ b/torch/csrc/lazy/core/ir_dump_util.cpp @@ -80,7 +80,7 @@ c10::optional ParseAttrTag( return tag; } -NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { +NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { NodeIdMap id_map; for (auto node : post_order) { TORCH_CHECK(id_map.emplace(node, id_map.size()).second, node->ToString()); @@ -89,7 +89,7 @@ NodeIdMap GenerateIdMap(c10::ArrayRef post_order) { } std::unordered_map GetRootsIds( - c10::ArrayRef roots) { + c10::ArrayRef roots) { std::unordered_map roots_ids; for (const auto i : c10::irange(roots.size())) { roots_ids[roots[i]] = i; @@ -178,14 +178,14 @@ std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) { } // namespace -std::string DumpUtil::ToDot(c10::ArrayRef nodes) { +std::string DumpUtil::ToDot(c10::ArrayRef nodes) { auto post_order = Util::ComputePostOrder(nodes); return PostOrderToDot(post_order, nodes); } std::string DumpUtil::PostOrderToDot( - c10::ArrayRef post_order, - c10::ArrayRef roots) { + c10::ArrayRef post_order, + 
c10::ArrayRef roots) { std::unordered_map roots_ids = GetRootsIds(roots); NodeIdMap id_map = GenerateIdMap(post_order); std::stringstream ss; @@ -218,14 +218,14 @@ std::string DumpUtil::PostOrderToDot( return ss.str(); } -std::string DumpUtil::ToText(c10::ArrayRef nodes) { +std::string DumpUtil::ToText(c10::ArrayRef nodes) { auto post_order = Util::ComputePostOrder(nodes); return PostOrderToText(post_order, nodes); } std::string DumpUtil::PostOrderToText( - c10::ArrayRef post_order, - c10::ArrayRef roots) { + c10::ArrayRef post_order, + c10::ArrayRef roots) { std::unordered_map roots_ids = GetRootsIds(roots); NodeIdMap id_map = GenerateIdMap(post_order); std::stringstream ss; diff --git a/torch/csrc/lazy/core/ir_dump_util.h b/torch/csrc/lazy/core/ir_dump_util.h index 22cf139bfbd6..4b4e1e0749b2 100644 --- a/torch/csrc/lazy/core/ir_dump_util.h +++ b/torch/csrc/lazy/core/ir_dump_util.h @@ -11,17 +11,17 @@ class BackendDevice; class TORCH_API DumpUtil { public: - static std::string ToDot(c10::ArrayRef nodes); + static std::string ToDot(c10::ArrayRef nodes); static std::string PostOrderToDot( - c10::ArrayRef post_order, - c10::ArrayRef roots); + c10::ArrayRef post_order, + c10::ArrayRef roots); - static std::string ToText(c10::ArrayRef nodes); + static std::string ToText(c10::ArrayRef nodes); static std::string PostOrderToText( - c10::ArrayRef post_order, - c10::ArrayRef roots); + c10::ArrayRef post_order, + c10::ArrayRef roots); static std::string ToBackend( c10::ArrayRef values, diff --git a/torch/csrc/lazy/core/ir_util.cpp b/torch/csrc/lazy/core/ir_util.cpp index 2d463bb99d5f..b2a2a8ecfa20 100644 --- a/torch/csrc/lazy/core/ir_util.cpp +++ b/torch/csrc/lazy/core/ir_util.cpp @@ -5,13 +5,12 @@ namespace torch { namespace lazy { -std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { - std::vector post_order; - std::vector queue; - // std::vector to c10::ArrayRef conversion is not supported, - // so we need to drop const in the return vector and use const_cast here. 
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - queue.push_back(const_cast(node)); +std::vector Util::ComputePostOrder( + const Node* node, + EmissionMap* emap) { + std::vector post_order; + std::vector queue; + queue.push_back(node); while (!queue.empty()) { node = queue.back(); auto it = emap->find(node); @@ -20,8 +19,7 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { for (auto& output : node->operands()) { auto oit = emap->find(output.node); if (oit == emap->end()) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - queue.push_back(const_cast(output.node)); + queue.push_back(output.node); } else { TORCH_CHECK( oit->second != kEmitting, @@ -38,8 +36,7 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { output.node->ToString()); } (*emap)[node] = kEmitted; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - post_order.push_back(const_cast(node)); + post_order.push_back(node); queue.pop_back(); } else { TORCH_CHECK(it->second == kEmitted); @@ -49,10 +46,10 @@ std::vector Util::ComputePostOrder(const Node* node, EmissionMap* emap) { return post_order; } -std::vector Util::ComputePostOrder( - c10::ArrayRef nodes, +std::vector Util::ComputePostOrder( + c10::ArrayRef nodes, EmissionMap* emap) { - std::vector post_order; + std::vector post_order; for (auto node : nodes) { auto node_post_order = ComputePostOrder(node, emap); post_order.insert( @@ -61,12 +58,13 @@ std::vector Util::ComputePostOrder( return post_order; } -std::vector Util::ComputePostOrder(c10::ArrayRef nodes) { +std::vector Util::ComputePostOrder( + c10::ArrayRef nodes) { EmissionMap emap; return ComputePostOrder(nodes, &emap); } -size_t Util::GetGraphSize(c10::ArrayRef nodes) { +size_t Util::GetGraphSize(c10::ArrayRef nodes) { return ComputePostOrder(nodes).size(); } diff --git a/torch/csrc/lazy/core/ir_util.h b/torch/csrc/lazy/core/ir_util.h index a95b1a523bfa..df3d0fd7ac40 100644 --- a/torch/csrc/lazy/core/ir_util.h +++ b/torch/csrc/lazy/core/ir_util.h @@ -25,21 +25,22 @@ class TORCH_API Util { // this API. The returned post-order can be empty if the node has already been // emitted inside the emission map. An error is generated if a loop is // detected. - static std::vector ComputePostOrder( + static std::vector ComputePostOrder( const Node* node, EmissionMap* emap); - static std::vector ComputePostOrder( - c10::ArrayRef nodes, + static std::vector ComputePostOrder( + c10::ArrayRef nodes, EmissionMap* emap); // Same as above, but computes the post order on the set of nodes specified as // argument. - static std::vector ComputePostOrder(c10::ArrayRef nodes); + static std::vector ComputePostOrder( + c10::ArrayRef nodes); // Retrieves the number of nodes within the graph whose sink are passed in the // nodes argument. 
- static size_t GetGraphSize(c10::ArrayRef nodes); + static size_t GetGraphSize(c10::ArrayRef nodes); }; } // namespace lazy diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index 4989ce24a0ef..1201971f3bc2 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -721,7 +721,7 @@ std::vector LazyGraphExecutor::FetchTensorData( LazyGraphExecutor::PostOrderData LazyGraphExecutor::RunPostOrder( const std::vector& tensors, SyncTensorCollection* coll) { - std::vector roots; + std::vector roots; roots.reserve(coll->indices.size()); for (auto index : coll->indices) { Value ir_value = tensors.at(index)->CurrentIrValue(); diff --git a/torch/csrc/lazy/core/lazy_graph_executor.h b/torch/csrc/lazy/core/lazy_graph_executor.h index b7e10374fbb7..9894295f3b32 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.h +++ b/torch/csrc/lazy/core/lazy_graph_executor.h @@ -158,7 +158,7 @@ class TORCH_API LazyGraphExecutor { }; struct PostOrderData { - std::vector post_order; + std::vector post_order; Util::EmissionMap emission_map; std::vector parameters_data; std::vector parameter_sequence; diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 2d421a3eb2ae..774df68e26de 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -42,9 +42,9 @@ std::ptrdiff_t GetTensorId(const at::Tensor& tensor) { std::string GetTensorsDump( const std::vector& tensors, - const std::function)>& + const std::function)>& coverter) { - std::vector nodes; + std::vector nodes; std::vector values; for (auto& tensor : tensors) { auto inner = at::functionalization::impl::from_functional_tensor(tensor); @@ -142,7 +142,7 @@ void initLazyBindings(PyObject* module) { lazy.def( "_get_tensors_text", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto coverter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToText(nodes); }; return GetTensorsDump(tensors, coverter); @@ -150,7 +150,7 @@ void initLazyBindings(PyObject* module) { lazy.def( "_get_tensors_dot", [](const std::vector& tensors) -> std::string { - auto coverter = [](c10::ArrayRef nodes) { + auto coverter = [](c10::ArrayRef nodes) { return torch::lazy::DumpUtil::ToDot(nodes); }; return GetTensorsDump(tensors, coverter); @@ -222,7 +222,7 @@ void initLazyBindings(PyObject* module) { [](const std::vector& tensors) -> std::pair, std::vector> { #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - std::vector roots; + std::vector roots; for (auto& tensor : tensors) { auto xtensor = TryGetLtcTensor(tensor); roots.push_back(xtensor->GetIrValue().node.get()); diff --git a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp index 4003a005fbfa..488dd9f24d9d 100644 --- a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp +++ b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp @@ -61,7 +61,7 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { std::unique_ptr CreateLoweringContext( const std::string& name, torch::lazy::BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, torch::lazy::Util::EmissionMap emit_status) const override { return std::make_unique( name, device, post_order, emit_status); @@ -113,8 +113,9 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { return std::make_shared(scalar, device); } - torch::lazy::BackendDataPtr GetComputationDataFromNode(Node* node) const { - 
auto* device_data_node = dynamic_cast(node); + torch::lazy::BackendDataPtr GetComputationDataFromNode( + const Node* node) const { + auto* device_data_node = DeviceData::Cast(node); if (!device_data_node) { return nullptr; } diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp index ff3d1aa07b78..ad1cac4870f5 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.cpp @@ -17,7 +17,7 @@ TSLoweringContext::TSLoweringContext( TSLoweringContext::TSLoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status) : torch::lazy::LoweringContext(name, device, post_order, emit_status), graph_(std::make_shared()), diff --git a/torch/csrc/lazy/ts_backend/ts_lowering_context.h b/torch/csrc/lazy/ts_backend/ts_lowering_context.h index 700f27d505fd..0ad2b669c0e6 100644 --- a/torch/csrc/lazy/ts_backend/ts_lowering_context.h +++ b/torch/csrc/lazy/ts_backend/ts_lowering_context.h @@ -71,7 +71,7 @@ class TORCH_API TSLoweringContext : public LoweringContext { TSLoweringContext( const std::string& name, BackendDevice device, - c10::ArrayRef post_order, + c10::ArrayRef post_order, Util::EmissionMap emit_status); size_t AddResult(const Output& output) override { From 4bcf2c53e521f5c61615b0adb84312513ad583f2 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 10 Nov 2022 19:22:09 +0000 Subject: [PATCH 020/453] Add warnings & regressions info text (#88837) Add text about what warnings and accuracy regressions dropdowns mean. Sample: https://github.com/pytorch/torchdynamo/issues/1831#issuecomment-1310770285 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88837 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 9c0538368b44..99c70426cd36 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -707,7 +707,12 @@ def get_metric_title(self, metric): def generate_warnings(self): title = "## Warnings ##" - body = "" + body = ( + "We flag models where:\n\n" + " - speedup < 0.95x\n" + " - compilation latency > 120 sec.\n" + " - compression ratio < 0.9\n\n" + ) for metric in [ "speedup", "compilation_latency", @@ -858,9 +863,14 @@ def find_last_2(self, suite, device, dtype, compiler): def generate_comment(self): title = "## Accuracy Regressions ##\n" - body = "" + body = ( + "For each relevant compiler, we compare the most recent 2 reports " + "(that run actually the compiler) to find models where previously " + "successful accuracy tests now fail.\n\n" + ) dtype = self.args.dtypes[0] device = self.args.devices[0] + regressions_present = False for suite in self.args.suites: dfs = [] for compiler in self.args.flag_compilers: @@ -893,6 +903,7 @@ def generate_comment(self): df = pd.concat(dfs, axis=0) if df.empty: continue + regressions_present = True tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") str_io = io.StringIO() str_io.write("\n") @@ -902,6 +913,9 @@ def generate_comment(self): str_io.write("~~~\n") body += str_io.getvalue() + if not regressions_present: + body += "No accuracy regressions found.\n" + comment = generate_dropdown_comment(title, body) with open(f"{self.args.output_dir}/gh_accuracy_regression.txt", "w") as gh_fh: From 
1d54ce9d5d4e44416a55ad002b8dc9b984ecc906 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Thu, 10 Nov 2022 06:31:46 -0800 Subject: [PATCH 021/453] [14/N] Refactor _new_process_group_helper() to remove repeated code (#88351) Changes: - refactor parts of `_new_process_group_helper()` to remove repeated code Differential Revision: [D41188274](https://our.internmc.facebook.com/intern/diff/D41188274) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88351 Approved by: https://github.com/kwen2501 --- torch/distributed/distributed_c10d.py | 92 ++++++++------------------- 1 file changed, 25 insertions(+), 67 deletions(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 41d0ee21d3e3..4a132d141e00 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -925,8 +925,6 @@ def _new_process_group_helper( pg = ProcessGroupMPI.create(global_ranks_in_group) if not pg: return GroupMember.NON_GROUP_MEMBER - _world.pg_map[pg] = (Backend.MPI, None) - _world.pg_names[pg] = group_name else: # If this is a subgroup (which means group_ranks is specified), # we check if the current process is a member of the new group. @@ -943,27 +941,6 @@ def _new_process_group_helper( if pg_options is not None: raise RuntimeError("GLOO options not supported") pg = ProcessGroupGloo(prefix_store, group_rank, group_size, timeout=timeout) - # In debug mode and if GLOO is available, wrap in a wrapper PG that - # enables enhanced collective checking for debugability. - if get_debug_level() == DebugLevel.DETAIL: - if not _GLOO_AVAILABLE: - logger.info( - """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but - GLOO is not available. Build with Gloo to - create a wrapper process group in debug mode - to aid collective desynchronization debugging.""" - ) - else: - pg = _create_process_group_wrapper( - wrapped_pg=pg, - store_prefix=group_name, - store=store, - rank=group_rank, - world_size=group_size, - timeout=timeout, - ) - _world.pg_map[pg] = (Backend.GLOO, store) - _world.pg_names[pg] = group_name elif backend == Backend.NCCL: if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL " "built in") @@ -978,54 +955,12 @@ def _new_process_group_helper( pg_options._timeout = timeout pg = ProcessGroupNCCL(prefix_store, group_rank, group_size, pg_options) - # In debug mode and if GLOO is available, wrap in a wrapper PG that - # enables enhanced collective checking for debugability. - if get_debug_level() == DebugLevel.DETAIL: - if not _GLOO_AVAILABLE: - logger.info( - """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but - GLOO is not available. Build with Gloo to - create a wrapper process group in debug mode - to aid collective desynchronization debugging.""" - ) - else: - pg = _create_process_group_wrapper( - wrapped_pg=pg, - store_prefix=group_name, - store=store, - rank=group_rank, - world_size=group_size, - timeout=timeout, - ) - _world.pg_map[pg] = (Backend.NCCL, store) - _world.pg_names[pg] = group_name elif backend == Backend.UCC and is_ucc_available(): # TODO: once UCC plugin is fully deprecated, remove # is_ucc_available() from above elif-condition and raise # RuntimeError if is_ucc_available() returns false. pg = ProcessGroupUCC(prefix_store, group_rank, group_size, timeout=timeout) - # In debug mode and if GLOO is available, wrap in a wrapper PG that - # enables enhanced collective checking for debugability. 
- if get_debug_level() == DebugLevel.DETAIL: - if not _GLOO_AVAILABLE: - logger.info( - """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but - GLOO is not available. Build with Gloo to - create a wrapper process group in debug mode - to aid collective desynchronization debugging.""" - ) - else: - pg = _create_process_group_wrapper( - wrapped_pg=pg, - store_prefix=group_name, - store=store, - rank=group_rank, - world_size=group_size, - timeout=timeout, - ) - _world.pg_map[pg] = (Backend.UCC, store) - _world.pg_names[pg] = group_name else: assert backend.upper() in Backend._plugins, ( f"Unknown c10d backend type {backend.upper()}" @@ -1047,9 +982,32 @@ def _new_process_group_helper( dist_backend_opts.global_ranks_in_group = global_ranks_in_group pg = creator_fn(dist_backend_opts, pg_options) - _world.pg_map[pg] = (backend, store) - _world.pg_names[pg] = group_name + # Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set + if backend in [Backend.GLOO, Backend.NCCL, Backend.UCC]: + # In debug mode and if GLOO is available, wrap in a wrapper PG that + # enables enhanced collective checking for debuggability. + if get_debug_level() == DebugLevel.DETAIL: + if not _GLOO_AVAILABLE: + logger.info( + """TORCH_DISTRIBUTED_DEBUG was set to DETAIL, but + GLOO is not available. Build with Gloo to + create a wrapper process group in debug mode + to aid collective desynchronization debugging.""" + ) + else: + pg = _create_process_group_wrapper( + wrapped_pg=pg, + store_prefix=group_name, + store=store, + rank=group_rank, + world_size=group_size, + timeout=timeout, + ) + + # update global state + _world.pg_map[pg] = (backend, store) + _world.pg_names[pg] = group_name return pg From 98ecd06580b667441a45bfe7a67bc95ddb8a9353 Mon Sep 17 00:00:00 2001 From: Felix Divo <4403130+felixdivo@users.noreply.github.com> Date: Thu, 10 Nov 2022 19:29:29 +0000 Subject: [PATCH 022/453] Bring Unfold/Fold param doc order in line with code (#88819) Now the first parameter (if used as a positional argument) is the first that is listed in the docs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88819 Approved by: https://github.com/ngimel --- torch/nn/modules/fold.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/nn/modules/fold.py b/torch/nn/modules/fold.py index 5380cf155c90..a7b1f758dd5a 100644 --- a/torch/nn/modules/fold.py +++ b/torch/nn/modules/fold.py @@ -50,13 +50,13 @@ class Fold(Module): output_size (int or tuple): the shape of the spatial dimensions of the output (i.e., ``output.sizes()[2:]``) kernel_size (int or tuple): the size of the sliding blocks - stride (int or tuple): the stride of the sliding blocks in the input - spatial dimensions. Default: 1 - padding (int or tuple, optional): implicit zero padding to be added on - both sides of input. Default: 0 dilation (int or tuple, optional): a parameter that controls the stride of elements within the neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple): the stride of the sliding blocks in the input + spatial dimensions. 
Default: 1 * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then @@ -192,13 +192,13 @@ class Unfold(Module): Args: kernel_size (int or tuple): the size of the sliding blocks - stride (int or tuple, optional): the stride of the sliding blocks in the input - spatial dimensions. Default: 1 - padding (int or tuple, optional): implicit zero padding to be added on - both sides of input. Default: 0 dilation (int or tuple, optional): a parameter that controls the stride of elements within the neighborhood. Default: 1 + padding (int or tuple, optional): implicit zero padding to be added on + both sides of input. Default: 0 + stride (int or tuple, optional): the stride of the sliding blocks in the input + spatial dimensions. Default: 1 * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or :attr:`stride` is an int or a tuple of length 1, their values will be From 90cf14ddf691bfae2d5c793376c68921b7111fde Mon Sep 17 00:00:00 2001 From: erjia Date: Thu, 10 Nov 2022 19:54:19 +0000 Subject: [PATCH 023/453] [DataPipe] Deprecating drop_empty_batches from Filter and other functional APIs (#88693) - Deprecating based on https://github.com/pytorch/data/issues/163 Corresponding PRs from TorchData: https://github.com/pytorch/data/pull/890 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88693 Approved by: https://github.com/NivekT --- torch/utils/data/datapipes/iter/selecting.py | 43 +++++--------------- torch/utils/data/datapipes/utils/common.py | 16 +------- 2 files changed, 12 insertions(+), 47 deletions(-) diff --git a/torch/utils/data/datapipes/iter/selecting.py b/torch/utils/data/datapipes/iter/selecting.py index 2ba91b36fffb..470d2952241f 100644 --- a/torch/utils/data/datapipes/iter/selecting.py +++ b/torch/utils/data/datapipes/iter/selecting.py @@ -1,11 +1,10 @@ -from typing import Callable, Iterator, Optional, TypeVar +from typing import Callable, Iterator, Tuple, TypeVar from torch.utils.data.datapipes._decorator import functional_datapipe from torch.utils.data.datapipes.datapipe import IterDataPipe from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper from torch.utils.data.datapipes.utils.common import ( _check_unpickable_fn, - _deprecation_warning, StreamWrapper, validate_input_col ) @@ -13,6 +12,7 @@ __all__ = ["FilterIterDataPipe", ] +T = TypeVar('T') T_co = TypeVar('T_co', covariant=True) @@ -24,7 +24,6 @@ class FilterIterDataPipe(IterDataPipe[T_co]): Args: datapipe: Iterable DataPipe being filtered filter_fn: Customized function mapping an element to a boolean. - drop_empty_batches (Deprecated): By default, drops a batch if it is empty after filtering instead of keeping an empty list input_col: Index or indices of data which ``filter_fn`` is applied, such as: - ``None`` as default to apply ``filter_fn`` to the data directly. 
@@ -41,15 +40,13 @@ class FilterIterDataPipe(IterDataPipe[T_co]): >>> list(filter_dp) [0, 2, 4] """ - datapipe: IterDataPipe + datapipe: IterDataPipe[T_co] filter_fn: Callable - drop_empty_batches: bool def __init__( self, - datapipe: IterDataPipe, + datapipe: IterDataPipe[T_co], filter_fn: Callable, - drop_empty_batches: Optional[bool] = None, input_col=None, ) -> None: super().__init__() @@ -58,17 +55,6 @@ def __init__( _check_unpickable_fn(filter_fn) self.filter_fn = filter_fn # type: ignore[assignment] - if drop_empty_batches is None: - drop_empty_batches = True - else: - _deprecation_warning( - type(self).__name__, - deprecation_version="1.12", - removal_version="1.14", - old_argument_name="drop_empty_batches", - ) - self.drop_empty_batches = drop_empty_batches - self.input_col = input_col validate_input_col(filter_fn, input_col) @@ -83,13 +69,13 @@ def _apply_filter_fn(self, data) -> bool: def __iter__(self) -> Iterator[T_co]: for data in self.datapipe: - filtered = self._returnIfTrue(data) - if self._isNonEmpty(filtered): + condition, filtered = self._returnIfTrue(data) + if condition: yield filtered else: StreamWrapper.close_streams(data) - def _returnIfTrue(self, data): + def _returnIfTrue(self, data: T) -> Tuple[bool, T]: condition = self._apply_filter_fn(data) if df_wrapper.is_column(condition): @@ -99,18 +85,11 @@ def _returnIfTrue(self, data): if mask: result.append(df_wrapper.get_item(data, idx)) if len(result): - return df_wrapper.concat(result) + return True, df_wrapper.concat(result) else: - return None + return False, None # type: ignore[return-value] if not isinstance(condition, bool): raise ValueError("Boolean output is required for `filter_fn` of FilterIterDataPipe, got", type(condition)) - if condition: - return data - - def _isNonEmpty(self, data): - if df_wrapper.is_dataframe(data): - return True - r = data is not None and \ - not (isinstance(data, list) and len(data) == 0 and self.drop_empty_batches) - return r + + return condition, data diff --git a/torch/utils/data/datapipes/utils/common.py b/torch/utils/data/datapipes/utils/common.py index 20c61c0ead11..75d9a5cf173c 100644 --- a/torch/utils/data/datapipes/utils/common.py +++ b/torch/utils/data/datapipes/utils/common.py @@ -227,21 +227,7 @@ def validate_pathname_binary_tuple(data: Tuple[str, IOBase]): # Deprecated function names and its corresponding DataPipe type and kwargs for the `_deprecation_warning` function -_iter_deprecated_functional_names: Dict[str, Dict] = {"open_file_by_fsspec": - {"old_class_name": "FSSpecFileOpener", - "deprecation_version": "0.4.0", - "removal_version": "0.6.0", - "old_functional_name": "open_file_by_fsspec", - "new_functional_name": "open_files_by_fsspec", - "deprecate_functional_name_only": True}, - "open_file_by_iopath": - {"old_class_name": "IoPathFileOpener", - "deprecation_version": "0.4.0", - "removal_version": "0.6.0", - "old_functional_name": "open_file_by_iopath", - "new_functional_name": "open_files_by_iopath", - "deprecate_functional_name_only": True}} - +_iter_deprecated_functional_names: Dict[str, Dict] = {} _map_deprecated_functional_names: Dict[str, Dict] = {} From 29550e2c1df4cf3ef949e8f1ef973fd5e103a2d3 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 10 Nov 2022 20:56:30 +0000 Subject: [PATCH 024/453] Revert "[Inductor] Build FX Linear + Permute Vertical Fusion in Inductor (#88566)" This reverts commit 48b58930cbfa725ac25a9303d496c76bf983574d. 
Reverted https://github.com/pytorch/pytorch/pull/88566 on behalf of https://github.com/huydhn due to This change breaks trunk https://hud.pytorch.org/pytorch/pytorch/commit/48b58930cbfa725ac25a9303d496c76bf983574d --- test/inductor/test_torchinductor.py | 106 --------------- torch/_inductor/config.py | 3 - torch/_inductor/overrides.py | 199 ---------------------------- 3 files changed, 308 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 064f04291a8e..db6c5dfc2bd1 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10,7 +10,6 @@ import typing import unittest import weakref -from typing import Any, Callable from unittest.mock import patch import torch @@ -19,7 +18,6 @@ from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.testing._internal.common_utils import ( IS_FBCODE, @@ -42,14 +40,6 @@ from torch._inductor import codecache, config, metrics from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing - from torch._inductor.overrides import ( - linear_permute_fusion, - linear_transpose, - permute_linear_fusion, - permute_matmul_fusion, - transpose_linear, - transpose_matmul, - ) from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed @@ -139,29 +129,6 @@ def maybe_test(*args, **kwargs): return wrap_test -PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] - - -def chain_passes(*passes: PassFunc) -> PassFunc: - def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: - for pass_ in passes: - if isinstance(module, torch.fx.GraphModule): - ShapeProp(module).propagate(*input) - module = pass_(module) - return module - - return parent_pass - - -def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: - return sum( - [ - 1 if (n.op == "call_function" and n.target == target_op) else 0 - for n in module.graph.nodes - ] - ) - - class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -1615,79 +1582,6 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) - def test_linear_permute_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - a0 = torch.nn.functional.linear(input, self.weight, self.bias) - b0 = a0.permute(0, 2, 1) - return b0 - - m, k, n = 16, 8, 4 - trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, m, k) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_linear_transpose = count_call_function(traced, linear_transpose) - self.assertEqual(num_linear, 0) - self.assertEqual(num_linear_transpose, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - def test_permute_linear_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = 
torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - input1 = input.permute(0, 2, 1) - output = torch.nn.functional.linear(input1, self.weight, self.bias) - return output - - m, k, n = 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, k, m) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_transpose_linear = count_call_function(traced, transpose_linear) - self.assertEqual(num_linear, 0) - self.assertEqual(num_transpose_linear, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - def test_permute_bmm_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, batch: int, k: int, n: int): - super().__init__() - self.other = torch.randn(batch, k, n) - - def forward(self, input: torch.Tensor): - input1 = input.permute(0, 2, 1) - output = torch.bmm(input1, self.other) - return output - - batch, m, k, n = 6, 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) - module = TestModule(batch, k, n).eval() - input = torch.randn(batch, k, m) - traced = trace_func(module, [input]) - num_bmm = count_call_function(traced, torch.bmm) - num_transpose_matmul = count_call_function(traced, transpose_matmul) - self.assertEqual(num_bmm, 0) - self.assertEqual(num_transpose_matmul, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - def test_slice1(self): def fn(a): return ( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c9b7623cf528..910e6d20b4d6 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -67,9 +67,6 @@ # How to import torchdynamo, either torchdynamo or torch.dynamo dynamo_import = inductor_import.replace("inductor", "dynamo") -# Fx-based linear/matmul/bmm + permute/transpose vertical fusion -permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" - # config specific to codegen/cpp.pp class cpp: diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 69a5bc6710f8..581e1996a436 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -19,8 +19,6 @@ from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.overrides import TorchFunctionMode -from . import config - log = logging.getLogger(__name__) @@ -315,14 +313,6 @@ def check_node_is_binary(node): def fuse_fx(gm: torch.fx.GraphModule, example_inputs): - if config.permute_fusion: - # For linear permute fusion, we need to check input info to identify - # and perform proper permutation/transpose - ShapeProp(gm).propagate(*example_inputs) - gm = linear_permute_fusion(gm) - gm = permute_linear_fusion(gm) - gm = permute_matmul_fusion(gm) - # make sure the autograd is disabled. 
if torch.is_grad_enabled(): return gm @@ -418,195 +408,6 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) -class NormalizedLinearNode: - def __init__(self, node: torch.fx.Node) -> None: - assert node.op == "call_function" - assert node.target in [torch.nn.functional.linear] - self.node: torch.fx.Node = node - - def get_input(self) -> torch.fx.Node: - if len(self.node.args) > 0: - return self.node.args[0] - else: - return self.node.kwargs["input"] - - def get_weight(self) -> torch.fx.Node: - if len(self.node.args) > 1: - return self.node.args[1] - else: - return self.node.kwargs["weight"] - - def get_bias(self) -> torch.fx.Node: - if len(self.node.args) > 2: - return self.node.args[2] - else: - return self.node.kwargs["bias"] - - -class NormalizedMatmulNode: - def __init__(self, node: torch.fx.Node) -> None: - assert node.op == "call_function" - assert node.target in [torch.bmm, torch.matmul] - self.node: torch.fx.Node = node - - def get_input(self) -> torch.fx.Node: - if len(self.node.args) > 0: - return self.node.args[0] - else: - return self.node.kwargs["input"] - - def get_other(self) -> torch.fx.Node: - if len(self.node.args) > 1: - return self.node.args[1] - else: - return self.node.kwargs["other"] - - -def check_permute(node: torch.fx.Node): - ranks = len(node.meta["tensor_meta"].shape) - if len(node.args) > 3: - permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] - elif ( - "permutation" in node.kwargs - and node.kwargs["permutation"] is not None - and len(node.kwargs["permutation"]) > 2 - ): - permutation = [i % ranks for i in node.kwargs["permutation"]] - else: - return False - allowed_permutation = list(range(ranks)) - allowed_permutation[-1] = ranks - 2 - allowed_permutation[-2] = ranks - 1 - return permutation == allowed_permutation - - -def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if ( - node.op == "call_method" - and node.target == "permute" - and check_permute(node) - ): - if len(node.args) > 0: - input_node = node.args[0] - else: - input_node = node.kwargs["input"] - if ( - input_node.op == "call_function" - and input_node.target == torch.nn.functional.linear - ): - normalized = NormalizedLinearNode(input_node) - input = normalized.get_input() - weight = normalized.get_weight() - bias = normalized.get_bias() - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - linear_transpose, args=(input, weight, bias) - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -# Y1 = X * W^T + bias -# Y2 = Y1.permute(0, 2, 1) -# ----> -# Y2 = (W * X^T + bias.unsqueeze(-1))^T -def linear_transpose( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) - - -def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if node.op == "call_function" and node.target == torch.nn.functional.linear: - if len(node.args) > 0: - input_node = node.args[0] - else: - input_node = node.kwargs["input"] - if ( - input_node.op == "call_method" - and input_node.target == "permute" - and check_permute(input_node) - ): - normalized = NormalizedLinearNode(node) - if len(input_node.args) > 0: - input = input_node.args[0] - else: - input = input_node.kwargs["input"] - weight = normalized.get_weight() - bias = 
normalized.get_bias() - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - transpose_linear, args=(input, weight, bias) - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if node.op == "call_function" and ( - node.target == torch.bmm or node.target == torch.matmul - ): - normalized = NormalizedMatmulNode(node) - A = normalized.get_input() - B = normalized.get_other() - Atrans = Btrans = False - if A.op == "call_method" and A.target == "permute" and check_permute(A): - Atrans = True - if len(A.args) > 0: - A = A.args[0] - else: - A = A.kwargs["input"] - - if B.op == "call_method" and B.target == "permute" and check_permute(B): - Btrans = True - if len(B.args) > 0: - B = B.args[0] - else: - B = B.kwargs["input"] - - if Atrans or Btrans: - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - transpose_matmul, - args=(A, B, Atrans, Btrans), - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -# X1 = X.permute(0, 2, 1) -# Y1 = X1 * W1^T + bias1 -# ----> -# Y2 = X1.transpose(-1, -2) * W1^T + bias1 -def transpose_linear( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return torch.matmul(input.transpose(-1, -2), weight.t()) + bias - - -def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): - if Atrans: - A = A.transpose(-1, -2) - if Btrans: - B = B.transpose(-1, -2) - return torch.matmul(A, B) - - def replace_and_fuse_for_binary( computation_node, node, fuse_func, attr, modules, index_node, index_pointwise ): From d5e1e2f0fcd4e0602295bfaf80b8aeb80c86a70d Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 9 Nov 2022 15:31:44 -0800 Subject: [PATCH 025/453] [xnnpack][on-device] executor class (#88778) # Executor Class Executor object used to wrap our xnn_runtime object. The ideal flow of this object looks as such: ``` executor.set_inputs(vector inputs, vector outputs) executor.forward() ``` This will likely be returned by our delegate compile and given over to execute in order to run inference using the xnn runtime ##### Executorch Considerations ``` #include #include ``` These Aten functions are included in order to use at::Tensor when setting the inputs, this will change when used for Executorch because we will be switching from at::Tensor to whatever tensor abstraction is used for ET. Seems like they have the same call for `.data_ptr()`, so realistically all logic here will be the same. ATen/Utils is used for TORCH_CHECK. We will switch to ET_CHECK_MESSAGE for executorch. 
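For orientation, a caller that already holds an executor would drive it roughly as below. This is a hedged sketch only: `run_inference` and its include list are illustrative and not part of this diff; the real call site is the delegate runtime wiring added later in this series.

```
// Hypothetical helper, assuming float tensors whose shapes already match the
// shapes baked into the serialized graph.
#include <vector>

#include <ATen/ATen.h>
#include <torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h>

bool run_inference(
    torch::jit::xnnpack::delegate::XNNExecutor& executor,
    at::Tensor input,
    at::Tensor output) {
  // The runtime only sees raw data pointers; the delegate keeps the shape bookkeeping.
  std::vector<float*> ins{input.data_ptr<float>()};
  std::vector<float*> outs{output.data_ptr<float>()};
  if (!executor.set_inputs(ins, outs)) { // fails if counts don't match the graph
    return false;
  }
  return executor.forward(); // xnn_setup_runtime + xnn_invoke_runtime under the hood
}
```
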
Differential Revision: [D40733121](https://our.internmc.facebook.com/intern/diff/D40733121/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88778 Approved by: https://github.com/digantdesai --- .../backends/xnnpack/executor/xnn_executor.h | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h diff --git a/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h new file mode 100644 index 000000000000..f82bde231c90 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h @@ -0,0 +1,69 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNExecutor { + private: + std::unique_ptr runtime_; + std::vector input_ids_; + std::vector output_ids_; + std::vector externals_; + + public: + XNNExecutor(xnn_runtime_t runtime_ptr) + : runtime_(runtime_ptr, xnn_delete_runtime){}; + + template + bool set_inputs(std::vector& inputs, std::vector& outputs) { + externals_.clear(); + + if (inputs.size() != input_ids_.size()) { + return false; + } + + for (int i = 0; i < inputs.size(); i++) { + externals_.emplace_back(xnn_external_value{input_ids_[i], inputs[i]}); + } + + if (outputs.size() != output_ids_.size()) { + return false; + } + + for (int i = 0; i < outputs.size(); i++) { + externals_.emplace_back(xnn_external_value{output_ids_[i], outputs[i]}); + } + + return true; + }; + + bool forward() { + xnn_status status = + xnn_setup_runtime(runtime_.get(), externals_.size(), externals_.data()); + + if (status != xnn_status_success) { + return false; + } + + status = xnn_invoke_runtime(runtime_.get()); + + if (status != xnn_status_success) { + return false; + } + + return true; + }; + + friend class XNNCompiler; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch From 394b998de2228a4b4730c52b50975a2ecf756049 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 10 Nov 2022 21:04:35 +0000 Subject: [PATCH 026/453] sub setup.py install -> develop (#88507) If someone is building the project from source they're likely a contributor for which develop will be much more useful. For people that want to try the latest and greatest they can leverage the nightlies Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/88507 Approved by: https://github.com/malfet --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3a80c8083a49..bcce2997b25b 100644 --- a/README.md +++ b/README.md @@ -234,7 +234,7 @@ python tools/amd_build/build_amd.py Install PyTorch ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -python setup.py install +python setup.py develop ``` Note that if you are using [Anaconda](https://www.anaconda.com/distribution/#download-section), you may experience an error caused by the linker: @@ -251,7 +251,7 @@ This is caused by `ld` from the Conda environment shadowing the system `ld`. 
You ```bash export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} -MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install +MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py develop ``` **On Windows** @@ -274,7 +274,7 @@ In this mode PyTorch computations will run on your CPU, not your GPU ```cmd conda activate -python setup.py install +python setup.py develop ``` Note on OpenMP: The desired OpenMP implementation is Intel OpenMP (iomp). In order to link against iomp, you'll need to manually download the library and set up the building environment by tweaking `CMAKE_INCLUDE_PATH` and `LIB`. The instruction [here](https://github.com/pytorch/pytorch/blob/master/docs/source/notes/windows.rst#building-from-source) is an example for setting up both MKL and Intel OpenMP. Without these configurations for CMake, Microsoft Visual C OpenMP runtime (vcomp) will be used. @@ -315,7 +315,7 @@ for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\ :: [Optional] If you want to override the CUDA host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe -python setup.py install +python setup.py develop ``` From 3a4e8736ad66db2089cbcb3a24cf779aab3a7564 Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 9 Nov 2022 15:33:00 -0800 Subject: [PATCH 027/453] [xnnpack][on-device] compiler --> executor object (#88779) #### XNN Compiler Object This is purely to abstract away the subgraph rebuild from the flatbuffer object. CompileModel return an executor object which we can use to setup inputs and run forward with. #### Executorch Considerations We Include ATen/utils for torch_check, this will be changed when moving to executorch Differential Revision: [D40733163](https://our.internmc.facebook.com/intern/diff/D40733163/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88779 Approved by: https://github.com/digantdesai --- .../xnnpack/compiler/xnn_compiler.cpp | 128 ++++++++++++++++++ .../backends/xnnpack/compiler/xnn_compiler.h | 25 ++++ 2 files changed, 153 insertions(+) create mode 100644 torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp create mode 100644 torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp new file mode 100644 index 000000000000..395d59a1cf21 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -0,0 +1,128 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +#include +#include + +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +XNNExecutor XNNCompiler::compileModel(std::string ser_model) { + const char* buffer_pointer = ser_model.data(); + + auto output_min = -std::numeric_limits::infinity(); + auto output_max = std::numeric_limits::infinity(); + + auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(buffer_pointer); + // initialize xnnpack + xnn_status status = xnn_initialize(/*allocator =*/nullptr); + TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack"); + + // create xnnpack subgraph + xnn_subgraph_t subgraph_ptr = nullptr; + + // TODO: @maxren serialize extern_ids in flatbuffer schema + std::unordered_set extern_ids; + for (auto input_id : *flatbuffer_graph->input_ids()) { + extern_ids.insert(input_id); + } + for (auto output_id : *flatbuffer_graph->output_ids()) { + extern_ids.insert(output_id); + } + status = xnn_create_subgraph( + /*external_value_ids=*/extern_ids.size(), + /*flags=*/0, + &subgraph_ptr); + TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph"); + + // mapping from old ids to new created value ids + // The old ids that were serialied were generated AoT, since + // we are re-defining tensor values, the defined IDs could be + // different from the ones generated AoT, as a result, we need + // a new mapping from the old ids to the newly created ones + std::unordered_map remapped_ids; + + for (auto value : *flatbuffer_graph->values()) { + switch (value->value_type()) { + case fb_xnnpack::ValueUnion::XNNTensorValue: { + auto tensor_value = value->value_as_XNNTensorValue(); + + const void* data_ptr = nullptr; + auto buffer_idx = tensor_value->constant_buffer_idx(); + if (buffer_idx != 0) { + // TODO: @maxren implement data handling + TORCH_CHECK(false, "Cosntant data handling not yet implemented") + } + std::vector dims_data; + for (auto dim : *tensor_value->dims()) { + dims_data.push_back(static_cast(dim)); + } + + uint32_t id = XNN_INVALID_VALUE_ID; + status = xnn_define_tensor_value( + /*subgraph=*/subgraph_ptr, + /*datatype=*/xnn_datatype_fp32, + /*num_dims=*/tensor_value->num_dims(), + /*dims=*/dims_data.data(), + /*data=*/data_ptr, + /*external_id=*/tensor_value->external_id(), + /*flags=*/tensor_value->flags(), + /*id_out=*/&id); + TORCH_CHECK( + status == xnn_status_success, + "Failed to define tensor values in graph") + // map serialized id to newly generated id + remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id)); + break; + } + default: { + TORCH_CHECK(false, "Unhandled value type found in deserialization"); + } + } + } + + for (auto node : *flatbuffer_graph->nodes()) { + switch (node->node_type()) { + case fb_xnnpack::NodeUnion::XNNAdd: { + auto graph_node = node->node_as_XNNAdd(); + status = xnn_define_add2( + subgraph_ptr, + output_min, + output_max, + remapped_ids.at(graph_node->input1_id()), + remapped_ids.at(graph_node->input2_id()), + remapped_ids.at(graph_node->output_id()), + graph_node->flags()); + TORCH_CHECK(status == xnn_status_success, "Failed to create add node") + break; + } + default: + TORCH_CHECK(false, "Unhandled node type found in deserialization"); + } + } + + xnn_runtime_t runtime_ptr = nullptr; + status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr); + TORCH_CHECK(xnn_status_success == status); + + XNNExecutor executor(runtime_ptr); + + for (auto old_id : *flatbuffer_graph->input_ids()) { + executor.input_ids_.push_back(remapped_ids.at(old_id)); + } + + for (auto old_id 
: *flatbuffer_graph->output_ids()) { + executor.output_ids_.push_back(remapped_ids.at(old_id)); + } + + return executor; +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h new file mode 100644 index 000000000000..99eecfdcaa45 --- /dev/null +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -0,0 +1,25 @@ +// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace xnnpack { +namespace delegate { + +class XNNCompiler { + public: + // Takes Flatbuffer Serialized XNNPack Model and rebuilds the xnn-subgraph + // returns an executor object that holds the xnn runtime object which we + // can then use to set inputs and run inference using the xnn graph. + static XNNExecutor compileModel(std::string ser_model); +}; + +} // namespace delegate +} // namespace xnnpack +} // namespace jit +} // namespace torch From 1ae772a663f772171f0c5d6d7d311792f331206a Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Thu, 10 Nov 2022 06:56:26 -0800 Subject: [PATCH 028/453] [inductor] Remove import check for fast_flush (#88812) https://github.com/pytorch/pytorch/pull/88557/ has a guard to make sure that triton's `do_bench` includes the `fast_flush` argument. Since we've updated Triton to a sufficiently recent revision, we can remove that guard. Differential Revision: [D41185280](https://our.internmc.facebook.com/intern/diff/D41185280/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88812 Approved by: https://github.com/soumith --- torch/_inductor/triton_ops/autotune.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torch/_inductor/triton_ops/autotune.py b/torch/_inductor/triton_ops/autotune.py index 0fbdd2d4591b..808241cd02a2 100644 --- a/torch/_inductor/triton_ops/autotune.py +++ b/torch/_inductor/triton_ops/autotune.py @@ -132,14 +132,9 @@ def kernel_call(): stream=stream, ) - import inspect - from triton.testing import do_bench - if "fast_flush" in inspect.signature(do_bench).parameters.keys(): - return do_bench(kernel_call, rep=40, fast_flush=True) - else: - return do_bench(kernel_call, rep=40) + return do_bench(kernel_call, rep=40, fast_flush=True) @dynamo_utils.dynamo_timed def autotune_to_one_config(self, *args, **kwargs): From de38c8769835ab0efa055baaf7605be37e410417 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Thu, 10 Nov 2022 21:32:41 +0000 Subject: [PATCH 029/453] Use run_test in MPS (#88829) Run mps through run_test to get disable test infra, create xml files (which can then be used for flakiness detection), and reruns Also added the workflow steps for uploading the xml files Pull Request resolved: https://github.com/pytorch/pytorch/pull/88829 Approved by: https://github.com/malfet, https://github.com/huydhn --- .github/workflows/_mac-test-mps.yml | 18 ++++++++++++++++-- test/run_test.py | 17 ++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_mac-test-mps.yml b/.github/workflows/_mac-test-mps.yml index 3f7ba04f3e84..24203e005153 100644 --- a/.github/workflows/_mac-test-mps.yml +++ b/.github/workflows/_mac-test-mps.yml @@ -66,6 +66,7 @@ jobs: ${CONDA_RUN} python3 -mpip install --no-index --no-deps dist/*.whl - name: Run MPS tests + id: test env: ENV_NAME: 
conda-test-env-${{ github.run_id }} shell: arch -arch arm64 bash {0} @@ -74,5 +75,18 @@ jobs: set -ex # TODO(https://github.com/pytorch/pytorch/issues/79293) - ${CONDA_RUN} --cwd test python3 test_mps.py -v - ${CONDA_RUN} --cwd test python3 test_metal.py -v + ${CONDA_RUN} python3 test/run_test.py --mps --verbose + + - name: Get workflow job id + id: get-job-id + uses: ./.github/actions/get-workflow-job-id + if: always() + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload test artifacts + uses: ./.github/actions/upload-test-artifacts + if: always() && (steps.test.conclusion == 'success' || steps.test.conclusion == 'failure') + with: + use-gha: true + file-suffix: ${{ github.job }}-mps-1-1-macos-m1-12_${{ steps.get-job-id.outputs.job-id }} diff --git a/test/run_test.py b/test/run_test.py index 307b83dfdcd7..59454c6aaa3f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -101,9 +101,6 @@ def skip_test_p(name: str) -> bool: 'test_jit_simple', 'test_jit_string', 'test_kernel_launch_checks', - 'test_metal', - # Right now we have a separate CI job for running MPS - 'test_mps', 'test_nnapi', 'test_segment_reductions', 'test_static_runtime', @@ -843,6 +840,14 @@ def parse_args(): "This requires functorch to already be installed." ) ) + parser.add_argument( + "--mps", + "--mps", + action="store_true", + help=( + "If this flag is present, we will only run test_mps and test_metal" + ) + ) parser.add_argument( "-core", "--core", @@ -1052,6 +1057,12 @@ def get_selected_tests(options): # Exclude all functorch tests otherwise options.exclude.extend(FUNCTORCH_TESTS) + if options.mps: + selected_tests = ['test_mps', 'test_metal'] + else: + # Exclude all mps tests otherwise + options.exclude.extend(['test_mps', 'test_metal']) + # process reordering if options.bring_to_front: to_front = set(options.bring_to_front) From 37b468ac777ba548a2808010fd2f1b146b779fe0 Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 9 Nov 2022 15:33:57 -0800 Subject: [PATCH 030/453] [xnnpack][lite-int][on-device] rebuild serialized modules at runtime (#88780) This is the on-device runtime work. We modify the compile and execute from our hacky solution from before to what will actually be running at runtime. First we rebuild our graph from the serialized flatbuffer string. We also introduce a runtime wrapper that inherits CustomClassHolder that allows us to forward along the built xnngraph runtime to our execute function Once the subgraph object has been rebuilt by our we pass it along to the runtime wrapper for us to forward along to execute At execute we prep the input/outputs and invoke the runtime using our runtime wrapper. Finally we forward those results to our execution Differential Revision: [D39413031](https://our.internmc.facebook.com/intern/diff/D39413031/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D39413031/)! 
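For orientation, a condensed sketch of the capsule round-trip described above is below. The helper names (`wrap_runtime`, `unwrap_runtime`) and the exact includes are illustrative; the real logic lives in the `compile`/`execute` methods in the xnnpack_backend_lib.cpp diff that follows.

```
#include <ATen/core/ivalue.h>
#include <torch/custom_class.h>
#include <torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h>

using torch::jit::xnnpack::delegate::XNNExecutor;

// Wrapper type so the built xnn runtime can ride along inside an IValue capsule.
struct XNNModelWrapper : torch::CustomClassHolder {
  XNNExecutor executor_;
  explicit XNNModelWrapper(XNNExecutor e) : executor_(std::move(e)) {}
};

// compile() side: stash the runtime behind an opaque handle.
c10::IValue wrap_runtime(XNNExecutor executor) {
  auto wrapper = c10::make_intrusive<XNNModelWrapper>(std::move(executor));
  return c10::IValue::make_capsule(wrapper);
}

// execute() side: recover the wrapper and hand its executor to inference.
XNNExecutor& unwrap_runtime(const c10::IValue& handle) {
  auto wrapper =
      c10::static_intrusive_pointer_cast<XNNModelWrapper>(handle.toCapsule());
  return wrapper->executor_;
}
```

The capsule is what lets execute() get back a live runtime without re-deserializing the flatbuffer on every forward call.
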
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88780 Approved by: https://github.com/digantdesai --- test/jit/xnnpack/test_xnnpack_delegate.py | 2 +- .../backends/xnnpack/xnnpack_backend_lib.cpp | 68 +++++++++++++++++-- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/test/jit/xnnpack/test_xnnpack_delegate.py b/test/jit/xnnpack/test_xnnpack_delegate.py index 167a049ec0cc..997cc757e629 100644 --- a/test/jit/xnnpack/test_xnnpack_delegate.py +++ b/test/jit/xnnpack/test_xnnpack_delegate.py @@ -91,7 +91,7 @@ def forward(self, x, y): add_module, { "forward": { - "inputs" : [sample_inputs[0], sample_inputs[1]], + "inputs" : [sample_inputs[0].clone(), sample_inputs[1].clone()], "outputs": [sample_output] } } diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index d55e89ed216f..a5718820fc19 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -1,15 +1,27 @@ +#include #include #include #include #include -#include +#include +#include namespace torch { namespace jit { namespace xnnpack { namespace delegate { +class XNNModelWrapper : public CustomClassHolder { + public: + XNNExecutor executor_; + XNNModelWrapper(XNNExecutor executor) : executor_(std::move(executor)){}; + + XNNModelWrapper() = delete; + + XNNModelWrapper(const XNNModelWrapper& oldObject) = delete; +}; + class XNNPackBackend : public PyTorchBackendInterface { public: // Constructor. @@ -26,9 +38,27 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::IValue processed, c10::impl::GenericDict method_compile_spec) override { auto dict = processed.toGenericDict(); + + // Compiling and wrapping exeuction object + std::string ser_model = dict.at("ser_model").toStringRef(); + XNNExecutor executor = XNNCompiler::compileModel(ser_model); + + auto model_ptr = c10::make_intrusive(std::move(executor)); + auto runtime_handle = IValue::make_capsule(model_ptr); + auto wrapper = c10::static_intrusive_pointer_cast( + runtime_handle.toCapsule()); + + // Packing outputs into generic dict c10::Dict handles( c10::StringType::get(), c10::AnyType::get()); - handles.insert("forward", dict); + + c10::Dict ret( + c10::StringType::get(), c10::AnyType::get()); + + ret.insert("runtime", runtime_handle); + ret.insert("output_shapes", dict.at("outputs")); + + handles.insert("forward", ret); return handles; } @@ -41,9 +71,39 @@ class XNNPackBackend : public PyTorchBackendInterface { c10::impl::GenericList execute( c10::IValue handle, c10::impl::GenericList inputs) override { - auto answer = handle.toGenericDict().at("Answer"); + auto dict = handle.toGenericDict(); + auto output_shapes = dict.at("output_shapes").toList(); + + auto capsule = dict.at("runtime").toCapsule(); + auto model_wrapper = + c10::static_intrusive_pointer_cast(capsule); + + XNNExecutor& executor = model_wrapper->executor_; + + std::vector input_pointers; + for (int i = 0; i < inputs.size(); ++i) { + at::IValue val = inputs.get(i); + TORCH_CHECK(val.isTensor(), "Non-tensor inputs not supported"); + input_pointers.push_back(val.toTensor().data_ptr()); + } + + std::vector output_tensors; + std::vector output_pointers; + output_tensors.reserve(output_shapes.size()); + for (int i = 0; i < output_shapes.size(); i++) { + auto o_shape = output_shapes.get(i).toIntVector(); + auto output = at::empty(o_shape, c10::ScalarType::Float); + output_tensors.push_back(output); + 
output_pointers.push_back(output.data_ptr()); + } + + TORCH_CHECK( + executor.set_inputs(input_pointers, output_pointers), + "Number of inputs/outputs does not match expected number of inputs/outputs"); + TORCH_CHECK(executor.forward(), "Failed to invoke XNNPack runtime"); - return answer.toList(); + c10::List output_list(output_tensors); + return c10::impl::toList(output_list); } }; From ad2eba802c04394875af0f00b985f7f338423f1e Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 8 Nov 2022 07:59:11 -0800 Subject: [PATCH 031/453] [ao] fuser_method_mappings.py fixing public v private (#87516) Summary: made _get_valid_patterns, _DEFAULT_PATTERN_TO_FUSER_METHOD, _reverse3, _reverse2, _reverse_sequential_wrapper2, _DEFAULT_OP_LIST_TO_FUSER_METHOD, _sequential_wrapper2 private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709281](https://our.internmc.facebook.com/intern/diff/D40709281) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87516 Approved by: https://github.com/jcaip --- .../ao_migration/test_quantization.py | 2 +- test/quantization/core/test_backend_config.py | 4 +- torch/ao/quantization/__init__.py | 5 -- .../_common_operator_config_utils.py | 28 ++++---- .../backend_config/backend_config.py | 4 +- .../quantization/backend_config/executorch.py | 6 +- .../ao/quantization/fuser_method_mappings.py | 72 +++++++++---------- torch/ao/quantization/fx/README.md | 2 +- torch/quantization/fuser_method_mappings.py | 2 +- 9 files changed, 57 insertions(+), 68 deletions(-) diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 52b8f631711f..2617e7a1187d 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -225,7 +225,7 @@ def test_function_import_fuser_method_mappings(self): "get_fuser_method", ] dict_list = [ - "DEFAULT_OP_LIST_TO_FUSER_METHOD" + "_DEFAULT_OP_LIST_TO_FUSER_METHOD" ] self._test_function_import('fuser_method_mappings', function_list) self._test_dict_import('fuser_method_mappings', dict_list) diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py index e1e7067d4135..aa9de64824bc 100644 --- a/test/quantization/core/test_backend_config.py +++ b/test/quantization/core/test_backend_config.py @@ -14,7 +14,7 @@ ObservationType, ) from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 +from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2 from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer @@ -106,7 +106,7 @@ def test_dtype_config_to_dict(self): # BackendPatternConfig # ====================== - _fuser_method = reverse_sequential_wrapper2(nni.LinearReLU) + _fuser_method = _reverse_sequential_wrapper2(nni.LinearReLU) _num_tensor_args_to_observation_type = { 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index abc0bd24d97b..2e8390c1acc7 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -114,7 +114,6 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", - "get_valid_patterns", 
"is_activation_post_process", "load_observer_state_dict", "no_observer_set", @@ -132,12 +131,8 @@ "quantize_jit", "quantize_qat", "register_activation_post_process_hook", - "reverse2", - "reverse3", - "reverse_sequential_wrapper2", "script_qconfig", "script_qconfig_dict", - "sequential_wrapper2", "swap_module", "weight_observer_range_neg_127_to_127", ] diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index bc6f678485fb..c2f0f7227b10 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -15,9 +15,9 @@ ) from ..fake_quantize import FixedQParamsFakeQuantize from ..fuser_method_mappings import ( - reverse_sequential_wrapper2, - reverse2, - reverse3, + _reverse_sequential_wrapper2, + _reverse2, + _reverse3, fuse_conv_bn, fuse_conv_bn_relu, fuse_linear_bn, @@ -115,13 +115,13 @@ def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPattern linear_configs.append( BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU)) # linear relu, linear module + functional relu linear_configs.append( BackendPatternConfig((torch.nn.functional.relu, torch.nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU)) # 2.2 linear module + relu, fused module configs @@ -158,7 +158,7 @@ def _get_linear_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPattern linear_configs.append( BackendPatternConfig((nn.BatchNorm1d, nn.Linear)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_linear_bn)) + .set_fuser_method(_reverse2(fuse_linear_bn)) .set_fused_module(nni.LinearBn1d)) # 3.2 linear bn fused @@ -218,13 +218,13 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((torch.nn.ReLU, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # conv relu fusion, conv module + functional relu conv_configs.append( BackendPatternConfig((F.relu, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # 2.2 conv module + relu fused module configs # conv relu, fused module @@ -273,20 +273,20 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((convs.bn, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_conv_bn)) + .set_fuser_method(_reverse2(fuse_conv_bn)) .set_fused_module(convs.fused_conv_bn)) # conv + bn + relu module fusion conv_configs.append( BackendPatternConfig((nn.ReLU, (convs.bn, convs.root))) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse3(fuse_conv_bn_relu)) + .set_fuser_method(_reverse3(fuse_conv_bn_relu)) .set_fused_module(convs.fused_conv_bn_relu)) # conv + bn + relu 
functional fusion conv_configs.append( BackendPatternConfig((F.relu, (convs.bn, convs.root))) .set_dtype_configs(dtype_configs) # noqa: E131 .set_root_module(convs.root) - .set_fuser_method(reverse3(fuse_conv_bn_relu)) + .set_fuser_method(_reverse3(fuse_conv_bn_relu)) .set_fused_module(convs.fused_conv_bn_relu)) # TODO: we can add fusion for torch.relu as well @@ -330,7 +330,7 @@ def _get_conv_configs(dtype_configs): conv_configs.append( BackendPatternConfig((convs.bn, convs.transpose)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse2(fuse_convtranspose_bn)) + .set_fuser_method(_reverse2(fuse_convtranspose_bn)) .set_root_module(convs.transpose) .set_reference_quantized_module(convs.transpose_reference)) @@ -497,13 +497,13 @@ def _get_bn_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConf bn_configs.append( BackendPatternConfig((torch.nn.ReLU, bn)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(fused_bn)) + .set_fuser_method(_reverse_sequential_wrapper2(fused_bn)) .set_fused_module(fused_bn)) # bn module + F.relu fusion config bn_configs.append( BackendPatternConfig((torch.nn.functional.relu, bn)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(bn_to_fused_bn[bn])) + .set_fuser_method(_reverse_sequential_wrapper2(bn_to_fused_bn[bn])) .set_fused_module(fused_bn)) bn_configs.append( BackendPatternConfig(bn) diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 2f491b162404..1305c32a4ea8 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -229,7 +229,7 @@ class BackendConfig: import torch from torch.ao.quantization.backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType - from torch.ao.quantization.fuser_method_mappings import reverse_sequential_wrapper2 + from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2 weighted_int8_dtype_config = DTypeConfig( input_dtype=torch.quint8, @@ -248,7 +248,7 @@ class BackendConfig: .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ .add_dtype_config(weighted_int8_dtype_config) \ .set_fused_module(torch.nn.intrinsic.ConvReLU2d) \ - .set_fuser_method(reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d)) + .set_fuser_method(_reverse_sequential_wrapper2(torch.nn.intrinsic.ConvReLU2d)) backend_config = BackendConfig("my_backend") \ .set_backend_pattern_config(linear_config) \ diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py index 4c0f2a48b552..3c729327de76 100644 --- a/torch/ao/quantization/backend_config/executorch.py +++ b/torch/ao/quantization/backend_config/executorch.py @@ -7,7 +7,7 @@ import torch.nn.quantized._reference as nnqr from .backend_config import BackendConfig, BackendPatternConfig, DTypeConfig, ObservationType from ._common_operator_config_utils import _Conv2dMetadata -from ..fuser_method_mappings import reverse_sequential_wrapper2 +from ..fuser_method_mappings import _reverse_sequential_wrapper2 __all__ = [ @@ -105,13 +105,13 @@ def _get_conv_configs() -> List[BackendPatternConfig]: conv_configs.append( BackendPatternConfig((torch.nn.ReLU, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + 
.set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # conv module + functional relu conv_configs.append( BackendPatternConfig((F.relu, convs.root)) .set_dtype_configs(dtype_configs) # noqa: E131 - .set_fuser_method(reverse_sequential_wrapper2(convs.fused_conv_relu)) + .set_fuser_method(_reverse_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu)) # fused conv relu module conv_configs.append( diff --git a/torch/ao/quantization/fuser_method_mappings.py b/torch/ao/quantization/fuser_method_mappings.py index 2e39f87321d4..db4cc9a04d76 100644 --- a/torch/ao/quantization/fuser_method_mappings.py +++ b/torch/ao/quantization/fuser_method_mappings.py @@ -10,13 +10,7 @@ "fuse_conv_bn_relu", "fuse_linear_bn", "fuse_convtranspose_bn", - "sequential_wrapper2", "get_fuser_method", - "reverse_sequential_wrapper2", - "reverse2", - "reverse3", - "DEFAULT_PATTERN_TO_FUSER_METHOD", - "get_valid_patterns", "get_fuser_method_new", ] @@ -156,7 +150,7 @@ def fuse_convtranspose_bn(is_qat, convt, bn): else: return nn.utils.fusion.fuse_conv_bn_eval(convt, bn, transpose=True) -def sequential_wrapper2(sequential): +def _sequential_wrapper2(sequential): """ Given a sequential class for two modules, return a function that takes is_qat, and then two modules as argument, that ignores the is_qat flag and always returns the sequential that combines the two input modules @@ -165,20 +159,20 @@ def fuser_method(is_qat, m1, m2): return sequential(m1, m2) return fuser_method -DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = { +_DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = { (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, (nn.Conv2d, nn.BatchNorm2d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv3d, nn.BatchNorm3d): fuse_conv_bn, (nn.Conv3d, nn.BatchNorm3d, nn.ReLU): fuse_conv_bn_relu, - (nn.Conv1d, nn.ReLU): sequential_wrapper2(nni.ConvReLU1d), - (nn.Conv2d, nn.ReLU): sequential_wrapper2(nni.ConvReLU2d), - (nn.Conv3d, nn.ReLU): sequential_wrapper2(nni.ConvReLU3d), + (nn.Conv1d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU1d), + (nn.Conv2d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU2d), + (nn.Conv3d, nn.ReLU): _sequential_wrapper2(nni.ConvReLU3d), (nn.Linear, nn.BatchNorm1d): fuse_linear_bn, - (nn.Linear, nn.ReLU): sequential_wrapper2(nni.LinearReLU), - (nn.BatchNorm2d, nn.ReLU): sequential_wrapper2(nni.BNReLU2d), - (nn.BatchNorm3d, nn.ReLU): sequential_wrapper2(nni.BNReLU3d), + (nn.Linear, nn.ReLU): _sequential_wrapper2(nni.LinearReLU), + (nn.BatchNorm2d, nn.ReLU): _sequential_wrapper2(nni.BNReLU2d), + (nn.BatchNorm3d, nn.ReLU): _sequential_wrapper2(nni.BNReLU3d), (nn.ConvTranspose1d, nn.BatchNorm1d): fuse_convtranspose_bn, (nn.ConvTranspose2d, nn.BatchNorm2d): fuse_convtranspose_bn, (nn.ConvTranspose3d, nn.BatchNorm3d): fuse_convtranspose_bn, @@ -190,13 +184,13 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None): ''' if additional_fuser_method_mapping is None: additional_fuser_method_mapping = {} - all_mappings = get_combined_dict(DEFAULT_OP_LIST_TO_FUSER_METHOD, + all_mappings = get_combined_dict(_DEFAULT_OP_LIST_TO_FUSER_METHOD, additional_fuser_method_mapping) fuser_method = all_mappings.get(op_list, None) assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list) return fuser_method -def reverse_sequential_wrapper2(sequential): +def 
_reverse_sequential_wrapper2(sequential): """ Given a sequential class for two modules, return a function that takes is_qat, and then two modules as argument, that ignores the is_qat flag and always returns the sequential that combines the two input modules, with @@ -206,37 +200,37 @@ def fuser_method(is_qat, m1, m2): return sequential(m2, m1) return fuser_method -def reverse2(f): +def _reverse2(f): def reversed(is_qat, x, y): return f(is_qat, y, x) return reversed -def reverse3(f): +def _reverse3(f): def reversed(is_qat, x, w): y, z = w return f(is_qat, z, y, x) return reversed -DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = { - (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu), - (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu), - (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn), - (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu), - (nn.ReLU, nn.Conv1d): reverse_sequential_wrapper2(nni.ConvReLU1d), - (nn.ReLU, nn.Conv2d): reverse_sequential_wrapper2(nni.ConvReLU2d), - (nn.ReLU, nn.Conv3d): reverse_sequential_wrapper2(nni.ConvReLU3d), - (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn), - (nn.ReLU, nn.Linear): reverse_sequential_wrapper2(nni.LinearReLU), - (nn.ReLU, nn.BatchNorm2d): reverse_sequential_wrapper2(nni.BNReLU2d), - (nn.ReLU, nn.BatchNorm3d): reverse_sequential_wrapper2(nni.BNReLU3d), - (nn.BatchNorm1d, nn.ConvTranspose1d): reverse2(fuse_convtranspose_bn), - (nn.BatchNorm2d, nn.ConvTranspose2d): reverse2(fuse_convtranspose_bn), - (nn.BatchNorm3d, nn.ConvTranspose3d): reverse2(fuse_convtranspose_bn), +_DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = { + (nn.BatchNorm1d, nn.Conv1d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): _reverse3(fuse_conv_bn_relu), + (nn.BatchNorm2d, nn.Conv2d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): _reverse3(fuse_conv_bn_relu), + (nn.BatchNorm3d, nn.Conv3d): _reverse2(fuse_conv_bn), + (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): _reverse3(fuse_conv_bn_relu), + (nn.ReLU, nn.Conv1d): _reverse_sequential_wrapper2(nni.ConvReLU1d), + (nn.ReLU, nn.Conv2d): _reverse_sequential_wrapper2(nni.ConvReLU2d), + (nn.ReLU, nn.Conv3d): _reverse_sequential_wrapper2(nni.ConvReLU3d), + (nn.BatchNorm1d, nn.Linear): _reverse2(fuse_linear_bn), + (nn.ReLU, nn.Linear): _reverse_sequential_wrapper2(nni.LinearReLU), + (nn.ReLU, nn.BatchNorm2d): _reverse_sequential_wrapper2(nni.BNReLU2d), + (nn.ReLU, nn.BatchNorm3d): _reverse_sequential_wrapper2(nni.BNReLU3d), + (nn.BatchNorm1d, nn.ConvTranspose1d): _reverse2(fuse_convtranspose_bn), + (nn.BatchNorm2d, nn.ConvTranspose2d): _reverse2(fuse_convtranspose_bn), + (nn.BatchNorm3d, nn.ConvTranspose3d): _reverse2(fuse_convtranspose_bn), } -def get_valid_patterns(op_pattern): +def _get_valid_patterns(op_pattern): """ Returns a list of valid patterns generated from the op_pattern, since MatchAllNode can match all types of nodes, @@ -261,7 +255,7 @@ def get_valid_patterns(op_pattern): if isinstance(op_pattern, (tuple, list)): sub_combs = [] for sub_pattern in op_pattern: - sub_combs.append(get_valid_patterns(sub_pattern)) + sub_combs.append(_get_valid_patterns(sub_pattern)) result = list(itertools.product(*sub_combs)) else: result = [op_pattern, MatchAllNode] @@ -274,9 +268,9 @@ def get_fuser_method_new( Would like to implement this first and have a separate PR 
for deprecation """ if fuser_method_mapping is None: - fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD + fuser_method_mapping = _DEFAULT_PATTERN_TO_FUSER_METHOD - op_patterns = get_valid_patterns(op_pattern) + op_patterns = _get_valid_patterns(op_pattern) fuser_method = None for op_pattern in op_patterns: fuser_method = fuser_method_mapping.get(op_pattern, None) diff --git a/torch/ao/quantization/fx/README.md b/torch/ao/quantization/fx/README.md index cba11e9d3641..622acd30956c 100644 --- a/torch/ao/quantization/fx/README.md +++ b/torch/ao/quantization/fx/README.md @@ -81,7 +81,7 @@ What we did in this example are: ``` BackendPatternConfig((torch.nn.ReLU, torch.nn.Linear)) - .set_fuser_method(reverse_sequential_wrapper2(nni.LinearReLU)) + .set_fuser_method(_reverse_sequential_wrapper2(nni.LinearReLU)) ._set_root_node_getter(my_root_node_getter) ._set_extra_inputs_getter(my_extra_inputs_getter) ``` diff --git a/torch/quantization/fuser_method_mappings.py b/torch/quantization/fuser_method_mappings.py index 50520b3f7967..22f4e638ea69 100644 --- a/torch/quantization/fuser_method_mappings.py +++ b/torch/quantization/fuser_method_mappings.py @@ -10,6 +10,6 @@ fuse_conv_bn, fuse_conv_bn_relu, fuse_linear_bn, - DEFAULT_OP_LIST_TO_FUSER_METHOD, + _DEFAULT_OP_LIST_TO_FUSER_METHOD, get_fuser_method, ) From c1553880de95845c5a194247c683872949d66cd6 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Nov 2022 21:38:04 +0000 Subject: [PATCH 032/453] Have kernel names include fused ops (#88624) - Propagates origin fx nodes through inlining during lowering - Concatenates op names into kernel name - Adds config to cap the number of ops in the kernel name so they don't get too long Caveats: - The ordering in the name may not match the order that the ops are executed in the kernel Pull Request resolved: https://github.com/pytorch/pytorch/pull/88624 Approved by: https://github.com/anijain2305, https://github.com/jansel --- test/inductor/test_torchinductor.py | 15 +++--- torch/_inductor/codegen/cpp.py | 2 +- torch/_inductor/codegen/triton.py | 7 ++- torch/_inductor/codegen/triton_template.py | 2 +- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/config.py | 4 ++ torch/_inductor/graph.py | 59 +++++++++++----------- torch/_inductor/ir.py | 1 + torch/_inductor/utils.py | 27 ++++++++++ 9 files changed, 79 insertions(+), 42 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index db6c5dfc2bd1..229f0fa83dd4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4893,14 +4893,13 @@ def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: with V.set_graph_handler(graph), V.set_debug_handler(DebugContext()): graph.run(*(cxt.example_args)) mod = graph.compile_to_module() - i = 0 - while True: - attribute = f"kernel{i}" - if not hasattr(mod, attribute): - break - else: - kernels.append(getattr(mod, attribute)) - i = i + 1 + + for val in mod.__dict__.values(): + if isinstance( + val, torch._inductor.triton_ops.autotune.CachingAutotuner + ): + kernels.append(val) + return kernels def test_divisibile_by_16_covers_numel_args(self): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 90ae4b44d579..65a9335d6cbf 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1291,7 +1291,7 @@ def codegen_define_and_call(self, wrapper): codecache_def.splice(code) codecache_def.writeline("''')") - kernel_name = 
wrapper.next_kernel_name() + kernel_name = "kernel_cpp_" + wrapper.next_kernel_suffix() codecache_str = codecache_def.getvalue() # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does # not use BracesBuffer, so we have no good indicator of a C++ buffer atm. diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 3471d23a7213..88a0ad4977be 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -16,6 +16,7 @@ from ..ir import ReductionHint from ..utils import ( free_symbol_startswith, + get_fused_kernel_name, instance_descriptor, sympy_product, sympy_subs, @@ -1281,7 +1282,11 @@ def codegen_node_schedule(self, node_schedule, numel, reduction_numel): if src_code in wrapper.kernels: kernel_name = wrapper.kernels[src_code] else: - kernel_name = wrapper.next_kernel_name() + kernel_name = ( + "triton_" + + get_fused_kernel_name(node_schedule) + + wrapper.next_kernel_suffix() + ) wrapper.kernels[src_code] = kernel_name subs_name = kernel_name if config.triton.ordered_kernel_names else "kernel" src_code = src_code.replace("KERNEL_NAME", subs_name) diff --git a/torch/_inductor/codegen/triton_template.py b/torch/_inductor/codegen/triton_template.py index 4d86feeccec8..0de771ff6574 100644 --- a/torch/_inductor/codegen/triton_template.py +++ b/torch/_inductor/codegen/triton_template.py @@ -335,7 +335,7 @@ def template_codegen(scheduler, scheduler_node, epilogue): break assert kernel_buf_replace_name is not None - kernel_name = wrapper.next_kernel_name() + kernel_name = "triton_template_" + wrapper.next_kernel_suffix() # code gen kernel wrapper.header.splice( kernel.codegen_kernel( diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 7efc1cf1aa8c..cf8fb46c84bd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -255,8 +255,8 @@ def write_get_cuda_stream(self, index): self.writeline(f"{name} = get_cuda_stream({index})") return name - def next_kernel_name(self): - return f"kernel{next(self._names_iter)}" + def next_kernel_suffix(self): + return f"{next(self._names_iter)}" def codegen_allocation(self, buffer): name = buffer.get_name() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 910e6d20b4d6..87e2793782be 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -61,6 +61,10 @@ compile_threads = min(32, os.cpu_count()) if sys.platform != "win32" else 1 +# If kernel is fused, the name is generated from the origin node op names +# for larger kernels limit this +kernel_name_max_ops = 10 + # How to import torchinductor, either torchinductor or torch.inductor inductor_import = __name__.replace(".config", "") diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index adf8ed961421..f69a891fca7b 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -21,7 +21,7 @@ from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox from .lowering import lowerings, make_fallback, needs_realized_inputs from .sizevars import SizeVarAllocator -from .utils import dynamo_utils +from .utils import dynamo_utils, gather_origins from .virtualized import V log = logging.getLogger(__name__) @@ -212,34 +212,35 @@ def placeholder(self, target, args, kwargs): return tensor def call_function(self, target, args, kwargs): - if target is operator.getitem and isinstance(args[0], (list, tuple)): - return super().call_function(target, args, kwargs) - - if target not in 
lowerings: - if config.implicit_fallbacks: - error = ( - MissingOperatorWithDecomp - if get_decompositions([target]) - else MissingOperatorWithoutDecomp - ) - log.warning( - "Creating implicit fallback for:\n%s", - error.operator_str(target, args, kwargs), - ) - make_fallback(target) - elif get_decompositions([target]): - # There isn't a good way to dynamically patch this in - # since AOT Autograd already ran. The error message tells - # the user how to fix it. - raise MissingOperatorWithDecomp(target, args, kwargs) - else: - raise MissingOperatorWithoutDecomp(target, args, kwargs) - - try: - out = lowerings[target](*args, **kwargs) - return out - except Exception as e: - raise LoweringException(e, target, args, kwargs) from e + with ir.IRNode.current_origins(gather_origins(args, kwargs)): + if target is operator.getitem and isinstance(args[0], (list, tuple)): + return super().call_function(target, args, kwargs) + + if target not in lowerings: + if config.implicit_fallbacks: + error = ( + MissingOperatorWithDecomp + if get_decompositions([target]) + else MissingOperatorWithoutDecomp + ) + log.warning( + "Creating implicit fallback for:\n%s", + error.operator_str(target, args, kwargs), + ) + make_fallback(target) + elif get_decompositions([target]): + # There isn't a good way to dynamically patch this in + # since AOT Autograd already ran. The error message tells + # the user how to fix it. + raise MissingOperatorWithDecomp(target, args, kwargs) + else: + raise MissingOperatorWithoutDecomp(target, args, kwargs) + + try: + out = lowerings[target](*args, **kwargs) + return out + except Exception as e: + raise LoweringException(e, target, args, kwargs) from e def get_attr(self, target, args, kwargs): # this is a constant diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0353bcc8b0be..924ec7aaa7b2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3626,6 +3626,7 @@ def realize(self): data=self.data, ) self.data.name = V.graph.register_buffer(self.data) + self.data.origins = self.origins return self.data.name def realize_hint(self): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 829fbd2897d5..5bfda50dd6f7 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -178,6 +178,33 @@ def wrapper(self): return wrapper +def get_fused_kernel_name(node_schedule): + return "_".join( + ["fused"] + + [ + str(origin.name) + for origin in functools.reduce( + operator.or_, + [node.node.origins for node in node_schedule if hasattr(node, "node")], + ) + if origin.op == "call_function" + ][0 : config.kernel_name_max_ops] + ) + + +def gather_origins(args, kwargs): + import itertools + + from .ir import ComputedBuffer, IRNode + + def is_unrealized_node(n): + return isinstance(n, IRNode) and not isinstance(n, ComputedBuffer) + + kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] + arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] + return set(itertools.chain(*arg_origins, *kwarg_origins)) + + def sympy_str(expr: sympy.Expr): """ Normal sympy str is very slow, this is a lot faster. 
The result are From a6610faa93ac008c088bcbe26bdbb56de8275cf1 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 8 Nov 2022 07:59:11 -0800 Subject: [PATCH 033/453] [ao] qconfig_mapping_utils.py fixing public v private (#87517) Summary: made _get_object_type_qconfig, _get_module_name_regex_qconfig, _get_module_name_qconfig, _maybe_adjust_qconfig_for_module_type_or_name, _get_flattened_qconfig_dict _update_qconfig_for_qat private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709279](https://our.internmc.facebook.com/intern/diff/D40709279) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87517 Approved by: https://github.com/jcaip --- test/quantization/fx/test_quantize_fx.py | 28 ++++++++--------- torch/ao/quantization/fx/convert.py | 4 +-- torch/ao/quantization/fx/prepare.py | 8 ++--- .../quantization/fx/qconfig_mapping_utils.py | 16 +++++----- .../ao/quantization/qconfig_mapping_utils.py | 31 +++++++------------ 5 files changed, 40 insertions(+), 47 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 236a5587d859..8c75658a04e1 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -99,9 +99,9 @@ ) from torch.ao.quantization.qconfig_mapping_utils import ( - get_object_type_qconfig, - get_module_name_qconfig, - get_module_name_regex_qconfig, + _get_object_type_qconfig, + _get_module_name_qconfig, + _get_module_name_regex_qconfig, ) from torch.ao.quantization.fx.pattern_utils import ( @@ -1876,9 +1876,9 @@ def test_qconfig_mapping_set_object_type(self): qconfig_mapping.set_object_type(torch.nn.Linear, qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.Linear], qconfig3) self.assertEqual(qconfig_mapping.object_type_qconfigs[torch.nn.ReLU], qconfig2) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) - self.assertEqual(get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.Linear, None), qconfig3) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, torch.nn.ReLU, None), qconfig2) + self.assertEqual(_get_object_type_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_regex(self): qconfig1 = get_default_qconfig() @@ -1898,11 +1898,11 @@ def test_qconfig_mapping_set_module_name_regex(self): qconfig_mapping.set_module_name_regex("foo.*bar", qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*bar"], qconfig3) self.assertEqual(qconfig_mapping.module_name_regex_qconfigs["foo.*"], qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) - self.assertEqual(get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo123bar", None), qconfig3) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobar", None), qconfig3) + 
self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foobaz", None), qconfig2) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "foo", None), qconfig2) + self.assertEqual(_get_module_name_regex_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name(self): qconfig1 = get_default_qconfig() @@ -1922,9 +1922,9 @@ def test_qconfig_mapping_set_module_name(self): qconfig_mapping.set_module_name("mod1", qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod1"], qconfig3) self.assertEqual(qconfig_mapping.module_name_qconfigs["mod2"], qconfig2) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) - self.assertEqual(get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod1", None), qconfig3) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "mod2", None), qconfig2) + self.assertEqual(_get_module_name_qconfig(qconfig_mapping, "nomatch", None), None) def test_qconfig_mapping_set_module_name_object_type_order(self): qconfig1 = get_default_qconfig() diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 74eb8f1ca542..b5e9cf3bbcb3 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -24,7 +24,7 @@ ) from ..qconfig_mapping import QConfigMapping from ..qconfig_mapping_utils import ( - update_qconfig_for_qat, + _update_qconfig_for_qat, ) from .qconfig_mapping_utils import ( generate_node_name_to_qconfig, @@ -563,7 +563,7 @@ def convert( modules_copy = copy.deepcopy(modules) if model._is_qat: - update_qconfig_for_qat(qconfig_mapping, {}) + _update_qconfig_for_qat(qconfig_mapping, {}) update_qconfig_for_fusion(model, qconfig_mapping) compare_prepare_convert_qconfig_mappings(prepare_qconfig_mapping, qconfig_mapping) # type: ignore[arg-type] diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 160b80a8807f..281bd960ed7b 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -28,8 +28,8 @@ QConfigMapping, ) from ..qconfig_mapping_utils import ( - get_flattened_qconfig_dict, - update_qconfig_for_qat, + _get_flattened_qconfig_dict, + _update_qconfig_for_qat, ) from .qconfig_mapping_utils import ( generate_node_name_to_qconfig, @@ -1587,14 +1587,14 @@ def prepare( update_qconfig_for_fusion(model, qconfig_mapping) update_qconfig_for_fusion(model, _equalization_config) - flattened_qconfig_dict = get_flattened_qconfig_dict(qconfig_mapping) + flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping) # TODO: support regex as well propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict()) if is_qat: module_to_qat_module = get_module_to_qat_module(backend_config) qat_swap_modules(model, module_to_qat_module) - update_qconfig_for_qat(qconfig_mapping, {}) + _update_qconfig_for_qat(qconfig_mapping, {}) # mapping from fully qualified module name to module instance # for example, diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 2abfaf826c42..66dffd50cd00 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -29,8 +29,8 @@ QConfigMapping, ) from ..qconfig_mapping_utils import ( - get_object_type_qconfig, - 
maybe_adjust_qconfig_for_module_type_or_name, + _get_object_type_qconfig, + _maybe_adjust_qconfig_for_module_type_or_name, ) @@ -121,17 +121,17 @@ def generate_node_name_to_qconfig( qconfig = None if node.op == "get_attr": module_name, _ = _parent_name(node.target) - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[module_name]), module_name, global_qconfig) qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None)) elif node.op == "call_function": # precedence: module_name_qconfig # > function_qconfig > global_qconfig # module_name takes precedence over function qconfig - function_qconfig = get_object_type_qconfig( + function_qconfig = _get_object_type_qconfig( qconfig_mapping, node.target, global_qconfig) module_path, module_type = node_name_to_scope[node.name] - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, module_type, module_path, function_qconfig) cur_object_type_idx = \ @@ -146,11 +146,11 @@ def generate_node_name_to_qconfig( # first use node.target (string) to get the qconfig # this is to support configs like # "object_type": [("reshpe", qconfig)] - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, node.target, module_path, global_qconfig) # if there is no special config for the method, we'll fall back to the # config for the module that contains the call_method node - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, module_type, module_path, qconfig) # currently call_method does not support modifying qconfig # by order, we can add this later if it is needed. @@ -160,7 +160,7 @@ def generate_node_name_to_qconfig( # if the node is an observer, just continue - don't add it to the qconfig_map if is_activation_post_process(modules[node.target]): continue - qconfig = maybe_adjust_qconfig_for_module_type_or_name( + qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[node.target]), node.target, global_qconfig) module_path, module_type = node_name_to_scope[node.name] diff --git a/torch/ao/quantization/qconfig_mapping_utils.py b/torch/ao/quantization/qconfig_mapping_utils.py index 09bce4fbebb0..0109729e580c 100644 --- a/torch/ao/quantization/qconfig_mapping_utils.py +++ b/torch/ao/quantization/qconfig_mapping_utils.py @@ -1,5 +1,5 @@ import re -from typing import Dict, Callable, Union +from typing import Dict, Callable, Union, List from .utils import ( get_combined_dict, @@ -12,25 +12,18 @@ from .qconfig_mapping import QConfigMapping -# TODO: revisit this list. 
Many helper methods shouldn't be public -__all__ = [ - "get_flattened_qconfig_dict", - "get_object_type_qconfig", - "get_module_name_qconfig", - "get_module_name_regex_qconfig", - "maybe_adjust_qconfig_for_module_type_or_name", - "update_qconfig_for_qat", +__all__: List[str] = [ ] -def get_object_type_qconfig( +def _get_object_type_qconfig( qconfig_mapping: QConfigMapping, object_type: Union[Callable, str], fallback_qconfig: QConfigAny) -> QConfigAny: return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig) -def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): +def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig): for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items(): if re.match(regex_pattern, module_name): # first match wins @@ -38,7 +31,7 @@ def get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig return fallback_qconfig -def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): +def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): if module_name == '': # module name qconfig not found return fallback_qconfig @@ -46,23 +39,23 @@ def get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig): return qconfig_mapping.module_name_qconfigs[module_name] else: parent, _ = _parent_name(module_name) - return get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) + return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig) -def maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig): +def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig): # get qconfig for module_name, # fallback to module_name_regex_qconfig, module_type_qconfig, # global_qconfig if necessary - module_type_qconfig = get_object_type_qconfig( + module_type_qconfig = _get_object_type_qconfig( qconfig_mapping, module_type, global_qconfig) - module_name_regex_qconfig = get_module_name_regex_qconfig( + module_name_regex_qconfig = _get_module_name_regex_qconfig( qconfig_mapping, module_name, module_type_qconfig) - module_name_qconfig = get_module_name_qconfig( + module_name_qconfig = _get_module_name_qconfig( qconfig_mapping, module_name, module_name_regex_qconfig) return module_name_qconfig -def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]: +def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]: """ flatten the global, object_type and module_name qconfig to the same qconfig_dict so that it can be used by propagate_qconfig_ function. @@ -94,7 +87,7 @@ def get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Ca return flattened -def update_qconfig_for_qat( +def _update_qconfig_for_qat( qconfig_mapping: QConfigMapping, additional_qat_module_mapping: Dict[Callable, Callable]): """ From 20ae19aa1dd307f9bdde0754c327ffb69eef13c0 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:31 -0800 Subject: [PATCH 034/453] [ONNX] Improve diagnostic message formatting (#87830) * Reflect required arguments in method signature for each diagnostic rule. Previous design accepts arbitrary sized tuple which is hard to use and prone to error. ![image](https://user-images.githubusercontent.com/9376104/200381982-d1e905f0-a159-4ef5-8d2e-070524e8f5bf.png) * Removed `DiagnosticTool` to keep things compact. 
* Removed specifying supported rule set for tool(context) and checking if rule of reported diagnostic falls inside the set, to keep things compact. * Initial overview markdown file. * Change `full_description` definition. Now `text` field should not be empty. And its markdown should be stored in `markdown` field. * Change `message_default_template` to allow only named fields (excluding numeric fields). `field_name` provides clarity on what argument is expected. * Added `diagnose` api to `torch.onnx._internal.diagnostics`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/87830 Approved by: https://github.com/abock --- test/onnx/internal/test_diagnostics.py | 38 +--- tools/onnx/gen_diagnostics.py | 83 +++++++-- tools/onnx/templates/rules.py.in | 6 +- .../jit/passes/onnx/shape_type_inference.cpp | 2 +- torch/csrc/onnx/diagnostics/diagnostics.h | 18 +- torch/onnx/_internal/diagnostics/OVERVIEW.md | 83 +++++++++ torch/onnx/_internal/diagnostics/__init__.py | 4 +- .../onnx/_internal/diagnostics/_diagnostic.py | 53 +++--- torch/onnx/_internal/diagnostics/_rules.py | 88 +++++++-- .../_internal/diagnostics/infra/__init__.py | 2 - .../_internal/diagnostics/infra/_infra.py | 174 ++++++++++-------- .../_internal/diagnostics/infra/engine.py | 51 ++--- torch/onnx/_internal/diagnostics/rules.yaml | 21 ++- torch/onnx/errors.py | 44 ++--- 14 files changed, 414 insertions(+), 253 deletions(-) create mode 100644 torch/onnx/_internal/diagnostics/OVERVIEW.md diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index fbe79216d087..fbd888329a50 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -176,7 +176,7 @@ def test_diagnostics_engine_records_diagnosis_reported_outside_of_export( sample_rule, sample_level, ): - diagnostics.context.diagnose(sample_rule, sample_level, ("foo",)) + diagnostics.context.diagnose(sample_rule, sample_level) @dataclasses.dataclass @@ -196,31 +196,17 @@ class TestDiagnosticsInfra(common_utils.TestCase): def setUp(self): self.engine = infra.DiagnosticEngine() self.rules = _RuleCollectionForTest() - self.diagnostic_tool = infra.DiagnosticTool("test_tool", "1.0.0", self.rules) with contextlib.ExitStack() as stack: self.context = stack.enter_context( - self.engine.create_diagnostic_context(self.diagnostic_tool) + self.engine.create_diagnostic_context("test", "1.0.0") ) self.addCleanup(stack.pop_all().close) return super().setUp() - def test_diagnose_raises_value_error_when_rule_not_supported(self): - rule_id = "0" - rule_name = "nonexistent-rule" - with self.assertRaisesRegex( - ValueError, - f"Rule '{rule_id}:{rule_name}' is not supported by this tool " - f"'{self.diagnostic_tool.name} {self.diagnostic_tool.version}'.", - ): - self.context.diagnose( - infra.Rule(id=rule_id, name=rule_name, message_default_template=""), - infra.Level.WARNING, - ) - def test_diagnostics_engine_records_diagnosis_reported_in_nested_contexts( self, ): - with self.engine.create_diagnostic_context(self.diagnostic_tool) as context: + with self.engine.create_diagnostic_context("inner_test", "1.0.1") as context: context.diagnose(self.rules.rule_without_message_args, infra.Level.WARNING) sarif_log = self.engine.sarif_log() self.assertEqual(len(sarif_log.runs), 2) @@ -250,9 +236,7 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self): ) with self.engine.create_diagnostic_context( - tool=infra.DiagnosticTool( - name="custom_tool", version="1.0", rules=custom_rules - ) + 
"custom_rules", "1.0" ) as diagnostic_context: with assert_all_diagnostics( self, @@ -269,20 +253,6 @@ def test_diagnostics_engine_records_diagnosis_with_custom_rules(self): custom_rules.custom_rule_2, infra.Level.ERROR # type: ignore[attr-defined] ) - def test_diagnostic_tool_raises_type_error_when_diagnostic_type_is_invalid( - self, - ): - with self.assertRaisesRegex( - TypeError, - "Expected diagnostic_type to be a subclass of Diagnostic, but got", - ): - _ = infra.DiagnosticTool( - "custom_tool", - "1.0", - self.rules, - diagnostic_type=int, - ) - if __name__ == "__main__": common_utils.run_tests() diff --git a/tools/onnx/gen_diagnostics.py b/tools/onnx/gen_diagnostics.py index ba6fd43bee29..92960024e048 100644 --- a/tools/onnx/gen_diagnostics.py +++ b/tools/onnx/gen_diagnostics.py @@ -14,6 +14,7 @@ import argparse import os +import string import subprocess import textwrap from typing import Any, Mapping, Sequence @@ -30,19 +31,37 @@ Diagnostic rules for PyTorch ONNX export. """ -_PY_RULE_TEMPLATE = """\ -{0}: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif(**{1}), +_PY_RULE_CLASS_COMMENT = """\ +GENERATED CODE - DO NOT EDIT DIRECTLY +The purpose of generating a class for each rule is to override the `format_message` +method to provide more details in the signature about the format arguments. +""" + +_PY_RULE_CLASS_TEMPLATE = """\ +class _{pascal_case_name}(infra.Rule): + \"\"\"{short_description}\"\"\" + def format_message(self, {message_arguments}) -> str: # type: ignore[override] + \"\"\"Returns the formatted default message of this Rule. + + Message template: {message_template} + \"\"\" + return self.message_default_template.format({message_arguments_assigned}) + +""" + +_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\ +{snake_case_name}: _{pascal_case_name} = dataclasses.field( + default=_{pascal_case_name}.from_sarif(**{sarif_dict}), init=False, ) -\"\"\"{2}\"\"\" +\"\"\"{short_description}\"\"\" """ _CPP_RULE_TEMPLATE = """\ /** - * @brief {1} + * @brief {short_description} */ -{0}, +{name}, """ _RuleType = Mapping[str, Any] @@ -56,24 +75,62 @@ def _kebab_case_to_pascal_case(name: str) -> str: return "".join(word.capitalize() for word in name.split("-")) -def _format_rule_for_python(rule: _RuleType) -> str: - name = _kebab_case_to_snake_case(rule["name"]) +def _format_rule_for_python_class(rule: _RuleType) -> str: + pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) short_description = rule["short_description"]["text"] + message_template = rule["message_strings"]["default"]["text"] + field_names = [ + field_name + for _, field_name, _, _ in string.Formatter().parse(message_template) + if field_name is not None + ] + for field_name in field_names: + assert isinstance( + field_name, str + ), f"Unexpected field type {type(field_name)} from {field_name}. " + "Field name must be string.\nFull message template: {message_template}" + assert ( + not field_name.isnumeric() + ), f"Unexpected numeric field name {field_name}. 
" + "Only keyword name formatting is supported.\nFull message template: {message_template}" + message_arguments = ", ".join(field_names) + message_arguments_assigned = ", ".join( + [f"{field_name}={field_name}" for field_name in field_names] + ) + return _PY_RULE_CLASS_TEMPLATE.format( + pascal_case_name=pascal_case_name, + short_description=short_description, + message_template=repr(message_template), + message_arguments=message_arguments, + message_arguments_assigned=message_arguments_assigned, + ) + - return _PY_RULE_TEMPLATE.format(name, rule, short_description) +def _format_rule_for_python_field(rule: _RuleType) -> str: + snake_case_name = _kebab_case_to_snake_case(rule["name"]) + pascal_case_name = _kebab_case_to_pascal_case(rule["name"]) + short_description = rule["short_description"]["text"] + + return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format( + snake_case_name=snake_case_name, + pascal_case_name=pascal_case_name, + sarif_dict=rule, + short_description=short_description, + ) def _format_rule_for_cpp(rule: _RuleType) -> str: name = f"k{_kebab_case_to_pascal_case(rule['name'])}" short_description = rule["short_description"]["text"] - return _CPP_RULE_TEMPLATE.format(name, short_description) + return _CPP_RULE_TEMPLATE.format(name=name, short_description=short_description) def gen_diagnostics_python( rules: Sequence[_RuleType], out_py_dir: str, template_dir: str ) -> None: - rule_lines = [_format_rule_for_python(rule) for rule in rules] + rule_class_lines = [_format_rule_for_python_class(rule) for rule in rules] + rule_field_lines = [_format_rule_for_python_field(rule) for rule in rules] fm = torchgen_utils.FileManager( install_dir=out_py_dir, template_dir=template_dir, dry_run=False @@ -83,7 +140,9 @@ def gen_diagnostics_python( "rules.py.in", lambda: { "generated_comment": _RULES_GENERATED_COMMENT, - "rules": textwrap.indent("\n".join(rule_lines), " " * 4), + "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT, + "rule_classes": "\n".join(rule_class_lines), + "rules": textwrap.indent("\n".join(rule_field_lines), " " * 4), }, ) _lint_file(os.path.join(out_py_dir, "_rules.py")) diff --git a/tools/onnx/templates/rules.py.in b/tools/onnx/templates/rules.py.in index e29c202dc6a7..2137119d14c2 100644 --- a/tools/onnx/templates/rules.py.in +++ b/tools/onnx/templates/rules.py.in @@ -7,10 +7,14 @@ import dataclasses # flake8: noqa from torch.onnx._internal.diagnostics import infra +""" +${generated_rule_class_comment} +""" + +${rule_classes} @dataclasses.dataclass class _POERules(infra.RuleCollection): ${rules} - rules = _POERules() diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index d2873ddf464c..f646fe77e07a 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -1897,7 +1897,7 @@ void UpdateReliable( diagnostics::Diagnose( diagnostics::Rule::kNodeMissingOnnxShapeInference, diagnostics::Level::kWarning, - {output->node()->kind().toDisplayString()}); + {{"op_name", output->node()->kind().toDisplayString()}}); } auto reliable = false; if (inferred) { diff --git a/torch/csrc/onnx/diagnostics/diagnostics.h b/torch/csrc/onnx/diagnostics/diagnostics.h index 65f59d4f1f9f..65ca626b843b 100644 --- a/torch/csrc/onnx/diagnostics/diagnostics.h +++ b/torch/csrc/onnx/diagnostics/diagnostics.h @@ -34,14 +34,6 @@ inline py::object _PyDiagnostics() { return py::module::import("torch.onnx._internal.diagnostics"); } -inline py::object _PyEngine() { - return 
_PyDiagnostics().attr("engine"); -} - -inline py::object _PyContext() { - return _PyDiagnostics().attr("context"); -} - inline py::object _PyRule(Rule rule) { return _PyDiagnostics().attr("rules").attr( kPyRuleNames[static_cast(rule)]); @@ -55,15 +47,15 @@ inline py::object _PyLevel(Level level) { inline void Diagnose( Rule rule, Level level, - std::vector messageArgs = {}) { + std::unordered_map messageArgs = {}) { py::object py_rule = _PyRule(rule); py::object py_level = _PyLevel(level); - py::object py_context = _PyContext(); - py::dict kwargs = py::dict(); // TODO: statically check that size of messageArgs matches with rule. - kwargs["message_args"] = messageArgs; - py_context.attr("diagnose")(py_rule, py_level, **kwargs); + py::object py_message = + py_rule.attr("format_message")(**py::cast(messageArgs)); + + _PyDiagnostics().attr("diagnose")(py_rule, py_level, py_message); } } // namespace diagnostics diff --git a/torch/onnx/_internal/diagnostics/OVERVIEW.md b/torch/onnx/_internal/diagnostics/OVERVIEW.md new file mode 100644 index 000000000000..0dffb0d20b45 --- /dev/null +++ b/torch/onnx/_internal/diagnostics/OVERVIEW.md @@ -0,0 +1,83 @@ +# PyTorch ONNX Exporter Diagnostics + +NOTE: This feature is underdevelopment and is subject to change. + +Summary of source tree: +- [OVERVIEW.md](OVERVIEW.md): Technical overview of the diagnostics infrastructure. +- [generated/](generated): Generated diagnostics rules from [rules.yaml](rules.yaml). +- [infra/](infra): Generic diagnostics infrastructure built on top of [SARIF](https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html). +- [_diagnostic.py](diagnostic.py): Python API for diagnostics. +- [rules.yaml](rules.yaml): Single source of truth for diagnostics rules. Used to generate C++ and Python interfaces, and documentation pages. +- [tools/onnx/](/tools/onnx): Scripts for generating source code and documentation for diagnostics rules. + +## Table of Contents + + + +- [Introduction](#introduction) + - [Motivation](#motivation) + - [Diagnostics as documentation](#diagnostics-as-documentation) + - [Different context and background](#different-context-and-background) + - [Machine parsable](#machine-parsable) + - [Design](#design) + - [Adopting SARIF for diagnostic structure](#adopting-sarif-for-diagnostic-structure) + - [Single source of truth for diagnostic rules](#single-source-of-truth-for-diagnostic-rules) +- [Internal Details](#internal-details) + - [Rules](#rules) + - [Infrastructure](#infrastructure) + - [Documentation](#documentation) +- [Usage](#usage) + - [Python](#python) + - [C++](#c) + + + +# Introduction + +The goal is to improve the diagnostics to help users debug and improve their model export. +* The diagnostics are emitted in machine parsable [Static Analysis Results Interchange Format (SARIF)](https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html). +* A new clearer, structured way to add new and keep track of diagnostic rules. +* Serve as foundation for more future improvements consuming the diagnostics. + +## Motivation ## + +The previous diagnostics were only scattered warning or error messages. They are not structured and are not machine parsable. This makes it hard to consume the diagnostics in a systematic way. This is a blocker for improving the diagnostics and for building tools on top of them. The diagnostics are also not very helpful for users to debug their model export. They are often not actionable and do not provide enough context to help users debug their model export. 
Some unsupported patterns or code are documented in the [PyTorch ONNX doc](https://pytorch.org/docs/stable/onnx.html#limitations). The information is scattered, hard to find, and hard to maintain and thus often outdated. The new diagnostics system aim to address these issues with the following key properties. + +### Diagnostics as documentation + +The diagnostics are the source of truth for the documentation of export issues. Documentations are no longer separated. Any changes are directly reflected as the diagnostic progress. The diagnostic itself serves as the means to track the history and progress of any specific issue. Linking the source code, the issues, the PRs, the fix, the docs, etc together through this single entity. + +### Different context and background + +There are two very different audiences: users and converter developers. The users care more about where the error is coming from the model, and how to resolve it for a successful export. They are not experts in the internal of exporter or JIT. The converter developers on the other hand need more info of the internal state of the converter to debug the issue. The diagnostics should be actionable for users and provide enough context for converter developers to debug and fix the issues. It should display the right information and context to the right audience, in a clean and concise way. + +### Machine parsable + +The diagnostics are emitted in machine parsable SARIF format. This opens the door for the diagnostics to be consumed by tools and systems. Future applications like auto fixing, formatted displaying, auto reporting, etc are possible. + +## Design ## + +### Adopting SARIF for diagnostic structure + +The diagnostics are emitted in machine parsable [Static Analysis Results Interchange Format (SARIF)](https://docs.oasis-open.org/sarif/sarif/v2.1.0/sarif-v2.1.0.html), with [python classes for SARIF object model](https://github.com/microsoft/sarif-python-om) as starting point. This is a standard format for the output of static analysis tools, and can be consumed by the SARIF Viewer, [VSCode extension](https://marketplace.visualstudio.com/items?itemName=MS-SarifVSCode.sarif-viewer) for example. The diagnostics are also emitted in a human readable format for users to read. The human readable format is a subset of the SARIF format. The human readable format is emitted to stdout and the SARIF format is emitted to a file. [Authoring rule metadata and result messages](https://github.com/microsoft/sarif-tutorials/blob/main/docs/Authoring-rule-metadata-and-result-messages.md) is a good starting point for understanding the SARIF format. + +### Single source of truth for diagnostic rules + +The diagnostic rules are defined in a single location, in [SARIF `reportingDescriptor` format](https://docs.oasis-open.org/sarif/sarif/v2.1.0/os/sarif-v2.1.0-os.html#_Toc34317836). From it, respective C++, python and documentation files are generated during build. With a bit of redundancy, this approach makes all the rules statically accessible under both Python and C++, while maintaining a single source of truth. 
+ +# Internal Details + +## Rules ## + + +## Infrastructure ## + + +## Documentation ## + + +# Usage + +## Python ## + +## C++ ## diff --git a/torch/onnx/_internal/diagnostics/__init__.py b/torch/onnx/_internal/diagnostics/__init__.py index 822e6a3482e6..304978dbe22d 100644 --- a/torch/onnx/_internal/diagnostics/__init__.py +++ b/torch/onnx/_internal/diagnostics/__init__.py @@ -1,19 +1,19 @@ from ._diagnostic import ( context, create_export_diagnostic_context, + diagnose, engine, ExportDiagnostic, - ExportDiagnosticTool, ) from ._rules import rules from .infra import levels __all__ = [ "ExportDiagnostic", - "ExportDiagnosticTool", "rules", "levels", "engine", "context", "create_export_diagnostic_context", + "diagnose", ] diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 6b1c1216cd14..ae6615e831cb 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -1,10 +1,10 @@ """Diagnostic components for PyTorch ONNX export.""" import contextlib -from typing import Any, Optional, Tuple, TypeVar +from typing import Optional, TypeVar import torch -from torch.onnx._internal.diagnostics import _rules, infra +from torch.onnx._internal.diagnostics import infra # This is a workaround for mypy not supporting Self from typing_extensions. _ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic") @@ -20,12 +20,10 @@ class ExportDiagnostic(infra.Diagnostic): def __init__( self, - rule: infra.Rule, - level: infra.Level, - message_args: Optional[Tuple[Any, ...]], + *args, **kwargs, ) -> None: - super().__init__(rule, level, message_args, **kwargs) + super().__init__(*args, **kwargs) def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic: # TODO: Implement this. @@ -56,22 +54,6 @@ def with_export_source_location( return self -class ExportDiagnosticTool(infra.DiagnosticTool): - """Base class for all export diagnostic tools. - - This class is used to represent all export diagnostic tools. It is a subclass - of infra.DiagnosticTool. - """ - - def __init__(self) -> None: - super().__init__( - name="torch.onnx.export", - version=torch.__version__, - rules=_rules.rules, - diagnostic_type=ExportDiagnostic, - ) - - class ExportDiagnosticEngine(infra.DiagnosticEngine): """PyTorch ONNX Export diagnostic engine. @@ -93,7 +75,10 @@ class ExportDiagnosticEngine(infra.DiagnosticEngine): def __init__(self) -> None: super().__init__() self._background_context = infra.DiagnosticContext( - ExportDiagnosticTool(), options=None + name="torch.onnx", + version=torch.__version__, + diagnostic_type=ExportDiagnostic, + options=None, ) @property @@ -102,7 +87,7 @@ def background_context(self) -> infra.DiagnosticContext: def clear(self): super().clear() - self._background_context._diagnostics.clear() + self._background_context.diagnostics.clear() def sarif_log(self): log = super().sarif_log() @@ -122,8 +107,26 @@ def create_export_diagnostic_context(): export internals via global variable. See `ExportDiagnosticEngine` for more details. 
""" global context - context = engine.create_diagnostic_context(ExportDiagnosticTool()) + context = engine.create_diagnostic_context( + "torch.onnx.export", torch.__version__, diagnostic_type=ExportDiagnostic + ) try: yield context finally: context = engine.background_context + + +def diagnose( + rule: infra.Rule, + level: infra.Level, + message: Optional[str] = None, + **kwargs, +) -> ExportDiagnostic: + """Creates a diagnostic and record it in the global diagnostic context. + + This is a wrapper around `context.record` that uses the global diagnostic context. + """ + global context + diagnostic = ExportDiagnostic(rule, level, message, **kwargs) + context.add_diagnostic(diagnostic) + return diagnostic diff --git a/torch/onnx/_internal/diagnostics/_rules.py b/torch/onnx/_internal/diagnostics/_rules.py index 430fe3ea4fe5..f9948388d5da 100644 --- a/torch/onnx/_internal/diagnostics/_rules.py +++ b/torch/onnx/_internal/diagnostics/_rules.py @@ -11,22 +11,78 @@ # flake8: noqa from torch.onnx._internal.diagnostics import infra +""" +GENERATED CODE - DO NOT EDIT DIRECTLY +The purpose of generating a class for each rule is to override the `format_message` +method to provide more details in the signature about the format arguments. +""" + + +class _NodeMissingOnnxShapeInference(infra.Rule): + """Node is missing ONNX shape inference.""" + + def format_message(self, op_name) -> str: # type: ignore[override] + """Returns the formatted default message of this Rule. + + Message template: 'The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function.' + """ + return self.message_default_template.format(op_name=op_name) + + +class _MissingCustomSymbolicFunction(infra.Rule): + """Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.""" + + def format_message(self, op_name) -> str: # type: ignore[override] + """Returns the formatted default message of this Rule. + + Message template: 'ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version.' + """ + return self.message_default_template.format(op_name=op_name) + + +class _MissingStandardSymbolicFunction(infra.Rule): + """Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.""" + + def format_message(self, op_name, opset_version, issue_url) -> str: # type: ignore[override] + """Returns the formatted default message of this Rule. + + Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}." + """ + return self.message_default_template.format( + op_name=op_name, opset_version=opset_version, issue_url=issue_url + ) + + +class _OperatorSupportedInNewerOpsetVersion(infra.Rule): + """Operator is supported in newer opset version.""" + + def format_message(self, op_name, opset_version, supported_opset_version) -> str: # type: ignore[override] + """Returns the formatted default message of this Rule. + + Message template: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version." 
+ """ + return self.message_default_template.format( + op_name=op_name, + opset_version=opset_version, + supported_opset_version=supported_opset_version, + ) + @dataclasses.dataclass class _POERules(infra.RuleCollection): - node_missing_onnx_shape_inference: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif( + node_missing_onnx_shape_inference: _NodeMissingOnnxShapeInference = dataclasses.field( + default=_NodeMissingOnnxShapeInference.from_sarif( **{ "id": "POE0001", "name": "node-missing-onnx-shape-inference", "short_description": {"text": "Node is missing ONNX shape inference."}, "full_description": { - "text": "", + "text": "Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec.", "markdown": "Node is missing ONNX shape inference.\nThis usually happens when the node is not valid under standard ONNX operator spec.\n", }, "message_strings": { "default": { - "text": "The shape inference of {0} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function." + "text": "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function." } }, "help_uri": None, @@ -37,8 +93,8 @@ class _POERules(infra.RuleCollection): ) """Node is missing ONNX shape inference.""" - missing_custom_symbolic_function: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif( + missing_custom_symbolic_function: _MissingCustomSymbolicFunction = dataclasses.field( + default=_MissingCustomSymbolicFunction.from_sarif( **{ "id": "POE0002", "name": "missing-custom-symbolic-function", @@ -46,12 +102,12 @@ class _POERules(infra.RuleCollection): "text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX." }, "full_description": { - "text": "", + "text": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.", "markdown": "Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.\n", }, "message_strings": { "default": { - "text": "ONNX export failed on an operator with unrecognized namespace {0}. If you are trying to export a custom operator, make sure you registered it with the right domain and version." + "text": "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version." } }, "help_uri": None, @@ -62,8 +118,8 @@ class _POERules(infra.RuleCollection): ) """Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX.""" - missing_standard_symbolic_function: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif( + missing_standard_symbolic_function: _MissingStandardSymbolicFunction = dataclasses.field( + default=_MissingStandardSymbolicFunction.from_sarif( **{ "id": "POE0003", "name": "missing-standard-symbolic-function", @@ -71,12 +127,12 @@ class _POERules(infra.RuleCollection): "text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX." 
}, "full_description": { - "text": "", + "text": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.", "markdown": "Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.\n", }, "message_strings": { "default": { - "text": "Exporting the operator '{0}' to ONNX opset version {1} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {2}." + "text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}." } }, "help_uri": None, @@ -87,8 +143,8 @@ class _POERules(infra.RuleCollection): ) """Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX.""" - operator_supported_in_newer_opset_version: infra.Rule = dataclasses.field( - default=infra.Rule.from_sarif( + operator_supported_in_newer_opset_version: _OperatorSupportedInNewerOpsetVersion = dataclasses.field( + default=_OperatorSupportedInNewerOpsetVersion.from_sarif( **{ "id": "POE0004", "name": "operator-supported-in-newer-opset-version", @@ -96,12 +152,12 @@ class _POERules(infra.RuleCollection): "text": "Operator is supported in newer opset version." }, "full_description": { - "text": "", + "text": "Operator is supported in newer opset version.", "markdown": "Operator is supported in newer opset version.\n\nExample:\n```python\ntorch.onnx.export(model, args, ..., opset_version=9)\n```\n", }, "message_strings": { "default": { - "text": "Exporting the operator '{0}' to ONNX opset version {1} is not supported. Support for this operator was added in version {2}, try exporting with this version." + "text": "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. Support for this operator was added in version {supported_opset_version}, try exporting with this version." } }, "help_uri": None, diff --git a/torch/onnx/_internal/diagnostics/infra/__init__.py b/torch/onnx/_internal/diagnostics/infra/__init__.py index 6a51350871f9..ac9e6e99a974 100644 --- a/torch/onnx/_internal/diagnostics/infra/__init__.py +++ b/torch/onnx/_internal/diagnostics/infra/__init__.py @@ -2,7 +2,6 @@ Diagnostic, DiagnosticContext, DiagnosticOptions, - DiagnosticTool, Level, levels, Location, @@ -17,7 +16,6 @@ "DiagnosticContext", "DiagnosticEngine", "DiagnosticOptions", - "DiagnosticTool", "Level", "levels", "Location", diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index 14be9d205dbb..6966ccccbb26 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -4,7 +4,7 @@ import dataclasses import enum -from typing import Any, FrozenSet, List, Optional, Sequence, Set, Tuple, Type, TypeVar +from typing import FrozenSet, List, Optional, Sequence, Tuple, Type, TypeVar from torch.onnx._internal.diagnostics.infra import formatter, sarif @@ -32,6 +32,21 @@ class Tag(enum.Enum): pass +class PatchedPropertyBag(sarif.PropertyBag): + """Key/value pairs that provide additional information about the object. + + The definition of PropertyBag via SARIF spec is "A property bag is an object (§3.6) + containing an unordered set of properties with arbitrary names." However it is not + reflected in the json file, and therefore not captured by the python representation. 
+ This patch adds additional **kwargs to the `__init__` method to allow recording + arbitrary key/value pairs. + """ + + def __init__(self, tags: Optional[List[str]] = None, **kwargs): + super().__init__(tags=tags) + self.__dict__.update(kwargs) + + @dataclasses.dataclass(frozen=True) class Rule: id: str @@ -39,22 +54,16 @@ class Rule: message_default_template: str short_description: Optional[str] = None full_description: Optional[str] = None + full_description_markdown: Optional[str] = None help_uri: Optional[str] = None @classmethod - def from_sarif(cls, **kwargs) -> Rule: + def from_sarif(cls, **kwargs): """Returns a rule from the SARIF reporting descriptor.""" - short_description = ( - kwargs["short_description"]["text"] - if "short_description" in kwargs - else None - ) - full_description = ( - kwargs["full_description"]["markdown"] - if "full_description" in kwargs - else None - ) - help_uri = kwargs["help_uri"] if "help_uri" in kwargs else None + short_description = kwargs.get("short_description", {}).get("text") + full_description = kwargs.get("full_description", {}).get("text") + full_description_markdown = kwargs.get("full_description", {}).get("markdown") + help_uri = kwargs.get("help_uri") rule = cls( id=kwargs["id"], @@ -62,6 +71,7 @@ def from_sarif(cls, **kwargs) -> Rule: message_default_template=kwargs["message_strings"]["default"]["text"], short_description=short_description, full_description=full_description, + full_description_markdown=full_description_markdown, help_uri=help_uri, ) return rule @@ -74,7 +84,9 @@ def sarif(self) -> sarif.ReportingDescriptor: else None ) full_description = ( - sarif.MultiformatMessageString(text="", markdown=self.full_description) + sarif.MultiformatMessageString( + text=self.full_description, markdown=self.full_description_markdown + ) if self.full_description is not None else None ) @@ -86,6 +98,15 @@ def sarif(self) -> sarif.ReportingDescriptor: help_uri=self.help_uri, ) + def format_message(self, *args, **kwargs) -> str: + """Returns the formatted default message of this Rule. + + This method should be overridden (with code generation) by subclasses to reflect + the exact arguments needed by the message template. This is a helper method to + create the default message for a diagnostic. + """ + return self.message_default_template.format(*args, **kwargs) + @dataclasses.dataclass class Location: @@ -147,21 +168,40 @@ def add_frame( _Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic") +@dataclasses.dataclass +class Graph: + """A graph of diagnostics. + + This class stores the string representation of a model graph. + The `nodes` and `edges` fields are unused in the current implementation. 
+ """ + + graph_str: str + name: str + description: Optional[str] = None + + def sarif(self) -> sarif.Graph: + """Returns the SARIF representation of this graph.""" + return sarif.Graph( + description=sarif.Message(text=self.graph_str), + properties=PatchedPropertyBag(name=self.name, description=self.description), + ) + + @dataclasses.dataclass class Diagnostic: rule: Rule level: Level - message_args: Optional[Tuple[Any, ...]] + message: Optional[str] = None locations: List[Location] = dataclasses.field(default_factory=list) stacks: List[Stack] = dataclasses.field(default_factory=list) + graphs: List[Graph] = dataclasses.field(default_factory=list) additional_message: Optional[str] = None tags: List[Tag] = dataclasses.field(default_factory=list) def sarif(self) -> sarif.Result: """Returns the SARIF Result representation of this diagnostic.""" - if self.message_args is None: - self.message_args = tuple() - message = self.rule.message_default_template.format(*self.message_args) + message = self.message or self.rule.message_default_template if self.additional_message is not None: message = f"{message}\n{self.additional_message}" sarif_result = sarif.Result( @@ -171,6 +211,7 @@ def sarif(self) -> sarif.Result: ) sarif_result.locations = [location.sarif() for location in self.locations] sarif_result.stacks = [stack.sarif() for stack in self.stacks] + sarif_result.graphs = [graph.sarif() for graph in self.graphs] sarif_result.properties = sarif.PropertyBag( tags=[tag.value for tag in self.tags] ) @@ -186,6 +227,11 @@ def with_stack(self: _Diagnostic, stack: Stack) -> _Diagnostic: self.stacks.append(stack) return self + def with_graph(self: _Diagnostic, graph: Graph) -> _Diagnostic: + """Adds a graph to the diagnostic.""" + self.graphs.append(graph) + return self + def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic: """Adds an additional message to the diagnostic.""" if self.additional_message is None: @@ -231,61 +277,6 @@ def custom_collection_from_list( )() -@dataclasses.dataclass(frozen=True) -class DiagnosticTool: - name: str - version: str - rules: RuleCollection - diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic) - _triggered_rules: Set[Rule] = dataclasses.field(init=False, default_factory=set) - - def __post_init__(self) -> None: - if not issubclass(self.diagnostic_type, Diagnostic): - raise TypeError( - "Expected diagnostic_type to be a subclass of Diagnostic, " - f"but got {self.diagnostic_type}" - ) - - def sarif(self) -> sarif.Tool: - """Returns the SARIF Tool representation.""" - return sarif.Tool( - driver=sarif.ToolComponent( - name=self.name, - version=self.version, - rules=[rule.sarif() for rule in self._triggered_rules], - ) - ) - - def create_diagnostic( - self, - rule: Rule, - level: Level, - message_args: Optional[Tuple[Any, ...]], - **kwargs, - ) -> Diagnostic: - """Creates a diagnostic for the given arguments. - - Args: - rule: The rule that triggered the diagnostic. - level: The level of the diagnostic. - message_args: The arguments to format the rule's message template. - **kwargs: Additional arguments to pass to the Diagnostic constructor. - - Returns: - The created diagnostic. - - Raises: - ValueError: If the rule is not supported by the tool. - """ - if rule not in self.rules: - raise ValueError( - f"Rule '{rule.id}:{rule.name}' is not supported by this tool '{self.name} {self.version}'." 
- f" Supported rules are: {self.rules._rule_id_name_set}" - ) - self._triggered_rules.add(rule) - return self.diagnostic_type(rule, level, message_args, **kwargs) - - class Invocation: # TODO: Implement this. def __init__(self) -> None: @@ -301,9 +292,11 @@ class DiagnosticOptions: @dataclasses.dataclass class DiagnosticContext: - tool: DiagnosticTool + name: str + version: str options: Optional[DiagnosticOptions] = None - _diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) + diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic) + diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) _invocation: Invocation = dataclasses.field(init=False) def __enter__(self): @@ -315,15 +308,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): def sarif(self) -> sarif.Run: """Returns the SARIF Run object.""" return sarif.Run( - tool=self.tool.sarif(), - results=[diagnostic.sarif() for diagnostic in self._diagnostics], + tool=sarif.Tool( + driver=sarif.ToolComponent( + name=self.name, + version=self.version, + rules=[diagnostic.rule.sarif() for diagnostic in self.diagnostics], + ) + ), + results=[diagnostic.sarif() for diagnostic in self.diagnostics], ) + def add_diagnostic(self, diagnostic: Diagnostic) -> None: + """Adds a diagnostic to the context. + + Use this method to add diagnostics that are not created by the context. + Args: + diagnostic: The diagnostic to add. + """ + if not isinstance(diagnostic, self.diagnostic_type): + raise TypeError( + f"Expected diagnostic of type {self.diagnostic_type}, got {type(diagnostic)}" + ) + self.diagnostics.append(diagnostic) + def diagnose( self, rule: Rule, level: Level, - message_args: Optional[Tuple[Any, ...]] = None, + message: Optional[str] = None, **kwargs, ) -> Diagnostic: """Creates a diagnostic for the given arguments. @@ -331,7 +343,7 @@ def diagnose( Args: rule: The rule that triggered the diagnostic. level: The level of the diagnostic. - message_args: The arguments to format the rule's message template. + message: The message of the diagnostic. **kwargs: Additional arguments to pass to the Diagnostic constructor. Returns: @@ -340,6 +352,6 @@ def diagnose( Raises: ValueError: If the rule is not supported by the tool. """ - diagnostic = self.tool.create_diagnostic(rule, level, message_args, **kwargs) - self._diagnostics.append(diagnostic) + diagnostic = self.diagnostic_type(rule, level, message, **kwargs) + self.add_diagnostic(diagnostic) return diagnostic diff --git a/torch/onnx/_internal/diagnostics/infra/engine.py b/torch/onnx/_internal/diagnostics/infra/engine.py index 19fd846c35de..2678268fbaf9 100644 --- a/torch/onnx/_internal/diagnostics/infra/engine.py +++ b/torch/onnx/_internal/diagnostics/infra/engine.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List, Optional +from typing import List, Optional, Type from torch.onnx._internal.diagnostics import infra from torch.onnx._internal.diagnostics.infra import formatter, sarif @@ -14,9 +14,7 @@ class DiagnosticEngine: This class is the main interface for diagnostics. It manages the creation of diagnostic contexts. A DiagnosticContext provides the entry point for recording Diagnostics. - Each DiagnosticContext is powered by a DiagnosticTool, which can be customized with - custom RuleCollection and Diagnostic type. - See infra.DiagnosticContext and infra.DiagnosticTool for more details. + See infra.DiagnosticContext for more details. Examples: Step 1: Create a set of rules. 
@@ -31,36 +29,29 @@ class DiagnosticEngine: ... ], ... ) - Step 2: Create a diagnostic tool. - >>> tool = infra.DiagnosticTool( - ... name="tool", - ... version="1.0.0", - ... rules=rules, - ... ) - - Step 3: Create a diagnostic engine. + Step 2: Create a diagnostic engine. >>> engine = DiagnosticEngine() - Step 4: Start a new diagnostic context. - >>> with engine.start_diagnostic_context(tool) as context: + Step 3: Start a new diagnostic context. + >>> with engine.create_diagnostic_context("torch.onnx.export", version="1.0") as context: - Step 5: Add diagnostics in your code. + Step 4: Add diagnostics in your code. ... context.diagnose(rules.rule1, infra.Level.ERROR) - Step 6: Afterwards, get the SARIF log. + Step 5: Afterwards, get the SARIF log. >>> sarif_log = engine.sarif_log() """ - _contexts: List[infra.DiagnosticContext] + contexts: List[infra.DiagnosticContext] def __init__(self) -> None: - self._contexts = [] + self.contexts = [] def sarif_log(self) -> sarif.SarifLog: return sarif.SarifLog( version=sarif_version.SARIF_VERSION, schema_uri=sarif_version.SARIF_SCHEMA_LINK, - runs=[context.sarif() for context in self._contexts], + runs=[context.sarif() for context in self.contexts], ) def __str__(self) -> str: @@ -75,13 +66,27 @@ def to_json(self) -> str: def clear(self) -> None: """Clears all diagnostic contexts.""" - self._contexts.clear() + self.contexts.clear() def create_diagnostic_context( self, - tool: infra.DiagnosticTool, + name: str, + version: str, options: Optional[infra.DiagnosticOptions] = None, + diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic, ) -> infra.DiagnosticContext: - context = infra.DiagnosticContext(tool, options) - self._contexts.append(context) + """Creates a new diagnostic context. + + Args: + name: The subject name for the diagnostic context. + version: The subject version for the diagnostic context. + options: The options for the diagnostic context. + + Returns: + A new diagnostic context. + """ + context = infra.DiagnosticContext( + name, version, options, diagnostic_type=diagnostic_type + ) + self.contexts.append(context) return context diff --git a/torch/onnx/_internal/diagnostics/rules.yaml b/torch/onnx/_internal/diagnostics/rules.yaml index 717ce5e139fe..9d527bccf1e2 100644 --- a/torch/onnx/_internal/diagnostics/rules.yaml +++ b/torch/onnx/_internal/diagnostics/rules.yaml @@ -11,13 +11,14 @@ short_description: text: Node is missing ONNX shape inference. full_description: - text: "" + text: "Node is missing ONNX shape inference. + This usually happens when the node is not valid under standard ONNX operator spec." markdown: | Node is missing ONNX shape inference. This usually happens when the node is not valid under standard ONNX operator spec. message_strings: default: - text: "The shape inference of {0} type is missing, so it may result in wrong shape inference for the exported graph. + text: "The shape inference of {op_name} type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function." help_uri: properties: @@ -29,12 +30,12 @@ short_description: text: Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX. full_description: - text: "" + text: Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX. markdown: | Missing symbolic function for custom PyTorch operator, cannot translate node to ONNX. message_strings: default: - text: "ONNX export failed on an operator with unrecognized namespace {0}. 
+ text: "ONNX export failed on an operator with unrecognized namespace {op_name}. If you are trying to export a custom operator, make sure you registered it with the right domain and version." help_uri: @@ -47,13 +48,13 @@ short_description: text: Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX. full_description: - text: "" + text: Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX. markdown: | Missing symbolic function for standard PyTorch operator, cannot translate node to ONNX. message_strings: default: - text: "Exporting the operator '{0}' to ONNX opset version {1} is not supported. - Please feel free to request support or submit a pull request on PyTorch GitHub: {2}." + text: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. + Please feel free to request support or submit a pull request on PyTorch GitHub: {issue_url}." help_uri: properties: deprecated: false @@ -65,7 +66,7 @@ short_description: text: Operator is supported in newer opset version. full_description: - text: "" + text: Operator is supported in newer opset version. markdown: | Operator is supported in newer opset version. @@ -75,8 +76,8 @@ ``` message_strings: default: - text: "Exporting the operator '{0}' to ONNX opset version {1} is not supported. - Support for this operator was added in version {2}, try exporting with this version." + text: "Exporting the operator '{op_name}' to ONNX opset version {opset_version} is not supported. + Support for this operator was added in version {supported_opset_version}, try exporting with this version." help_uri: properties: deprecated: false diff --git a/torch/onnx/errors.py b/torch/onnx/errors.py index 467494c56044..f5ad684cf168 100644 --- a/torch/onnx/errors.py +++ b/torch/onnx/errors.py @@ -46,49 +46,27 @@ class UnsupportedOperatorError(OnnxExporterError): """Raised when an operator is unsupported by the exporter.""" def __init__(self, name: str, version: int, supported_version: Optional[int]): - msg = f"Exporting the operator '{name}' to ONNX opset version {version} is not supported. " if supported_version is not None: - msg += ( - f"Support for this operator was added in version {supported_version}. " - "Please try exporting with this version." 
- ) - diagnostics.context.diagnose( - diagnostics.rules.operator_supported_in_newer_opset_version, - diagnostics.levels.ERROR, - message_args=( - name, - version, - supported_version, - ), + diagnostic_rule: diagnostics.infra.Rule = ( + diagnostics.rules.operator_supported_in_newer_opset_version ) + msg = diagnostic_rule.format_message(name, version, supported_version) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) else: - msg += "Please feel free to request support or submit a pull request on PyTorch GitHub: " - msg += _constants.PYTORCH_GITHUB_ISSUES_URL - if ( name.startswith("aten::") or name.startswith("prim::") or name.startswith("quantized::") ): - diagnostics.context.diagnose( - diagnostics.rules.missing_standard_symbolic_function, - diagnostics.levels.ERROR, - message_args=( - name, - version, - _constants.PYTORCH_GITHUB_ISSUES_URL, - ), + diagnostic_rule = diagnostics.rules.missing_standard_symbolic_function + msg = diagnostic_rule.format_message( + name, version, _constants.PYTORCH_GITHUB_ISSUES_URL ) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) else: - msg += ( - "If you are trying to export a custom operator, make sure you registered " - "it with the correct domain and version." - ) - diagnostics.context.diagnose( - diagnostics.rules.missing_custom_symbolic_function, - diagnostics.levels.ERROR, - message_args=(name,), - ) + diagnostic_rule = diagnostics.rules.missing_custom_symbolic_function + msg = diagnostic_rule.format_message(name) + diagnostics.diagnose(diagnostic_rule, diagnostics.levels.ERROR, msg) super().__init__(msg) From 500fd65531e77deb7784d3ac4f78c5cbe21efe41 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:31 -0800 Subject: [PATCH 035/453] [ONNX] Create common ExportTestCase base class (#88145) Refactor out a common base class `ExportTestCase`, for common things in `setUp`. 
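For illustration, a condensed sketch of the pattern this introduces (the real implementation is in the `test/onnx/pytorch_test_common.py` hunk below; `TestSomething` is a hypothetical suite, and the diagnostics-engine cleanup done by the actual class is omitted here to keep the sketch self-contained):

```
# Sketch only: a shared base class so individual ONNX test files stop
# re-implementing the same seeding boilerplate in their own setUp().
import random

import numpy as np
import torch
from torch.testing._internal import common_utils


class ExportTestCase(common_utils.TestCase):
    """Base class for torch.onnx export tests: seeds all RNGs the same way."""

    def setUp(self):
        super().setUp()
        torch.manual_seed(0)
        random.seed(0)
        np.random.seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)


class TestSomething(ExportTestCase):
    # Hypothetical suite: it inherits the seeding above and only adds what is
    # specific to itself (e.g. onnxruntime.set_seed(0) in onnx_test_common).
    def test_placeholder(self):
        self.assertTrue(torch.allclose(torch.ones(2), torch.ones(2)))


if __name__ == "__main__":
    common_utils.run_tests()
```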
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88145 Approved by: https://github.com/justinchuby, https://github.com/abock, https://github.com/AllenTiTaiWang --- test/onnx/onnx_test_common.py | 14 +++------- test/onnx/pytorch_test_common.py | 26 ++++++++++++++++++ test/onnx/test_autograd_funs.py | 8 +++--- test/onnx/test_custom_ops.py | 7 ++--- test/onnx/test_export_modes.py | 4 ++- test/onnx/test_models.py | 5 ++-- test/onnx/test_models_onnxruntime.py | 5 ++-- test/onnx/test_onnx_opset.py | 3 ++- test/onnx/test_operators.py | 5 ++++ test/onnx/test_pytorch_helper.py | 3 ++- test/onnx/test_pytorch_jit_onnx.py | 3 ++- test/onnx/test_pytorch_onnx_caffe2.py | 27 +++++++------------ .../test_pytorch_onnx_caffe2_quantized.py | 3 ++- test/onnx/test_pytorch_onnx_no_runtime.py | 4 +-- .../onnx/test_pytorch_onnx_shape_inference.py | 3 ++- test/onnx/test_utility_funs.py | 11 +++----- 16 files changed, 76 insertions(+), 55 deletions(-) diff --git a/test/onnx/onnx_test_common.py b/test/onnx/onnx_test_common.py index 45f90d4193ce..6963d16284ce 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -3,15 +3,13 @@ from __future__ import annotations import os -import random from typing import Any, Mapping, Type -import numpy as np import onnxruntime +import pytorch_test_common import torch from torch.onnx import _constants, verification -from torch.testing._internal import common_utils onnx_model_dir = os.path.join( os.path.dirname(os.path.realpath(__file__)), @@ -54,13 +52,7 @@ def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]) return f"{cls.__name__}_{suffix}" -def set_rng_seed(seed): - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - - -class _TestONNXRuntime(common_utils.TestCase): +class _TestONNXRuntime(pytorch_test_common.ExportTestCase): opset_version = _constants.ONNX_DEFAULT_OPSET keep_initializers_as_inputs = True # For IR version 3 type export. is_script = False @@ -68,7 +60,7 @@ class _TestONNXRuntime(common_utils.TestCase): check_dtype = True def setUp(self): - set_rng_seed(0) + super().setUp() onnxruntime.set_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py index 4a44932fb120..4e443c333f35 100644 --- a/test/onnx/pytorch_test_common.py +++ b/test/onnx/pytorch_test_common.py @@ -2,12 +2,17 @@ import functools import os +import random import sys import unittest from typing import Optional +import numpy as np + import torch from torch.autograd import function +from torch.onnx._internal import diagnostics +from torch.testing._internal import common_utils pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.insert(-1, pytorch_test_dir) @@ -188,3 +193,24 @@ def wrapper(self, *args, **kwargs): def flatten(x): return tuple(function._iter_filter(lambda o: isinstance(o, torch.Tensor))(x)) + + +def set_rng_seed(seed): + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + +class ExportTestCase(common_utils.TestCase): + """Test case for ONNX export. + + Any test case that tests functionalities under torch.onnx should inherit from this class. + """ + + def setUp(self): + super().setUp() + # TODO(#88264): Flaky test failures after changing seed. 
+ set_rng_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + diagnostics.engine.clear() diff --git a/test/onnx/test_autograd_funs.py b/test/onnx/test_autograd_funs.py index 97f0652ecf37..a5498f39d2da 100644 --- a/test/onnx/test_autograd_funs.py +++ b/test/onnx/test_autograd_funs.py @@ -1,16 +1,16 @@ # Owner(s): ["module: onnx"] -import unittest +import pytorch_test_common import torch - from onnx_test_common import run_model_test from torch.onnx import OperatorExportTypes from torch.onnx._globals import GLOBALS from torch.onnx.utils import _model_to_graph +from torch.testing._internal import common_utils -class TestAutogradFuns(unittest.TestCase): +class TestAutogradFuns(pytorch_test_common.ExportTestCase): opset_version = GLOBALS.export_onnx_opset_version keep_initializers_as_inputs = False onnx_shape_inference = True @@ -209,4 +209,4 @@ def forward(self, input): if __name__ == "__main__": - unittest.main() + common_utils.run_tests() diff --git a/test/onnx/test_custom_ops.py b/test/onnx/test_custom_ops.py index 4242d70583ba..5609b497535e 100644 --- a/test/onnx/test_custom_ops.py +++ b/test/onnx/test_custom_ops.py @@ -4,6 +4,7 @@ import numpy as np import onnx import onnx_test_common +import pytorch_test_common import torch import torch.utils.cpp_extension from test_pytorch_onnx_caffe2 import do_export @@ -11,7 +12,7 @@ from torch.testing._internal import common_utils -class TestCustomOps(common_utils.TestCase): +class TestCustomOps(pytorch_test_common.ExportTestCase): def test_custom_add(self): op_source = """ #include @@ -56,7 +57,7 @@ def symbolic_custom_add(g, self, other): np.testing.assert_array_equal(caffe2_out[0], model(x, y).cpu().numpy()) -class TestCustomAutogradFunction(common_utils.TestCase): +class TestCustomAutogradFunction(pytorch_test_common.ExportTestCase): opset_version = 9 keep_initializers_as_inputs = False onnx_shape_inference = True @@ -130,7 +131,7 @@ def symbolic_pythonop(ctx: torch.onnx.SymbolicContext, g, *args, **kwargs): onnx_test_common.run_model_test(self, model, input_args=(x,)) -class TestExportAsContribOps(common_utils.TestCase): +class TestExportAsContribOps(pytorch_test_common.ExportTestCase): opset_version = 14 keep_initializers_as_inputs = False onnx_shape_inference = True diff --git a/test/onnx/test_export_modes.py b/test/onnx/test_export_modes.py index 0f3024a2e366..502f31b38b10 100644 --- a/test/onnx/test_export_modes.py +++ b/test/onnx/test_export_modes.py @@ -15,11 +15,13 @@ # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) +import pytorch_test_common + from torch.testing._internal import common_utils # Smoke tests for export methods -class TestExportModes(common_utils.TestCase): +class TestExportModes(pytorch_test_common.ExportTestCase): class MyModel(nn.Module): def __init__(self): super(TestExportModes.MyModel, self).__init__() diff --git a/test/onnx/test_models.py b/test/onnx/test_models.py index 7084bd75bace..15904839957e 100644 --- a/test/onnx/test_models.py +++ b/test/onnx/test_models.py @@ -2,8 +2,9 @@ import unittest -import torch +import pytorch_test_common +import torch from model_defs.dcgan import _netD, _netG, bsz, imgsz, nz, weights_init from model_defs.emb_seq import EmbeddingNetwork1, EmbeddingNetwork2 from model_defs.mnist import MNIST @@ -44,7 +45,7 @@ def toC(x): BATCH_SIZE = 2 -class TestModels(common_utils.TestCase): +class TestModels(pytorch_test_common.ExportTestCase): opset_version = 9 # Caffe2 
doesn't support the default. keep_initializers_as_inputs = False diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index c84640e535e1..de1003ce449e 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -8,6 +8,7 @@ import onnx_test_common import parameterized import PIL +import pytorch_test_common import test_models import torch @@ -64,7 +65,7 @@ def exportTest( TestModels = type( "TestModels", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict( test_models.TestModels.__dict__, is_script_test_enabled=False, @@ -77,7 +78,7 @@ def exportTest( # model tests for scripting with new JIT APIs and shape inference TestModels_new_jit_API = type( "TestModels_new_jit_API", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict( TestModels.__dict__, exportTest=exportTest, diff --git a/test/onnx/test_onnx_opset.py b/test/onnx/test_onnx_opset.py index dab33bf00b09..ef79e82ee266 100644 --- a/test/onnx/test_onnx_opset.py +++ b/test/onnx/test_onnx_opset.py @@ -4,6 +4,7 @@ import itertools import onnx +import pytorch_test_common import torch import torch.onnx @@ -70,7 +71,7 @@ def check_onnx_opsets_operator( check_onnx_opset_operator(model, ops[opset_version], opset_version) -class TestONNXOpset(common_utils.TestCase): +class TestONNXOpset(pytorch_test_common.ExportTestCase): def test_opset_fallback(self): class MyModule(Module): def forward(self, x): diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 9b743a50d332..cfb36732af4d 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -31,6 +31,7 @@ ) from torch.autograd import Function, Variable from torch.nn import functional, Module +from torch.onnx._internal import diagnostics from torch.onnx.symbolic_helper import ( _get_tensor_dim_size, _get_tensor_sizes, @@ -71,6 +72,10 @@ def forward(self, *args): class TestOperators(common_utils.TestCase): + def setUp(self): + super().setUp() + diagnostics.engine.clear() + def assertONNX(self, f, args, params=None, **kwargs): if params is None: params = () diff --git a/test/onnx/test_pytorch_helper.py b/test/onnx/test_pytorch_helper.py index 362841d8bf90..7d7f3ade7f58 100644 --- a/test/onnx/test_pytorch_helper.py +++ b/test/onnx/test_pytorch_helper.py @@ -4,6 +4,7 @@ import unittest import numpy as np +import pytorch_test_common import torch.nn.init as init import torch.onnx @@ -15,7 +16,7 @@ from torch.testing._internal.common_utils import skipIfNoLapack -class TestCaffe2Backend(common_utils.TestCase): +class TestCaffe2Backend(pytorch_test_common.ExportTestCase): @skipIfNoLapack @unittest.skip("test broken because Lapack was always missing.") def test_helper(self): diff --git a/test/onnx/test_pytorch_jit_onnx.py b/test/onnx/test_pytorch_jit_onnx.py index f069251ee064..784bd0954b0a 100644 --- a/test/onnx/test_pytorch_jit_onnx.py +++ b/test/onnx/test_pytorch_jit_onnx.py @@ -1,5 +1,6 @@ # Owner(s): ["module: onnx"] import onnxruntime +import pytorch_test_common import torch from pytorch_test_common import skipIfNoCuda @@ -171,7 +172,7 @@ def MakeTestCase(opset_version: int) -> type: name = f"TestJITIRToONNX_opset{opset_version}" return type( str(name), - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(_TestJITIRToONNX.__dict__, opset_version=opset_version), ) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 141d3683171f..78440ac6ecb5 100644 --- 
a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -12,6 +12,7 @@ import model_defs.word_language_model as word_language_model import numpy as np import onnx +import pytorch_test_common import torch.onnx import torch.onnx.operators import torch.utils.model_zoo as model_zoo @@ -129,18 +130,10 @@ def do_export(model, inputs, *args, **kwargs): } -class TestCaffe2Backend_opset9(common_utils.TestCase): +class TestCaffe2Backend_opset9(pytorch_test_common.ExportTestCase): opset_version = 9 embed_params = False - def setUp(self): - # the following should ideally be super().setUp(), https://github.com/pytorch/pytorch/issues/79630 - common_utils.TestCase.setUp(self) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - np.random.seed(seed=0) - def convert_cuda(self, model, input): cuda_model = model.cuda() # input might be nested - we want to move everything to GPU @@ -3198,44 +3191,44 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9 = type( "TestCaffe2BackendEmbed_opset9", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) # opset 7 tests TestCaffe2Backend_opset7 = type( "TestCaffe2Backend_opset7", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=7), ) TestCaffe2BackendEmbed_opset7 = type( "TestCaffe2BackendEmbed_opset7", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=7), ) # opset 8 tests TestCaffe2Backend_opset8 = type( "TestCaffe2Backend_opset8", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=8), ) TestCaffe2BackendEmbed_opset8 = type( "TestCaffe2BackendEmbed_opset8", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=8), ) # opset 10 tests TestCaffe2Backend_opset10 = type( "TestCaffe2Backend_opset10", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, opset_version=10), ) TestCaffe2BackendEmbed_opset10 = type( "TestCaffe2BackendEmbed_opset10", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True, opset_version=10), ) @@ -3243,7 +3236,7 @@ def setup_rnn_tests(): # to embed_params=True TestCaffe2BackendEmbed_opset9_new_jit_API = type( "TestCaffe2BackendEmbed_opset9_new_jit_API", - (common_utils.TestCase,), + (pytorch_test_common.ExportTestCase,), dict(TestCaffe2Backend_opset9.__dict__, embed_params=True), ) diff --git a/test/onnx/test_pytorch_onnx_caffe2_quantized.py b/test/onnx/test_pytorch_onnx_caffe2_quantized.py index f6466aa0869e..92079ebbe6d9 100644 --- a/test/onnx/test_pytorch_onnx_caffe2_quantized.py +++ b/test/onnx/test_pytorch_onnx_caffe2_quantized.py @@ -6,13 +6,14 @@ import numpy as np import onnx +import pytorch_test_common import torch.ao.nn.quantized as nnq import torch.nn as nn import torch.onnx from torch.testing._internal import common_utils -class TestQuantizedOps(common_utils.TestCase): +class TestQuantizedOps(pytorch_test_common.ExportTestCase): def generic_test( self, model, sample_inputs, input_names=None, decimal=3, relaxed_check=False ): diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py 
b/test/onnx/test_pytorch_onnx_no_runtime.py index 1ec86ce69515..622f42effb4a 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -11,9 +11,9 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np - import onnx import onnx.numpy_helper +import pytorch_test_common import torch import torch.nn.functional as F @@ -74,7 +74,7 @@ def export_to_onnx( return onnx_model -class TestONNXExport(common_utils.TestCase): +class TestONNXExport(pytorch_test_common.ExportTestCase): def test_fuse_addmm(self): class AddmmModel(torch.nn.Module): def forward(self, x): diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index 86258fb1d0ec..cf9ef2fd893e 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -1,6 +1,7 @@ # Owner(s): ["module: onnx"] import numpy as np +import pytorch_test_common import torch from pytorch_test_common import skipIfUnsupportedMinOpsetVersion @@ -19,7 +20,7 @@ def verify(actual_type): return verify -class TestONNXShapeInference(common_utils.TestCase): +class TestONNXShapeInference(pytorch_test_common.ExportTestCase): def setUp(self): self.opset_version = _constants.ONNX_MAX_OPSET symbolic_helper._set_onnx_shape_inference(True) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 26467d54c1c6..51adaef317af 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -8,6 +8,7 @@ import onnx import parameterized +import pytorch_test_common import torch import torch.onnx @@ -27,13 +28,7 @@ from verify import verify -class _BaseTestCase(common_utils.TestCase): - def setUp(self): - super().setUp() - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - +class _BaseTestCase(pytorch_test_common.ExportTestCase): def _model_to_graph( self, model, @@ -64,7 +59,7 @@ def _model_to_graph( @common_utils.instantiate_parametrized_tests -class TestUnconvertibleOps(common_utils.TestCase): +class TestUnconvertibleOps(pytorch_test_common.ExportTestCase): """Unit tests for the `unconvertible_ops` function.""" def setUp(self): From cc04cf50bfb6110e4c1c5889ad7da626dafac384 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Nov 2022 23:37:29 +0000 Subject: [PATCH 036/453] [Inductor] Fix lowmem_dropout() missing 1 required positional argument: 'p' (#88716) Fixes error from 7k github models: https://github.com/jansel/pytorch-jit-paritybench/blob/master/generated/test_GuYuc_WS_DAN_PyTorch.py Error: ``` TypeError: lowmem_dropout() missing 1 required positional argument: 'p' While executing %lowmem_dropout : [#users=1] = call_function[target=torch._inductor.overrides.lowmem_dropout](args = (%avg_pool2d_9,), kwargs = {training: False}) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88716 Approved by: https://github.com/ngimel, https://github.com/jansel, https://github.com/desertfire --- test/inductor/test_torchinductor.py | 21 ++++++++++++++++----- torch/_inductor/overrides.py | 2 +- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 229f0fa83dd4..8fd4fa29bf98 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3666,13 +3666,24 @@ def test_dropout(self): torch.manual_seed(1234) @torch._dynamo.optimize("inductor") - def 
fn(a): - return torch.nn.functional.dropout(a, 0.5, True) + def fn1(a): + return torch.nn.functional.dropout(a) x = torch.ones(1000, device=self.device, dtype=torch.float32) - result = fn(x) - self.assertTrue(400 < result.nonzero().shape[0] < 600) - self.assertTrue(0.9 < result.mean().item() < 1.1) + result1 = fn1(x) + self.assertTrue(400 < result1.nonzero().shape[0] < 600) + self.assertTrue(0.9 < result1.mean().item() < 1.1) + + random.seed(1234) + torch.manual_seed(1234) + + @torch._dynamo.optimize("inductor") + def fn2(a): + return torch.nn.functional.dropout(a, 0.5, True) + + result2 = fn2(x) + self.assertTrue(400 < result2.nonzero().shape[0] < 600) + self.assertTrue(0.9 < result2.mean().item() < 1.1) def test_dropout_deterministic(self): @torch._dynamo.optimize("inductor") diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 581e1996a436..d89ee82674dd 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -562,7 +562,7 @@ def backward(ctx, grad_output): @torch.fx.wrap -def lowmem_dropout(input, p, training=True, inplace=False): +def lowmem_dropout(input, p=0.5, training=True, inplace=False): if isinstance(input, torch.fx.Proxy): # double check we don't FX trace this return input.tracer.create_proxy( From d9ad08ce8a07a3d17df397051b32591f4446edfa Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 10 Nov 2022 20:35:52 +0000 Subject: [PATCH 037/453] Symbolic shape: sym_floor , sym_sqrt, sym_int (#88760) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88760 Approved by: https://github.com/ezyang --- test/test_dynamic_shapes.py | 43 ++++++++++++++++++++-- torch/__init__.py | 33 +++++++++++++++-- torch/fx/experimental/symbolic_shapes.py | 45 +++++++++++++++--------- 3 files changed, 99 insertions(+), 22 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b23af9bbfb67..0f1f49d2e6ea 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -20,7 +20,7 @@ from torch.utils._pytree import tree_map from torch.fx.experimental import symbolic_shapes from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode +from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int from torch.utils._python_dispatch import TorchDispatchMode from torch import SymInt @@ -335,6 +335,45 @@ def test_guard_int(self): self.assertEqual(guard_int(a0), 2) self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 2)") + @skipIfNoSympy + def test_sym_int(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 5) + r = sym_int(a0) + self.assertEqual(r, 5) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertEqual(str(shape_env.guards[0][0]), "Eq(s0, 5)") + + a1 = create_symint(shape_env, 7) + r = sym_int(a1 / 2) + self.assertEqual(guard_int(r), 3) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertEqual(str(shape_env.guards[1][0]), "Eq(floor(s1/2), 3)") + + a2 = create_symint(shape_env, -3) + r = sym_int(a2 / 2) + self.assertEqual(guard_int(r), -1) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertEqual(str(shape_env.guards[2][0]), "Eq(ceiling(-s2/2), -1)") + + @skipIfNoSympy + def test_sym_sqrt(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 4) + r = sym_sqrt(a0) + self.assertEqual(r, 2) + self.assertIsInstance(r, torch.SymFloat, msg=type(r)) + 
self.assertEqual(str(shape_env.guards[0][0]), "Eq(sqrt(s0), 2)") + + @skipIfNoSympy + def test_sym_floor(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 5) + r = math.floor(a0 / 2) + self.assertEqual(r, 2) + self.assertIsInstance(r, torch.SymInt, msg=type(r)) + self.assertEqual(str(shape_env.guards[0][0]), "Eq(floor(s0/2), 2)") + @skipIfNoSympy def test_int_conversion(self): shape_env = ShapeEnv() @@ -526,7 +565,7 @@ def guard_fn(v): @parametrize("first_type", ["int", "float"]) @parametrize("second_type", ["int", "float"]) def test_method(self, fn, first_type, second_type): - if first_type == "float" and fn in symbolic_shapes.magic_methods_not_on_float: + if first_type == "float": self.skipTest(f"{fn} is not a float magic method") is_unary_fn = fn in symbolic_shapes.unary_magic_methods diff --git a/torch/__init__.py b/torch/__init__.py index 2abf4ba4b07d..ee271c0a975a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -218,8 +218,23 @@ def __int__(self): # Magic methods installed by torch.fx.experimental.symbolic_shapes + def __eq__(self, other: object) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + def __sym_float__(self): - ... + raise AssertionError("type stub not overridden") def __repr__(self): return self.node.str() @@ -247,8 +262,20 @@ def __bool__(self): # Magic methods installed by torch.fx.experimental.symbolic_shapes - def __sym_int__(self): - ... + def __eq__(self, other: object) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __lt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __gt__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __le__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") + + def __ge__(self, other) -> builtins.bool: + raise AssertionError("type stub not overridden") def __repr__(self): return self.node.str() diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 82e1d5107d79..d9b0a8fc2019 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1,6 +1,6 @@ import torch import torch.utils._pytree as pytree -from typing import Set, Dict, List, Type, Optional, cast +from typing import Set, Dict, List, Type, Optional, cast, Union import sys import operator import builtins @@ -24,7 +24,8 @@ __all__ = [ "has_symbolic_sizes_strides", "create_contiguous", "ShapeEnv", - "SymDispatchMode", "sym_float", "FloorDiv", "guard_int", "wrap_node" + "SymDispatchMode", "sym_int", "sym_float", "FloorDiv", "guard_int", "wrap_node", + "sym_sqrt", ] SYM_FUNCTION_MODE = None @@ -103,11 +104,26 @@ def sym_float(a): return a.__sym_float__() return float(a) +# Drop in replacement for math.sqrt +def sym_sqrt(a): + if hasattr(a, '__sym_sqrt__'): + return a.__sym_sqrt__() + return math.sqrt(a) + +# Drop in replacement for math.floor/ceil. 
Actually, math.floor/ceil +# directly usable, but this has a more relaxed type signature for mypy +# (mypy requires SupportFloat which is too strict) +def sym_floor(a): + return math.floor(a) # type: ignore[type] + +def sym_ceil(a): + return math.ceil(a) # type: ignore[type] + def sym_int(a): if isinstance(a, SymInt): return a - elif hasattr(a, '__sym_int__'): - return a.__sym_int__() + elif isinstance(a, SymFloat): + return sym_floor(a) if a > 0 else sym_ceil(a) return int(a) # TODO: An incomplete list @@ -255,29 +271,28 @@ def _nyi(): 'lt': lambda a, b: sympy.Lt(a, b), 'le': lambda a, b: sympy.Le(a, b), 'ge': lambda a, b: sympy.Ge(a, b), - 'sym_float': lambda a: a, # TODO: why can't I wrap with sympy.Float? - 'sym_int': lambda a: _nyi(), + 'floor': lambda a: sympy.floor(a), + 'sym_float': lambda a: a, # Cannot use sympy.Float(a) here, coz it expects python literals 'ceil': lambda a: sympy.ceiling(a), 'neg': lambda a: -a, 'min': lambda a, b: sympy.Min(a, b), 'max': lambda a, b: sympy.Max(a, b), + 'sym_sqrt': lambda a: sympy.sqrt(a), } unary_magic_methods = { 'sym_float', - 'sym_int', 'ceil', + 'floor', 'neg', + 'sym_sqrt', } -# TODO: sym_int should also work on floats -magic_methods_not_on_float = {"sym_int"} - magic_methods_on_builtins = {"min", "max"} magic_methods_on_math = {"ceil", "floor"} -magic_methods_on_submodule = {"sym_float", "sym_int"} +magic_methods_on_submodule = {"sym_float", "sym_sqrt"} -always_float_magic_methods = {"truediv", "sym_float"} +always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"} always_int_magic_methods = {"ceil", "floor"} always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"} @@ -383,10 +398,6 @@ def rbinary_magic_impl(self, other): for method, func in magic_methods.items(): _make_user_magic(method, SymInt) - -for method, func in magic_methods.items(): - if method in magic_methods_not_on_float: - continue _make_user_magic(method, SymFloat) del method @@ -479,7 +490,7 @@ def create_symbolic_sizes_strides(self, ex: torch.Tensor): assert all(x is not None for x in stride) return [self.create_symintnode(i) for i in size], [self.create_symintnode(i) for i in stride] # type: ignore[arg-type] - def create_symintnode(self, expr: "sympy.Expr"): + def create_symintnode(self, expr: Union["sympy.Expr", int]): return SymInt(SymNode(expr, self, int)) def create_symbol(self, val: int) -> "sympy.Expr": From ae01615d7558d02383efe673ec0b92e2abe40db5 Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Thu, 10 Nov 2022 23:44:49 +0000 Subject: [PATCH 038/453] Fix cupti search path in CMake (#88657) Minor fix for when cuda is installed via conda. In this case the libraries are in `lib` and not `lib64`. 
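For context, a minimal sketch of the lookup this patch extends — the point is simply that the conda layout's `lib` directory has to be listed next to `lib64` (the `cupti` library name, the `$ENV{CUDA_HOME}` root and the `message` call are illustrative; the real code uses `${CUPTI_LIB_NAME}` with the existing `CUDA_SOURCE_DIR`, see the `cmake/Dependencies.cmake` hunk below):

```
# Sketch only: with conda-installed CUDA, libcupti sits under <prefix>/lib
# rather than <prefix>/lib64, so both directories must be searched.
set(CUDA_SOURCE_DIR "$ENV{CUDA_HOME}")   # assumption: CUDA root taken from the environment
find_library(CUPTI_LIBRARY_PATH cupti
  PATHS
    ${CUDA_SOURCE_DIR}
    ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64
    ${CUDA_SOURCE_DIR}/lib               # the directory this change adds
    ${CUDA_SOURCE_DIR}/lib64
  NO_DEFAULT_PATH)
message(STATUS "Found cupti at: ${CUPTI_LIBRARY_PATH}")
```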
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88657 Approved by: https://github.com/kit1980, https://github.com/malfet --- cmake/Dependencies.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index cf3c2c2caafd..104056ee0724 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1952,6 +1952,7 @@ if(USE_KINETO) find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS ${CUDA_SOURCE_DIR} ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64 + ${CUDA_SOURCE_DIR}/lib ${CUDA_SOURCE_DIR}/lib64 NO_DEFAULT_PATH) From b30222e0c481f29fe0785dde518c590ac392e9a2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 10 Nov 2022 23:47:21 +0000 Subject: [PATCH 039/453] [Dynamo] Add complete support for Tensor.is_contiguous (#88407) Fixes https://github.com/pytorch/torchdynamo/issues/1783 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88407 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 12 +++++++ torch/_dynamo/variables/tensor.py | 55 ++++++++++++++++++++++++++----- 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 45433b6795cc..4df7153b8fb2 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1409,6 +1409,18 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) + def test_tensor_is_contiguous(self): + def fn(x): + input = torch.randn((1, 16, 1, 1)) + weight = torch.randn((8, 16, 3, 3)) + weight = weight.to(memory_format=x) + output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1) + return output.is_contiguous(memory_format=x) + + opt_fn = torch._dynamo.optimize("eager")(fn) + for x in [torch.contiguous_format, torch.channels_last]: + self.assertEqual(fn(x), opt_fn(x)) + def test_python_slice(self): def f1(input): y = 0 diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5a30f838e3f3..0974f24ee969 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -409,7 +409,13 @@ def specialize(value: torch.Tensor): if not config.dynamic_shapes: props["size"] = tuple(value.size()) props["stride"] = tuple(value.stride()) - props["is_contiguous"] = value.is_contiguous() + props["is_contiguous"] = tuple( + [ + x + for x in torch._prims_common._memory_formats + if value.is_contiguous(memory_format=x) + ] + ) return props def var_getattr(self, tx, name): @@ -492,13 +498,13 @@ def call_method( elif name == "is_floating_point" and self.dtype is not None: constant_result = ConstantVariable(self.dtype.is_floating_point, **options) elif name == "is_contiguous" and self.is_contiguous is not None: - if ( - "memory_format" in kwargs - and kwargs["memory_format"].as_python_constant() - == torch.contiguous_format - ): - kwargs.pop("memory_format") - constant_result = ConstantVariable(self.is_contiguous, **options) + if "memory_format" in kwargs: + memory_format = kwargs.pop("memory_format").as_python_constant() + else: + memory_format = torch.contiguous_format + constant_result = ConstantVariable( + memory_format in self.is_contiguous, **options + ) else: constant_result = None @@ -555,6 +561,39 @@ def call_method( current_tx=tx, ) return ConstantVariable(None, **options) + elif name in ("resize_", "resize_as_"): + if "memory_format" in kwargs: + memory_format = kwargs["memory_format"].as_python_constant() + else: + memory_format = torch.contiguous_format + + if name == "resize_": + 
self.size = args[0].as_python_constant() + self.is_contiguous = (memory_format,) + else: + assert isinstance(args[0], TensorVariable) + if self.size and args[0].size: + if ( + self.size == args[0].size + or memory_format is torch.preserve_format + ): + self.is_contiguous = args[0].is_contiguous + else: + self.size = args[0].size + self.stride = args[0].stride + self.ndim = args[0].ndim + self.is_contiguous = (memory_format,) + + return self.__class__.create( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + args, kwargs), + current_tx=tx, + ), + **options, + ) else: # Convert x.new(torch.Size) into x.new_empty(torch.Size), # as Tensor.new acts differently with a Size input versus a tuple input. From 62ef15e320f4a0aaa2f39296e9299f56926fb7c9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 10 Nov 2022 23:52:27 +0000 Subject: [PATCH 040/453] [MPS] Fix `test_embedding_dense_backward` (#88847) By copying randomly initialized weights distribution from MPS `nn.Embedding` to `cpu` Test plan: `python test_mps.py -k test_embedding_dense_backward --repeat 150` Fixes https://github.com/pytorch/pytorch/issues/88679 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88847 Approved by: https://github.com/seemethere --- test/test_mps.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_mps.py b/test/test_mps.py index 2ff5a9da71ef..30546f50fd65 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4282,8 +4282,9 @@ def helper(shape, dim, index, idx_dtype=torch.int32): def test_embedding_dense_backward(self): def helper(n, d, m, idx): embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps') + emedding_weight = embeddingMPS.weight.detach().cpu() W_MPS = torch.randn((m, d), requires_grad=True, device='mps') - idx_MPS = torch.tensor(idx).to('mps') + idx_MPS = torch.tensor(idx, device='mps') a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable a_MPS.retain_grad() b_MPS = embeddingMPS(idx_MPS) @ W_MPS.t() # modifies weight in-place @@ -4292,7 +4293,7 @@ def helper(n, d, m, idx): loss_MPS = out_MPS.sigmoid().prod() loss_MPS.backward() - embeddingCPU = nn.Embedding(n, d, max_norm=True, scale_grad_by_freq=True) + embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=emedding_weight) W_CPU = W_MPS.to('cpu') idx_CPU = torch.tensor(idx) a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable From 8441443132106fd673a81cd8f6728b332d16f837 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 10 Nov 2022 23:56:49 +0000 Subject: [PATCH 041/453] Revert "Add nondeterministic error for `scatter` (#88244)" This reverts commit e940a2f8e2a3aa9d98291e73b3d40fcffb6182c8. 
Reverted https://github.com/pytorch/pytorch/pull/88244 on behalf of https://github.com/mehtanirav due to Internal test failures --- .../ATen/native/TensorAdvancedIndexing.cpp | 4 -- test/test_torch.py | 40 ------------------- torch/__init__.py | 1 - 3 files changed, 45 deletions(-) diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index fa78b60c6684..3004dc1b31c7 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -1512,10 +1512,6 @@ TORCH_IMPL_FUNC(scatter_src_out) const Tensor& index, const Tensor& src, const Tensor& out) { - // See note [Writing Nondeterministic Operations] - // Nondeterministic when index contains duplicate entries, src is a tensor, - // and reduce=None - at::globalContext().alertNotDeterministic("scatter with src tensor and reduce=None"); scatter_impl(self, dim, index, src, out, scatter_reduce_stub, scatter_stub); diff --git a/test/test_torch.py b/test/test_torch.py index 82d0807d81a7..3ebc92676fe0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -1478,46 +1478,6 @@ def test_nondeterministic_alert_put_accumulate(self, device): 'put_', torch.device(device).type == 'cuda') - @expectedFailureMeta # expected a non-determinitic error, but it was not raised - @onlyNativeDeviceTypes - def test_nondeterministic_alert_scatter(self, device): - a = torch.randn(10, device=device) - indices = torch.tensor([0, 0], device=device) - values = torch.tensor([0., 1.], device=device) - result = torch.empty_like(a) - - error_msg = 'scatter with src tensor and reduce=None' - - error_cases = [ - lambda: torch.Tensor.scatter(a, 0, indices, values), - lambda: torch.Tensor.scatter_(a, 0, indices, values), - lambda: torch.scatter(a, 0, indices, values), - lambda: torch.scatter(a, 0, indices, values, out=result), - ] - - no_error_cases = [ - lambda: torch.Tensor.scatter(a, 0, indices, 0), - lambda: torch.Tensor.scatter_(a, 0, indices, 0), - lambda: torch.scatter(a, 0, indices, 0), - lambda: torch.scatter(a, 0, indices, 0, out=result), - - lambda: torch.Tensor.scatter(a, 0, indices, values, reduce='add'), - lambda: torch.Tensor.scatter_(a, 0, indices, values, reduce='add'), - lambda: torch.scatter(a, 0, indices, values, reduce='add'), - lambda: torch.scatter(a, 0, indices, values, out=result, reduce='add'), - ] - - for error_case in error_cases: - self.check_nondeterministic_alert( - error_case, - error_msg) - - for no_error_case in no_error_cases: - self.check_nondeterministic_alert( - no_error_case, - error_msg, - False) - @skipIfMps def test_nondeterministic_alert_histc(self, device): a = torch.tensor([], device=device) diff --git a/torch/__init__.py b/torch/__init__.py index ee271c0a975a..6049967b6f18 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -527,7 +527,6 @@ def use_deterministic_algorithms(mode, *, warn_only=False): ``mode='max'`` * :func:`torch.Tensor.put_` when ``accumulate=False`` * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor - * :func:`torch.Tensor.scatter` when ``src`` is a tensor and ``reduce=None`` * :func:`torch.histc` when called on a CUDA tensor * :func:`torch.bincount` when called on a CUDA tensor * :func:`torch.kthvalue` with called on a CUDA tensor From f9221bf53b376d1284e2356b716c2cd47fcd65f2 Mon Sep 17 00:00:00 2001 From: Ian Graves Date: Fri, 11 Nov 2022 00:19:20 +0000 Subject: [PATCH 042/453] [pytorch] Enable memory map file support for Android, Apple, and CXX (#88545) 
Summary: See title. Left Windows out so it still compiles. Test Plan: Add a `#fail` below [this line](https://fburl.com/code/p0mlhlw4) and build for various platforms and confirm it fails which proves the `#ifdef` was hit. ``` buck2 build xplat/langtech/tuna/cli:tuclixAndroid buck2 build xplat/langtech/tuna/cli:tuclix ``` CI/CD for the rest. Differential Revision: D41054824 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88545 Approved by: https://github.com/qihqi --- c2_defs.bzl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/c2_defs.bzl b/c2_defs.bzl index 573ba9f6ad64..0a89bb88093d 100644 --- a/c2_defs.bzl +++ b/c2_defs.bzl @@ -166,6 +166,7 @@ def get_c2_fbandroid_xplat_compiler_flags(): # T95767731 -- remove this once all builds are on at least llvm-13 "-Wno-unknown-warning-option", "-Wno-unused-but-set-variable", + "-DHAVE_MMAP", ] if get_c2_strip_glog(): @@ -392,6 +393,7 @@ def c2_cxx_library(**kwargs): args = get_c2_default_cxx_args() args.update(kwargs) args.setdefault("platforms", (ANDROID, APPLE, CXX, WINDOWS)) + fb_xplat_cxx_library( labels = [ "supermodule:android/default/caffe2", From 072834d56dada58f99216ce398fb57cce57968a9 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 8 Nov 2022 07:59:12 -0800 Subject: [PATCH 043/453] [ao] qconfig_mapping.py fixing public v private (#87518) Summary: made _GLOBAL_DICT_KEY, _OBJECT_TYPE_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709278](https://our.internmc.facebook.com/intern/diff/D40709278) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87518 Approved by: https://github.com/jcaip --- test/quantization/fx/test_quantize_fx.py | 20 ++++++------ .../quantization/fx/qconfig_mapping_utils.py | 8 ++--- torch/ao/quantization/qconfig_mapping.py | 32 +++++++++---------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 8c75658a04e1..6eb9246c85a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -90,11 +90,11 @@ from torch.ao.quantization.qconfig_mapping import ( _get_symmetric_qnnpack_qconfig_mapping, - GLOBAL_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, - OBJECT_TYPE_DICT_KEY, + _GLOBAL_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, QConfigMapping, ) @@ -1972,20 +1972,20 @@ def _get_qconfig_dict_for_qconfig_mapping_test(self, global_qconfig, qconfig1, q Return a dummy qconfig_dict to test QConfigMapping's to_dict and from_dict methods. 
""" return { - GLOBAL_DICT_KEY: global_qconfig, - OBJECT_TYPE_DICT_KEY: [ + _GLOBAL_DICT_KEY: global_qconfig, + _OBJECT_TYPE_DICT_KEY: [ (torch.nn.Linear, qconfig1), (torch.nn.ReLU, qconfig2), ], - MODULE_NAME_REGEX_DICT_KEY: [ + _MODULE_NAME_REGEX_DICT_KEY: [ ("foo.*bar", qconfig1), ("foo.*", qconfig2), ], - MODULE_NAME_DICT_KEY: [ + _MODULE_NAME_DICT_KEY: [ ("bazbaz", qconfig1), ("borbor", qconfig2), ], - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ ("bazbaz", torch.nn.Linear, 0, qconfig1), ("foofoo", torch.nn.ReLU, 1, qconfig2), ], diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 66dffd50cd00..0b0407c0b106 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -23,9 +23,9 @@ get_qconfig_dtypes, ) from ..qconfig_mapping import ( - OBJECT_TYPE_DICT_KEY, - MODULE_NAME_DICT_KEY, - MODULE_NAME_REGEX_DICT_KEY, + _OBJECT_TYPE_DICT_KEY, + _MODULE_NAME_DICT_KEY, + _MODULE_NAME_REGEX_DICT_KEY, QConfigMapping, ) from ..qconfig_mapping_utils import ( @@ -223,7 +223,7 @@ def compare_prepare_convert_qconfig_mappings( convert_qconfig_mapping.module_name_qconfigs, convert_qconfig_mapping.module_name_regex_qconfigs, ] - dict_names = [OBJECT_TYPE_DICT_KEY, MODULE_NAME_DICT_KEY, MODULE_NAME_REGEX_DICT_KEY] + dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY] for i in range(len(prepare_dicts)): for name, qconfig in prepare_dicts[i].items(): assert name in convert_dicts[i], "Missing key {} {} in convert QConfigMapping \ diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 418cbb334814..e3410a52a9d8 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -33,11 +33,11 @@ # TODO: replace all usages with these constants -GLOBAL_DICT_KEY = "" -OBJECT_TYPE_DICT_KEY = "object_type" -MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" -MODULE_NAME_DICT_KEY = "module_name" -MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" +_GLOBAL_DICT_KEY = "" +_OBJECT_TYPE_DICT_KEY = "object_type" +_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex" +_MODULE_NAME_DICT_KEY = "module_name" +_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" _FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = { torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, @@ -274,11 +274,11 @@ def to_dict(self) -> Dict[str, Any]: The values of this dictionary are lists of tuples. """ return { - GLOBAL_DICT_KEY: self.global_qconfig, - OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), - MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), - MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ + _GLOBAL_DICT_KEY: self.global_qconfig, + _OBJECT_TYPE_DICT_KEY: list(self.object_type_qconfigs.items()), + _MODULE_NAME_REGEX_DICT_KEY: list(self.module_name_regex_qconfigs.items()), + _MODULE_NAME_DICT_KEY: list(self.module_name_qconfigs.items()), + _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY: [ (*k, v) for k, v in self.module_name_object_type_order_qconfigs.items() ], } @@ -302,14 +302,14 @@ def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping: The values of this dictionary are expected to be lists of tuples. 
""" conf = cls() - if GLOBAL_DICT_KEY in qconfig_dict: - conf.set_global(qconfig_dict[GLOBAL_DICT_KEY]) - for object_type, qconfig in qconfig_dict.get(OBJECT_TYPE_DICT_KEY, []): + if _GLOBAL_DICT_KEY in qconfig_dict: + conf.set_global(qconfig_dict[_GLOBAL_DICT_KEY]) + for object_type, qconfig in qconfig_dict.get(_OBJECT_TYPE_DICT_KEY, []): conf.set_object_type(object_type, qconfig) - for module_name_regex, qconfig in qconfig_dict.get(MODULE_NAME_REGEX_DICT_KEY, []): + for module_name_regex, qconfig in qconfig_dict.get(_MODULE_NAME_REGEX_DICT_KEY, []): conf.set_module_name_regex(module_name_regex, qconfig) - for module_name, qconfig in qconfig_dict.get(MODULE_NAME_DICT_KEY, []): + for module_name, qconfig in qconfig_dict.get(_MODULE_NAME_DICT_KEY, []): conf.set_module_name(module_name, qconfig) - for module_name, object_type, index, qconfig in qconfig_dict.get(MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): + for module_name, object_type, index, qconfig in qconfig_dict.get(_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) return conf From 534ae6ae4790aec1b148b7e878ae60828ae45ac0 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Fri, 11 Nov 2022 01:08:16 +0000 Subject: [PATCH 044/453] [primTorch] Implement group norm reference (#87054) Add group norm reference Split from #81191 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87054 Approved by: https://github.com/mruberry --- test/test_fx.py | 4 +- test/test_ops.py | 5 +- torch/_decomp/decompositions.py | 31 ------ torch/_refs/__init__.py | 62 ++++++++++++ torch/_refs/nn/functional/__init__.py | 40 ++++++++ torch/nn/functional.py | 2 + .../_internal/common_methods_invocations.py | 97 ++++++++++++++++--- 7 files changed, 191 insertions(+), 50 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index 0aa5b28a3de7..0aff631b8e81 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -3925,7 +3925,6 @@ def tearDown(self): "max_pool2d": PROXY_ITERABLE, "max_pool3d": PROXY_ITERABLE, - "group_norm": PROXY_ITERATED, "lp_pool2d": PROXY_ITERATED, "max_unpool1d": PROXY_ITERATED, "max_unpool2d": PROXY_ITERATED, @@ -3959,6 +3958,7 @@ def tearDown(self): "gaussian_nll_loss": CONTROL_FLOW, "glu": CONTROL_FLOW, "grid_sample": CONTROL_FLOW, + "group_norm": CONTROL_FLOW, "gumbel_softmax": CONTROL_FLOW, "hardsigmoid": CONTROL_FLOW, "hardswish": CONTROL_FLOW, @@ -4029,7 +4029,7 @@ def tearDown(self): "max_pool2d": PROXY_ITERATED, "max_pool3d": PROXY_ITERATED, - "group_norm": LEN_ERROR + "group_norm": CONTROL_FLOW } @classmethod diff --git a/test/test_ops.py b/test/test_ops.py index d0aa0906784d..73758bfc6b46 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -417,9 +417,10 @@ def test_python_ref_executor(self, device, dtype, op, executor): # skip zero-dim tensors for some composites of reduction operations and view skip_zero_dim_ops = [ - "_refs.softmax", "_refs.logsumexp", "_refs.log_softmax", + "_refs.native_group_norm", + "_refs.softmax", "_refs.sum_to_size", "ops.nvprims.view", ] @@ -1659,11 +1660,13 @@ class TestRefsOpsInfo(TestCase): '_refs.index_add_', '_refs.index_copy_', '_refs.index_fill_', + '_refs.native_group_norm', } not_in_decomp_table = { # duplicated in _decomp and _refs '_refs.nn.functional.elu', + '_refs.nn.functional.group_norm', '_refs.nn.functional.mse_loss', '_refs.rsub', # duplicated due to efficiency concerns of the ref vs the decomp diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 
0e1d1cd1dd51..fe63e0db007a 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1138,37 +1138,6 @@ def normalize(input, norm_dims, eps): return out, mean, rstd -@register_decomposition(aten.native_group_norm.default) -def native_group_norm( - input: Tensor, - weight: Optional[Tensor], - bias: Optional[Tensor], - N: int, - C: int, - HxW: int, - group: int, - eps: float, -) -> Tuple[Tensor, Tensor, Tensor]: - orig_shape = input.shape - input = input.view(N, group, C // group, HxW) - reduction_dims = [2, 3] - out, mean, rstd = normalize(input, reduction_dims, eps) - mean = _squeeze_multiple(mean, reduction_dims) - rstd = _squeeze_multiple(rstd, reduction_dims) - out = out.view(orig_shape) - if weight is not None: - weight = _unsqueeze_to_dim(weight, out.dim() - 1) - out = out * weight - if bias is not None: - bias = _unsqueeze_to_dim(bias, out.dim() - 1) - out = out + bias - - out = out.to(dtype=input.dtype) - mean = mean.to(dtype=input.dtype) - rstd = rstd.to(dtype=input.dtype) - return (out, mean, rstd) - - @register_decomposition(aten.native_group_norm_backward) @pw_cast_for_opmath def native_group_norm_backward( diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index cd0344eba7a9..36fef59df375 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -238,6 +238,7 @@ "movedim", "narrow", "narrow_copy", + "native_group_norm", "native_layer_norm", "permute", "ravel", @@ -2781,6 +2782,7 @@ def _normalize( mean (Tensor): mean of the tensor along norm_dims. rstd (Tensor): 1/std of the tensor along norm_dims. """ + norm_dims = utils.canonicalize_dims(a.ndim, norm_dims) computation_dtype = utils.get_computation_dtype(a.dtype) a_acc = _maybe_convert_to_dtype(a, computation_dtype) assert isinstance(a_acc, TensorLike) # to avoid mypy error for var_mean @@ -2792,6 +2794,66 @@ def _normalize( return out, mean, rstd +# add all specified dimensions +def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in sorted(dimensions): + x = torch.unsqueeze(x, dim) + return x + + +@register_decomposition(torch.ops.aten.native_group_norm.default) +def native_group_norm( + input: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + batch_size: int, + num_channels: int, + flattened_inner_size: int, + num_groups: int, + eps: float, +) -> Tuple[Tensor, Tensor, Tensor]: + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # num_channels / num_groups and flattened inner dimension are the reduction axes + reduction_dims = [2, 3] + input_reshaped = torch.reshape( + input, + [batch_size, num_groups, num_channels // num_groups, flattened_inner_size], + ) + out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps) + out = out.view(input.shape) + + broadcast_dims = [0] + list(dim for dim in range(2, input.ndim)) + unsqueeze_bias = None + if bias is not None: + unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims) + unsqueeze_weight = None + if weight is not None: + unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims) + + if unsqueeze_weight is not None: + out = out * unsqueeze_weight + if unsqueeze_bias is not None: + out = out + unsqueeze_bias + + out = _maybe_convert_to_dtype(out, input.dtype) # type: 
ignore[assignment] + mean = _maybe_convert_to_dtype(mean, input.dtype) # type: ignore[assignment] + rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] + + # remove broadcast dimensions from mean and rstd + mean = prims.squeeze(mean, reduction_dims) + rstd = prims.squeeze(rstd, reduction_dims) + return (out, mean, rstd) + + @register_decomposition(torch.ops.aten.native_layer_norm) def native_layer_norm( input: Tensor, diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 3cde67844947..dcd86d8952d2 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -171,6 +171,46 @@ def relu(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: return torch.where(torch.le(a, 0), 0, a) +def group_norm( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +) -> Tensor: + """ + Reference implementation of :func:`torch.nn.functional.group_norm`. + """ + utils.check( + input.ndim >= 2, + lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}", + ) + + batch_size = input.shape[0] + num_channels = input.shape[1] + utils.check( + num_channels % num_groups == 0, + lambda: "Expected number of channels in input to be divisible by num_groups, " + + f"but got input of shape {input.shape} and num_groups = {num_groups}", + ) + + # input shape is (N, C, *), so we flatten all inner dimensions except (N, C) + flattened_inner_size = 1 + for dim_length in input.shape[2:]: + flattened_inner_size *= dim_length + + return torch.native_group_norm( + input, + weight, + bias, + batch_size, + num_channels, + flattened_inner_size, + num_groups, + eps, + )[0] + + def layer_norm( input: Tensor, normalized_shape: ShapeType, diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 79bf6297e587..961dd83f57b2 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2524,6 +2524,8 @@ def group_norm( """ if has_torch_function_variadic(input, weight, bias): return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps) + if input.dim() < 2: + raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}") _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:])) return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 731dc008ccce..b702c1161860 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3334,27 +3334,72 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample= def sample_inputs_group_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - # Ordered as input shape, num groups, and eps + # Ordered as input shape, num groups, and kwargs for eps cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment] - ((1, 6, 3), 2, 0.5), - ((2, 6, 3), 2, -0.5), - ((1, 2), 1, None), - ((0, 2), 1, None), + ((1, 6, 3), 2, {'eps' : 0.5}), + ((2, 6, 3), 2, {'eps' : -0.5}), + ((1, 3), 1, {'eps' : 1e-5}), + ((0, 2), 1, {'eps' : 1e-5}), + ((S, S, S), 1, {'eps' : 0.5}), ) - for input_shape, num_groups, eps in cases: + # 
num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: # Shape of weight and bias should be the same as num_channels - weight = make_arg(input_shape[1]) - bias = make_arg(input_shape[1]) - kwargs = {'weight': weight, 'bias': bias} if eps is None else {'weight': weight, 'bias': bias, 'eps': eps} - yield SampleInput( - make_arg(input_shape), - args=(num_groups,), - kwargs=kwargs - ) + channels = input_shape[1] if len(input_shape) > 1 else 0 + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(make_arg(input_shape), num_groups, **kwargs) + # Without any optional args yield SampleInput(make_arg((1, 2)), args=(1,)) +def reference_inputs_group_norm(op_info, device, dtype, requires_grad, **kwargs): + yield from sample_inputs_group_norm( + op_info, device, dtype, requires_grad, **kwargs) + + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + # Ordered as input shape, num groups, and kwargs for eps + cases: Tuple[Tuple[int], int, float] = ( # type: ignore[assignment] + ((20, 6, 10, 10), 3, {'eps' : 1e-5}), + # equivalent with InstanceNorm + # GroupNorm(C, num_groups=C) == InstanceNorm(num_features=C) + ((20, 6, 10, 10), 6, {'eps' : 1e-5}), + # equivalent with LayerNorm + # GroupNorm(C, num_groups=1, affine=False) == LayerNorm(normalized_shape=[C, H, W], elementwise_affine=False) + ((20, 6, 10, 10), 1, {'eps' : 1e-5}), + ) + + # num_channels is inferred to be input.shape[1] dimension + for input_shape, num_groups, kwargs in cases: + # Shape of weight and bias should be the same as num_channels + channels = input_shape[1] if len(input_shape) > 1 else 0 + input_tensor = make_arg(input_shape) + weight_tensor = make_arg(channels) + bias_tensor = make_arg(channels) + + # Checking for permutations of weights and biases as `None` + weights = [weight_tensor, None] + biases = [bias_tensor, None] + for weight, bias in itertools.product(weights, biases): + kwargs = { + 'weight': weight, + 'bias': bias, + **kwargs + } + yield SampleInput(input_tensor, num_groups, **kwargs) + def sample_inputs_instance_norm(opinfo, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3481,6 +3526,18 @@ def sample_inputs_native_layer_norm(opinfo, device, dtype, requires_grad, **kwar args=(normalized_shape, None, None, eps), ) +def error_inputs_group_norm(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) + + # check that input has minimum number of dimensions + err_msg1 = "Expected at least 2 dimensions for input tensor but received" + s1 = SampleInput(make_arg((1)), args=(1,)) + yield ErrorInput(s1, error_regex=err_msg1) + + # check that the channels dimension is compatible with number of groups + err_msg2 = "Expected number of channels in input to be divisible by num_groups, but got input of shape" + s2 = SampleInput(make_arg((2, 7, 4)), args=(2,)) + yield ErrorInput(s2, error_regex=err_msg2) def error_inputs_native_layer_norm(opinfo, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float32, requires_grad=False) @@ -7747,12 +7804,12 @@ def reference_group_norm(inp: np.ndarray, 
num_groups: int, weight=None, bias=Non if weight is not None: # weight is a vector of length equal to the channel if len(Y.shape) > 2: - weight = np.tile(np.expand_dims(weight, 1), [1] + list(inp.shape[2:])) + weight = np.expand_dims(weight, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) Y = Y * weight if bias is not None: # bias is a vector of length equal to the channel if len(Y.shape) > 2: - bias = np.tile(np.expand_dims(bias, 1), [1] + list(inp.shape[2:])) + bias = np.expand_dims(bias, [0] + [idx + 2 for idx in range(inp.ndim - 2)]) Y = Y + bias return Y @@ -10921,12 +10978,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_group_norm, decorators=[ # RuntimeError: Cannot insert a Tensor that requires grad as a constant. # Consider making it a parameter or input, or detaching the gradient DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)) ], sample_inputs_func=sample_inputs_group_norm, + reference_inputs_func=reference_inputs_group_norm, supports_expanded_weight=True,), OpInfo('nn.functional.instance_norm', # no ref because instance_norm will often have numerical instability (large numbers or nan) @@ -17941,6 +18000,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), ) ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + supports_nvfuser=False, + validate_view_consistency=False, + ), PythonRefInfo( "_refs.narrow_copy", torch_opinfo_name="narrow_copy", From c961e45ee559a61bfb4f1e8a548e574ef89d3102 Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 10 Nov 2022 12:21:50 -0800 Subject: [PATCH 045/453] handle zero dims in reductions (#88280) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88280 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 21 +++++++++++++++++ torch/_inductor/ir.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 8fd4fa29bf98..121f3d31f39c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4224,6 +4224,27 @@ def forward(x): ] self.common(forward, args) + def test_zero_dim_reductions(self): + for kd in [True, False]: + inps0 = (torch.zeros(2, 0, device=self.device, dtype=torch.float16), 1, kd) + failed_ops = [aten.argmin, aten.argmax, aten.max, aten.min] + for fo in failed_ops: + with self.assertRaisesRegex( + IndexError, "Expected reduction dim 1 to have non-zero size" + ): + mod = make_fx(fo)(*inps0) + _ = compile_fx_inner(mod, inps0) + + pass_ops = [ + lambda *x: fn(*x) for fn in [aten.sum, aten.prod, aten.any, aten.all] + ] + for po in pass_ops: + compiled = torch._dynamo.optimize("inductor")(po) + expected = po(*inps0) + actual = compiled(*inps0) + + self.assertTrue(torch.allclose(actual, expected, atol=1e-3, rtol=1e-3)) + @requires_cuda() def test_unspec_inputs(self): def fn(x, y): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 924ec7aaa7b2..448c057ecb0e 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -729,6 +729,42 @@ def create( reduction_hint: ReductionHint = ReductionHint.DEFAULT, ): reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) + + if reduction_numel == 0: 
+ + # N.B. This is a hack to generate the literal of the given type + # Ideally, we should be fixing `def constant` in triton.py + # but it breaks due to hardcoded dtypes in other places + def py_cnst(val): + return ( + bool(val) + if dst_dtype == torch.bool + else float(val) + if dst_dtype.is_floating_point + else int(val) + ) + + rtypes_to_inits = { + "sum": py_cnst(0), + "prod": py_cnst(1), + "any": py_cnst(0), + # "all" is desugared to `!any(!val)` + } + + assert ( + reduction_type in rtypes_to_inits.keys() + ), f"{reduction_type} not supported for zero-dimension tensors!" + + def const_fn(index): + return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) + + return Pointwise.create( + device=device, + dtype=src_dtype, + inner_fn=const_fn, + ranges=list(ranges), + ) + if reduction_numel == 1: # this reduction is actually a pointwise op if reduction_type in ("argmin", "argmax"): From fc9e36dd426d4747bb7c71ee93bcbaa700bda01d Mon Sep 17 00:00:00 2001 From: anjali411 Date: Thu, 10 Nov 2022 22:41:47 +0000 Subject: [PATCH 046/453] Add meta support for scalar_tensor and argmax (#88590) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88590 Approved by: https://github.com/albanD --- test/functorch/test_vmap.py | 1 + test/test_proxy_tensor.py | 6 +-- torch/_meta_registrations.py | 42 +++++++++++++++++++ .../_internal/common_methods_invocations.py | 32 ++++++++++++-- 4 files changed, 74 insertions(+), 7 deletions(-) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 3acab4172fce..5ba35de21b8b 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3229,6 +3229,7 @@ def test(): xfail('linspace', ''), # test runner can't handle factory functions xfail('arange', ''), # test runner can't handle factory functions xfail('logspace', ''), # test runner can't handle factory functions + xfail('scalar_tensor'), # test runner can't handle factory functions xfail('empty', ''), # test runner can't handle factory functions xfail('ones', ''), # test runner can't handle factory functions xfail('zeros', ''), # test runner can't handle factory functions diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fbeaa04aa65d..72c7249f4f14 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1116,8 +1116,8 @@ def f(a, b, c, d, e): skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition + xfail('masked.argmax', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... + xfail('masked.argmin', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... 
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition @@ -1134,8 +1134,6 @@ def f(a, b, c, d, e): xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition - xfail('argmax', ''), # aten.argmax.default - couldn't find symbolic meta function/decomposition - xfail('argmin', ''), # aten.argmin.default - couldn't find symbolic meta function/decomposition xfail('argwhere', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 5035eadf84a4..04c522ab9e3b 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1735,6 +1735,48 @@ def meta_sort(self, stable=None, dim=-1, descending=False): return torch.empty_like(self), torch.empty_like(self, dtype=torch.int64) +def zero_numel_check_dims(self, dim, fn_name): + if self.ndim == 0: + check( + dim == 0 or dim == -1, + lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}", + IndexError, + ) + else: + check( + self.size(dim) != 0, + lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.", + IndexError, + ) + + +# From aten/src/ATen/native/ReduceOps.cpp +def check_argmax_argmin(name, self, dim): + if dim is not None: + dim = maybe_wrap_dim(dim, self.dim()) + zero_numel_check_dims(self, dim, name) + else: + check( + self.numel() != 0, + lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.", + ) + + +@register_meta([aten.argmax.default, aten.argmin.default]) +def argmax_argmin_meta(self, dim=None, keepdim=False): + check_argmax_argmin("argmax", self, dim) + dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None) + shape = _compute_reduction_shape(self, dims, keepdim) + return self.new_empty(shape, dtype=torch.int64) + + +@register_meta(aten.scalar_tensor.default) +def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): + return torch.empty( + (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b702c1161860..b41e74a24c10 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1372,6 +1372,15 @@ def sample_inputs_empty(op, device, dtype, requires_grad, **kwargs): for case in cases: yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad) +def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs): + # Not including a scalar tensor in vals because meta tests start failing due to + # lack of meta support for _local_scalar_dense + # torch.tensor(2, device=device) + vals = (-5, 0, 1) + + 
for item in vals: + yield SampleInput(item, device=device, dtype=dtype, requires_grad=requires_grad) + def sample_inputs_eye(op, device, dtype, requires_grad, **kwargs): # only ints >= 0 are allowed for both arguments, unless m is omitted sizes = (None, 0, 1, 2, 3, 4, 7, L, M, S) @@ -9287,9 +9296,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): error_inputs_func=error_inputs_diag), OpInfo('diag_embed', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf), - # TODO: this is very questionable, because we do have - # diag_embed.out but it's not bound to Python somehow - # https://github.com/pytorch/pytorch/issues/88598 supports_out=False, # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, @@ -10546,6 +10552,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): assert_jit_shape_analysis=True, sample_inputs_func=sample_inputs_native_batch_norm, skips=( + # NotImplementedError: Could not run + # 'aten::native_batch_norm.out' with arguments from the 'CPU' backend. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type="cpu"), # RuntimeError: out_invstd.dim() == 1 && out_invstd.is_contiguous() && out_invstd.sizes()[0] DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type="cuda"), @@ -14511,6 +14519,24 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), )), + OpInfo('scalar_tensor', + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + sample_inputs_func=sample_inputs_scalar_tensor, + supports_autograd=False, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'), + # fails to match any schemas despite working in the interpreter + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), + # skip these tests since we have non tensor input + DecorateInfo(unittest.skip('Skipped!'), "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'), + )), OpInfo('new_full', op=lambda x, *args, **kwargs: x.new_full(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), From 3fbf748f2109de408bd47efb1a43e3897d7a775c Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Fri, 11 Nov 2022 02:30:29 +0000 Subject: [PATCH 047/453] Assert we have triton before scheduling on triton (#88849) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88849 Approved by: https://github.com/wconstab, https://github.com/ngimel, https://github.com/jansel --- torch/_inductor/scheduler.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2f1c4b7c2e64..cb71a4443804 100644 --- a/torch/_inductor/scheduler.py 
+++ b/torch/_inductor/scheduler.py @@ -16,7 +16,7 @@ from . import config, dependencies, ir from .dependencies import MemoryDep, StarDep from .sizevars import SimplifyIndexing -from .utils import cache_on_self, cmp, dynamo_utils +from .utils import cache_on_self, cmp, dynamo_utils, has_triton from .virtualized import V log = logging.getLogger(__name__) @@ -1078,6 +1078,16 @@ def create_backend(self, device: torch.device): return CppScheduling(self) else: + if not has_triton(): + device_props = torch.cuda.get_device_properties(device) + if device_props.major < 6: + raise RuntimeError( + f"Found {device_props.name} which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 6.0, but your device is of CUDA capability {device_props.major}.{device_props.minor}" # noqa: B950 + ) + else: + raise RuntimeError( + "Cannot find a working triton installation. More information on installing Triton can be found at https://github.com/openai/triton" # noqa: B950 + ) from .codegen.triton import TritonScheduling return TritonScheduling(self) From 495e7b1c729e64693e794ea22640b4552816f0ef Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 10 Nov 2022 21:22:29 +0000 Subject: [PATCH 048/453] Ref for aten.full; symint changes in prim (#88762) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88762 Approved by: https://github.com/ezyang --- test/functorch/test_vmap.py | 1 + test/test_ops.py | 1 - torch/_prims_common/__init__.py | 5 ++- torch/_refs/__init__.py | 17 +++++--- .../_internal/common_methods_invocations.py | 40 +++++++++++++++++++ 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 5ba35de21b8b..6d95077b627e 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3233,6 +3233,7 @@ def test(): xfail('empty', ''), # test runner can't handle factory functions xfail('ones', ''), # test runner can't handle factory functions xfail('zeros', ''), # test runner can't handle factory functions + xfail('full', ''), # test runner can't handle factory functions xfail('eye', ''), # non-tensor input xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse diff --git a/test/test_ops.py b/test/test_ops.py index 73758bfc6b46..c688f6521af1 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1743,7 +1743,6 @@ class TestRefsOpsInfo(TestCase): '_refs.unflatten', '_refs.sum_to_size', # ref implementation missing kwargs - '_refs.full', # missing "layout" '_refs.full_like', # missing "layout" '_refs.ones_like', # missing "layout" '_refs.round', # missing "decimals" diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 90777ed6601a..128796dfa3d0 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -837,10 +837,11 @@ def type_to_dtype(typ: type) -> torch.dtype: if typ is bool: return torch.bool - if typ is int: + if typ in [int, torch.SymInt]: return torch.long - if typ is float: + if typ in [float, torch.SymFloat]: return torch.get_default_dtype() + # TODO: sym_complex_float? 
if typ is complex: return corresponding_complex_dtype(torch.get_default_dtype()) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 36fef59df375..43b0c74192de 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -322,7 +322,7 @@ def _broadcast_shapes(*_shapes): common_shape = [ 1, ] * reduce(max, (len(shape) for shape in shapes)) - for shape in shapes: + for arg_idx, shape in enumerate(shapes): for idx in range(-1, -1 - len(shape), -1): if common_shape[idx] == 1: if shape[idx] < 0: @@ -333,9 +333,9 @@ def _broadcast_shapes(*_shapes): elif shape[idx] != 1: if common_shape[idx] != shape[idx]: raise RuntimeError( - "Attempting to broadcast a dimension of length ", - str(shape[idx]), - "!", + f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}" ) return common_shape @@ -4495,6 +4495,7 @@ def eye( # result.requires_grad_(requires_grad) +@register_decomposition(torch.ops.aten.full) @out_wrapper() def full( shape: ShapeType, @@ -4506,6 +4507,12 @@ def full( pin_memory: bool = False, requires_grad: bool = False, ) -> TensorLikeType: + utils.check_layout(layout) + utils.check_pin_memory(pin_memory) + + dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value)) + device = device if device is not None else torch.device("cpu") + e = empty( shape, dtype=dtype, @@ -4514,7 +4521,7 @@ def full( pin_memory=pin_memory, requires_grad=requires_grad, ) - return fill(e, fill_value) + return torch.fill(e, fill_value) # type: ignore[arg-type] def full_like( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index b41e74a24c10..5178ec978bd1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -772,6 +772,20 @@ def sample_inputs_ones_zeros(op, device, dtype, requires_grad, **kwargs): for size in sizes: yield SampleInput(size, kwargs={'dtype': dtype, 'device': device}) +def sample_inputs_full(op, device, dtype, requires_grad, **kwargs): + def get_val(dtype): + return make_tensor([], dtype=dtype, device="cpu").item() + + sizes = ( + (M,), + (S, S), + ) + fill_values = [get_val(dtype), get_val(torch.int)] + + for size, fill_value in product(sizes, fill_values): + yield SampleInput(size, fill_value, dtype=dtype, device=device) + + def error_inputs_uniform(op, device, **kwargs): t = torch.zeros([10], device=device) yield ErrorInput( @@ -14373,6 +14387,32 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. 
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), )), + OpInfo('full', + op=torch.full, + supports_autograd=False, + is_factory_function=True, + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), + supports_out=True, + sample_inputs_func=sample_inputs_full, + skips=( + # Tests that assume input is a tensor or sequence of tensors + DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), + DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), + # Same failure as arange: cannot find linspace in captured graph + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), + # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # boolean alpha not handled properly + DecorateInfo(unittest.expectedFailure, + 'TestCudaFuserOpInfo', + 'test_nvfuser_correctness', + dtypes=(torch.bool,)), + # RuntimeError: UNSUPPORTED DTYPE: bool + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', dtypes=(torch.bool,)), + )), OpInfo('new_empty', op=lambda x, *args, **kwargs: x.new_empty(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), From 3082378701605884ff07f7ba7984864340b19b34 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 03:33:55 +0000 Subject: [PATCH 049/453] [vision hash update] update the pinned vision hash (#88853) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88853 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index d8180093d885..48685938a146 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -ffd5a567eb90abf6b5555063da434d3c130d540f +d72e90640ec8514e0369b5419d7f3b74a387b1d7 From 9d09968bbe05fc6d7d7c3d8b1acfbe1b1b1413a8 Mon Sep 17 00:00:00 2001 From: Emil Lynegaard Date: Fri, 11 Nov 2022 03:34:54 +0000 Subject: [PATCH 050/453] Disable check for dropout in MultiheadAttention fast_path (#88831) Since we already enforce eval mode for the fast_path, we do not need to also check for a falsy dropout value, as a model trained with dropout will have a non-zero dropout during eval mode, even though it won't be applied. 
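A minimal check, separate from the patch itself (module size, dropout value, and tensor shapes are made up for illustration), of the reasoning above: once `.eval()` is set, the configured dropout probability is never applied, so the fast path stays numerically equivalent even when `dropout > 0`.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, dropout=0.1, batch_first=True)
mha.eval()  # training disabled; dropout stays configured at 0.1 but is never applied

x = torch.randn(4, 16, 8)  # (batch, seq, embed)
with torch.inference_mode():
    out1, _ = mha(x, x, x)
    out2, _ = mha(x, x, x)

# Deterministic outputs across calls confirm dropout is inert in eval mode,
# so gating the fast path on a zero dropout value is unnecessary.
torch.testing.assert_close(out1, out2)
```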
Fixes #88806 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88831 Approved by: https://github.com/drisspg --- torch/nn/modules/activation.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 5f5615b496d7..7b0e7e3effaa 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -904,7 +904,6 @@ class MultiheadAttention(Module): - inputs are batched (3D) with ``batch_first==True`` - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` - training is disabled (using ``.eval()``) - - dropout is 0 - ``add_bias_kv`` is ``False`` - ``add_zero_attn`` is ``False`` - ``batch_first`` is ``True`` and the input is batched @@ -1088,8 +1087,6 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = "self.bias_k was not None" elif self.bias_v is not None: why_not_fast_path = "self.bias_v was not None" - elif self.dropout: - why_not_fast_path = f"dropout was {self.dropout}, required zero" elif self.add_zero_attn: why_not_fast_path = "add_zero_attn was enabled" elif not self._qkv_same_embed_dim: From c4fc5d372f3db37380fe213b5726403cb1330d5d Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 7 Nov 2022 23:46:29 +0000 Subject: [PATCH 051/453] [FSDP][state_dict][1/N] Moving state_dict logic to pre_state_dict_hook (#87900) This is one step toward the ultimate goal: remove the overwritten state_dict in FSDP. All the logic should be either in `pre_state_dict_hook` or `post_state_dict_hook`. Since current `nn.Module` does not support `pre_state_dict_hook`, this PR mimic `pre_state_dict_hook` by calling the pre hook inside post the hook, effectively ditching all the work done by `nn.Module.state_dict`. Once `pre_state_dict_hook` is supported by `nn.Module`, these pre hook calls can be moved out from the post hooks and be registered to `nn.Module.pre_state_dict_hook`. The major issue of this temporary solution is that `post_state_dict_hook` is called from the leaf node to the root node. This makes the `module._lazy_init()` invalid as FSDP assumes `_lazy_init()` to be called from the root. As a result, `FSDP.state_dict` currently contains only one logic -- calling `module._lazy_init()`. 
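A rough sketch of the interim pattern described above, using simplified names that are not the actual FSDP helpers: because `nn.Module` has no pre-state_dict hook yet, the "pre" logic is invoked at the top of the registered post-hook, which then redoes the work `nn.Module.state_dict` would otherwise have prepared.

```python
from typing import Any, Dict


def _pre_state_dict_hook(module, state_dict: Dict[str, Any], prefix: str) -> None:
    # Work that must happen before parameters are read into the state_dict,
    # e.g. lazy initialization and unsharding.
    ...


def _post_state_dict_hook(module, state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # TODO: move this call to a real nn.Module pre-state_dict hook once it exists.
    _pre_state_dict_hook(module, state_dict, prefix)
    # Post-processing: fix key prefixes, clone/offload tensors, exit unsharding.
    return state_dict
```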
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87900 Approved by: https://github.com/rohan-varma --- test/distributed/fsdp/test_fsdp_state_dict.py | 2 +- torch/distributed/fsdp/_runtime_utils.py | 19 +- torch/distributed/fsdp/_state_dict_utils.py | 388 +++++++++++++----- .../fsdp/fully_sharded_data_parallel.py | 101 +---- 4 files changed, 288 insertions(+), 222 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 133405033730..48dad3118db7 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -447,7 +447,7 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): ) @parametrize("fp16", [True, False]) @parametrize("state_dict_rank0_and_offload", [True, False]) - @parametrize("use_orig_params", [False, True]) + @parametrize("use_orig_params", [True, False]) def test_basic_save_and_load_state_dict( self, state_dict_type: StateDictType, diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 9aee15a016c4..e0986d300a65 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -1113,28 +1113,23 @@ def _get_buffers_and_dtypes_for_computation( @no_type_check -def _get_buffers_and_dtypes_for_checkpoint( +def _get_buffer_dtypes( state: _FSDPState, - root_module: nn.Module, -) -> Tuple[List[torch.Tensor], List[torch.dtype]]: + buffer_names: List[str], +) -> List[torch.dtype]: """ - Returns all buffers in the module tree rooted at ``root_module`` and a - corresponding list of the buffer dtypes for checkpointing. Each buffer - dtype is the original buffer dtype ignoring any buffer mixed precision. + Returns the original buffer types of the given buffer names. 
""" - p_assert(state._is_root, "Expects the root to cast buffers") - buffers: List[torch.Tensor] = [] - buffer_dtypes: List[Optional[torch.dtype]] = [] - for buffer_name, buffer in root_module.named_buffers(): + buffer_dtypes: List[torch.dtype] = [] + for buffer_name in buffer_names: p_assert( buffer_name in state._buffer_name_to_orig_dtype, f"{buffer_name} is missing from pre-computed dict on rank " f"{state.rank}, which only has keys " f"{state._buffer_name_to_orig_dtype.keys()}", ) - buffers.append(buffer) buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name]) - return buffers, buffer_dtypes + return buffer_dtypes def _cast_buffers_to_dtype_and_device( diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 0169aa8f10eb..1109f1e88150 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, cast, Dict +from typing import Any, Callable, cast, Dict import torch import torch.distributed as dist @@ -11,15 +11,22 @@ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn import torch.nn.functional as F + from torch.distributed._shard.sharded_tensor import ( init_from_local_shards, Shard, ShardedTensor, ) -from torch.distributed.fsdp._common_utils import clean_tensor_name +from torch.distributed.fsdp._common_utils import ( + clean_tensor_name, + FSDP_PREFIX, + TrainingState, +) from torch.distributed.fsdp._runtime_utils import ( _cast_buffers_to_dtype_and_device, - _get_buffers_and_dtypes_for_computation, + _clear_grads_if_needed, + _get_buffer_dtypes, + _lazy_init, ) from torch.distributed.utils import _replace_by_prefix @@ -31,49 +38,218 @@ from .flat_param import FlatParamHandle -def _full_post_state_dict_hook( +def _enter_full_param_ctx( + module, + recurse: bool = False, + writeback: bool = False, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +) -> None: + """ + state_dict hooks cannot use the pure context call as the checkpoint flow + requires to enter the context in the pre-hook but leave the context in the + post-hook. This API enters the context of ``summon_full_params``. + """ + assert module._full_param_ctx is None, ( + "Entering the ``summon_full_params`` context but module._full_param_ctx " + "is not None." + ) + assert module.training_state != TrainingState.SUMMON_FULL_PARAMS, ( + "Entering the summon_full_params context but the state is already " + "SUMMON_FULL_PARAMS." + ) + module._full_param_ctx = module._summon_full_params( + recurse=recurse, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + module._full_param_ctx.__enter__() + + +def _exit_full_param_ctx(module) -> None: + """A helper function to exit ``summon_full_params`` context.""" + module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert module._full_param_ctx is not None + module._full_param_ctx.__exit__(None, None, None) + module._full_param_ctx = None + + +def _common_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """Performs the pre-state_dict tasks shared by all state_dict types.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + _lazy_init(module, module) + # TODO: change to this call after pre_state_dict_hook is in `nn.Module`. 
+ # if module.is_root: + # _clear_grads_if_needed(module._fsdp_handles(module)) + if module._has_params: + _clear_grads_if_needed([module._handles[0]]) + + +def _common_summon_pre_state_dict_hook( + module, + offload_to_cpu: bool, + rank0_only: bool, +) -> None: + """ + Performs the pre-state_dict tasks shared by all state_dict types that require + ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + """ + _enter_full_param_ctx( + module, + recurse=False, + writeback=False, + offload_to_cpu=offload_to_cpu, + rank0_only=rank0_only, + ) + + +# TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. +def _common_summon_post_state_dict_hook( module, state_dict: Dict[str, Any], prefix: str, + param_hook: Callable, ) -> Dict[str, Any]: """ - Hook that runs after model.state_dict() is called before returning result to - user. For FSDP, we may have to clone the tensors in state_dict as params go - back to sharded version after _summon_full_params ends, and also remove - the ``FSDP_WRAPPED_MODULE`` prefix. + The post-state_dict flow that shared by all state_dict types that require + ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + hook. """ - _replace_by_prefix(state_dict, prefix + f"{fsdp_file.FSDP_PREFIX}", prefix) - module._assert_state([fsdp_file.TrainingState.SUMMON_FULL_PARAMS]) + _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) + module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) # Return early for trivial cases if not state_dict or not module._has_params: + _exit_full_param_ctx(module) return state_dict - # If a rank has already exited the `summon_full_params()` context here - # (e.g. when `rank0_only=True` and `rank != 0`), then the rank only - # needed to participate in the all-gather and does not need to save the - # state dict. For `use_orig_params=False`, we can check this via - # `FlatParameter` registration. - # TODO: For `use_orig_params=True`, we check for the reshard upon - # exiting `summon_full_params()` via the parameter shape. However, for - # `NO_SHARD`, we cannot tell from the shape, so we do not return early. - if ( - not module._use_orig_params - and fsdp_file.FLAT_PARAM in module.module._parameters - ) or ( - module._use_orig_params - and module._handles - and module._handles[0].uses_sharded_strategy - and module._handles[0].is_sharded(module._handles[0].flat_param) - ): - return state_dict + # TODO: Once pre_state_dict hook is supported, this pop should be removed. + # For `use_orig_params=True`, the `FlatParameter` is not registered, so + # there is no entry in the state dict for it to pop. + if not module._use_orig_params: + state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") - offload_to_cpu = module._state_dict_config.offload_to_cpu - cpu_device = torch.device("cpu") + # If a rank does not have unsharded parameters(when `rank0_only=True` + # and `rank != 0`), then the rank only needed to participate in the + # all-gather and does not need to save the # state dict. We simply check + # rank0_only to ensure this issue. + rank0_only = ( + module._state_dict_type == fsdp_file.StateDictType.FULL_STATE_DICT + and cast(fsdp_file.FullStateDictConfig, module._state_dict_config).rank0_only + ) + # no_fsdp_return means the state_dict returned by this rank should contain + # only non-FSDP controlled parameters and buffers. 
+ no_fsdp_return = rank0_only and module.rank != 0 + if no_fsdp_return and not module._use_orig_params: + for clean_key in module._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + state_dict.pop(f"{prefix}{clean_key}", None) + _exit_full_param_ctx(module) + return state_dict # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. for fqn, param_name, module_name in module._param_fqns: + # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. + param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" + if no_fsdp_return: + state_dict.pop(fqn) + continue + state_dict[fqn] = param + assert fqn in state_dict, ( + f"FSDP assumes {fqn} is in the state_dict but the state_dict only " + f"has {state_dict.keys()}. " + f"prefix={prefix}, module_name={module_name}, " + f"param_name={param_name} rank={module.rank}." + ) + + param_hook(module, state_dict, prefix, fqn) + _exit_full_param_ctx(module) + + cpu_device = torch.device("cpu") + buffer_clean_fqns = [] + buffers = [] + for clean_key in module._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_tensor_name(clean_key) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if no_fsdp_return: + state_dict.pop(fqn) + else: + buffer = state_dict[fqn] + if module._state_dict_config.offload_to_cpu and buffer.device != cpu_device: + state_dict[fqn] = buffer.to(cpu_device) + # TODO: for composable FSDP, this should be clean_tensor_name(clean_key), + buffer_clean_fqns.append(clean_key) + buffers.append(state_dict[fqn]) + if buffers and module._mixed_precision_enabled_for_buffers(): + buffer_dtypes = _get_buffer_dtypes(module, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, module.compute_device) + for buffers, clean_fqn in zip(buffers, buffer_clean_fqns): + fqn = f"{prefix}{clean_fqn}" + state_dict[fqn] = buffer.clone() + return state_dict + + +def _full_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. pre-state_dict hook is + not actually supported by ``nn.Module``. As a result, this API is called + from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict + is supported in ``nn.Module``, this hook will be registered as a hook in + ``nn.Module``. + + TODO: clean the callsites and hacks after ``pre_state_dict_hook` ` is supported + in ``nn.Module``. + """ + _common_pre_state_dict_hook(module, state_dict, prefix) + _common_summon_pre_state_dict_hook( + module, + offload_to_cpu=module._state_dict_config.offload_to_cpu, + rank0_only=cast( + fsdp_file.FullStateDictConfig, module._state_dict_config + ).rank0_only, + ) + + +def _full_post_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> Dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _summon_full_params ends, and also remove + the ``FSDP_WRAPPED_MODULE`` prefix. + """ + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. 
+ _full_pre_state_dict_hook(module, state_dict, prefix) + + def param_hook( + module, + state_dict: Dict[str, Any], + prefix: str, + fqn: str, + ) -> None: clean_key = fqn clean_prefix = clean_tensor_name(prefix) # Strip prefix out of key if needed as buffer names and param names @@ -84,11 +260,6 @@ def _full_post_state_dict_hook( # Clone non-ignored parameters before exiting the # `_summon_full_params()` context - assert fqn in state_dict, ( - f"FSDP assumes {fqn} is in the state_dict but the state_dict " - f"only has {state_dict.keys()}. prefix={prefix}, " - f"module_name={module_name} param_name={param_name} rank={module.rank}." - ) if clean_key not in module._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): @@ -104,24 +275,7 @@ def _full_post_state_dict_hook( f"implementation of {fqn}. Error: {str(e)}" ) - # Offload the buffer to CPU if needed -- we do not do this in - # `_summon_full_params()` since without care, that would free - # the original buffer's GPU memory and require reallocating - # that memory later; this only affects the state dict's buffer - # variable and leaves the original buffer's GPU memory intact - if offload_to_cpu: - for clean_key in module._buffer_names: - # This is a hack to support activation checkpoint. - clean_key = clean_key.replace( - f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" - ) - fqn = f"{prefix}{clean_key}" - if fqn not in state_dict: - # A buffer can be registered as non-persistent. - continue - if state_dict[fqn].device != cpu_device: - state_dict[fqn] = state_dict[fqn].to(cpu_device) - return state_dict + return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) def _full_pre_load_state_dict_hook( @@ -129,21 +283,30 @@ def _full_pre_load_state_dict_hook( state_dict: Dict[str, Any], prefix: str, ) -> None: - # We do not expect to be calling pre-hooks twice without post-hook - # call in between. - assert getattr(module, "_full_param_ctx", None) is None - # Note that it needs writeback=True to persist. - module._full_param_ctx = module._summon_full_params(recurse=False, writeback=True) - module._full_param_ctx.__enter__() - _replace_by_prefix(state_dict, prefix, prefix + f"{fsdp_file.FSDP_PREFIX}") + _enter_full_param_ctx(module, recurse=False, writeback=True) + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None: - # We should exit summon_full_params context. - module._assert_state([fsdp_file.TrainingState.SUMMON_FULL_PARAMS]) - assert getattr(module, "_full_param_ctx", None) is not None - module._full_param_ctx.__exit__(None, None, None) - module._full_param_ctx = None + _exit_full_param_ctx(module) + + +def _local_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. Right now, pre-state_dict + hook is not supported by the PyTorch core. So this API is called from + `_local_post_state_dict_hook()` to simulate the case. + """ + if module._has_params and not module._handles[0].uses_sharded_strategy: + raise RuntimeError( + "``local_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, state_dict, prefix) def _local_post_state_dict_hook( @@ -156,7 +319,10 @@ def _local_post_state_dict_hook( the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy will happen. The underlying storage is the same. 
""" - _replace_by_prefix(state_dict, f"{prefix}{fsdp_file.FSDP_PREFIX}", prefix) + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. + _local_pre_state_dict_hook(module, state_dict, prefix) + + _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) if not module._has_params: return state_dict @@ -198,8 +364,8 @@ def _local_pre_load_state_dict_hook( state_dict. The flat_param should be a ShardedTensor. This hook converts the ShardedTensor to a tensor. No copy happen unless padding is required. """ - _replace_by_prefix(state_dict, prefix, f"{prefix}{fsdp_file.FSDP_PREFIX}") - fqn = f"{prefix}{fsdp_file.FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") + fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" if fqn not in state_dict: assert not module._has_params, ( "No `FlatParameter` in `state_dict` for this FSDP instance " @@ -229,6 +395,30 @@ def _local_pre_load_state_dict_hook( state_dict[fqn] = load_tensor +def _sharded_pre_state_dict_hook( + module, + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """ + Hook that runs before model.state_dict() is called. Check + ``_full_pre_load_state_dict_hook`` for the detail. + """ + if module._has_params and not module._handles[0].uses_sharded_strategy: + raise RuntimeError( + "``sharded_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, state_dict, prefix) + # Setting offload_to_cpu here does not work even if offload_to_cpu is True. + # We have to create ShardedTensor first then move it to CPU. + _common_summon_pre_state_dict_hook( + module, + offload_to_cpu=False, + rank0_only=False, + ) + + def _sharded_post_state_dict_hook( module, state_dict: Dict[str, Any], @@ -238,33 +428,24 @@ def _sharded_post_state_dict_hook( The hook replaces the unflattened, unsharded parameter in the state_dict with a unflattened, sharded parameter (a ShardedTensor). """ - _replace_by_prefix(state_dict, f"{prefix}{fsdp_file.FSDP_PREFIX}", prefix) - if not module._has_params: - return state_dict - assert module.training_state != fsdp_file.TrainingState.SUMMON_FULL_PARAMS, ( - "Inside _sharded_post_state_dict_hook, the training_state must " - "not be SUMMON_FULL_PARAMS." - ) - with module._summon_full_params(recurse=False, writeback=False): - for fqn, _, _ in module._param_fqns: - # Create a ShardedTensor for the unflattened, non-sharded parameter. - param = functools.reduce(getattr, fqn.split("."), module.module) - sharded_tensor = _ext_chunk_tensor( - tensor=param, - rank=module.rank, - world_size=module.world_size, - num_devices_per_node=torch.cuda.device_count(), - pg=module.process_group, - ) - if module._state_dict_config.offload_to_cpu: - sharded_tensor = sharded_tensor.cpu() - state_dict[f"{prefix}{fqn}"] = sharded_tensor - # For `use_orig_params=True`, the `FlatParameter` is not registered, so - # there is no entry in the state dict for it to pop. - if not module._use_orig_params: - state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") - return state_dict + # TODO: remove the hack. See ``_full_pre_state_dict_hook``. 
+ _sharded_pre_state_dict_hook(module, state_dict, prefix) + + def param_hook(module, state_dict: Dict[str, Any], prefix: str, fqn: str): + param = state_dict[fqn] + sharded_tensor = _ext_chunk_tensor( + tensor=param, + rank=module.rank, + world_size=module.world_size, + num_devices_per_node=torch.cuda.device_count(), + pg=module.process_group, + ) + if module._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[fqn] = sharded_tensor + + return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) def _sharded_post_load_state_dict_hook(module, *args, **kwargs) -> None: @@ -281,7 +462,7 @@ def _sharded_pre_load_state_dict_hook( The hook combines the unflattened, sharded parameters (ShardedTensor) to a new FlatParameter and shards the new FlatParameter to the local chunk. """ - _replace_by_prefix(state_dict, prefix, prefix + f"{fsdp_file.FSDP_PREFIX}") + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") if not module._has_params: return @@ -295,7 +476,7 @@ def _sharded_pre_load_state_dict_hook( shared_fqns = [fqn for fqn, _, _ in module._shared_param_fqns] loaded_shapes = [] for fqn, _, _ in module._param_fqns: - full_fqn = f"{prefix}{fsdp_file.FSDP_PREFIX}{fqn}" + full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: continue @@ -353,9 +534,7 @@ def _sharded_pre_load_state_dict_hook( f"The loaded local chunk has different padding({num_to_pad}) " f"from the local chunk {flat_param._shard_numel_padded}." ) - state_dict[ - f"{prefix}{fsdp_file.FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" - ] = loaded_flat_tensor + state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor if module._use_orig_params: module._deregister_orig_params() @@ -381,17 +560,6 @@ def _post_state_dict_hook( processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( fsdp_module, state_dict, prefix ) - # Restore buffers, which currently are in their full precision type, - # back to their mixed precision type. This is because buffers are cast - # during lazy_init() and stay at their mixed precision type before/after - # forward/backward. As a result state_dict() should maintain this. 
- if fsdp_module._is_root and fsdp_module._mixed_precision_enabled_for_buffers(): - buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation( - fsdp_module, fsdp_module - ) - _cast_buffers_to_dtype_and_device( - buffers, buffer_dtypes, fsdp_module.compute_device - ) return processed_state_dict diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 6f5537aad520..9934e7189342 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -50,9 +50,7 @@ _init_state_dict_state, ) from torch.distributed.fsdp._runtime_utils import ( - _cast_buffers_to_dtype_and_device, _clear_grads_if_needed, - _get_buffers_and_dtypes_for_checkpoint, _lazy_init, _post_forward, _post_forward_reshard, @@ -512,6 +510,7 @@ def __init__( _pre_load_state_dict_hook, with_module=True ) self.register_load_state_dict_post_hook(_post_load_state_dict_hook) + self._full_param_ctx: Optional[Generator] = None @property def module(self) -> nn.Module: @@ -813,104 +812,8 @@ def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]: yield fqn, param_name, module_name def state_dict(self, *args, **kwargs): - """ - This is the entry point of all three FSDP ``state_dict`` APIs: full, - local, and sharded. For the full state dict - (``StateDictType.FULL_STATE_DICT``), FSDP attempts to unshard the model - on all ranks, which may result in an OOM error if the full model cannot - fit on a single GPU. In that case, users may pass in a - :class:`FullStateDictConfig` to only save the checkpoint on rank 0 and/ - or to offload it to CPU memory layer by layer, enabling much larger - checkpoints. If the full model cannot fit in CPU memory, then users may - instead take a local state dict (``StateDictType.LOCAL_STATE_DICT``) - that only saves the local shard of the model. The sharded state dict - (``StateDictType.SHARDED_STATE_DICT``) saves the model parameters as - ``ShardedTensor`` s. The ``state_dict`` type can be configured using - the :meth:`state_dict_type` context manager. - - Example:: - - >>> # xdoctest: +SKIP("undefined variables") - >>> import torch - >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - >>> from torch.distributed.fsdp import StateDictType - >>> torch.cuda.set_device(device_id) - >>> my_module = nn.Linear(...) - >>> sharded_module = FSDP(my_module) - >>> full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - >>> with FSDP.state_dict_type(sharded_module, StateDictType.FULL_STATE_DICT, full_state_dict_config): - >>> full_dict = sharded_module.state_dict() - >>> full_dict.keys() - >>> odict_keys(['weight', 'bias']) - >>> # using local state dict - >>> with FSDP.state_dict_type(sharded_module, StateDictType.LOCAL_STATE_DICT): - >>> local_dict = sharded_module.state_dict() - >>> local_dict.keys() - >>> odict_keys(['flat_param', 'inner.flat_param']) - - .. warning:: This needs to be called on all ranks since it uses - collective communications. - """ - # TODO (rohan-varma): separate these out once a state_dict pre-hook - # is available. 
- if torch.cuda.is_available(): - torch.cuda.synchronize() _lazy_init(self, self) - if self._is_root: - _clear_grads_if_needed(self._fsdp_handles(self)) - if self._state_dict_type == StateDictType.FULL_STATE_DICT: - # Get config args - full_state_dict_config = ( - self._state_dict_config - if self._state_dict_config is not None - else FullStateDictConfig() - ) - rank0_only = full_state_dict_config.rank0_only - offload_to_cpu = full_state_dict_config.offload_to_cpu - summon_ctx = ( - self._summon_full_params( - recurse=False, - writeback=False, - offload_to_cpu=offload_to_cpu, - rank0_only=rank0_only, - ) - if self.training_state != TrainingState.SUMMON_FULL_PARAMS - else contextlib.suppress() - ) - with summon_ctx: - # Since buffers stay in their low precision throughout runtime, - # we must explicitly restore them to their original dtypes for - # model checkpointing. We have the root module cast for all - # submodules. - # TODO: Investigate if this can and should be refactored into - # `summon_full_params()`. - if self._is_root and self._mixed_precision_enabled_for_buffers(): - buffers, buffer_dtypes = _get_buffers_and_dtypes_for_checkpoint( - self, self - ) - _cast_buffers_to_dtype_and_device( - buffers, buffer_dtypes, self.compute_device - ) - state_dict = super().state_dict(*args, **kwargs) - - # TODO: support offload to CPU in post state dict hook. - if not rank0_only or self.rank == 0: - return state_dict - else: - return {} - - elif ( - self._state_dict_type == StateDictType.LOCAL_STATE_DICT - or self._state_dict_type == StateDictType.SHARDED_STATE_DICT - ): - if self._has_params and not self._handles[0].uses_sharded_strategy: - raise RuntimeError( - "sharded_state_dict/local_state_dict can only be called " - "when parameters are flatten and sharded." - ) - return super().state_dict(*args, **kwargs) - else: - raise ValueError(f"Unknown StateDictType {self._state_dict_type}.") + return super().state_dict(*args, **kwargs) def forward(self, *args: Any, **kwargs: Any) -> Any: """ From 86b7aa26f0bb8878d925a625af45d16d4bb2f2af Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Fri, 11 Nov 2022 03:49:27 +0000 Subject: [PATCH 052/453] Fix FakeTensorProp on Module with Parameters or Buffers (#88700) In `FakeTensorMode.__torch_dispatch__`, the output is now always computed by meta kernels in ```python try: with in_kernel_invocation_manager(self): r = func(*args, **kwargs) # <----- "r" can be a real tensor. except NotImplementedError as not_implemented_error: # no meta kernel registered, fallback to kernel for the device if not self.allow_fallback_kernels: raise not_implemented_error return run_fallback_kernel(self, func, args, kwargs, not_implemented_error) return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs) ``` For example, I observed a CPU tensor is generated when executing `aten.addmm` when running `FakeTensorProp`. Therefore, I'd like to allow `FakeTensorMode` to wrap real tensor as `FakeTensor` during the computation. Does this PR look a good direction to fix this problem? If yes, I can go ahead and add some tests. 
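A hedged usage sketch of the resulting API, mirroring the unit test added below (the `Sequential` model and shapes are placeholders): passing one `FakeTensorMode` both for faking the parameters and to `FakeTensorProp` keeps every intermediate a `FakeTensor` instead of a real CPU tensor produced by a fallback kernel.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(model)
value = torch.randn(5, 4)

with FakeTensorMode() as mode:
    # Fake the parameters with the same mode that FakeTensorProp will run under.
    fake_params = {k: mode.from_tensor(v) for k, v in gm.named_parameters()}
    with torch.nn.utils.stateless._reparametrize_module(gm, fake_params):
        result = FakeTensorProp(gm, mode).propagate(value)

print(type(result), result.shape)  # a FakeTensor of shape (5, 3)
```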
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88700 Approved by: https://github.com/eellison, https://github.com/ezyang --- test/test_fake_tensor.py | 59 +++++++++++++++++++++++++++++ torch/fx/passes/fake_tensor_prop.py | 12 +++++- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index ad9042196bff..3d47cc8ea0e5 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -2,6 +2,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfCrossRef, skipIfRocm import torch +import torch._dynamo import itertools import numpy as np from torch.testing._internal.jit_utils import RUN_CUDA @@ -11,6 +12,7 @@ FakeTensorConverter, DynamicOutputShapeException, ) +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.testing import FileCheck from torch import nn import unittest @@ -663,5 +665,62 @@ def test_like_ops(self): op = self.get_aten_op(schema) self.assertIn(op, torch._subclasses.fake_tensor._like_tensor_constructors) +class FakeTensorPropTest(TestCase): + def test_fake_tensor_prop_on_nn_module(self): + class ToyNnModuleWithParameters(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = torch.nn.Linear(4, 3) + self.layer2 = torch.nn.Linear(3, 2) + + def forward(self, value): + value = self.layer1(value) + value = torch.relu(value) + value = self.layer2(value) + return value + + model = ToyNnModuleWithParameters() + value = torch.randn(5, 4) + # Convert nn.Module to GraphModule so that FakeTensorProp runs. + graph_model = torch.fx.symbolic_trace(model, (value,)) + # The following block runs FakeTensorProp on graph_module w/to the same FakeTensorMode + # + # TODO(wschin): there should be an API to run FakeTensorProp for GraphModule + # with parameters and buffers. + with FakeTensorMode() as fake_tensor_mode: + + def to_fake_tensor(x): + if isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor): + return fake_tensor_mode.from_tensor(x) + return x + + fake_parameters_and_buffers = { + k: to_fake_tensor(v) + for k, v in itertools.chain( + graph_model.named_parameters(), graph_model.named_buffers() + ) + } + with torch.nn.utils.stateless._reparametrize_module( + graph_model, fake_parameters_and_buffers + ): + # This case uses the **same** fake tensor mode to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The result should be correct. + result = FakeTensorProp(graph_model, fake_tensor_mode).propagate(value) + self.assertTrue(isinstance(result, FakeTensor)) + self.assertEqual(result.shape, (5, 2)) + # This case uses the **different** fake tensor modes to + # 1. create fake parameters and fake buffers, and + # 2. run FakeTensorProp + # The following code should fail. 
+ failed = False + try: + FakeTensorProp(graph_model).propagate(value) + except AssertionError: + # AssertionError: tensor's device must be `meta`, got cpu instead + failed = True + self.assertTrue(failed) + if __name__ == "__main__": run_tests() diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index b034b5341b06..403db5b9a009 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch.fx from torch.fx import Node from torch.fx._compatibility import compatibility @@ -17,7 +19,13 @@ class FakeTensorProp(torch.fx.Interpreter): Args: module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. """ + def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode def run_node(self, n: Node): result = super().run_node(n) @@ -25,6 +33,6 @@ def run_node(self, n: Node): return result def propagate(self, *args): - with FakeTensorMode.push() as mode: - fake_args = [mode.from_tensor(a) for a in args] + with self._mode: + fake_args = [self._mode.from_tensor(a) for a in args] return super().run(*fake_args) From 310335de48ab9d8bcd33b98f3f71ef88ae4bd45c Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Fri, 11 Nov 2022 04:02:44 +0000 Subject: [PATCH 053/453] Update lr_scheduler.pyi to match lr_scheduler.py (#88818) Following #88503, we should also update the pyi file Pull Request resolved: https://github.com/pytorch/pytorch/pull/88818 Approved by: https://github.com/soulitzer --- torch/optim/lr_scheduler.pyi | 37 +++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi index 97603e064a70..00d9eb512ae1 100644 --- a/torch/optim/lr_scheduler.pyi +++ b/torch/optim/lr_scheduler.pyi @@ -1,7 +1,7 @@ from typing import Iterable, Any, Optional, Callable, Union, List from .optimizer import Optimizer -class _LRScheduler: +class LRScheduler: optimizer: Optimizer = ... base_lrs: List[float] = ... last_epoch: int = ... @@ -14,46 +14,49 @@ class _LRScheduler: def step(self, epoch: Optional[int] = ...) -> None: ... def print_lr(self, is_verbose: bool, group: dict, lr: float, epoch: Optional[int] = ...) -> None: ... -class LambdaLR(_LRScheduler): +class _LRScheduler(LRScheduler): + ... + +class LambdaLR(LRScheduler): lr_lambdas: List[Callable[[int], float]] = ... def __init__(self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = ..., verbose: bool = ...) -> None: ... -class MultiplicativeLR(_LRScheduler): +class MultiplicativeLR(LRScheduler): lr_lambdas: List[Callable[[int], float]] = ... def __init__(self, optimizer: Optimizer, lr_lambda: Union[Callable[[int], float], List[Callable[[int], float]]], last_epoch: int = ..., verbose: bool = ...) -> None: ... -class StepLR(_LRScheduler): +class StepLR(LRScheduler): step_size: int = ... gamma: float = ... def __init__(self, optimizer: Optimizer, step_size: int, gamma: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class MultiStepLR(_LRScheduler): +class MultiStepLR(LRScheduler): milestones: Iterable[int] = ... gamma: float = ... 
def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class ConstantLR(_LRScheduler): +class ConstantLR(LRScheduler): factor: float = ... total_iters: int = ... def __init__(self, optimizer: Optimizer, factor: float=..., total_iters: int=..., last_epoch: int=..., verbose: bool = ...) -> None: ... -class LinearLR(_LRScheduler): +class LinearLR(LRScheduler): start_factor: float = ... end_factor: float = ... total_iters: int = ... def __init__(self, optimizer: Optimizer, start_factor: float=..., end_factor: float= ..., total_iters: int= ..., last_epoch: int= ..., verbose: bool = ...) -> None: ... -class ExponentialLR(_LRScheduler): +class ExponentialLR(LRScheduler): gamma: float = ... def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int = ..., verbose: bool = ...) -> None: ... -class ChainedScheduler(_LRScheduler): - def __init__(self, schedulers: List[_LRScheduler]) -> None: ... +class ChainedScheduler(LRScheduler): + def __init__(self, schedulers: List[LRScheduler]) -> None: ... -class SequentialLR(_LRScheduler): - def __init__(self, optimizer: Optimizer, schedulers: List[_LRScheduler], milestones: List[int], last_epoch: int=..., verbose: bool=...) -> None: ... +class SequentialLR(LRScheduler): + def __init__(self, optimizer: Optimizer, schedulers: List[LRScheduler], milestones: List[int], last_epoch: int=..., verbose: bool=...) -> None: ... -class CosineAnnealingLR(_LRScheduler): +class CosineAnnealingLR(LRScheduler): T_max: int = ... eta_min: float = ... def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... @@ -82,7 +85,7 @@ class ReduceLROnPlateau: def state_dict(self) -> dict: ... def load_state_dict(self, state_dict: dict) -> None: ... -class CyclicLR(_LRScheduler): +class CyclicLR(LRScheduler): max_lrs: List[float] = ... total_size: float = ... step_ratio: float = ... @@ -95,7 +98,7 @@ class CyclicLR(_LRScheduler): def __init__(self, optimizer: Optimizer, base_lr: Union[float, List[float]], max_lr: Union[float, List[float]], step_size_up: int = ..., step_size_down: Optional[int] = ..., mode: str = ..., gamma: float = ..., scale_fn: Optional[Callable[[float], float]] = ..., scale_mode: str = ..., cycle_momentum: bool = ..., base_momentum: float = ..., max_momentum: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... def scale_fn(self, x: Any) -> float: ... -class CosineAnnealingWarmRestarts(_LRScheduler): +class CosineAnnealingWarmRestarts(LRScheduler): T_0: int = ... T_i: int = ... T_mult: Optional[int] = ... @@ -104,14 +107,14 @@ class CosineAnnealingWarmRestarts(_LRScheduler): def __init__(self, optimizer: Optimizer, T_0: int, T_mult: int = ..., eta_min: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... def step(self, epoch: Optional[Any] = ...): ... -class OneCycleLR(_LRScheduler): +class OneCycleLR(LRScheduler): total_steps: int = ... anneal_func: Callable[[float, float, float], float] = ... cycle_momentum: bool = ... use_beta1: bool = ... 
def __init__(self, optimizer: Optimizer, max_lr: Union[float, List[float]], total_steps: int = ..., epochs: int = ..., steps_per_epoch: int = ..., pct_start: float = ..., anneal_strategy: str = ..., cycle_momentum: bool = ..., base_momentum: Union[float, List[float]] = ..., max_momentum: Union[float, List[float]] = ..., div_factor: float = ..., final_div_factor: float = ..., three_phase: bool = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... -class PolynomialLR(_LRScheduler): +class PolynomialLR(LRScheduler): total_iters: int = ... power: float = ... def __init__(self, optimizer: Optimizer, total_iters: int = ..., power: float = ..., last_epoch: int = ..., verbose: bool = ...) -> None: ... From 0de8f047c1cc950c59b0448b9b78dafc0202c43f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 04:19:08 +0000 Subject: [PATCH 054/453] Revert "[dynamo] fixes dict changed during runtime error (#87526)" This reverts commit cf04b36ce8f531730210b03eaa347977a1c2d75c. Reverted https://github.com/pytorch/pytorch/pull/87526 on behalf of https://github.com/anijain2305 due to error reported --- test/dynamo/test_aot_cudagraphs.py | 3 +++ torch/_dynamo/convert_frame.py | 15 +++------------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index fdb7c88762b8..cb1d2a0e601f 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -71,6 +71,7 @@ def fn(x, y): y = torch.randn(3, device="cuda") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_dtoh(self): def model(x, y): @@ -104,6 +105,7 @@ def fn(x, y): y = torch.randn((), device="cpu") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch("functorch._src.config.use_functionalize", True) @patch_all(ok=False) # input mutation not supported yet def test_mutate_input(self): @@ -143,6 +145,7 @@ def fn(x, y): y = torch.randn(1, device="cuda") fn(x, y) + @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_factory(self): def model(y): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index db9b23f2da7e..f1ce83727a19 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -156,11 +156,7 @@ def has_tensor(obj): seen_ids[obj_id] = any([has_tensor(v) for v in obj]) return seen_ids[obj_id] elif istype(obj, dict): - # Some packages like pytest can be updated during runtime. 
So, make a - # copy of values to avoid issues like "RuntimeError: dictionary - # changed size during iteration" - values = list(obj.values()) - seen_ids[obj_id] = any([has_tensor(v) for v in values]) + seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) return seen_ids[obj_id] elif istype(obj, (str, int, float, type(None), bool)): seen_ids[obj_id] = False @@ -168,13 +164,8 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif ( - not is_allowed(obj) - and not hasattr(obj, "__get__") # overridden get can mutate the object - and hasattr(obj, "__dict__") - and istype(obj.__dict__, dict) - ): - seen_ids[obj_id] = has_tensor(obj.__dict__) + elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): + seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) return seen_ids[obj_id] else: # if config.debug: From a6d72f44a4e8b6e9d2e878f30fd8b1d3e1197f0e Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 9 Nov 2022 17:27:22 +0000 Subject: [PATCH 055/453] [ONNX] Add onnx::Max into standard Op for scalar type alignment (#88750) Easy fix for onnx::Max ScalarType Pull Request resolved: https://github.com/pytorch/pytorch/pull/88750 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- aten/src/ATen/core/interned_strings.h | 1 + torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 80919e52b58f..2abc6217516d 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -239,6 +239,7 @@ namespace c10 { _(onnx, LSTM) \ _(onnx, MatMul) \ _(onnx, Min) \ + _(onnx, Max) \ _(onnx, Mul) \ _(onnx, Pow) \ _(onnx, RNN) \ diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 657c27f70c7d..3af0360b7e01 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -48,6 +48,7 @@ static const std::unordered_set standardOps = { onnx::Div, onnx::Gemm, onnx::Min, + onnx::Max, onnx::Mod, onnx::Mul, onnx::Pow, From 396c3b1d88d7624938a2bb0b287f2a19f1e89bb4 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 11 Nov 2022 05:23:48 +0000 Subject: [PATCH 056/453] Use `atomicAdd` for `bfloat16` in Ampere and above (#84981) WIP to fix extremely slow `scatter_add` issue vs. fp16. The current changes seem to improve performance, but it still appears to lag behind the fp16 equivalent. 
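For context, the kind of gap described above can be reproduced with a rough micro-benchmark along the following lines. This is only a sketch: the sizes, iteration counts, and the assumption of a CUDA device (Ampere or newer to exercise the native `bfloat16` path) are illustrative and not part of this change.

```python
import torch

def bench_scatter_add(dtype, n=1_000_000, dim_size=1024, iters=50):
    # 1-D scatter_add_ with many colliding indices, which stresses atomicAdd.
    src = torch.randn(n, device="cuda").to(dtype)
    index = torch.randint(dim_size, (n,), device="cuda")
    out = torch.zeros(dim_size, device="cuda", dtype=dtype)
    for _ in range(5):  # warm-up
        out.scatter_add_(0, index, src)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out.scatter_add_(0, index, src)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

for dtype in (torch.float16, torch.bfloat16):
    print(dtype, f"{bench_scatter_add(dtype):.3f} ms")
```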
CC @ngimel @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/84981 Approved by: https://github.com/ngimel --- aten/src/ATen/cuda/Atomic.cuh | 17 ++++++-- aten/src/ATen/native/cuda/KernelUtils.cuh | 48 ++++++++++++++++++++++- 2 files changed, 60 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/cuda/Atomic.cuh b/aten/src/ATen/cuda/Atomic.cuh index 42975411e841..3d60b672e972 100644 --- a/aten/src/ATen/cuda/Atomic.cuh +++ b/aten/src/ATen/cuda/Atomic.cuh @@ -6,6 +6,10 @@ #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + template struct AtomicFPOp; @@ -219,10 +223,15 @@ static inline __device__ at::Half gpuAtomicAdd(at::Half *address, at::Half val) } static inline __device__ at::BFloat16 gpuAtomicAdd(at::BFloat16 *address, at::BFloat16 val) { - return AtomicFPOp()(address, val, - [](at::BFloat16 bsum, at::BFloat16 val) { - return bsum + val; - }); +#if defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) +return AtomicFPOp()(address, val, + [](at::BFloat16 bsum, at::BFloat16 val) { + return bsum + val; + }); +#else + __nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val)); + return *reinterpret_cast(&r); +#endif } #if defined(CUDA_VERSION) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000) diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 1e36e2db74d5..d2e956d1a3e4 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -1,6 +1,10 @@ #pragma once #include +#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))) +#include +#endif + namespace at { namespace native { @@ -66,7 +70,49 @@ __device__ __forceinline__ void fastSpecializedAtomicAdd( template < typename scalar_t, typename index_t, - typename std::enable_if::value>::type* = + typename std::enable_if::value>::type* = + nullptr> +__device__ __forceinline__ void fastSpecializedAtomicAdd( + scalar_t* tensor, + index_t index, + const index_t numel, + scalar_t value) { +#if ( \ + (defined(USE_ROCM)) || \ + (defined(CUDA_VERSION) && (CUDA_VERSION < 11000)) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))) + gpuAtomicAddNoReturn( + reinterpret_cast(tensor) + index, + static_cast(value)); +#else + // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned) + __nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index); + bool low_byte = (reinterpret_cast(target_addr) % sizeof(__nv_bfloat162) == 0); + + if (low_byte && index < (numel - 1)) { + __nv_bfloat162 value2; + value2.x = *reinterpret_cast<__nv_bfloat16*>(&value); + value2.y = __int2bfloat16_rz(0); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2); + + } else if (!low_byte && index > 0) { + __nv_bfloat162 value2; + value2.x = __int2bfloat16_rz(0); + value2.y = *reinterpret_cast<__nv_bfloat16*>(&value); + atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2); + + } else { + atomicAdd( + reinterpret_cast<__nv_bfloat16*>(tensor) + index, *reinterpret_cast<__nv_bfloat16*>(&value)); + } +#endif +} + + +template < + typename scalar_t, + typename index_t, + typename std::enable_if::value && !std::is_same::value >::type* = nullptr> 
__device__ __forceinline__ void fastSpecializedAtomicAdd( scalar_t* tensor, From b843f4db0a26aae6536e6b971f73bcc5af21c90a Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 9 Nov 2022 17:41:10 +0000 Subject: [PATCH 057/453] [ONNX] Add test case for onnx::Max scalar type (#88751) Referenced by minimum cases Pull Request resolved: https://github.com/pytorch/pytorch/pull/88751 Approved by: https://github.com/wschin, https://github.com/BowenBao --- test/onnx/test_pytorch_onnx_onnxruntime.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 1e36163d0394..e4fc3f83b288 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -8728,6 +8728,28 @@ def forward(self, x, y): y = torch.full_like(x, True) self.run_test(MinimumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(12) + def test_maximum_dtypes(self): + class MaximumModel(torch.nn.Module): + def forward(self, x, y): + return torch.maximum(x, y) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randn((5, 5), dtype=torch.float) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randn((5, 5), dtype=torch.float16) + y = torch.randint(10, (5, 5), dtype=torch.int16) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int16) + y = torch.randint(10, (5, 5), dtype=torch.int32) + self.run_test(MaximumModel(), (x, y)) + + x = torch.randint(10, (5, 5), dtype=torch.int) + y = torch.full_like(x, True) + self.run_test(MaximumModel(), (x, y)) + @skipIfUnsupportedMinOpsetVersion(9) def test_any(self): class M(torch.nn.Module): From d15a6b0c975b9e1e90ed4e951071e5269c10ac5b Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 11 Nov 2022 08:51:26 +0000 Subject: [PATCH 058/453] Error on ZeroTensor serialization (#88803) Follow-up : https://github.com/pytorch/pytorch/pull/88182#issuecomment-1308628415 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88803 Approved by: https://github.com/anjali411 --- test/cpp/api/serialize.cpp | 8 ++++++++ test/test_serialization.py | 22 ++++++++++++++++++++++ torch/csrc/jit/serialization/pickler.h | 6 ++++++ 3 files changed, 36 insertions(+) diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 05bb0f941d40..20d572853d3a 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -288,6 +288,14 @@ TEST(SerializeTest, MathBits) { ASSERT_EQ(actual.sizes().vec(), expected.sizes().vec()); ASSERT_TRUE(actual.allclose(expected)); } + + { + // We don't support serializing `ZeroTensor` as it is not public facing yet. + // If in future, `ZeroTensor` serialization is supported, this test should + // start failing! + auto t = torch::_efficientzerotensor({5, 5}); + ASSERT_THROWS_WITH(save_and_load(t), "ZeroTensor is not serializable,"); + } } TEST(SerializeTest, BasicToFile) { diff --git a/test/test_serialization.py b/test/test_serialization.py index af0317e87a14..779d6fb5c20c 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -931,6 +931,28 @@ def _save_load_check(t): t_n_c = torch._neg_view(torch.conj(t)) _save_load_check(t_n_c) + @parametrize('weights_only', (False, True)) + def test_serialization_efficient_zerotensor(self, weights_only): + # We don't support serializing `ZeroTensor` as it is not public + # facing yet. 
+ # If in future, `ZeroTensor` serialization is supported, this test + # should start failing! + t = torch._efficientzerotensor((4, 5)) + + def _save_load_check(t): + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + # Unsafe load should work + self.assertEqual(torch.load(f, weights_only=weights_only), t) + + # NOTE: `torch.save` fails before we hit the TORCH_CHECK in `getTensoMetadata` + # as nullptr storage is disabled. + err_msg = (r'python bindings to nullptr storage \(e.g., from torch.Tensor._make_wrapper_subclass\)' + ' are currently unsafe and thus disabled') + with self.assertRaisesRegex(RuntimeError, err_msg): + _save_load_check(t) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super(TestSerialization, self).run(*args, **kwargs) diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index c289cae12b64..26f9fcf42396 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -300,6 +300,12 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls); // For now, it only takes care of `conj` and `neg` bit. inline std::unordered_map getTensorMetadata( const at::Tensor& t) { + // We don't support serializing `ZeroTensor` as it is not public + // facing yet. + TORCH_CHECK( + !t._is_zerotensor(), + "ZeroTensor is not serializable,", + " please file an issue if required."); std::unordered_map metadata{}; // Only add meta-data if the value is not default. From ee91c328da5739ce03b3127cd7c542ce505212b8 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Fri, 11 Nov 2022 12:19:31 +0000 Subject: [PATCH 059/453] Fix cuda/cpu check on NoneType (#88854) Summary: Fix cuda/cpu check on NoneType Test Plan: sabdcastle/ github CI/CD Differential Revision: D41203955 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88854 Approved by: https://github.com/drisspg, https://github.com/ngimel --- torch/nn/modules/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 7b0e7e3effaa..e6b3b778e5fb 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1111,7 +1111,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): + elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the " From 324ac93a43a93f671bb34b835926b22d13442735 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 8 Nov 2022 00:16:14 +0000 Subject: [PATCH 060/453] [FSDP][state_dict][2/N] Move state_dict related enums/dataclasses/states to state_dict_utils.py, api.py and init_state_dict() (#88481) **Motivation**: Several Enums, Dataclasses and states defined in fully_sharded_data_paralle.py should be moved to a place where the composable FSDP can access. This PR does the move. 
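Call sites keep working as before; the enums and configs are now importable from `torch.distributed.fsdp.api` and plug into the existing `state_dict_type` context manager. A minimal sketch (assuming a process group is already initialized and a GPU is available; the toy module is for illustration only):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType

model = FSDP(nn.Linear(8, 8).cuda())  # toy module wrapped with FSDP
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state = model.state_dict()  # full state_dict, materialized on rank 0 only
```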
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88481 Approved by: https://github.com/rohan-varma, https://github.com/awgu --- torch/distributed/fsdp/_init_utils.py | 12 +- torch/distributed/fsdp/_state_dict_utils.py | 72 +++++++--- torch/distributed/fsdp/api.py | 96 ++++++++++++- .../fsdp/fully_sharded_data_parallel.py | 127 +----------------- 4 files changed, 164 insertions(+), 143 deletions(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index c89f65c3a5b8..966e61f7fe12 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -3,6 +3,7 @@ from typing import ( Callable, Dict, + Generator, Iterable, Iterator, List, @@ -33,8 +34,11 @@ from torch.distributed.fsdp.api import ( BackwardPrefetch, CPUOffload, + FullStateDictConfig, MixedPrecision, ShardingStrategy, + StateDictConfig, + StateDictType, ) from torch.distributed.fsdp.flat_param import ( _HandlesKey, @@ -206,7 +210,13 @@ def _init_prefetching_state( def _init_state_dict_state(state: _FSDPState) -> _FSDPState: - # TODO: after rebase + state._state_dict_type = StateDictType.FULL_STATE_DICT + state_dict_config: StateDictConfig = FullStateDictConfig() + state._state_dict_config = state_dict_config + full_param_ctx: Optional[Generator] = None + # TODO: For composable API, this should be a dict that maps from a module to + # handles. + state._full_param_ctx = full_param_ctx return state diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 1109f1e88150..c90bd4d409b1 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, Callable, cast, Dict +from typing import Any, Callable, cast, Dict, Iterator, Tuple import torch import torch.distributed as dist @@ -20,6 +20,7 @@ from torch.distributed.fsdp._common_utils import ( clean_tensor_name, FSDP_PREFIX, + FSDP_WRAPPED_MODULE, TrainingState, ) from torch.distributed.fsdp._runtime_utils import ( @@ -28,6 +29,7 @@ _get_buffer_dtypes, _lazy_init, ) +from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType from torch.distributed.utils import _replace_by_prefix from ._fsdp_extensions import ( @@ -38,6 +40,33 @@ from .flat_param import FlatParamHandle +def _convert_to_wrapped_module_name(module_name: str) -> str: + module_name = module_name.replace(f"{FSDP_PREFIX}", "") + module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # Activation checkpoint adds a prefix that has to be + # removed as well. 
+ module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "") + return module_name + + +def _param_fqns(module) -> Iterator[Tuple[str, str, str]]: + if not module._has_params: + return + for param_name, module_name in module._handles[0].parameter_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +def _shared_param_fqns(module) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in module._handles[0].shared_parameter_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + def _enter_full_param_ctx( module, recurse: bool = False, @@ -71,7 +100,10 @@ def _enter_full_param_ctx( def _exit_full_param_ctx(module) -> None: """A helper function to exit ``summon_full_params`` context.""" - module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert module.training_state == TrainingState.SUMMON_FULL_PARAMS, ( + "Exiting the summon_full_params context but the state is not " + "SUMMON_FULL_PARAMS." + ) assert module._full_param_ctx is not None module._full_param_ctx.__exit__(None, None, None) module._full_param_ctx = None @@ -124,7 +156,9 @@ def _common_summon_post_state_dict_hook( hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) - module._assert_state([TrainingState.SUMMON_FULL_PARAMS]) + assert ( + module.training_state == TrainingState.SUMMON_FULL_PARAMS + ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases if not state_dict or not module._has_params: _exit_full_param_ctx(module) @@ -141,8 +175,8 @@ def _common_summon_post_state_dict_hook( # all-gather and does not need to save the # state dict. We simply check # rank0_only to ensure this issue. rank0_only = ( - module._state_dict_type == fsdp_file.StateDictType.FULL_STATE_DICT - and cast(fsdp_file.FullStateDictConfig, module._state_dict_config).rank0_only + module._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, module._state_dict_config).rank0_only ) # no_fsdp_return means the state_dict returned by this rank should contain # only non-FSDP controlled parameters and buffers. @@ -159,7 +193,7 @@ def _common_summon_post_state_dict_hook( # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. - for fqn, param_name, module_name in module._param_fqns: + for fqn, param_name, module_name in _param_fqns(module): # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. 
param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" @@ -224,9 +258,7 @@ def _full_pre_state_dict_hook( _common_summon_pre_state_dict_hook( module, offload_to_cpu=module._state_dict_config.offload_to_cpu, - rank0_only=cast( - fsdp_file.FullStateDictConfig, module._state_dict_config - ).rank0_only, + rank0_only=cast(FullStateDictConfig, module._state_dict_config).rank0_only, ) @@ -473,9 +505,9 @@ def _sharded_pre_load_state_dict_hook( ) nonsharded_tensors = [] - shared_fqns = [fqn for fqn, _, _ in module._shared_param_fqns] + shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module)] loaded_shapes = [] - for fqn, _, _ in module._param_fqns: + for fqn, _, _ in _param_fqns(module): full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: @@ -552,9 +584,9 @@ def _post_state_dict_hook( what postprocessing will be done. """ _post_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( @@ -576,9 +608,9 @@ def _pre_load_state_dict_hook( will be done. """ _pre_load_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, } # Code that is common for all state_dict impls fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) @@ -593,9 +625,9 @@ def _pre_load_state_dict_hook( @torch.no_grad() def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: _post_load_state_dict_hook_fn = { - fsdp_file.StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, - fsdp_file.StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, - fsdp_file.StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, + StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, } # Code that is common for all state_dict impls fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) diff --git a/torch/distributed/fsdp/api.py b/torch/distributed/fsdp/api.py index 9e1327c80633..18f3cd3069dd 100644 --- a/torch/distributed/fsdp/api.py +++ b/torch/distributed/fsdp/api.py @@ -10,7 +10,17 @@ import torch -__all__ = ["ShardingStrategy", "BackwardPrefetch", "MixedPrecision", "CPUOffload"] +__all__ = [ + "ShardingStrategy", + "BackwardPrefetch", + "MixedPrecision", + "CPUOffload", + "StateDictType", + "StateDictConfig", + "FullStateDictConfig", + "LocalStateDictConfig", + "ShardedStateDictConfig", +] class ShardingStrategy(Enum): @@ -149,3 +159,87 @@ class CPUOffload: """ offload_params: bool = False + + +class StateDictType(Enum): + """ + This enum indicates 
that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + ..note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all state_dict configuration classes. + Users should instantiate a child version (i.e. ``FullStateDictConfig``) in + order to configure settings for the particular type of ``state_dict`` + implementation FSDP will use. + """ + + offload_to_cpu: bool = False + + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters, + ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload + the full ``state_dict`` to CPU and to materialize the ``state_dict`` on + rank 0 only. When used, it is recommended to enable both of these flags + together to optimize memory savings when taking checkpoints. Note that + this config class is meant for user via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. 
+ """ + + rank0_only: bool = False + + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + pass diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 9934e7189342..773686081a4d 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -5,7 +5,6 @@ import traceback import warnings from contextlib import contextmanager -from dataclasses import dataclass from enum import auto, Enum from typing import ( Any, @@ -25,7 +24,6 @@ import torch.nn as nn from torch.distributed import ProcessGroup from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - _CHECKPOINT_PREFIX, _CHECKPOINT_WRAPPED_MODULE, ActivationWrapper, ) @@ -68,8 +66,13 @@ from torch.distributed.fsdp.api import ( BackwardPrefetch, CPUOffload, + FullStateDictConfig, + LocalStateDictConfig, MixedPrecision, + ShardedStateDictConfig, ShardingStrategy, + StateDictConfig, + StateDictType, ) from ._optim_utils import ( @@ -103,11 +106,6 @@ __all__ = [ "FullyShardedDataParallel", - "StateDictType", - "StateDictConfig", - "FullStateDictConfig", - "LocalStateDictConfig", - "ShardedStateDictConfig", "OptimStateKeyType", ] @@ -115,90 +113,6 @@ FLAT_PARAM = "_flat_param" -class StateDictType(Enum): - """ - This enum indicates that which type of ``state_dict`` the FSDP module is - currently processing (returning or loading). - The default value is FULL_STATE_DICT to comply the PyTorch convention. - ..note:: - FSDP currently supports three types of ``state_dict``: - 1. ``state_dict/load_state_dict`: this pair of APIs return and load - the non-sharded, unflattened parameters. The semantics is the - same as using DDP. - 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return - and load local sharded, flattened parameters. The values returned - by ``_local_state_dict`` can be directly used by FSDP and is only - meaningful to FSDP (because parameters are flattened). Note that - these APIs are meant for use via the :func:`state_dict_type` - context manager as follows: - >>> # xdoctest: +SKIP("undefined variables") - >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): - ... state = fsdp.state_dict() # loads local state dict - 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs - return and load sharded, unflattened parameters. The ``state_dict`` - return by ``sharded_state_dict`` can be used by all other parallel - schemes (resharding may be required). - """ - - FULL_STATE_DICT = auto() - LOCAL_STATE_DICT = auto() - SHARDED_STATE_DICT = auto() - - -@dataclass -class StateDictConfig: - """ - ``StateDictConfig`` is the base class for all state_dict configuration classes. - Users should instantiate a child version (i.e. ``FullStateDictConfig``) in - order to configure settings for the particular type of ``state_dict`` - implementation FSDP will use. - """ - - offload_to_cpu: bool = False - - -@dataclass -class FullStateDictConfig(StateDictConfig): - """ - ``FullStateDictConfig`` is a config class meant to be used with - ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters, - ``offload_to_cpu`` and ``rank0_only`` which can be configured to offload - the full ``state_dict`` to CPU and to materialize the ``state_dict`` on - rank 0 only. 
When used, it is recommended to enable both of these flags - together to optimize memory savings when taking checkpoints. Note that - this config class is meant for user via the :func:`state_dict_type` - context manager as follows: - >>> # xdoctest: +SKIP("undefined variables") - >>> fsdp = FSDP(model, auto_wrap_policy=...) - >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): - >>> state = fsdp.state_dict() - >>> # state will be empty on non rank 0 and contain CPU tensors on rank 0. - >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: - >>> model = model_fn() # Initialize model on CPU in preparation for wrapping with FSDP - >>> if dist.get_rank() == 0: - >>> # Load checkpoint only on rank 0 to avoid memory redundancy - >>> state_dict = torch.load("my_checkpoint.pt") - >>> model.load_state_dict(state_dict) - >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument - >>> # communicates loaded checkpoint states from rank 0 to rest of the world. - >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) - >>> # After this point, all ranks have FSDP model with loaded checkpoint. - """ - - rank0_only: bool = False - - -@dataclass -class LocalStateDictConfig(StateDictConfig): - pass - - -@dataclass -class ShardedStateDictConfig(StateDictConfig): - pass - - class OptimStateKeyType(Enum): PARAM_NAME = auto() PARAM_ID = auto() @@ -502,15 +416,12 @@ def __init__( # `_state_dict_type` controls the `state_dict()` behavior, which is # implemented using post-save and pre-load hooks - _init_state_dict_state(self) # TODO: currently a no-op; need to refactor below - self._state_dict_type = StateDictType.FULL_STATE_DICT - self._state_dict_config = FullStateDictConfig() + _init_state_dict_state(self) self._register_state_dict_hook(_post_state_dict_hook) self._register_load_state_dict_pre_hook( _pre_load_state_dict_hook, with_module=True ) self.register_load_state_dict_post_hook(_post_load_state_dict_hook) - self._full_param_ctx: Optional[Generator] = None @property def module(self) -> nn.Module: @@ -785,32 +696,6 @@ def state_dict_type( module, prev_state_dict_type, prev_state_dict_config ) - def _convert_to_wrapped_module_name(self, module_name: str) -> str: - module_name = module_name.replace(f"{FSDP_PREFIX}", "") - module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") - if module_name: - module_name = f"{module_name}." - # Activation checkpoint adds a prefix that has to be - # removed as well. 
- module_name = module_name.replace(_CHECKPOINT_PREFIX, "") - return module_name - - @property - def _param_fqns(self) -> Iterator[Tuple[str, str, str]]: - if not self._has_params: - return - for param_name, module_name in self._handles[0].parameter_module_names(): - module_name = self._convert_to_wrapped_module_name(module_name) - fqn = f"{module_name}{param_name}" - yield fqn, param_name, module_name - - @property - def _shared_param_fqns(self) -> Iterator[Tuple[str, str, str]]: - for param_name, module_name in self._handles[0].shared_parameter_module_names(): - module_name = self._convert_to_wrapped_module_name(module_name) - fqn = f"{module_name}{param_name}" - yield fqn, param_name, module_name - def state_dict(self, *args, **kwargs): _lazy_init(self, self) return super().state_dict(*args, **kwargs) From 91b71cdbe4f31006fad91f9dd460123677a7c625 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 9 Nov 2022 20:39:50 +0000 Subject: [PATCH 061/453] [dynamo] Add torch.device to is_safe_constant (#88766) Test Plan: ``` PYTORCH_TEST_WITH_DYNAMO=1 python test/test_torch.py -k test_advancedindex_mixed_cpu_devices_cuda ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88766 Approved by: https://github.com/jansel --- torch/_dynamo/utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index ef2c1c38ea8b..067a80807374 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -583,7 +583,19 @@ def is_safe_constant(v): if istype(v, (tuple, frozenset)): return all(map(is_safe_constant, v)) return istype( - v, (types.CodeType, int, float, bool, str, bytes, type(None), slice, type(type)) + v, + ( + types.CodeType, + int, + float, + bool, + str, + bytes, + type(None), + slice, + type(type), + torch.device, + ), ) From b92acee8f83c7852194d6979362aea0c240709da Mon Sep 17 00:00:00 2001 From: soulitzer Date: Thu, 10 Nov 2022 19:08:42 -0500 Subject: [PATCH 062/453] Add context manager to allow mutation on saved tensors (#79056) Pull Request resolved: https://github.com/pytorch/pytorch/pull/79056 Approved by: https://github.com/albanD --- test/test_autograd.py | 178 ++++++++++++++++++++++++++++++++++++++++ torch/autograd/graph.py | 163 +++++++++++++++++++++++++++++++++++- 2 files changed, 338 insertions(+), 3 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index f5d890fad2d7..e08047860e42 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -8778,6 +8778,184 @@ def test_warning_in_backward(self, device): with self.assertWarnsRegex(UserWarning, "Warn from backward"): b.backward() +class TestAllowMutationOnSaved(TestCase): + def assertClonedLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.cloned.items())), n) + + def assertTIDMapLenEqual(self, ctx, n): + self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n) + + def test_basic(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + out = (b**2).sum() + b.sin_() + out.sum().backward() + return a.grad + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertTrue(torch.allclose(a * 2, da)) + self.assertClonedLenEqual(ctx, 0) + + def test_views(self): + a = torch.rand(2, 3, requires_grad=True) + + def fn(a): + b = a.clone() + c = b.view_as(b) + out = (b**2).sum() 
# How does this work? + c.sin_() + out.sum().backward() + return a.grad + + msg = "variables needed for gradient computation has been modified by an inplace" + with self.assertRaisesRegex(RuntimeError, msg): + fn(a) + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + da = fn(a) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, da)) + + def test_save_base_and_modify_view(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:1] + out = b**2 + # modify the view + c *= 10 + # self.assertClonedLenEqual(ctx, 1) + out.sum().backward() + self.assertClonedLenEqual(ctx, 0) + + self.assertClonedLenEqual(ctx, 0) + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_save_view_modify_base(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + c = b[:] + out = (c**2).sum() + b *= 2 + out.backward() + self.assertTrue(torch.allclose(a * 2, a.grad)) + + def test_double_backward(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + b = a.clone() + out = (b**2).sum() + b.sin_() + torch.autograd.grad(out, a, create_graph=True) + da, = torch.autograd.grad(out, a, create_graph=True) + d2a, = torch.autograd.grad(da.sum(), a) + + self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a)) + self.assertClonedLenEqual(ctx, 0) + + def test_saved_but_not_anymore(self): + # Make sure we don't clone if the tensor was once saved, but + # by the time we do in-place, it is no longer saved + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + out = (a**2).sum() + self.assertTIDMapLenEqual(ctx, 1) + self.assertClonedLenEqual(ctx, 0) + out.backward() + a.sin_() + self.assertClonedLenEqual(ctx, 0) + out = (a**2).sum() + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del out + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_many_times(self): + # We should only clone once + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 1) + del b, c + self.assertClonedLenEqual(ctx, 0) + + def test_saved_same_tensor_different_versions(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.randn(2, 3, requires_grad=True).clone() + b = a**2 + a.sin_() + c = a**2 + a.sin_() + self.assertClonedLenEqual(ctx, 2) + del b + self.assertClonedLenEqual(ctx, 1) + del c + self.assertClonedLenEqual(ctx, 0) + + def test_with_math_views(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + a.sin_() + out.backward() + + a = torch.tensor([1 + 1j], requires_grad=True).clone() + b = a.conj() + out = (b**2).sum() + # in this case, it is no longer a view it seems + b.sin_() + out.backward() + + def test_with_out_variant(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.tensor([1.], requires_grad=True) + b = torch.tensor([1.]) + c = torch.tensor([2.]) + out = a * b + self.assertTIDMapLenEqual(ctx, 1) + torch.sin(c, out=b) + self.assertClonedLenEqual(ctx, 1) + out.backward() + self.assertClonedLenEqual(ctx, 0) + + def test_backward_out_of_context(self): + # Out of 
context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + # Different context + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + a = torch.rand(2, 3, requires_grad=True) + out = (a**2).sum() + + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + with self.assertRaisesRegex(RuntimeError, msg): + out.backward() + + def test_disallow_nesting(self): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + msg = "allow_mutation_on_saved_tensors contexts cannot be nested" + with self.assertRaisesRegex(RuntimeError, msg): + with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: + pass class TestAutogradInferenceMode(TestCase): def _is_inference_tensor(self, tensor): diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 9c333c70bcf2..fc490a9d8e31 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,15 +1,17 @@ import torch import contextlib -from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List +from typing import Callable, Any, Dict, Tuple, Optional, Sequence, List, Set from torch.utils.hooks import RemovableHandle - -__all__ = ["saved_tensors_hooks", "save_on_cpu"] +from torch.utils._python_dispatch import TorchDispatchMode +from collections import defaultdict +import weakref __all__ = [ "saved_tensors_hooks", "save_on_cpu", "disable_saved_tensors_hooks", "register_multi_grad_hook", + "allow_mutation_on_saved_tensors", ] class saved_tensors_hooks(): @@ -270,3 +272,158 @@ def __setstate__(self, state): handles.append(t.register_hook(get_inner_hook(i))) return Handle(tuple(handles)) + + +# NOTE [Allow mutation on tensors saved for backward] +# +# 1. Tensor gets saved for backward +# - remember the python object id and the version of the tensor +# - remember aliasing information (data_ptr of base + version) +# - save the original so we control its lifetime +# 2. Any time a tensor gets in-placed +# - for each tensor aliased to it: +# - check using its object id and version to see if it has been saved +# - if it has been saved, clone it +# - delete the reference to the original +# 3. 
during backward +# - if the clone exists, the tensor must've been modified in-place +_allow_mutation_on_saved_tensors_enabled = False + +def _get_tid(t) -> Tuple[int, int, int]: + return (id(t), t.data_ptr(), t._version) + +def _get_sid(t) -> Tuple[int, int]: + return (t.data_ptr(), t._version) + +class _Handle(): + pass + +class _swap_with_cloned(saved_tensors_hooks): + def __init__(self, ctx): + def pack_hook(t): + tid = _get_tid(t) + sid = _get_sid(t) + # Tensors saved for backward have an entry in _tid_to_weakhandle + handle: Optional[_Handle] = None + + # Save aliasing information + ctx.sid_to_tid[sid].add(tid) + + # NB: The same tensor (of the same version) can be saved multiple times + if tid not in ctx.tid_to_weakhandle: + handle = _Handle() + ctx.tid_to_weakhandle[tid] = handle + ctx.original[handle] = t + else: + # Store an additional strong reference to the handle + handle = ctx.tid_to_weakhandle[tid] + return handle + + def unpack_hook(tup): + handle = tup + error_msg = ( + "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" + "in which the graph was originally recorded.") + assert _allow_mutation_on_saved_tensors_enabled, error_msg + if handle in ctx.cloned: + res = ctx.cloned[handle] + else: + assert handle in ctx.original, error_msg + res = ctx.original[handle] + return res + + super().__init__(pack_hook, unpack_hook) + +class _CloneArgBeforeMutateMode(TorchDispatchMode): + def __init__(self, ctx): + self.ctx = ctx + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + for idx, arg in enumerate(func._schema.arguments): + if arg.alias_info is not None and arg.alias_info.is_write: + t = kwargs["out"] if arg.is_out else args[idx] + tid = _get_tid(t) + sid = _get_sid(t) + ctx = self.ctx + if sid in ctx.sid_to_tid: + for tid in ctx.sid_to_tid[sid]: + if tid not in ctx.tid_to_weakhandle: + # We know that if tid is in sid_to_tid, then it must also be in + # tid_to_weakhandle. However, it is possible for the tensor to be + # saved at one point, but cleared by backward before it is modified + # in-place. Consider the following example: + # + # >>> a = torch.randn(2, 3, requires_grad=True).clone() + # >>> out = (a**2).sum() + # >>> out.backward() + # >>> a.sin_() + continue + handle = ctx.tid_to_weakhandle[tid] + if handle in ctx.cloned: + # The same exact tensor has been cloned already + continue + ctx.cloned[handle] = ctx.original[handle].clone() + del ctx.original[handle] + + rs = func(*args, **kwargs) + return rs + +class _AllowMutationOnSavedContext(): + def __init__(self): + self.cloned: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.original: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self.tid_to_weakhandle: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + self.sid_to_tid: Dict[Tuple[int, int], Set[Tuple[int, int, int]]] = defaultdict(set) + + def clear(self): + self.cloned.clear() + self.original.clear() + self.tid_to_weakhandle.clear() + self.sid_to_tid.clear() + +@contextlib.contextmanager +def allow_mutation_on_saved_tensors(): + """Context manager under which mutating tensors saved for backward is allowed + + Under this context manager, tensors saved for backward are cloned on mutation, + so the original version can still be used during backward. Normally, mutating a tensor + saved for backward will result in an error raised when it's used during backward. 
+ + To ensure the correct behavior, both the forward and backward should be run under + the same context manager. + + returns: + An _AllowMutationOnSavedContext object storing the state managed by this + context manager. This object can be useful for debugging purposes. The state + managed by the context manager is automatically cleared upon exiting. + + Example:: + + >>> import torch + >>> with torch.autograd.graph.allow_mutation_on_saved_tensors(): + ... # forward + ... a = torch.ones(2, 3, requires_grad=True) + ... b = a.clone() + ... out = (b**2).sum() + ... b.sin_() + ... # backward + ... out.sum().backward() + ... + tensor([[0.8415, 0.8415, 0.8415], + [0.8415, 0.8415, 0.8415]], grad_fn=) + """ + global _allow_mutation_on_saved_tensors_enabled + + ctx = _AllowMutationOnSavedContext() + + with _swap_with_cloned(ctx), _CloneArgBeforeMutateMode(ctx): + try: + if _allow_mutation_on_saved_tensors_enabled: + raise RuntimeError("allow_mutation_on_saved_tensors contexts cannot be nested") + _allow_mutation_on_saved_tensors_enabled = True + yield ctx + finally: + ctx.clear() + _allow_mutation_on_saved_tensors_enabled = False From 3c7f96665e784a793d2d1a120ea8fe370b3f6d81 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 19:54:56 +0000 Subject: [PATCH 063/453] [FSDP][state_dict][3/N] Change how state_dict utils access attributes in _FSDPState (#88635) **What This PR Does** _state_dict_utils currently accesses the FSDP states through module. To enable composable FSDP state_dict, these accesses need to go through _FSDPState. module is still required for most APIs as state_dict has to access per-module information. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88635 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_common_utils.py | 18 ++ torch/distributed/fsdp/_state_dict_utils.py | 260 ++++++++++++-------- 2 files changed, 177 insertions(+), 101 deletions(-) diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index c93c8abb5ebd..f6ccc3e9243f 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -61,6 +61,24 @@ def _all_handles(state: _FSDPState) -> List: ) +@no_type_check +def _module_handles(state: _FSDPState, module: nn.Module) -> List: + """ + Given a module and returns the flat handles that map to this module. If the + module is FullyShardedDataParallel, the module._handles will be returned. 
+ """ + if _is_composable(state): + return state._module_to_handles[module][:] + else: + return module._handles[:] + + +@no_type_check +def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: + """Given a module and returns if this module has parameters sharded by FSDP.""" + return len(_module_handles(state, module)) > 0 + + def clean_tensor_name(tensor_name: str) -> str: """ Cleans the parameter or buffer name by removing any module wrapper diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index c90bd4d409b1..0bfd149b0112 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -1,7 +1,7 @@ import functools import math import warnings -from typing import Any, Callable, cast, Dict, Iterator, Tuple +from typing import Any, Callable, cast, Dict, Iterator, no_type_check, Tuple import torch import torch.distributed as dist @@ -18,6 +18,9 @@ ShardedTensor, ) from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _has_fsdp_params, + _module_handles, clean_tensor_name, FSDP_PREFIX, FSDP_WRAPPED_MODULE, @@ -51,24 +54,28 @@ def _convert_to_wrapped_module_name(module_name: str) -> str: return module_name -def _param_fqns(module) -> Iterator[Tuple[str, str, str]]: - if not module._has_params: +def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str]]: + if not _has_fsdp_params(fsdp_state, module): return - for param_name, module_name in module._handles[0].parameter_module_names(): + for param_name, module_name in _module_handles(fsdp_state, module)[ + 0 + ].parameter_module_names(): module_name = _convert_to_wrapped_module_name(module_name) fqn = f"{module_name}{param_name}" yield fqn, param_name, module_name -def _shared_param_fqns(module) -> Iterator[Tuple[str, str, str]]: - for param_name, module_name in module._handles[0].shared_parameter_module_names(): +def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: + for param_name, module_name in _module_handles(fsdp_state, module)[ + 0 + ].shared_parameter_module_names(): module_name = _convert_to_wrapped_module_name(module_name) fqn = f"{module_name}{param_name}" yield fqn, param_name, module_name def _enter_full_param_ctx( - module, + fsdp_state: _FSDPState, recurse: bool = False, writeback: bool = False, rank0_only: bool = False, @@ -80,53 +87,56 @@ def _enter_full_param_ctx( requires to enter the context in the pre-hook but leave the context in the post-hook. This API enters the context of ``summon_full_params``. """ - assert module._full_param_ctx is None, ( - "Entering the ``summon_full_params`` context but module._full_param_ctx " + assert fsdp_state._full_param_ctx is None, ( + "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " "is not None." ) - assert module.training_state != TrainingState.SUMMON_FULL_PARAMS, ( + assert fsdp_state.training_state != TrainingState.SUMMON_FULL_PARAMS, ( "Entering the summon_full_params context but the state is already " "SUMMON_FULL_PARAMS." 
) - module._full_param_ctx = module._summon_full_params( + fsdp_state._full_param_ctx = fsdp_state._summon_full_params( recurse=recurse, writeback=writeback, rank0_only=rank0_only, offload_to_cpu=offload_to_cpu, with_grads=with_grads, ) - module._full_param_ctx.__enter__() + fsdp_state._full_param_ctx.__enter__() -def _exit_full_param_ctx(module) -> None: +@no_type_check +def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: """A helper function to exit ``summon_full_params`` context.""" - assert module.training_state == TrainingState.SUMMON_FULL_PARAMS, ( + assert fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS, ( "Exiting the summon_full_params context but the state is not " "SUMMON_FULL_PARAMS." ) - assert module._full_param_ctx is not None - module._full_param_ctx.__exit__(None, None, None) - module._full_param_ctx = None + assert fsdp_state._full_param_ctx is not None + fsdp_state._full_param_ctx.__exit__(None, None, None) + fsdp_state._full_param_ctx = None def _common_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: """Performs the pre-state_dict tasks shared by all state_dict types.""" if torch.cuda.is_available(): torch.cuda.synchronize() - _lazy_init(module, module) + # TODO: need to check if this is always correct for composable FSDP. + _lazy_init(fsdp_state, module) # TODO: change to this call after pre_state_dict_hook is in `nn.Module`. - # if module.is_root: - # _clear_grads_if_needed(module._fsdp_handles(module)) - if module._has_params: - _clear_grads_if_needed([module._handles[0]]) + # if fsdp_state.is_root: + # _clear_grads_if_needed(_all_handles(fsdp_state)) + if _has_fsdp_params(fsdp_state, module): + _clear_grads_if_needed([_module_handles(fsdp_state, module)[0]]) def _common_summon_pre_state_dict_hook( - module, + fsdp_state: _FSDPState, offload_to_cpu: bool, rank0_only: bool, ) -> None: @@ -135,7 +145,7 @@ def _common_summon_pre_state_dict_hook( ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ _enter_full_param_ctx( - module, + fsdp_state, recurse=False, writeback=False, offload_to_cpu=offload_to_cpu, @@ -144,8 +154,10 @@ def _common_summon_pre_state_dict_hook( # TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. +@no_type_check def _common_summon_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, param_hook: Callable, @@ -157,17 +169,17 @@ def _common_summon_post_state_dict_hook( """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) assert ( - module.training_state == TrainingState.SUMMON_FULL_PARAMS + fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases - if not state_dict or not module._has_params: - _exit_full_param_ctx(module) + if not state_dict or not _has_fsdp_params(fsdp_state, module): + _exit_full_param_ctx(fsdp_state) return state_dict # TODO: Once pre_state_dict hook is supported, this pop should be removed. # For `use_orig_params=True`, the `FlatParameter` is not registered, so # there is no entry in the state dict for it to pop. - if not module._use_orig_params: + if not fsdp_state._use_orig_params: state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") # If a rank does not have unsharded parameters(when `rank0_only=True` @@ -175,25 +187,25 @@ def _common_summon_post_state_dict_hook( # all-gather and does not need to save the # state dict. 
We simply check # rank0_only to ensure this issue. rank0_only = ( - module._state_dict_type == StateDictType.FULL_STATE_DICT - and cast(FullStateDictConfig, module._state_dict_config).rank0_only + fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only ) # no_fsdp_return means the state_dict returned by this rank should contain # only non-FSDP controlled parameters and buffers. - no_fsdp_return = rank0_only and module.rank != 0 - if no_fsdp_return and not module._use_orig_params: - for clean_key in module._buffer_names: + no_fsdp_return = rank0_only and fsdp_state.rank != 0 + if no_fsdp_return and not fsdp_state._use_orig_params: + for clean_key in fsdp_state._buffer_names: # This is a hack to support activation checkpoint. clean_key = clean_key.replace( f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" ) state_dict.pop(f"{prefix}{clean_key}", None) - _exit_full_param_ctx(module) + _exit_full_param_ctx(fsdp_state) return state_dict # Loop only the parameters saved in this instance's wrapped module to # avoid processing buffers. - for fqn, param_name, module_name in _param_fqns(module): + for fqn, param_name, module_name in _param_fqns(module, fsdp_state): # TODO: remove the parameter retrieval. See ``_full_pre_state_dict_hook``. param = functools.reduce(getattr, fqn.split("."), module.module) fqn = f"{prefix}{fqn}" @@ -205,16 +217,16 @@ def _common_summon_post_state_dict_hook( f"FSDP assumes {fqn} is in the state_dict but the state_dict only " f"has {state_dict.keys()}. " f"prefix={prefix}, module_name={module_name}, " - f"param_name={param_name} rank={module.rank}." + f"param_name={param_name} rank={fsdp_state.rank}." ) - param_hook(module, state_dict, prefix, fqn) - _exit_full_param_ctx(module) + param_hook(state_dict, prefix, fqn) + _exit_full_param_ctx(fsdp_state) cpu_device = torch.device("cpu") buffer_clean_fqns = [] buffers = [] - for clean_key in module._buffer_names: + for clean_key in fsdp_state._buffer_names: # This is a hack to support activation checkpoint. clean_key = clean_tensor_name(clean_key) fqn = f"{prefix}{clean_key}" @@ -225,22 +237,29 @@ def _common_summon_post_state_dict_hook( state_dict.pop(fqn) else: buffer = state_dict[fqn] - if module._state_dict_config.offload_to_cpu and buffer.device != cpu_device: + if ( + fsdp_state._state_dict_config.offload_to_cpu + and buffer.device != cpu_device + ): state_dict[fqn] = buffer.to(cpu_device) # TODO: for composable FSDP, this should be clean_tensor_name(clean_key), buffer_clean_fqns.append(clean_key) buffers.append(state_dict[fqn]) - if buffers and module._mixed_precision_enabled_for_buffers(): - buffer_dtypes = _get_buffer_dtypes(module, buffer_clean_fqns) - _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, module.compute_device) + if buffers and fsdp_state._mixed_precision_enabled_for_buffers(): + buffer_dtypes = _get_buffer_dtypes(fsdp_state, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes, fsdp_state.compute_device + ) for buffers, clean_fqn in zip(buffers, buffer_clean_fqns): fqn = f"{prefix}{clean_fqn}" state_dict[fqn] = buffer.clone() return state_dict +@no_type_check def _full_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -254,16 +273,18 @@ def _full_pre_state_dict_hook( TODO: clean the callsites and hacks after ``pre_state_dict_hook` ` is supported in ``nn.Module``. 
""" - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) _common_summon_pre_state_dict_hook( - module, - offload_to_cpu=module._state_dict_config.offload_to_cpu, - rank0_only=cast(FullStateDictConfig, module._state_dict_config).rank0_only, + fsdp_state, + offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, + rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, ) +@no_type_check def _full_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -274,10 +295,9 @@ def _full_post_state_dict_hook( the ``FSDP_WRAPPED_MODULE`` prefix. """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _full_pre_state_dict_hook(module, state_dict, prefix) + _full_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) def param_hook( - module, state_dict: Dict[str, Any], prefix: str, fqn: str, @@ -292,7 +312,7 @@ def param_hook( # Clone non-ignored parameters before exiting the # `_summon_full_params()` context - if clean_key not in module._ignored_param_names and not getattr( + if clean_key not in fsdp_state._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): try: @@ -300,31 +320,37 @@ def param_hook( state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] except BaseException as e: warnings.warn( - f"Failed to clone() tensor with name {fqn} on rank {module.rank}. " + f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " "This may mean that this state_dict entry could point to invalid " "memory regions after returning from state_dict() call if this " "parameter is managed by FSDP. Please check clone " f"implementation of {fqn}. Error: {str(e)}" ) - return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) + return _common_summon_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) def _full_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: - _enter_full_param_ctx(module, recurse=False, writeback=True) + _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") -def _full_post_load_state_dict_hook(module, *args, **kwargs) -> None: - _exit_full_param_ctx(module) +def _full_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + _exit_full_param_ctx(fsdp_state) def _local_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -333,16 +359,21 @@ def _local_pre_state_dict_hook( hook is not supported by the PyTorch core. So this API is called from `_local_post_state_dict_hook()` to simulate the case. """ - if module._has_params and not module._handles[0].uses_sharded_strategy: + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy + ): raise RuntimeError( "``local_state_dict`` can only be used when parameters are flatten " "and sharded." ) - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) +@no_type_check def _local_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -352,42 +383,45 @@ def _local_post_state_dict_hook( will happen. The underlying storage is the same. 
""" # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _local_pre_state_dict_hook(module, state_dict, prefix) + _local_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) - if not module._has_params: + if not _has_fsdp_params(fsdp_state, module): return state_dict # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor # value as the flat_param but it is a pure Tensor because # nn.Module.state_dict() will detach the parameter. Therefore, we need # to get flat_param to get the metadata. - assert module._handles, "Should have returned early" - flat_param = module._handles[0].flat_param + assert _module_handles(fsdp_state, module), "Should have returned early" + flat_param = _module_handles(fsdp_state, module)[0].flat_param # Construct a ShardedTensor from the flat_param. full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] - shard_offset = flat_param.numel() * module.rank + shard_offset = flat_param.numel() * fsdp_state.rank valid_data_size = flat_param.numel() - flat_param._shard_numel_padded if valid_data_size > 0 and flat_param._shard_numel_padded > 0: flat_param = flat_param.narrow(0, 0, valid_data_size) local_shards = [ - Shard.from_tensor_and_offsets(flat_param, [shard_offset], module.rank) + Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank) ] sharded_tensor = init_from_local_shards( - local_shards, full_numel, process_group=module.process_group + local_shards, full_numel, process_group=fsdp_state.process_group ) # type: ignore[assignment] - if module._state_dict_config.offload_to_cpu: + if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() state_dict[f"{prefix}{fsdp_file.FLAT_PARAM}"] = sharded_tensor return state_dict -def _local_post_load_state_dict_hook(module, *args, **kwargs) -> None: +def _local_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: pass def _local_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -399,7 +433,7 @@ def _local_pre_load_state_dict_hook( _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" if fqn not in state_dict: - assert not module._has_params, ( + assert not _has_fsdp_params(fsdp_state, module), ( "No `FlatParameter` in `state_dict` for this FSDP instance " "but it has parameters" ) @@ -416,7 +450,7 @@ def _local_pre_load_state_dict_hook( # Get the metadata of the flat_param to decide whether to pad the loaded # tensor. - flat_param = module._handles[0].flat_param + flat_param = _module_handles(fsdp_state, module)[0].flat_param assert flat_param is not None if flat_param._shard_numel_padded not in (0, flat_param.numel()): assert load_tensor.numel() < flat_param.numel(), ( @@ -429,6 +463,7 @@ def _local_pre_load_state_dict_hook( def _sharded_pre_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -436,23 +471,28 @@ def _sharded_pre_state_dict_hook( Hook that runs before model.state_dict() is called. Check ``_full_pre_load_state_dict_hook`` for the detail. 
""" - if module._has_params and not module._handles[0].uses_sharded_strategy: + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy + ): raise RuntimeError( "``sharded_state_dict`` can only be used when parameters are flatten " "and sharded." ) - _common_pre_state_dict_hook(module, state_dict, prefix) + _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) # Setting offload_to_cpu here does not work even if offload_to_cpu is True. # We have to create ShardedTensor first then move it to CPU. _common_summon_pre_state_dict_hook( - module, + fsdp_state, offload_to_cpu=False, rank0_only=False, ) +@no_type_check def _sharded_post_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> Dict[str, Any]: @@ -462,31 +502,38 @@ def _sharded_post_state_dict_hook( """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. - _sharded_pre_state_dict_hook(module, state_dict, prefix) + _sharded_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) - def param_hook(module, state_dict: Dict[str, Any], prefix: str, fqn: str): + def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): param = state_dict[fqn] sharded_tensor = _ext_chunk_tensor( tensor=param, - rank=module.rank, - world_size=module.world_size, + rank=fsdp_state.rank, + world_size=fsdp_state.world_size, num_devices_per_node=torch.cuda.device_count(), - pg=module.process_group, + pg=fsdp_state.process_group, ) - if module._state_dict_config.offload_to_cpu: + if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() state_dict[fqn] = sharded_tensor - return _common_summon_post_state_dict_hook(module, state_dict, prefix, param_hook) + return _common_summon_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) -def _sharded_post_load_state_dict_hook(module, *args, **kwargs) -> None: - if module._use_orig_params: - module._register_orig_params() +@no_type_check +def _sharded_post_load_state_dict_hook( + module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if fsdp_state._use_orig_params: + fsdp_state._register_orig_params() +@no_type_check def _sharded_pre_load_state_dict_hook( module, + fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: @@ -495,19 +542,19 @@ def _sharded_pre_load_state_dict_hook( a new FlatParameter and shards the new FlatParameter to the local chunk. """ _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") - if not module._has_params: + if not _has_fsdp_params(fsdp_state, module): return - if not module._handles[0].uses_sharded_strategy: + if not _module_handles(fsdp_state, module)[0].uses_sharded_strategy: raise RuntimeError( "load_sharded_state_dict can only be called when parameters " "are flatten and sharded." ) nonsharded_tensors = [] - shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module)] + shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module, fsdp_state)] loaded_shapes = [] - for fqn, _, _ in _param_fqns(module): + for fqn, _, _ in _param_fqns(module, fsdp_state): full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}" param = state_dict.pop(full_fqn) if fqn in shared_fqns: @@ -517,12 +564,12 @@ def _sharded_pre_load_state_dict_hook( loaded_shapes.append(param.size()) assert len(shards) < 2, ( "Expects 0 or 1 shard per rank " - f"but got {len(shards)} shards on rank {module.rank}." + f"but got {len(shards)} shards on rank {fsdp_state.rank}." 
) param_numel = param.size().numel() dim_0_size = param.size()[0] chunk_size = ( - math.ceil(dim_0_size / module.world_size) * param_numel // dim_0_size + math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size ) if len(shards) == 1: local_tensor = shards[0].tensor.flatten() @@ -534,14 +581,16 @@ def _sharded_pre_load_state_dict_hook( else: local_tensor = torch.zeros(chunk_size, dtype=param.dtype).cuda() tensor = torch.empty( - chunk_size * module.world_size, dtype=local_tensor.dtype + chunk_size * fsdp_state.world_size, dtype=local_tensor.dtype ).cuda() - dist.all_gather_into_tensor(tensor, local_tensor, group=module.process_group) + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group + ) tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) nonsharded_tensors.append(tensor) # Create a new flat_param from the loaded, non-sharded tensors. - flat_param = module._handles[0].flat_param + flat_param = _module_handles(fsdp_state, module)[0].flat_param loaded_flat_param = FlatParamHandle.flatten_params( nonsharded_tensors, requires_grad=False ) @@ -549,8 +598,8 @@ def _sharded_pre_load_state_dict_hook( # Get the chunk from the loaded flat_param for the local rank. loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard( loaded_flat_param, - module.rank, - module.world_size, + fsdp_state.rank, + fsdp_state.world_size, ) loaded_flat_tensor.to(flat_param.device) assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), ( @@ -567,10 +616,11 @@ def _sharded_pre_load_state_dict_hook( f"from the local chunk {flat_param._shard_numel_padded}." ) state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor - if module._use_orig_params: - module._deregister_orig_params() + if fsdp_state._use_orig_params: + fsdp_state._deregister_orig_params() +@no_type_check @torch.no_grad() def _post_state_dict_hook( module: nn.Module, @@ -580,21 +630,24 @@ def _post_state_dict_hook( ) -> Dict[str, Any]: """ _post_state_dict_hook() is called after the state_dict() of this - FSDP module is executed. ``module._state_dict_type`` is used to decide + FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide what postprocessing will be done. """ + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _post_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) - processed_state_dict = _post_state_dict_hook_fn[fsdp_module._state_dict_type]( - fsdp_module, state_dict, prefix + processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_module, fsdp_state, state_dict, prefix ) return processed_state_dict +@no_type_check @torch.no_grad() def _pre_load_state_dict_hook( module: nn.Module, @@ -604,9 +657,11 @@ def _pre_load_state_dict_hook( ) -> None: """ ``_pre_state_dict_hook` is called before ``module._load_from_state_dict()`` - is called. ``module._state_dict_type`` is used to decide what preprocessing + is called. ``fsdp_state._state_dict_type`` is used to decide what preprocessing will be done. 
""" + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _pre_load_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, @@ -617,13 +672,16 @@ def _pre_load_state_dict_hook( if torch.cuda.is_available(): torch.cuda.synchronize() # Dispatch into state_dict specific implementation of pre-hook. - _pre_load_state_dict_hook_fn[fsdp_module._state_dict_type]( - fsdp_module, state_dict, prefix + _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_module, fsdp_state, state_dict, prefix ) +@no_type_check @torch.no_grad() def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: + # TODO: get the composable state from module + fsdp_state: _FSDPState = module _post_load_state_dict_hook_fn = { StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, @@ -633,4 +691,4 @@ def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) # Dispatch into state_dict type specific implementation of post-hook for # loading state_dict. - _post_load_state_dict_hook_fn[fsdp_module._state_dict_type](fsdp_module) + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](fsdp_module, fsdp_state) From d615d1228932eaa5e026f5399e099f2036d2379b Mon Sep 17 00:00:00 2001 From: anjali411 Date: Fri, 11 Nov 2022 15:24:28 +0000 Subject: [PATCH 064/453] Add meta impl for topk (#88694) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88694 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 1 - test/test_proxy_tensor.py | 1 - torch/_meta_registrations.py | 17 +++++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 09b65a32bfee..4da39210343e 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1214,7 +1214,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition xfail('tensordot', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('topk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('trapz', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 72c7249f4f14..d1a5c9498bca 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1337,7 +1337,6 @@ def f(a, b, c, d, e): xfail('take_along_dim', ''), # dtype of indices should be Long but got Float xfail('take', ''), # aten.take.default - couldn't find symbolic meta function/decomposition xfail('tensordot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('topk', ''), # aten.topk.default - couldn't find symbolic meta function/decomposition xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('triangular_solve', ''), # 
aten.triangular_solve.default - couldn't find symbolic meta function/decomposition diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 04c522ab9e3b..5d583de67d19 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1777,6 +1777,23 @@ def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None): ) +@register_meta(aten.topk.default) +def topk_meta(self, k, dim=-1, largest=True, sorted=True): + # From aten/src/ATen/native/Sorting.cpp + dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True) + check( + k >= 0 and k <= (self.size(dim) if self.dim() > 0 else 1), + lambda: "selected index k out of range", + ) + sliceSize = 1 if self.dim() == 0 else self.size(dim) + check(k >= 0 and k <= sliceSize, lambda: "k not in range for dimension") + + topKSize = list(self.shape) + if len(topKSize) > 0: + topKSize[dim] = k + return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs From 1e8f95ace16cb617d71f8f8254c1d5bafd9f586c Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Fri, 11 Nov 2022 13:51:18 +0100 Subject: [PATCH 065/453] Symintify `broadcast_to` (#88776) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88776 Approved by: https://github.com/ezyang --- .../ATen/functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/TensorShape.cpp | 4 ++-- aten/src/ATen/native/native_functions.yaml | 4 +++- test/functorch/test_aotdispatch.py | 11 ++--------- test/test_proxy_tensor.py | 15 --------------- 5 files changed, 8 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 66aaa53bfcc1..e31b36d11241 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -63,7 +63,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(bitwise_or, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); OP_DECOMPOSE(broadcast_tensors); - OP_DECOMPOSE(broadcast_to); + m.impl("broadcast_to", native::broadcast_to_symint); OP_DECOMPOSE(cartesian_prod); OP_DECOMPOSE(cdist); OP_DECOMPOSE(clip); diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 31b4011c1281..deb9b949aa5d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -537,8 +537,8 @@ Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) { return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced); } -Tensor broadcast_to(const Tensor& self, IntArrayRef size) { - return self.expand(size); +Tensor broadcast_to_symint(const Tensor& self, SymIntArrayRef size) { + return self.expand_symint(size); } std::vector broadcast_tensors(TensorList tensors) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 0ea606f5e1fb..de087c0b8a89 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1195,8 +1195,10 @@ device_check: NoCheck device_guard: False -- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) +- func: broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) variants: function, method + dispatch: + CompositeImplicitAutograd: broadcast_to_symint - func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) variants: function diff 
--git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 4da39210343e..f4782b8a595d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1093,20 +1093,13 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked.cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('masked.cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('masked_fill', ''), # could not find kernel - xfail('masked.log_softmax', ''), # argument 'size' (position 2) must be tuple of ints, not ... xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=t... - xfail('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + # Seems flaky: https://github.com/pytorch/pytorch/issues/88883 + skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... - xfail('masked.softmax', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.softmin', ''), # argument 'size' (position 2) must be tuple of ints, not torc... - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... - xfail('masked.sum', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=to... xfail('matmul', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decompo... xfail('median', ''), # could not find kernel diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index d1a5c9498bca..86beb651cb2d 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1114,23 +1114,8 @@ def f(a, b, c, d, e): xfail('linalg.eig'), xfail('linalg.eigvals'), skip('masked.logsumexp', ''), # Tensors of type TensorImpl do not have numel - xfail('masked.amax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.amin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.argmax', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... - xfail('masked.argmin', ''), # broadcast_to(): argument 'size' (position 2) must be tuple of ints, but found ... 
xfail('masked.cumprod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.cumsum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.log_softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition - xfail('masked.mean', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, ... - xfail('masked.median', ''), # aten.nanmedian.dim - couldn't find symbolic meta function/decomposition - xfail('masked.norm', ''), # aten.linalg_vector_norm.default - couldn't find symbolic meta function/decomposition - xfail('masked.prod', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmax', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.softmin', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.std', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... - xfail('masked.sum', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition - xfail('masked.var', ''), # ones() received an invalid combination of arguments - got (torch.Size, device=torch.device, d... xfail('addmv', ''), # aten.addmv.default - couldn't find symbolic meta function/decomposition xfail('addr', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('aminmax', ''), # aten.aminmax.default - couldn't find symbolic meta function/decomposition From a6832b08a3f6c1b425a075fe204a1f21361f33d9 Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 8 Nov 2022 19:23:21 +0000 Subject: [PATCH 066/453] Regularize bernouilli_ with bernouilli decomp (#88349) Fix for https://github.com/pytorch/torchdynamo/issues/1796. Just like the other [bernouilli decomp](https://github.com/pytorch/pytorch/blob/master/torch/_inductor/decomposition.py#L302) we need to pass `dtype=float32` to avoid `"check_uniform_bounds" not implemented` errors. Are we planning on enabling `TEST_WITH_TORCHINDUCTOR` ? Do I need to change anything with the tests ? 
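For reference, the fix boils down to sampling the uniform noise in float32 before comparing against `p`. A minimal standalone sketch of the same pattern (the real decomposition is registered through inductor's decomposition table; the helper name here is made up for illustration):

```python
import torch

def bernoulli__decomp_sketch(self: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    # torch.rand_like(self) on a non-floating-point tensor hits the
    # '"check_uniform_bounds" not implemented' error, so force float32 noise.
    noise = torch.rand_like(self, dtype=torch.float32)
    return self.copy_(noise < p)

x = torch.empty(4, 4, dtype=torch.int64)
bernoulli__decomp_sketch(x, p=0.3)  # x now holds 0/1 draws
```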
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88349 Approved by: https://github.com/desertfire --- torch/_inductor/decomposition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index d7aa5e35f501..e8a20c0dbd26 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -325,7 +325,7 @@ def bernoulli_p(self, p=0.5, *, generator=None): @register_extra_random_decomp([aten.bernoulli_]) def bernoulli_(self, p=0.5): - return self.copy_(torch.rand_like(self) < p) + return self.copy_(torch.rand_like(self, dtype=torch.float32) < p) @functools.lru_cache(None) From 89a326ff7ea56a1d735d26800b07a10e35c2dff4 Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Fri, 11 Nov 2022 16:57:05 +0000 Subject: [PATCH 067/453] Explicitly check filelike arg of `torch.save` (#88867) Fixes #88793 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88867 Approved by: https://github.com/ezyang --- test/test_serialization.py | 9 +++++++++ torch/serialization.py | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/test/test_serialization.py b/test/test_serialization.py index 779d6fb5c20c..5ccc6f47b4c5 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -585,6 +585,15 @@ def test_serialization_filelike_exceptions(self): with self.assertRaises(TypeError): # Tries to serialize str into tensor with wrong callable write property torch.save('foo', x) + s_data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + s = torch.CharStorage(s_data) + with self.assertRaises(AttributeError): + # Tries to serialize list into CharStorage + torch.save(s_data, s) + x = torch.randint(10, (3, 3), dtype=torch.float).cpu().numpy() + with self.assertRaises(AttributeError): + # Tries to serialize ndarray into ndarray + torch.save(x, x) def test_serialization_storage_slice(self): diff --git a/torch/serialization.py b/torch/serialization.py index d123a955ad96..3078e57587be 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -375,6 +375,12 @@ def _check_dill_version(pickle_module) -> None: pickle_module.__version__ )) +def _check_save_filelike(f): + if not isinstance(f, (str, os.PathLike)) and not hasattr(f, 'write'): + raise AttributeError(( + "expected 'f' to be string, path, or a file-like object with " + "a 'write' attribute")) + def save( obj: object, f: FILE_LIKE, @@ -422,6 +428,7 @@ def save( >>> torch.save(x, buffer) """ _check_dill_version(pickle_module) + _check_save_filelike(f) if _use_new_zipfile_serialization: with _open_zipfile_writer(f) as opened_zipfile: From adfbd831cf59111c3d3a4a50ba6372bba94b63d1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 17:03:25 +0000 Subject: [PATCH 068/453] Revert "[Autograd] Use in-place input accumulation fast path for dense Tensors. (#88339)" This reverts commit 8f66ae413f8c9d7f2418d7f0b9f69d409c455b46. 
Reverted https://github.com/pytorch/pytorch/pull/88339 on behalf of https://github.com/mehtanirav due to Internal test failures --- torch/csrc/autograd/input_buffer.cpp | 54 ++++++++-------------------- 1 file changed, 14 insertions(+), 40 deletions(-) diff --git a/torch/csrc/autograd/input_buffer.cpp b/torch/csrc/autograd/input_buffer.cpp index 7e6df0cea8da..6cc6acefc9d4 100644 --- a/torch/csrc/autograd/input_buffer.cpp +++ b/torch/csrc/autograd/input_buffer.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -67,18 +66,6 @@ void record_stream_any_impl(Variable& var, c10::Stream& stream) { } } } - -bool can_accumulate_inplace(const Variable& v) { - return ( - // `v` is a "vanilla" Tensor - !(at::isTensorSubclassLike(v) || v._is_zerotensor() || v.is_nested()) && - - // with a favorable memory layout - v.is_non_overlapping_and_dense() && - - // and we hold the last reference - v.use_count() == 1 && v.storage().use_count() == 1); -} } // anonymous namespace static void accumulate( @@ -87,38 +74,25 @@ static void accumulate( Variable&& var) { TORCH_INTERNAL_ASSERT(pos < buffer.size()); auto& old_var = buffer[pos]; - // If we hold the last reference to `old_var` AND its storage we will try to - // repurpose it to store the output. (Or, if `old_var` is sparse then `var` - // becomes the candidate output Tensor.) We only do this if: - // 1) GradMode is disabled since Autograd has special handling for inplace - // mutation which we don't want to trigger. - // - // 2) We hold the last reference. - // (Both `.use_count` and `.storage().use_count()` are one) - // - // 3) The candidate tensor is a contiguous, non-overlapping, dense, and - // otherwise stock standard Tensor. - // - // 4) The candidate is mutable. Currently only ZeroTensors are immutable. - // - // 5) The other Tensor is not a Tensor subclass (except sparse), since - // it's hard to predict the semantics of arbitrary subclass behavior. - - if (at::GradMode::is_enabled()) { - buffer[pos] = old_var + var; - } else if ( - // ATen doesn't route sparse additions correctly... - old_var.is_sparse() || old_var.is_sparse_csr()) { - if (can_accumulate_inplace(var)) { + // ATen doesn't route sparse additions correctly... + // do dense + sparse in-place if possible + if (old_var.is_sparse()) { + // It is safe to change the Tensor inplace if the Tensor is only used in + // this buffer (this could be the gradient passed by the user) and that no + // other Tensor is using the same storage. + if (!var.is_sparse() && var.is_contiguous() && var.use_count() == 1 && + var.storage().use_count() == 1) { buffer[pos] = var.add_(old_var); } else { buffer[pos] = var + old_var; } - } else if ( - can_accumulate_inplace(old_var) && !at::isTensorSubclassLike(var)) { - buffer[pos] = old_var.add_(var); } else { - buffer[pos] = old_var + var; + if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && + old_var.use_count() == 1 && old_var.storage().use_count() == 1) { + buffer[pos] = old_var.add_(var); + } else { + buffer[pos] = old_var + var; + } } } From 8ff2e34ca6905404aba35a432acf667ee6a13c6e Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Fri, 11 Nov 2022 04:25:11 +0000 Subject: [PATCH 069/453] Take input striding for conv forward based on eager output (#88706) From discussion with @Chillee and @ngimel we'll likely need further fixes to ensure that we hit channels last kernels but this is still worth landing in its own right. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88706 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 26 +++++++++++ torch/_inductor/ir.py | 72 +++++++++++++++++------------ 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 121f3d31f39c..aea8013bdfac 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4601,6 +4601,8 @@ def fn(a): CommonTemplate.install(CudaTests, "cuda") class CudaReproTests(TestCase): + common = check_model_cuda + def test_index_put_issue(self): def forward( self, @@ -4637,6 +4639,30 @@ def forward( compiled = compile_fx_inner(mod, inps) compiled(inps) + @requires_cuda() + def test_input_channels_last(self): + m = torch.nn.Sequential( + torch.nn.Conv2d(3, 3, 1, 1), + ToTuple(), + ).cuda() + inp = ( + torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() + ) + + self.common( + m, + (inp,), + check_lowp=False, + ) + + @torch._dynamo.optimize() + def foo(m, inp): + return m(inp) + + self.assertTrue( + foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last) + ) + # https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527 @requires_cuda() def test_unspec_inputs_interop(self): diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 448c057ecb0e..240c196a73b6 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -19,7 +19,12 @@ import torch.fx import torch.utils._pytree as pytree -from torch._prims_common import is_boolean_dtype, is_float_dtype +from torch._prims_common import ( + is_boolean_dtype, + is_float_dtype, + make_channels_last_strides_for, + make_contiguous_strides_for, +) from torch._subclasses.fake_tensor import FakeTensorMode from . 
import config, dependencies @@ -133,7 +138,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] else: - stride = torch._prims_common.make_contiguous_strides_for(size) + stride = make_contiguous_strides_for(size) dtype = x.get_dtype() device = x.get_device() t = torch.empty_strided( @@ -2462,6 +2467,9 @@ def require_stride_order(cls, x, order): x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x + # TODO - Storage to InputBuffer + if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): + return x x = cls.copy_input(x) as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) assert is_stride_order_storage_and_layout(x, order) @@ -3052,9 +3060,32 @@ def create( output_padding_: List[int], groups: int, ): + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, + stride_, + padding_, + dilation_, + transposed, + output_padding_, + groups, + ) + req_stride_order = get_stride_order(output.stride()) + + if config.triton.convolution == "aten": + weight = cls.require_stride_order(weight, req_stride_order) + x = cls.require_stride_order(x, req_stride_order) + else: + x = cls.require_stride1(cls.realize_input(x)) + weight = cls.require_stride1(cls.realize_input(weight)) - weight = cls.require_stride1(cls.realize_input(weight)) - x = cls.require_stride_order(x, get_stride_order(weight.get_stride())) stride = tuple(stride_) padding = tuple(padding_) dilation = tuple(dilation_) @@ -3062,22 +3093,6 @@ def create( output_padding = tuple(output_padding_) assert isinstance(groups, int) - # TODO - enable FakeTensorMode for propagation more globally. 
incorrect stride metas for fallback - # kernels will lead to runtime failures - with FakeTensorMode(): - output, *_ = cls.process_kernel( - torch.ops.aten.convolution, - x, - weight, - bias, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - ) - output_size = output.shape weight_shape = [ @@ -3122,6 +3137,7 @@ def create( # for conv2d or conv3d, prefer channels last format if kernel == "triton_ops.conv": output_layout_str = "torch.channels_last" + elif config.tune_layout and len(x.get_size()) == 4: from .codegen.autotuner import tuned_conv_layout @@ -3151,14 +3167,19 @@ def create( if len(stride_order) < len(output_size): # add batch dim if it exists stride_order = [len(stride_order)] + stride_order + strides = make_channels_last_strides_for(output_size) else: stride_order = list(reversed(range(len(output_size)))) + strides = make_contiguous_strides_for(output_size) - output_layout = FlexibleLayout( + if config.triton.convolution != "aten": + x = cls.require_stride_order(x, stride_order) + + output_layout = FixedLayout( x.get_device(), x.get_dtype(), output_size, - stride_order, + strides, ) if bias is not None: @@ -3178,13 +3199,6 @@ def create( kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - def map_args(self): # x, w, bias in_args = [x.codegen_reference() for x in self.inputs] From 5f0783bd6d27a0a239263b943d626c533b8b9a90 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 11 Nov 2022 17:43:46 +0000 Subject: [PATCH 070/453] Fix ATen Fallback for BUILD_CAFFE2=0 for ONNX-only ops (#88504) Follow-up for #87735 Once again, because BUILD_CAFFE2=0 is not tested for ONNX exporter, one scenario slipped through. 
A use case where the model can be exported without aten fallback when operator_export_type=ONNX_ATEN_FALLBACK and BUILD_CAFFE2=0 A new unit test has been added, but it won't prevent regressions if BUILD_CAFFE2=0 is not executed on CI again Fixes #87313 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88504 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- test/onnx/test_pytorch_onnx_no_runtime.py | 220 +++++++++++++--------- torch/onnx/utils.py | 19 +- 2 files changed, 149 insertions(+), 90 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 622f42effb4a..89526c71ca38 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -18,7 +18,7 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch.onnx import symbolic_helper, utils +from torch.onnx import OperatorExportTypes, symbolic_helper, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import registration from torch.testing._internal import common_quantization, common_utils, jit_utils @@ -935,6 +935,139 @@ def forward(self, x, w): torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) + @common_utils.skipIfNoCaffe2 + def test_caffe2_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN, + OperatorExportTypes.ONNX_ATEN_FALLBACK, + ): + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + # support for linalg.qr was added in later op set versions. + opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfNoCaffe2 + def test_caffe2_onnx_aten_must_not_fallback(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + # TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize + for operator_export_type in ( + OperatorExportTypes.ONNX_ATEN_FALLBACK, + OperatorExportTypes.ONNX_ATEN, + ): + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=operator_export_type, + opset_version=10, # or higher + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + assert onnx_model.graph.node[0].op_type == "Mod" + + @common_utils.skipIfCaffe2 + def test_aten_fallback_must_fallback(self): + class ModelWithAtenNotONNXOp(torch.nn.Module): + def forward(self, x, y): + abcd = x + y + defg = torch.linalg.qr(abcd) + return defg + + x = torch.rand(3, 4) + y = torch.rand(3, 4) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenNotONNXOp(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + # support for linalg.qr was added in later op set versions. 
+ opset_version=9, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "linalg_qr") + + @common_utils.skipIfCaffe2 + def test_onnx_aten(self): + class ModelWithAtenFmod(torch.nn.Module): + def forward(self, x, y): + return torch.fmod(x, y) + + x = torch.randn(3, 4, dtype=torch.float32) + y = torch.randn(3, 4, dtype=torch.float32) + f = io.BytesIO() + torch.onnx.export( + ModelWithAtenFmod(), + (x, y), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + self.assertAtenOp(onnx_model, "fmod", "Tensor") + + @common_utils.skipIfCaffe2 + def test_onnx_aten_fallback_must_not_fallback(self): + # For BUILD_CAFFE2=0, aten fallback only when not exportable + class ONNXExportable(torch.nn.Module): + def __init__(self): + super(ONNXExportable, self).__init__() + self.quant = torch.quantization.QuantStub() + self.fc1 = torch.nn.Linear(12, 8) + self.fc2 = torch.nn.Linear(8, 4) + self.fc3 = torch.nn.Linear(4, 6) + self.dequant = torch.quantization.DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = x.view((-1, 12)) + h = F.relu(self.fc1(x)) + h = F.relu(self.fc2(h)) + h = F.relu(self.fc3(h)) + h = self.dequant(h) + return h + + dummy_input = torch.randn(12) + f = io.BytesIO() + torch.onnx.export( + ONNXExportable(), + (dummy_input,), + f, + do_constant_folding=False, + operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + ) + onnx_model = onnx.load(io.BytesIO(f.getvalue())) + all_aten_nodes = [ + p + for p in onnx_model.graph.node + if p.op_type == "ATen" and p.domain == "org.pytorch.aten" + ] + self.assertEqual(len(all_aten_nodes), 0) + class TestQuantizeEagerONNXExport(common_utils.TestCase): def _test_lower_graph_impl(self, model, data): @@ -997,91 +1130,6 @@ def test_lower_graph_conv3d(self): data = torch.from_numpy(data_numpy).to(dtype=torch.float) self._test_lower_graph_impl(model, data) - @common_utils.skipIfNoCaffe2 - def test_caffe2_aten_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # support for linalg.qr was added in later op set versions. 
- opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfNoCaffe2 - def test_caffe2_onnx_aten(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, - opset_version=10, # or higher - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - assert onnx_model.graph.node[0].op_type == "Mod" - - @common_utils.skipIfCaffe2 - def test_aten_fallback(self): - class ModelWithAtenNotONNXOp(torch.nn.Module): - def forward(self, x, y): - abcd = x + y - defg = torch.linalg.qr(abcd) - return defg - - x = torch.rand(3, 4) - y = torch.rand(3, 4) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenNotONNXOp(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, - # support for linalg.qr was added in later op set versions. - opset_version=9, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "linalg_qr") - - @common_utils.skipIfCaffe2 - def test_onnx_aten(self): - class ModelWithAtenFmod(torch.nn.Module): - def forward(self, x, y): - return torch.fmod(x, y) - - x = torch.randn(3, 4, dtype=torch.float32) - y = torch.randn(3, 4, dtype=torch.float32) - f = io.BytesIO() - torch.onnx.export( - ModelWithAtenFmod(), - (x, y), - f, - do_constant_folding=False, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN, - ) - onnx_model = onnx.load(io.BytesIO(f.getvalue())) - self.assertAtenOp(onnx_model, "fmod", "Tensor") - if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index ff0ef755968d..b30b71812aae 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1752,10 +1752,21 @@ def _should_aten_fallback( ) is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK - return name.startswith("aten::") and ( - ((is_onnx_aten_export or is_aten_fallback_export) and not is_caffe2_build) - or (not is_exportable_aten_op and is_aten_fallback_export) - ) + if not name.startswith("aten::"): + return False + + if is_caffe2_build: + if ( + is_onnx_aten_export or is_aten_fallback_export + ) and not is_exportable_aten_op: + return True + else: + if is_onnx_aten_export or ( + is_aten_fallback_export and not is_exportable_aten_op + ): + return True + + return False @_beartype.beartype From 3d1c5c89ed27ff16601aecf7834a6bd06f578c45 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:21 +0000 Subject: [PATCH 071/453] [FSDP][state_dict][4/N] Move the core logic of summon full parameters to _unshard_params_utils.py (#88636) **What** `_summon_full_parameters` is required for state_dict. To enable composable FSDP state_dict, `_summon_full_params` must be accessible without FullyShardedDataParall. This PR move the core logic of `_summon_full_params` to `_unshard_params_utils`. 
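For context, a rough sketch of the public entry point that this internal logic backs (not runnable as-is: it assumes an already-initialized process group, e.g. under `torchrun`, and a CUDA device):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torchrun).
model = FSDP(torch.nn.Linear(8, 8).cuda())

# state_dict-style code needs the full, unsharded parameters; that is what
# summon_full_params() provides, and its core now lives in
# _unshard_param_utils.py so composable FSDP can reuse it without the wrapper.
with FSDP.summon_full_params(model, writeback=False, rank0_only=False):
    full_sd = {k: v.clone() for k, v in model.state_dict().items()}
```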
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88636 Approved by: https://github.com/awgu --- test/distributed/fsdp/test_fsdp_state_dict.py | 2 +- .../fsdp/test_fsdp_summon_full_params.py | 4 +- torch/distributed/fsdp/_state_dict_utils.py | 34 ++- .../distributed/fsdp/_unshard_param_utils.py | 254 ++++++++++++++++++ .../fsdp/fully_sharded_data_parallel.py | 201 ++------------ 5 files changed, 290 insertions(+), 205 deletions(-) create mode 100644 torch/distributed/fsdp/_unshard_param_utils.py diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 48dad3118db7..ba51ae66ed1b 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -25,7 +25,7 @@ StateDictType, ) from torch.distributed.fsdp._shard_utils import _gather_state_dict -from torch.distributed.fsdp.fully_sharded_data_parallel import FLAT_PARAM +from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel diff --git a/test/distributed/fsdp/test_fsdp_summon_full_params.py b/test/distributed/fsdp/test_fsdp_summon_full_params.py index 0d4e98069117..18055dbebffb 100644 --- a/test/distributed/fsdp/test_fsdp_summon_full_params.py +++ b/test/distributed/fsdp/test_fsdp_summon_full_params.py @@ -212,7 +212,7 @@ def forward(self, fsdp_module): model = FSDP(MyModule()).cuda(self.rank) with self.assertRaisesRegex( - ValueError, "current state is TrainingState.FORWARD" + ValueError, "Current handle state is HandleTrainingState.FORWARD" ): model(model) @@ -231,7 +231,7 @@ def bad_backwards_hook(tensor): output.register_hook(bad_backwards_hook) with self.assertRaisesRegex( - ValueError, "current state is TrainingState.FORWARD_BACKWARD" + ValueError, "Current handle state is HandleTrainingState.BACKWARD_PRE" ): output.backward() diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 0bfd149b0112..eee5522340b4 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -24,7 +24,6 @@ clean_tensor_name, FSDP_PREFIX, FSDP_WRAPPED_MODULE, - TrainingState, ) from torch.distributed.fsdp._runtime_utils import ( _cast_buffers_to_dtype_and_device, @@ -40,6 +39,11 @@ _ext_pre_load_state_dict_transform, _extensions as _user_extensions, ) +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_orig_params, + FLAT_PARAM, +) from .flat_param import FlatParamHandle @@ -91,10 +95,6 @@ def _enter_full_param_ctx( "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " "is not None." ) - assert fsdp_state.training_state != TrainingState.SUMMON_FULL_PARAMS, ( - "Entering the summon_full_params context but the state is already " - "SUMMON_FULL_PARAMS." - ) fsdp_state._full_param_ctx = fsdp_state._summon_full_params( recurse=recurse, writeback=writeback, @@ -108,10 +108,6 @@ def _enter_full_param_ctx( @no_type_check def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: """A helper function to exit ``summon_full_params`` context.""" - assert fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS, ( - "Exiting the summon_full_params context but the state is not " - "SUMMON_FULL_PARAMS." 
- ) assert fsdp_state._full_param_ctx is not None fsdp_state._full_param_ctx.__exit__(None, None, None) fsdp_state._full_param_ctx = None @@ -168,9 +164,6 @@ def _common_summon_post_state_dict_hook( hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) - assert ( - fsdp_state.training_state == TrainingState.SUMMON_FULL_PARAMS - ), "Inside the post_state_dict_hook but the state is not SUMMON_FULL_PARAMS." # Return early for trivial cases if not state_dict or not _has_fsdp_params(fsdp_state, module): _exit_full_param_ctx(fsdp_state) @@ -180,7 +173,7 @@ def _common_summon_post_state_dict_hook( # For `use_orig_params=True`, the `FlatParameter` is not registered, so # there is no entry in the state dict for it to pop. if not fsdp_state._use_orig_params: - state_dict.pop(f"{prefix}{fsdp_file.FLAT_PARAM}") + state_dict.pop(f"{prefix}{FLAT_PARAM}") # If a rank does not have unsharded parameters(when `rank0_only=True` # and `rank != 0`), then the rank only needed to participate in the @@ -338,6 +331,7 @@ def _full_pre_load_state_dict_hook( state_dict: Dict[str, Any], prefix: str, ) -> None: + _lazy_init(fsdp_state, module) _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") @@ -409,7 +403,7 @@ def _local_post_state_dict_hook( ) # type: ignore[assignment] if fsdp_state._state_dict_config.offload_to_cpu: sharded_tensor = sharded_tensor.cpu() - state_dict[f"{prefix}{fsdp_file.FLAT_PARAM}"] = sharded_tensor + state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor return state_dict @@ -430,8 +424,9 @@ def _local_pre_load_state_dict_hook( state_dict. The flat_param should be a ShardedTensor. This hook converts the ShardedTensor to a tensor. No copy happen unless padding is required. """ + _lazy_init(fsdp_state, module) _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") - fqn = f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}" + fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}" if fqn not in state_dict: assert not _has_fsdp_params(fsdp_state, module), ( "No `FlatParameter` in `state_dict` for this FSDP instance " @@ -527,7 +522,7 @@ def _sharded_post_load_state_dict_hook( module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: if fsdp_state._use_orig_params: - fsdp_state._register_orig_params() + _register_orig_params(module, fsdp_state) @no_type_check @@ -541,6 +536,7 @@ def _sharded_pre_load_state_dict_hook( The hook combines the unflattened, sharded parameters (ShardedTensor) to a new FlatParameter and shards the new FlatParameter to the local chunk. """ + _lazy_init(fsdp_state, module) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") if not _has_fsdp_params(fsdp_state, module): return @@ -605,7 +601,7 @@ def _sharded_pre_load_state_dict_hook( assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), ( f"The original shapes in FSDP are {flat_param._shapes}. " f"The loaded shapes are {loaded_shapes}. " - f"FSDP extension is {'NOT' if _user_extensions is None else ''} None." + f"FSDP extension is {'NOT' if _user_extensions is not None else ''} None." ) assert flat_param.numel() == loaded_flat_tensor.numel(), ( f"The loaded local chunk has different numel({loaded_flat_tensor.numel()}) " @@ -615,9 +611,9 @@ def _sharded_pre_load_state_dict_hook( f"The loaded local chunk has different padding({num_to_pad}) " f"from the local chunk {flat_param._shard_numel_padded}." 
) - state_dict[f"{prefix}{FSDP_PREFIX}{fsdp_file.FLAT_PARAM}"] = loaded_flat_tensor + state_dict[f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"] = loaded_flat_tensor if fsdp_state._use_orig_params: - fsdp_state._deregister_orig_params() + _deregister_orig_params(module, fsdp_state) @no_type_check diff --git a/torch/distributed/fsdp/_unshard_param_utils.py b/torch/distributed/fsdp/_unshard_param_utils.py new file mode 100644 index 000000000000..950841850b62 --- /dev/null +++ b/torch/distributed/fsdp/_unshard_param_utils.py @@ -0,0 +1,254 @@ +import contextlib +import warnings +from typing import cast, Generator, List + +import torch +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _has_fsdp_params, + _module_handles, + HandleTrainingState, +) +from torch.distributed.fsdp._runtime_utils import ( + _clear_grads_if_needed, + _reshard, + _reshard_grads, + _unshard, + _unshard_grads, +) +from ._utils import p_assert +from .flat_param import FlatParamHandle + +FLAT_PARAM = "_flat_param" + + +@torch.no_grad() +def _writeback_to_local_shard( + handles: List[FlatParamHandle], + writeback_grad: bool, +): + """ + For each handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. If + ``writeback_grad=True``, then writes back to the sharded gradient as + well. + + Precondition: Each handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + for handle in handles: + # For `NO_SHARD`, `_local_shard` is the unsharded flattened + # parameter and `grad` is the unsharded gradient, so there is no + # need to writeback for either + if not handle.uses_sharded_strategy: + continue + assert ( + handle.flat_param.ndim == 1 + ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}" + + # Get the unpadded shard instead of the padded shard to persist + # user changes to the padding (though FSDP does not explicitly + # support this) + param_shard, _ = FlatParamHandle._get_unpadded_shard( + handle.flat_param, + handle.rank, + handle.world_size, + ) + handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] + if writeback_grad: + existing_grad = handle.sharded_grad + if existing_grad is not None: + assert handle.flat_param.grad is not None + grad_shard, _ = FlatParamHandle._get_unpadded_shard( + handle.flat_param.grad, + handle.rank, + handle.world_size, + ) + existing_grad[: grad_shard.numel()].copy_(grad_shard) + + +def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + De-registers the flattened parameter from the wrapped module, hiding it + from ``nn.Module`` methods. + + We do not use ``del`` because we want ``FLAT_PARAM`` to always be an + attribute but dynamically change whether it is visible to ``nn.Module`` + methods. + """ + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) + + +def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + Registers the flattened parameter to the wrapped module, making it + visible to ``nn.Module`` methods. + + We do not use :meth:`nn.Module.register_parameter` because we want + ``FLAT_PARAM`` to always be an attribute but dynamically change whether + it is visible to ``nn.Module`` methods. + """ + handles = _module_handles(state, module) + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. 
+ cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handles[0].flat_param + + +@contextlib.contextmanager +def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + de-registers the flattened parameter and unflattens the original + parameters as ``nn.Parameter`` views into the flattened parameter. + After the context, re-registers the flattened parameter and restores + the original parameters as ``Tensor`` views into the flattened + parameter. + """ + handles = _module_handles(state, module) + if not handles: + yield + else: + _deregister_flat_param(state, module) + try: + with handles[0].unflatten_as_params(): + yield + finally: + if not handles[0]._use_orig_params: + _register_flat_param(state, module) + + +@contextlib.contextmanager +def _unshard_params( + module: nn.Module, + state: _FSDPState, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +): + if with_grads and (offload_to_cpu or not state._use_orig_params): + raise NotImplementedError( + f"with_grads={with_grads} " + f"use_orig_params={state._use_orig_params} " + f"offload_to_cpu={offload_to_cpu} " + f"is not supported yet" + ) + if writeback and rank0_only: + raise ValueError( + "writeback=True and rank0_only=True is not supported, as model " + "parameter shapes will be different across ranks, and writing " + "to them can lead to inconsistencies across ranks when the " + "context is exited." + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu and rank0_only=False will result in " + "full parameters being redundantly copied to CPU memory for " + "GPUs that reside on the same machine, which may incur the risk of " + "CPU OOM. It is recommended to use ``offload_to_cpu`` with " + "rank0_only=True." + ) + + torch.cuda.synchronize() + # If handles are shared by other module(s), the handle may be already unsharded. + handles = [ + handle + for handle in _module_handles(state, module) + if handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS + ] + if not handles: + yield + return + + for handle in handles: + if handle._training_state != HandleTrainingState.IDLE: + raise ValueError(f"Current handle state is {handle._training_state}") + + for handle in handles: + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + _clear_grads_if_needed(handles) + free_unsharded_flat_params = [handle.needs_unshard() for handle in handles] + # No need to call `wait_stream()` since we unshard in the computation + # stream directly + computation_stream = torch.cuda.current_stream() + _unshard(state, handles, computation_stream, computation_stream) + if with_grads: + _unshard_grads(handles) + + if rank0_only and state.rank != 0: + # Free the unsharded flattened parameter early + _reshard(state, handles, free_unsharded_flat_params) + if with_grads: + _reshard_grads(handles) + try: + yield + finally: + for handle in handles: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + for handle in handles: + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # TODO (awgu): Since PyTorch enforces that a parameter + # and its gradients need to match metadata (e.g. + # device), we must move gradients to CPU *after* we + # move parameters. 
+ # TODO (awgu): This FPW call assumes 1 `FlatParameter` + if not state._use_orig_params: + stack.enter_context(_unflatten_as_params(state, module)) + try: + yield + finally: + stack.close() + if writeback: + _writeback_to_local_shard(handles, with_grads) + _reshard(state, handles, free_unsharded_flat_params) + if with_grads: + _reshard_grads(handles) + for handle in handles: + handle._training_state = HandleTrainingState.IDLE + + +def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the original parameters; registers the ``FlatParameter``. + """ + handles = _module_handles(state, module) + p_assert( + len(handles) <= 1, + "Expects <=1 handle per FSDP instance; needs to be refactored " + "for >1 handle (e.g. non-recursive wrapping)", + ) + if not handles: + return + handle = handles[0] + p_assert( + handle._use_orig_params, + f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " + f"handle: {handle._use_orig_params}", + ) + handle._deregister_orig_params() + _register_flat_param(state, module) + + +def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the ``FlatParameter``; registers the original parameters. + """ + handles = _module_handles(state, module) + if not handles: + return + handle = handles[0] + _deregister_flat_param(state, module) + if handle.is_sharded(handle.flat_param): + handle._use_sharded_views() + handle._use_sharded_grad_views() + else: + handle._use_unsharded_views(as_params=True) diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 773686081a4d..510f90de2023 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -48,18 +48,14 @@ _init_state_dict_state, ) from torch.distributed.fsdp._runtime_utils import ( - _clear_grads_if_needed, _lazy_init, _post_forward, _post_forward_reshard, _pre_forward, _pre_forward_unshard, _reshard, - _reshard_grads, _root_pre_forward, _should_free_in_backward, - _unshard, - _unshard_grads, _wait_for_computation_stream, ) from torch.distributed.fsdp._wrap_utils import _auto_wrap @@ -92,6 +88,12 @@ _post_state_dict_hook, _pre_load_state_dict_hook, ) +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_flat_param, + _register_orig_params, + _unshard_params, +) from ._utils import p_assert from .flat_param import FlatParameter, FlatParamHandle from .wrap import ParamExecOrderWrapPolicy @@ -409,7 +411,7 @@ def __init__( self._fsdp_wrapped_module = module if not use_orig_params: _check_orig_params_flattened(self, self._ignored_params) - self._register_flat_param() + _register_flat_param(self, self) # Delete to avoid keeping references after the constructor delattr(self, "_ignored_params") @@ -864,153 +866,20 @@ def _summon_full_params( yield return - torch.cuda.synchronize() _lazy_init(self, self) - self._assert_state([TrainingState.IDLE]) - for handle in self._handles: - assert handle._training_state == HandleTrainingState.IDLE - self.training_state = TrainingState.SUMMON_FULL_PARAMS - for handle in self._handles: - handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS - - if self._is_root: - _clear_grads_if_needed(self._fsdp_handles(self)) - free_unsharded_flat_params = [ - handle.needs_unshard() for handle in self._handles - ] - # No need to call `wait_stream()` since we unshard in the computation - # stream directly - computation_stream = torch.cuda.current_stream() 
- _unshard(self, self._handles, computation_stream, computation_stream) - if with_grads: - _unshard_grads(self._handles) - - if rank0_only and self.rank != 0: - # Free the unsharded flattened parameter early - _reshard(self, self._handles, free_unsharded_flat_params) - if with_grads: - _reshard_grads(self._handles) + with _unshard_params( + module=self, + state=self, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ): try: + self.training_state = TrainingState.SUMMON_FULL_PARAMS yield finally: self.training_state = TrainingState.IDLE - for handle in self._handles: - handle._training_state = HandleTrainingState.IDLE - else: - # Unflatten the unsharded flattened parameters - with contextlib.ExitStack() as stack: - # Invariant: rank == 0 or !rank0_only - for handle in self._handles: - if offload_to_cpu and handle.uses_sharded_strategy: - stack.enter_context(handle.to_cpu()) - # TODO (awgu): Since PyTorch enforces that a parameter - # and its gradients need to match metadata (e.g. - # device), we must move gradients to CPU *after* we - # move parameters. - # TODO (awgu): This FPW call assumes 1 `FlatParameter` - if not self._use_orig_params: - stack.enter_context(self._unflatten_as_params()) - try: - yield - finally: - stack.close() - if writeback: - self._writeback_to_local_shard(self._handles, with_grads) - _reshard(self, self._handles, free_unsharded_flat_params) - if with_grads: - _reshard_grads(self._handles) - self.training_state = TrainingState.IDLE - for handle in self._handles: - handle._training_state = HandleTrainingState.IDLE - - @torch.no_grad() - def _writeback_to_local_shard( - self, - handles: List[FlatParamHandle], - writeback_grad: bool, - ): - """ - For each handle, writes back the this rank's shard of the unsharded - flattened parameter to the sharded flattened parameter. If - ``writeback_grad=True``, then writes back to the sharded gradient as - well. - - Precondition: Each handle's ``FlatParameter`` 's data points to the - padded unsharded flattened parameter. - """ - for handle in handles: - # For `NO_SHARD`, `_local_shard` is the unsharded flattened - # parameter and `grad` is the unsharded gradient, so there is no - # need to writeback for either - if not handle.uses_sharded_strategy: - continue - assert ( - handle.flat_param.ndim == 1 - ), f"Expects `flat_param` to be flattened but got {handle.flat_param.shape}" - - # Get the unpadded shard instead of the padded shard to persist - # user changes to the padding (though FSDP does not explicitly - # support this) - param_shard, _ = FlatParamHandle._get_unpadded_shard( - handle.flat_param, - handle.rank, - handle.world_size, - ) - handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) - if writeback_grad: - existing_grad = handle.sharded_grad - if existing_grad is not None: - grad_shard, _ = FlatParamHandle._get_unpadded_shard( - handle.flat_param.grad, - handle.rank, - handle.world_size, - ) - existing_grad[: grad_shard.numel()].copy_(grad_shard) - - @contextlib.contextmanager - def _unflatten_as_params(self) -> Generator: - """ - Assumes that the flattened parameter is unsharded. When in the context, - de-registers the flattened parameter and unflattens the original - parameters as ``nn.Parameter`` views into the flattened parameter. - After the context, re-registers the flattened parameter and restores - the original parameters as ``Tensor`` views into the flattened - parameter. 
- """ - if not self._handles: - yield - else: - self._deregister_flat_param() - try: - with self._handles[0].unflatten_as_params(): - yield - finally: - if not self._handles[0]._use_orig_params: - self._register_flat_param() - - def _register_flat_param(self): - """ - Registers the flattened parameter to the wrapped module, making it - visible to ``nn.Module`` methods. - - We do not use :meth:`nn.Module.register_parameter` because we want - ``FLAT_PARAM`` to always be an attribute but dynamically change whether - it is visible to ``nn.Module`` methods. - """ - if self._has_params: - self.module._parameters[FLAT_PARAM] = self._handles[0].flat_param - - def _deregister_flat_param(self): - """ - De-registers the flattened parameter from the wrapped module, hiding it - from ``nn.Module`` methods. - - We do not use ``del`` because we want ``FLAT_PARAM`` to always be an - attribute but dynamically change whether it is visible to ``nn.Module`` - methods. - """ - if self._has_params: - self.module._parameters.pop(FLAT_PARAM, None) @contextlib.contextmanager def _deregister_orig_params_ctx(self): @@ -1026,46 +895,12 @@ def _deregister_orig_params_ctx(self): "`_use_orig_params=True`", ) for fsdp_module in self.fsdp_modules(self): - fsdp_module._deregister_orig_params() + _deregister_orig_params(fsdp_module, fsdp_module) try: yield finally: for fsdp_module in self.fsdp_modules(self): - fsdp_module._register_orig_params() - - def _deregister_orig_params(self): - """ - Deregisters the original parameters; registers the ``FlatParameter``. - """ - p_assert( - len(self._handles) <= 1, - "Expects <=1 handle per FSDP instance; needs to be refactored " - "for >1 handle (e.g. non-recursive wrapping)", - ) - if not self._handles: - return - handle = self._handles[0] - p_assert( - handle._use_orig_params, - f"Inconsistent `_use_orig_params` -- FSDP: {self._use_orig_params} " - f"handle: {handle._use_orig_params}", - ) - handle._deregister_orig_params() - self._register_flat_param() - - def _register_orig_params(self): - """ - Deregisters the ``FlatParameter``; registers the original parameters. - """ - if not self._handles: - return - handle = self._handles[0] - self._deregister_flat_param() - if handle.is_sharded(handle.flat_param): - handle._use_sharded_views() - handle._use_sharded_grad_views() - else: - handle._use_unsharded_views(as_params=True) + _register_orig_params(fsdp_module, fsdp_module) def _apply(self, *args, **kwargs): """ From 9d7d21f5691979f728f42a709e1a47ab3e905342 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:31 -0800 Subject: [PATCH 072/453] [ONNX] Add stack info to diagnostics (#87258) ~~Investigating strange bug releasing 'graph' right when returning from `_C._jit_pass_onnx`.~~ ~~Can be repro-ed locally via `test_cpp_diagnose`, with changes in this PR.~~ Resolved by https://github.com/pytorch/pytorch/pull/87829. This PR adds methods to record stack backtrace information to diagnostics. 
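For reference, a standalone sketch of the Python-side idea (the helpers actually added here live under `torch.onnx._internal.diagnostics.infra`; the function below is only illustrative):

```python
import inspect

def python_frames(frames_to_skip=0, frames_to_log=32):
    # Skip this helper's own frame in addition to the frames the caller skips.
    frames = inspect.stack()[frames_to_skip + 1 : frames_to_skip + 1 + frames_to_log]
    return [
        (f.filename, f.lineno,
         f.code_context[f.index].strip() if f.code_context else None)
        for f in frames
    ]

if __name__ == "__main__":
    for filename, lineno, snippet in python_frames():
        print(f"{filename}:{lineno}  {snippet}")
```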
* #87830 Pull Request resolved: https://github.com/pytorch/pytorch/pull/87258 Approved by: https://github.com/abock --- test/onnx/internal/test_diagnostics.py | 77 +++++++++++++++---- .../onnx/_internal/diagnostics/_diagnostic.py | 61 ++++++++++++--- .../_internal/diagnostics/infra/__init__.py | 2 + .../_internal/diagnostics/infra/_infra.py | 49 ++++++------ .../onnx/_internal/diagnostics/infra/utils.py | 35 +++++++++ 5 files changed, 169 insertions(+), 55 deletions(-) create mode 100644 torch/onnx/_internal/diagnostics/infra/utils.py diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index fbd888329a50..ea9a789e91c1 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -3,6 +3,7 @@ import contextlib import dataclasses import io +import typing import unittest from typing import AbstractSet, Tuple @@ -110,23 +111,15 @@ class TestOnnxDiagnostics(common_utils.TestCase): def setUp(self): engine = diagnostics.engine engine.clear() + self._sample_rule = diagnostics.rules.missing_custom_symbolic_function super().setUp() - def test_assert_diagnostic_raises_when_diagnostic_not_found(self): - with self.assertRaises(AssertionError): - with assert_diagnostic( - self, - diagnostics.engine, - diagnostics.rules.node_missing_onnx_shape_inference, - diagnostics.levels.WARNING, - ): - pass - - def test_cpp_diagnose_emits_warning(self): + def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp( + self, + ) -> diagnostics.ExportDiagnostic: class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): - ctx.save_for_backward(x, y) return x + y @staticmethod @@ -137,6 +130,30 @@ class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) + # trigger warning for missing shape inference. + rule = diagnostics.rules.node_missing_onnx_shape_inference + torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + + context = diagnostics.engine.contexts[-1] + for diagnostic in context.diagnostics: + if ( + diagnostic.rule == rule + and diagnostic.level == diagnostics.levels.WARNING + ): + return typing.cast(diagnostics.ExportDiagnostic, diagnostic) + raise AssertionError("No diagnostic found.") + + def test_assert_diagnostic_raises_when_diagnostic_not_found(self): + with self.assertRaises(AssertionError): + with assert_diagnostic( + self, + diagnostics.engine, + diagnostics.rules.node_missing_onnx_shape_inference, + diagnostics.levels.WARNING, + ): + pass + + def test_cpp_diagnose_emits_warning(self): with assert_diagnostic( self, diagnostics.engine, @@ -144,7 +161,7 @@ def forward(self, x): diagnostics.levels.WARNING, ): # trigger warning for missing shape inference. 
- torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() def test_py_diagnose_emits_error(self): class M(torch.nn.Module): @@ -168,15 +185,43 @@ def forward(self, x): def test_diagnostics_engine_records_diagnosis_reported_outside_of_export( self, ): - sample_rule = diagnostics.rules.missing_custom_symbolic_function sample_level = diagnostics.levels.ERROR with assert_diagnostic( self, diagnostics.engine, - sample_rule, + self._sample_rule, sample_level, ): - diagnostics.context.diagnose(sample_rule, sample_level) + diagnostics.context.diagnose(self._sample_rule, sample_level) + + def test_diagnostics_records_python_call_stack(self): + diagnostic = diagnostics.ExportDiagnostic( + self._sample_rule, diagnostics.levels.NOTE + ) + stack = diagnostic.python_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame = stack.frames[0] + assert frame.location.snippet is not None # for mypy + self.assertIn("self._sample_rule", frame.location.snippet) + assert frame.location.uri is not None # for mypy + self.assertIn("test_diagnostics.py", frame.location.uri) + + def test_diagnostics_records_cpp_call_stack(self): + diagnostic = ( + self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp() + ) + stack = diagnostic.cpp_call_stack + assert stack is not None # for mypy + self.assertGreater(len(stack.frames), 0) + frame_messages = [frame.location.message for frame in stack.frames] + self.assertTrue( + any( + isinstance(message, str) + and "torch::jit::ONNXShapeTypeInference" in message + for message in frame_messages + ) + ) @dataclasses.dataclass diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index ae6615e831cb..21e44f2b4467 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -5,11 +5,38 @@ import torch from torch.onnx._internal.diagnostics import infra +from torch.onnx._internal.diagnostics.infra import utils as infra_utils +from torch.utils import cpp_backtrace # This is a workaround for mypy not supporting Self from typing_extensions. _ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic") +def _cpp_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32): + """Returns the current C++ call stack. + + This function utilizes `torch.utils.cpp_backtrace` to get the current C++ call stack. + The returned C++ call stack is a concatenated string of the C++ call stack frames. + Each frame is separated by a newline character, in the same format of + r"frame #[0-9]+: (?P.*)". More info at `c10/util/Backtrace.cpp`. + + """ + frames = cpp_backtrace.get_cpp_backtrace(frames_to_skip, frames_to_log).split("\n") + frame_messages = [] + for frame in frames: + segments = frame.split(":", 1) + if len(segments) == 2: + frame_messages.append(segments[1].strip()) + else: + frame_messages.append("") + return infra.Stack( + frames=[ + infra.StackFrame(location=infra.Location(message=message)) + for message in frame_messages + ] + ) + + class ExportDiagnostic(infra.Diagnostic): """Base class for all export diagnostics. @@ -18,24 +45,34 @@ class ExportDiagnostic(infra.Diagnostic): diagnostic. 
""" + python_call_stack: Optional[infra.Stack] = None + cpp_call_stack: Optional[infra.Stack] = None + def __init__( self, *args, **kwargs, ) -> None: super().__init__(*args, **kwargs) - - def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic: - # TODO: Implement this. - # self.stacks.append(...) - raise NotImplementedError() - return self - - def with_python_stack(self: _ExportDiagnostic) -> _ExportDiagnostic: - # TODO: Implement this. - # self.stacks.append(...) - raise NotImplementedError() - return self + self.record_python_call_stack(frames_to_skip=1) + self.record_cpp_call_stack(frames_to_skip=1) + + def record_python_call_stack(self, frames_to_skip) -> None: + """Records the current Python call stack in the diagnostic.""" + frames_to_skip += 1 # Skip this function. + stack = infra_utils.python_call_stack(frames_to_skip=frames_to_skip) + stack.message = "Python call stack" + self.with_stack(stack) + self.python_call_stack = stack + + def record_cpp_call_stack(self, frames_to_skip) -> None: + """Records the current C++ call stack in the diagnostic.""" + # No need to skip this function because python frame is not recorded + # in cpp call stack. + stack = _cpp_call_stack(frames_to_skip=frames_to_skip) + stack.message = "C++ call stack" + self.with_stack(stack) + self.cpp_call_stack = stack def with_model_source_location( self: _ExportDiagnostic, diff --git a/torch/onnx/_internal/diagnostics/infra/__init__.py b/torch/onnx/_internal/diagnostics/infra/__init__.py index ac9e6e99a974..4f9dd9e5fa0b 100644 --- a/torch/onnx/_internal/diagnostics/infra/__init__.py +++ b/torch/onnx/_internal/diagnostics/infra/__init__.py @@ -8,6 +8,7 @@ Rule, RuleCollection, Stack, + StackFrame, ) from .engine import DiagnosticEngine @@ -22,4 +23,5 @@ "Rule", "RuleCollection", "Stack", + "StackFrame", ] diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index 6966ccccbb26..b8a4c5032f52 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -110,11 +110,12 @@ def format_message(self, *args, **kwargs) -> str: @dataclasses.dataclass class Location: - uri: str - message: str + uri: Optional[str] = None line: Optional[int] = None + message: Optional[str] = None start_column: Optional[int] = None end_column: Optional[int] = None + snippet: Optional[str] = None def sarif(self) -> sarif.Location: """Returns the SARIF representation of this location.""" @@ -124,43 +125,37 @@ def sarif(self) -> sarif.Location: region=sarif.Region( start_line=self.line, start_column=self.start_column, - end_line=self.line, end_column=self.end_column, + snippet=sarif.ArtifactContent(text=self.snippet), ), ), - message=sarif.Message(text=self.message), + message=sarif.Message(text=self.message) + if self.message is not None + else None, ) +@dataclasses.dataclass +class StackFrame: + location: Location + + def sarif(self) -> sarif.StackFrame: + """Returns the SARIF representation of this stack frame.""" + return sarif.StackFrame(location=self.location.sarif()) + + @dataclasses.dataclass class Stack: - frame_locations: List[Location] = dataclasses.field(default_factory=list) + frames: List[StackFrame] = dataclasses.field(default_factory=list) + message: Optional[str] = None def sarif(self) -> sarif.Stack: """Returns the SARIF representation of this stack.""" return sarif.Stack( - frames=[ - sarif.StackFrame(location=loc.sarif()) for loc in self.frame_locations - ] - ) - - def add_frame( - self, - uri: str, - 
message: str, - line: Optional[int] = None, - start_column: Optional[int] = None, - end_column: Optional[int] = None, - ) -> None: - """Adds a frame to the stack.""" - self.frame_locations.append( - Location( - uri=uri, - message=message, - line=line, - start_column=start_column, - end_column=end_column, - ) + frames=[frame.sarif() for frame in self.frames], + message=sarif.Message(text=self.message) + if self.message is not None + else None, ) diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py new file mode 100644 index 000000000000..c32de1c6b8ad --- /dev/null +++ b/torch/onnx/_internal/diagnostics/infra/utils.py @@ -0,0 +1,35 @@ +import inspect + +from torch.onnx._internal.diagnostics.infra import _infra + + +def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame: + """Returns a StackFrame for the given inspect.FrameInfo.""" + snippet = ( + frame.code_context[frame.index] + if frame.code_context is not None and frame.index is not None + else None + ) + + return _infra.StackFrame( + location=_infra.Location( + uri=frame.filename, + line=frame.lineno, + snippet=snippet, + ) + ) + + +def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 32) -> _infra.Stack: + """Returns the current Python call stack.""" + if frames_to_skip < 0: + raise ValueError("frames_to_skip must be non-negative") + if frames_to_log < 0: + raise ValueError("frames_to_log must be non-negative") + frames_to_skip += 1 # Skip this function. + stack = _infra.Stack() + stack.frames = [ + python_frame(frame) + for frame in inspect.stack()[frames_to_skip : frames_to_skip + frames_to_log] + ] + return stack From 4e5d7afe84c01ed730f0f43395d7fa0542e81f3a Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 19:08:30 +0000 Subject: [PATCH 073/453] Revert "add DisableTorchFunction that matches DisableTorchDispatch (#88219)" This reverts commit c0ecce15b5a54ff0185f9976e6bfb6f3a7de698d. 
Reverted https://github.com/pytorch/pytorch/pull/88219 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901 --- aten/src/ATen/PythonTorchFunctionTLS.cpp | 11 +-- aten/src/ATen/PythonTorchFunctionTLS.h | 12 +-- test/allowlist_for_publicAPI.json | 1 - test/test_overrides.py | 21 ---- test/test_public_bindings.py | 1 - torch/_C/__init__.pyi.in | 1 - torch/__init__.py | 2 +- torch/csrc/Module.cpp | 4 - torch/csrc/autograd/init.cpp | 9 +- torch/csrc/utils/disable_torch_function.cpp | 100 ++------------------ torch/csrc/utils/disable_torch_function.h | 1 - 11 files changed, 24 insertions(+), 139 deletions(-) diff --git a/aten/src/ATen/PythonTorchFunctionTLS.cpp b/aten/src/ATen/PythonTorchFunctionTLS.cpp index 00f372f370e6..c9487c6958cb 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.cpp +++ b/aten/src/ATen/PythonTorchFunctionTLS.cpp @@ -26,12 +26,12 @@ int64_t PythonTorchFunctionTLS::stack_len() { return pythonTorchFunctionState.stack_.size(); } -void PythonTorchFunctionTLS::set_disabled_state(TorchFunctionDisabledState disabled_state) { - pythonTorchFunctionState.disabled_state_ = disabled_state; +void PythonTorchFunctionTLS::set_disabled(bool disabled) { + pythonTorchFunctionState.disabled_ = disabled; } -TorchFunctionDisabledState PythonTorchFunctionTLS::get_disabled_state() { - return pythonTorchFunctionState.disabled_state_; +bool PythonTorchFunctionTLS::is_disabled() { + return pythonTorchFunctionState.disabled_; } void PythonTorchFunctionTLS::set_state(const PythonTorchFunctionTLS& state) { @@ -43,8 +43,7 @@ const PythonTorchFunctionTLS& PythonTorchFunctionTLS::get_state() { } bool torch_function_mode_enabled() { - return PythonTorchFunctionTLS::get_disabled_state() != TorchFunctionDisabledState::ALL_DISABLED && - PythonTorchFunctionTLS::stack_len() > 0; + return PythonTorchFunctionTLS::stack_len() > 0; } } // namespace impl diff --git a/aten/src/ATen/PythonTorchFunctionTLS.h b/aten/src/ATen/PythonTorchFunctionTLS.h index a1e3a61ea202..5940fb6f2dee 100644 --- a/aten/src/ATen/PythonTorchFunctionTLS.h +++ b/aten/src/ATen/PythonTorchFunctionTLS.h @@ -6,11 +6,9 @@ namespace at { namespace impl { -enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; - struct TORCH_API PythonTorchFunctionTLS { - static void set_disabled_state(TorchFunctionDisabledState disabled_state_); - static TorchFunctionDisabledState get_disabled_state(); + static void set_disabled(bool); + static bool is_disabled(); static void push_onto_stack(std::shared_ptr mode); static const std::shared_ptr pop_stack(); @@ -22,11 +20,11 @@ struct TORCH_API PythonTorchFunctionTLS { private: // The mode TLS is split into - // - disabled_state, which says which part of torch function are disabled + // - disabled_, which says whether or not to disable all torch function + // modes // - stack_, which is a vector of modes representing the stack of user // defined modes - TorchFunctionDisabledState disabled_state_ = - TorchFunctionDisabledState::ENABLED; + bool disabled_; std::vector> stack_; }; diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 45ba9ae94676..8a66dc12d4b6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,7 +1128,6 @@ "BFloat16Tensor", "ComplexDoubleStorage", "ComplexFloatStorage", - "DisableTorchFunction", "DisableTorchFunctionSubclass", "Generator", "HalfStorage", diff --git a/test/test_overrides.py b/test/test_overrides.py index 3b3a5ed063c7..01c763a548fc 100644 --- 
a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1453,27 +1453,6 @@ class B(torch.Tensor): self.assertTrue(called) - def test_disable_subclass_mode(self): - called = False - - class A(TorchFunctionMode): - def __torch_function__(self, func, types, args=(), kwargs=None): - nonlocal called - if kwargs is None: - kwargs = {} - called = True - return func(*args, **kwargs) - - class B(torch.Tensor): - pass - - x = B(torch.randn(5)) - with A(): - with torch._C.DisableTorchFunction(): - self.assertNotIsInstance(torch.sum(x), B) - - self.assertFalse(called) - def test_disable_enable_subclass(self): called = False diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 46c7396b9b07..6897c3102df6 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,7 +99,6 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", - "DisableTorchFunction", "DisableTorchFunctionSubclass", "DispatchKey", "DispatchKeySet", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index bc4bf03d8161..79dd6386c378 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,7 +108,6 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunction(): ... def DisableTorchFunctionSubclass(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp diff --git a/torch/__init__.py b/torch/__init__.py index 6049967b6f18..ec23499dce65 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -315,7 +315,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'DisableTorchFunction', 'Generator']: + if name not in ['DisableTorchFunctionSubclass', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 0a9aa53a0bbc..efe6c18ea0cd 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1597,10 +1597,6 @@ Call this whenever a new thread is created in order to propagate values from "DisableTorchFunctionSubclass", (PyObject*)THPModule_DisableTorchFunctionSubclassType(), /* incref= */ false)); - ASSERT_TRUE(set_module_attr( - "DisableTorchFunction", - (PyObject*)THPModule_DisableTorchFunctionType(), - /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6271cfd5cb99..d26db95f1295 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -60,14 +60,13 @@ struct DisableAutocast { struct EnableTorchFunction { EnableTorchFunction() - : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::ENABLED); + : old_(at::impl::PythonTorchFunctionTLS::is_disabled()) { + at::impl::PythonTorchFunctionTLS::set_disabled(false); } ~EnableTorchFunction() { - at::impl::PythonTorchFunctionTLS::set_disabled_state(old_); + at::impl::PythonTorchFunctionTLS::set_disabled(old_); } - at::impl::TorchFunctionDisabledState old_; + bool old_; }; struct EnablePythonDispatcher { diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 589b069250a3..516e6b89d43a 100644 --- 
a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -11,8 +11,7 @@ PyObject* disabled_torch_function = nullptr; PyObject* disabled_torch_dispatch = nullptr; bool torch_function_enabled() { - return at::impl::PythonTorchFunctionTLS::get_disabled_state() == - at::impl::TorchFunctionDisabledState::ENABLED; + return !at::impl::PythonTorchFunctionTLS::is_disabled(); } PyObject* disabled_torch_function_impl() { @@ -35,23 +34,20 @@ void set_disabled_torch_dispatch_impl(PyObject* value) { typedef struct { PyObject_HEAD /* Type-specific fields go here. */ - at::impl::TorchFunctionDisabledState old_state; + bool old_state; } DisableTorchFunctionSubclass; PyObject* DisableTorchFunctionSubclass__enter( PyObject* self, PyObject* unused) { - const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state(); - ((DisableTorchFunctionSubclass*)self)->old_state = old_state; - if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); - } + ((DisableTorchFunctionSubclass*)self)->old_state = + at::impl::PythonTorchFunctionTLS::is_disabled(); + at::impl::PythonTorchFunctionTLS::set_disabled(true); Py_RETURN_NONE; } PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::PythonTorchFunctionTLS::set_disabled( ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } @@ -119,81 +115,6 @@ PyObject* THPModule_DisableTorchFunctionSubclassType() { return (PyObject*)(&DisableTorchFunctionSubclassType); } -typedef struct { - PyObject_HEAD - /* Type-specific fields go here. */ - at::impl::TorchFunctionDisabledState old_state; -} DisableTorchFunction; - -PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { - ((DisableTorchFunctionSubclass*)self)->old_state = - at::impl::PythonTorchFunctionTLS::get_disabled_state(); - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::ALL_DISABLED); - Py_RETURN_NONE; -} - -PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - ((DisableTorchFunctionSubclass*)self)->old_state); - Py_RETURN_NONE; -} - -static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT - {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, - {nullptr, nullptr, 0, nullptr}}; - -PyTypeObject DisableTorchFunctionType = { - PyVarObject_HEAD_INIT( - nullptr, - 0) "torch._C.DisableTorchFunction", /* tp_name */ - sizeof(DisableTorchFunction), /* tp_basicsize */ - 0, /* tp_itemsize */ - nullptr, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - nullptr, /* tp_getattr */ - nullptr, /* tp_setattr */ - nullptr, /* tp_reserved */ - nullptr, /* tp_repr */ - nullptr, /* tp_as_number */ - nullptr, /* tp_as_sequence */ - nullptr, /* tp_as_mapping */ - nullptr, /* tp_hash */ - nullptr, /* tp_call */ - nullptr, /* tp_str */ - nullptr, /* tp_getattro */ - nullptr, /* tp_setattro */ - nullptr, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - nullptr, /* tp_doc */ - nullptr, /* tp_traverse */ - nullptr, /* tp_clear */ - nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - nullptr, /* tp_iter */ - nullptr, /* tp_iternext */ - DisableTorchFunction_methods, /* tp_methods */ - nullptr, /* tp_members */ - nullptr, /* 
tp_getset */ - nullptr, /* tp_base */ - nullptr, /* tp_dict */ - nullptr, /* tp_descr_get */ - nullptr, /* tp_descr_set */ - 0, /* tp_dictoffset */ - nullptr, /* tp_init */ - PyType_GenericAlloc, /* tp_alloc */ - PyType_GenericNew, /* tp_new */ -}; - -PyObject* THPModule_DisableTorchFunctionType() { - if (PyType_Ready(&DisableTorchFunctionType) < 0) { - return nullptr; - } - - return (PyObject*)(&DisableTorchFunctionType); -} - PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS PyObject *func = nullptr, *types = nullptr, *args = nullptr, @@ -216,14 +137,11 @@ PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { // These are all C-API calls so no exceptions will be raised // and therefore no need for RAII approach to storing // the old value. - auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state(); - if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) { - at::impl::PythonTorchFunctionTLS::set_disabled_state( - at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); - } + bool old_value = at::impl::PythonTorchFunctionTLS::is_disabled(); + at::impl::PythonTorchFunctionTLS::set_disabled(true); // kwargs can safely be nullptr here. PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs); - at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value); + at::impl::PythonTorchFunctionTLS::set_disabled(old_value); return result; END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 8fc5118830eb..881a7adb13eb 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,7 +29,6 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); -PyObject* THPModule_DisableTorchFunctionType(); PyObject* THPModule_DisableTorchFunctionSubclassType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); From ba4d5aae06bde7c0ad045e54b7ad86f4542efb86 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 11 Nov 2022 19:13:05 +0000 Subject: [PATCH 074/453] Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)" This reverts commit 7f28be10e5e71efda37800384fa897785499bed1. 
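For context, the behavior being restored can be seen in a short sketch that mirrors the updated `test_overrides.py` test (the subclass name is illustrative):

```python
import torch

class MyTensor(torch.Tensor):
    pass

x = MyTensor(torch.randn(3))
print(type(torch.sum(x)))      # MyTensor: the subclass __torch_function__ runs

with torch._C.DisableTorchFunction():
    print(type(torch.sum(x)))  # torch.Tensor: __torch_function__ is bypassed
```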
Reverted https://github.com/pytorch/pytorch/pull/88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901 --- test/allowlist_for_publicAPI.json | 2 +- test/profiler/test_profiler_tree.py | 2 +- test/test_overrides.py | 4 +-- test/test_public_bindings.py | 2 +- torch/_C/__init__.pyi.in | 2 +- torch/__init__.py | 2 +- torch/_dynamo/variables/builder.py | 2 +- torch/_dynamo/variables/misc.py | 2 +- torch/_dynamo/variables/tensor.py | 2 +- torch/_subclasses/fake_tensor.py | 2 +- torch/_tensor.py | 2 +- torch/csrc/Module.cpp | 4 +-- torch/csrc/autograd/init.cpp | 1 + torch/csrc/utils/disable_torch_function.cpp | 32 +++++++++---------- torch/csrc/utils/disable_torch_function.h | 2 +- torch/distributed/_shard/common_op_utils.py | 4 +-- torch/distributed/_shard/partial_tensor.py | 2 +- torch/distributed/_shard/replicated_tensor.py | 4 +-- .../_shard/sharded_tensor/_ops/tensor_ops.py | 2 +- torch/masked/maskedtensor/core.py | 2 +- 20 files changed, 38 insertions(+), 39 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 8a66dc12d4b6..ba4a2e96df21 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -1128,7 +1128,7 @@ "BFloat16Tensor", "ComplexDoubleStorage", "ComplexFloatStorage", - "DisableTorchFunctionSubclass", + "DisableTorchFunction", "Generator", "HalfStorage", "HalfTensor", diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 210530250f92..d4a31c645613 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -26,7 +26,7 @@ "torch/profiler/profiler.py(...): start": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): stop_trace": KEEP_ELLIPSES, "torch/profiler/profiler.py(...): _transit_action": KEEP_ELLIPSES, - "": PRUNE_ALL, + "": PRUNE_ALL, "cudaStreamIsCapturing": PRUNE_ALL, "cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PRUNE_ALL, } diff --git a/test/test_overrides.py b/test/test_overrides.py index 01c763a548fc..7082f75a2141 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1448,7 +1448,7 @@ class B(torch.Tensor): x = B(torch.randn(5)) with A(): - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self.assertNotIsInstance(torch.sum(x), B) self.assertTrue(called) @@ -1460,7 +1460,7 @@ class A(torch.Tensor): pass x = A(torch.randn(5)) - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): g = torch._C._EnableTorchFunction() try: self.assertIsInstance(torch.sum(x), A) diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 6897c3102df6..4d2df6512698 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -99,7 +99,7 @@ def test_no_new_bindings(self): "device", "DeviceObjType", "DictType", - "DisableTorchFunctionSubclass", + "DisableTorchFunction", "DispatchKey", "DispatchKeySet", "dtype", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 79dd6386c378..2d20da2a04f3 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -108,7 +108,7 @@ class layout: ... # Defined in torch/csrc/utils/disable_torch_function.cpp -def DisableTorchFunctionSubclass(): ... +def DisableTorchFunction(): ... # Defined in torch/csrc/utils/tensor_layouts.cpp strided : layout = ... 
diff --git a/torch/__init__.py b/torch/__init__.py index ec23499dce65..19be59282cca 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -315,7 +315,7 @@ def get_pyobj(self): if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type] if (obj.__module__ != 'torch'): # TODO: fix their module from C++ side - if name not in ['DisableTorchFunctionSubclass', 'Generator']: + if name not in ['DisableTorchFunction', 'Generator']: obj.__module__ = 'torch' if not TYPE_CHECKING: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9d8789746855..d3c5140fa4a9 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -506,7 +506,7 @@ def wrap_tensor(self, value: torch.Tensor): ) # Disable __torch_function__ to prevent cloning of `value` to hit # us - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): if is_constant_source(self.get_source()): return self.tx.output.register_attr_or_module( value, diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 6e4325b6c0f4..da327122a6a7 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -538,7 +538,7 @@ def call_function( options = VariableTracker.propagate(self, new_args, new_kwargs.values()) # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): if isinstance(args[0], TorchVariable): return TensorVariable.create( tx=tx, diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 0974f24ee969..e87b1d87bac9 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -743,7 +743,7 @@ def inline_torch_function_unwrapped( # Disable __torch_function__ here to prevent the clone of the # example tensor from going into the override. 
- with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return tx.inline_user_function_return(tf_func_var, tf_args, {}) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 79af51efc5b8..14f5cd2de0a7 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1093,5 +1093,5 @@ def __torch_function__(self, func, types, args=(), kwargs=None): memo[id(tensor)] = out return out else: - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return func(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 41b6569c06d8..793034bb64ed 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -1297,7 +1297,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with _C.DisableTorchFunctionSubclass(): + with _C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index efe6c18ea0cd..b8693a484ed9 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1594,8 +1594,8 @@ Call this whenever a new thread is created in order to propagate values from (PyObject*)THPDefaultCPUGenerator, /* incref= */ false)); ASSERT_TRUE(set_module_attr( - "DisableTorchFunctionSubclass", - (PyObject*)THPModule_DisableTorchFunctionSubclassType(), + "DisableTorchFunction", + (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false)); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index d26db95f1295..ee963232d316 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -343,6 +343,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { _C_m, "_RestorePythonTLSSnapshot") .def(py::init<>()); + // TODO: line up this binding with DisableTorchFunction py::class_(_C_m, "_DisableTorchDispatch") .def(py::init<>()); py::class_(_C_m, "_EnableTorchFunction") diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index 516e6b89d43a..682120d7e622 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -35,20 +35,18 @@ typedef struct { PyObject_HEAD /* Type-specific fields go here. 
*/ bool old_state; -} DisableTorchFunctionSubclass; +} DisableTorchFunction; -PyObject* DisableTorchFunctionSubclass__enter( - PyObject* self, - PyObject* unused) { - ((DisableTorchFunctionSubclass*)self)->old_state = +PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { + ((DisableTorchFunction*)self)->old_state = at::impl::PythonTorchFunctionTLS::is_disabled(); at::impl::PythonTorchFunctionTLS::set_disabled(true); Py_RETURN_NONE; } -PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { +PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { at::impl::PythonTorchFunctionTLS::set_disabled( - ((DisableTorchFunctionSubclass*)self)->old_state); + ((DisableTorchFunction*)self)->old_state); Py_RETURN_NONE; } @@ -60,16 +58,16 @@ PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) { } } -static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT - {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr}, - {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr}, +static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT + {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, + {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; -PyTypeObject DisableTorchFunctionSubclassType = { +PyTypeObject DisableTorchFunctionType = { PyVarObject_HEAD_INIT( nullptr, - 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */ - sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */ + 0) "torch._C.DisableTorchFunction", /* tp_name */ + sizeof(DisableTorchFunction), /* tp_basicsize */ 0, /* tp_itemsize */ nullptr, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ @@ -94,7 +92,7 @@ PyTypeObject DisableTorchFunctionSubclassType = { 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ - DisableTorchFunctionSubclass_methods, /* tp_methods */ + DisableTorchFunction_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ @@ -107,12 +105,12 @@ PyTypeObject DisableTorchFunctionSubclassType = { PyType_GenericNew, /* tp_new */ }; -PyObject* THPModule_DisableTorchFunctionSubclassType() { - if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) { +PyObject* THPModule_DisableTorchFunctionType() { + if (PyType_Ready(&DisableTorchFunctionType) < 0) { return nullptr; } - return (PyObject*)(&DisableTorchFunctionSubclassType); + return (PyObject*)(&DisableTorchFunctionType); } PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index 881a7adb13eb..3cdc33e90681 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -29,7 +29,7 @@ struct DisableTorchDispatch { } // namespace torch PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused); -PyObject* THPModule_DisableTorchFunctionSubclassType(); +PyObject* THPModule_DisableTorchFunctionType(); PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args); PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args); PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg); diff --git a/torch/distributed/_shard/common_op_utils.py b/torch/distributed/_shard/common_op_utils.py index 42d65923a536..08aa13282abc 100644 --- a/torch/distributed/_shard/common_op_utils.py +++ 
b/torch/distributed/_shard/common_op_utils.py @@ -53,11 +53,11 @@ def tensor_default_op(types, args=(), kwargs=None, pg=None): Handles ``__torch_function__`` dispatch for the default tensor ops that behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or ``torch.Tensor.dtype``. We simply lower to the real op call with - DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` + DisableTorchFunction context like ``torch.Tensor.__torch_function__`` to avoid recursions. """ if kwargs is None: kwargs = {} - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): return op(*args, **kwargs) diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py index 6a48163082c5..dc8d09bdd7f3 100644 --- a/torch/distributed/_shard/partial_tensor.py +++ b/torch/distributed/_shard/partial_tensor.py @@ -236,7 +236,7 @@ def find_process_group(e): # Need to disable all dispatch to print args and kwargs appropriately. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined] try: - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): raise RuntimeError( f"torch function '{func.__name__}', with args: {args} and " f"kwargs: {kwargs} not supported for PartialTensor!") diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py index e3db6b0fac66..1327f89e00aa 100644 --- a/torch/distributed/_shard/replicated_tensor.py +++ b/torch/distributed/_shard/replicated_tensor.py @@ -109,7 +109,7 @@ def dispatch_arg(arg): # We cann't do super().__torch_function__() as it implicitly convert the result # back to tensor subclasses, where in our case, we need to control the output type # base on the inter-op rules we defined. 
- with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): rs = func(*args, **kwargs) if func in get_default_nowrap_functions(): return rs @@ -157,7 +157,7 @@ def validate(self) -> bool: return True def __setstate__(self, state): - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self.data = state self.requires_grad = state.requires_grad from torch.distributed._shard.api import _get_current_process_group diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index 9ed83ee33f61..e52c29238a62 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -203,7 +203,7 @@ def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): local_shard.tensor.requires_grad_(requires_grad) # update the wrapper class property - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): self_st.requires_grad_(requires_grad) # update the metadata in the meanwhile self_st._metadata.tensor_properties.requires_grad = requires_grad diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 0459f24587bd..3274ef2ef956 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -270,7 +270,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if not all(issubclass(cls, t) for t in types): return NotImplemented - with torch._C.DisableTorchFunctionSubclass(): + with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret From f74946324e794d2332251d0497dc8ff4f831caa9 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 11 Nov 2022 21:11:12 +0000 Subject: [PATCH 075/453] [fix] allow saving python attr on Tensor and Parameter via torch.save (#81616) Fixes: https://github.com/pytorch/pytorch/issues/72129 TODO: * [x] Fix for Parameter Benchmark (Measurable diff for small tensors) ``` [-------------- Save and Load --------------] | After PR | Before PR 1 threads: ---------------------------------- () | 111.7 | 106.9 (4, 4) | 114.4 | 109.2 (128, 128) | 135.2 | 128.3 (1024, 1024) | 1431.9 | 1431.3 Times are in microseconds (us). ```
Benchmark Script ```python import torch from torch.testing._internal.common_utils import BytesIOContext from torch.utils import benchmark import pickle shapes = ((), (4, 4), (128, 128), (1024, 1024)) sizes = [1, 64, 1024, 10000] results = [] def save_load_fn(t): with BytesIOContext() as f: torch.save(t, f) f.seek(0) torch.load(f) for shape in shapes: t = torch.randn(shape) label = 'Save and Load' sub_label = f'{shape}' results.append(benchmark.Timer( stmt='save_load_fn(t)', globals={'t': t, 'save_load_fn':save_load_fn}, label=label, sub_label=sub_label, description='Before PR', ).blocked_autorange(min_run_time=2)) compare = benchmark.Compare(results) compare.print() with open('before_pr.pkl', 'wb') as f: pickle.dump(results, f) # with open('after_pr.pkl', 'rb') as f: # after_pr = pickle.load(f) # with open('before_pr.pkl', 'rb') as f: # before_pr = pickle.load(f) # compare = benchmark.Compare(after_pr + before_pr) # compare.print() ```
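For reference, a minimal usage sketch of the behavior this PR enables, mirroring the new `test_serialization_python_attr` test added below (regular tensors round-trip their Python attributes; `nn.Parameter` still raises, per the TODO in `torch/_utils.py`):

```python
import io

import torch

t = torch.zeros(3, 3)
t.foo = "foo"   # arbitrary Python attributes on a plain tensor
t.pi = 3.14

buf = io.BytesIO()
torch.save(t, buf)
buf.seek(0)
loaded = torch.load(buf)
print(loaded.foo, loaded.pi)  # attributes survive the save/load round trip
```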
NOTE : **BC-Breaking** : After this PR, all tensors (also regular tensors) will be serialised using `_rebuild_from_type_v2`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81616 Approved by: https://github.com/albanD, https://github.com/kurtamohler --- test/test_serialization.py | 22 +++++++ torch/_tensor.py | 43 ++----------- torch/_utils.py | 59 ++++++++++++++++++ torch/_weights_only_unpickler.py | 4 ++ torch/csrc/jit/serialization/unpickler.cpp | 71 ++++++++++++++++++++++ torch/csrc/jit/serialization/unpickler.h | 4 ++ torch/nn/parameter.py | 1 + 7 files changed, 165 insertions(+), 39 deletions(-) diff --git a/test/test_serialization.py b/test/test_serialization.py index 5ccc6f47b4c5..dca926be60e7 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -905,6 +905,28 @@ def test_meta_serialization(self, weights_only): self.assertEqual(state['weight'].size(), big_model.weight.size()) + def test_serialization_python_attr(self): + def _test_save_load_attr(t): + t.foo = 'foo' + t.pi = 3.14 + + with BytesIOContext() as f: + torch.save(t, f) + f.seek(0) + loaded_t = torch.load(f) + + self.assertEqual(t, loaded_t) + self.assertEqual(t.foo, loaded_t.foo) + self.assertEqual(t.pi, loaded_t.pi) + + t = torch.zeros(3, 3) + _test_save_load_attr(t) + # This should start failing once Parameter + # supports saving Python Attribute. + err_msg = "'Parameter' object has no attribute" + with self.assertRaisesRegex(AttributeError, err_msg): + _test_save_load_attr(torch.nn.Parameter(t)) + def test_weights_only_assert(self): class HelloWorld: def __reduce__(self): diff --git a/torch/_tensor.py b/torch/_tensor.py index 793034bb64ed..39fc56452f5a 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -55,9 +55,6 @@ def _rebuild_from_type(func, type, args, dict): def _rebuild_from_type_v2(func, new_type, args, state): - if new_type is Tensor: - return func(*args) - ret = func(*args) if type(ret) is not new_type: ret = ret.as_subclass(new_type) @@ -70,21 +67,7 @@ def _rebuild_from_type_v2(func, new_type, args, state): ): ret.__setstate__(state) else: - if isinstance(state, tuple): - if not len(state) == 2: - raise RuntimeError(f"Invalid serialized state: {state}") - dict_state = state[0] - slots_state = state[1] - else: - dict_state = state - slots_state = None - - for k, v in dict_state.items(): - setattr(ret, k, v) - - if slots_state: - for k, v in slots_state.items(): - setattr(ret, k, v) + ret = torch._utils._set_obj_state(ret, state) return ret @@ -223,31 +206,13 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - if type(self) is Tensor: + state = torch._utils._get_obj_state(self) + if type(self) is Tensor and not state: + # Fast path for regular tensor without Python state. 
return self._reduce_ex_internal(proto) if has_torch_function_unary(self): return handle_torch_function(Tensor.__reduce_ex__, (self,), self, proto) func, args = self._reduce_ex_internal(proto) - # Get the state of the python subclass - # This loosely mimicks the function on the object class but since Tensor do not inherit - # from it, we cannot call that function directly - # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 - getstate_fn = getattr(self, "__getstate__", None) - if getstate_fn: - state = getstate_fn() - else: - slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] - if slots_to_save: - state = ( - self.__dict__, - { - name: getattr(self, name) - for name in slots_to_save - if hasattr(self, name) - }, - ) - else: - state = self.__dict__ return (_rebuild_from_type_v2, (func, type(self), args, state)) def storage(self): diff --git a/torch/_utils.py b/torch/_utils.py index 3bc8a749b3e6..9c646a2f85e0 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -1,3 +1,4 @@ +import copyreg import sys import traceback import warnings @@ -335,6 +336,64 @@ def _rebuild_parameter(data, requires_grad, backward_hooks): return param +# TODO(kshitij12345): Support serializing nn.Parameter with Python Attributes. +# NOTE: We are just defining it here now for future use. +def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state): + param = torch.nn.Parameter(data, requires_grad) + # NB: This line exists only for backwards compatibility; the + # general expectation is that backward_hooks is an empty + # OrderedDict. See Note [Don't serialize hooks] + param._backward_hooks = backward_hooks + + # Restore state on Parameter like python attr. + param = _set_obj_state(param, state) + return param + + +def _get_obj_state(obj): + # Get the state of the python subclass + # This loosely mimicks the function on the object class but since Tensor do not inherit + # from it, we cannot call that function directly + # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891 + getstate_fn = getattr(obj, "__getstate__", None) + if getstate_fn: + state = getstate_fn() + else: + slots_to_save = copyreg._slotnames(obj.__class__) # type: ignore[attr-defined] + if slots_to_save: + state = ( + obj.__dict__, + { + name: getattr(obj, name) + for name in slots_to_save + if hasattr(obj, name) + }, + ) + else: + state = obj.__dict__ + + return state + + +def _set_obj_state(obj, state): + if isinstance(state, tuple): + if not len(state) == 2: + raise RuntimeError(f"Invalid serialized state: {state}") + dict_state = state[0] + slots_state = state[1] + else: + dict_state = state + slots_state = None + + for k, v in dict_state.items(): + setattr(obj, k, v) + + if slots_state: + for k, v in slots_state.items(): + setattr(obj, k, v) + return obj + + def _import_dotted_name(name): components = name.split(".") obj = __import__(components[0]) diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index ee00db937fc3..acc3554768b0 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -103,6 +103,10 @@ def _get_allowed_globals(): torch._utils._rebuild_sparse_csr_tensor, ]: rc[f"torch._utils.{f.__name__}"] = f + + # Handles Tensor Subclasses, Tensor's with attributes. + # NOTE: It calls into above rebuild functions for regular Tensor types. 
+ rc["torch._tensor._rebuild_from_type_v2"] = torch._tensor._rebuild_from_type_v2 return rc diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index f7e974919f03..4bbf7a783a23 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -532,6 +532,21 @@ PickleOpCode Unpickler::readInstruction() { } stack_.emplace_back(std::move(tensor)); } break; + case PickleOpCode::SETITEM: { + // At this OpCode, stack looks like + // | Stack Bottom | + // | ...... | + // | Dict | -> (stack_size - 3) + // | Key | -> (stack_size - 2) + // | Value | -> (stack_size - 1) + auto stack_size = stack_.size(); + auto dict_pos = stack_size - 3; + auto key_pos = stack_size - 2; + auto val_pos = stack_size - 1; + auto dict = stack_.at(dict_pos).toGenericDict(); + dict.insert_or_assign(stack_.at(key_pos), stack_.at(val_pos)); + stack_.erase(stack_.begin() + (key_pos), stack_.end()); + } break; default: { AT_ERROR( "Unknown opcode for unpickling at ", @@ -546,6 +561,23 @@ PickleOpCode Unpickler::readInstruction() { void Unpickler::readGlobal( const std::string& module_name, const std::string& class_name) { + if (this->skip_next_read_global) { + // See [NOTE] skip_next_read_global + this->skip_next_read_global--; + if (this->skip_next_read_global == 1) { + // Pass through to the correct handler + } else if (this->skip_next_read_global == 0) { + // Corresponds to the type of `Tensor` being unpickled + if (module_name != "torch" || class_name != "Tensor") { + TORCH_WARN( + "Trying to load a Subclassed Tensor, it will be converted to at::Tensor in C++"); + } + stack_.emplace_back(int64_t(globals_.size() - 1)); + return; + } else { + TORCH_CHECK(false, "INVALID VALUES") + } + } // TODO [unpickler refactor] __main__ isn't used by the pickler anymore, this // is only here for bc-compatibility reasons if (module_name == "__main__") { @@ -631,6 +663,12 @@ void Unpickler::readGlobal( // Unpickle a tensor bool quantized = class_name == "_rebuild_qtensor"; rebuildTensor(quantized); + } else if ( + module_name == "torch._tensor" && + (class_name == "_rebuild_from_type_v2")) { + // Unpickle a Tensor with Python attributes or + // a Subclassed Tensor. + rebuildTensorFromTypeV2(); } else if ( module_name == "torch._utils" && class_name == "_rebuild_sparse_tensor") { rebuildSparseTensor(); @@ -849,6 +887,39 @@ void Unpickler::rebuildTensor(bool quantized) { }); } +void Unpickler::rebuildTensorFromTypeV2() { + // [NOTE] skip_next_read_global + // When rebuilding Tensor with Python Attr or Subclassed Tensor, + // we receive `(func, type(self), args, state)` on stack for + // `rebuildTensorFromTypeV2`. + // Thus next call to readGlobal corresponds to `func` which is + // the function to rebuild the base tensor. + // The call after `func` to readGlobal corresponds to `type` of the + // Tensor where we raise warning if the type is not `torch.Tensor`. 
+ this->skip_next_read_global = 2; + auto curr_globals_idx = globals_.size(); + globals_.emplace_back([this, curr_globals_idx] { + // args is a tuple with following data + // (function to rebuild base tensor, type of tensor, + // arguments to construct base tensor, Python State (as dict)) + auto args = pop(stack_).toTuple(); + size_t tup_idx = 0; + const auto args_elems = args->elements(); + auto base_tensor_args = args_elems.at(tup_idx + 2).toTuple(); + auto py_state = args_elems.at(tup_idx + 3).toGenericDict(); + if (py_state.size() > 0) { + TORCH_WARN( + "Loading Tensor with Python attributes will return at::Tensor with Python attributes being discarded"); + } + // This calls the function to rebuild the + // base tensor. + // Eg. `rebuildTensor`, `rebuildSpareTensor`. + stack_.emplace_back(base_tensor_args); + globals_[curr_globals_idx + 1](); + stack_.emplace_back(pop(stack_)); + }); +} + #ifdef USE_RPC void Unpickler::rebuildRRef() { globals_.emplace_back([this] { diff --git a/torch/csrc/jit/serialization/unpickler.h b/torch/csrc/jit/serialization/unpickler.h index 5411d421a0c5..de00e7eacff2 100644 --- a/torch/csrc/jit/serialization/unpickler.h +++ b/torch/csrc/jit/serialization/unpickler.h @@ -120,6 +120,7 @@ class TORCH_API Unpickler { const std::string& module_name, const std::string& class_name); void rebuildTensor(bool quantized); + void rebuildTensorFromTypeV2(); void rebuildSparseTensor(); #ifdef USE_DISTRIBUTED void rebuildRRef(); @@ -176,6 +177,9 @@ class TORCH_API Unpickler { // See [type tag serialization] uint64_t version_; + + // See [NOTE] skip_next_read_global + uint8_t skip_next_read_global = 0; }; void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index e0f400f2642b..68908001238e 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -60,6 +60,7 @@ def __repr__(self): return 'Parameter containing:\n' + super(Parameter, self).__repr__() def __reduce_ex__(self, proto): + # TODO(kshitij12345): Support saving Python Attribute # See Note [Don't serialize hooks] return ( torch._utils._rebuild_parameter, From 575e02df5357ef6216b2d2db2424d10432679df2 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 11 Nov 2022 21:19:26 +0000 Subject: [PATCH 076/453] Fix CUDNN_PATH handling on Windows (#88898) Fixes https://github.com/pytorch/pytorch/issues/88873 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88898 Approved by: https://github.com/kit1980 --- torch/utils/cpp_extension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index aa03da23b38d..720935296504 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1686,7 +1686,7 @@ def _prepare_ldflags(extra_ldflags, with_cuda, verbose, is_standalone): extra_ldflags.append(f'/LIBPATH:{_join_cuda_home("lib", "x64")}') extra_ldflags.append('cudart.lib') if CUDNN_HOME is not None: - extra_ldflags.append(os.path.join(CUDNN_HOME, "lib", "x64")) + extra_ldflags.append(f'/LIBPATH:{os.path.join(CUDNN_HOME, "lib", "x64")}') elif not IS_HIP_EXTENSION: extra_ldflags.append(f'-L{_join_cuda_home("lib64")}') extra_ldflags.append('-lcudart') From 7aa144ac54808419f7a702ef0c5a4445dba4c587 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:21 +0000 Subject: [PATCH 077/453] [FSDP][state_dict][5/N] Remove the FSDP module dependency from _state_dict_utils (#88637) **What** This PR 
completely removes the `FullyShardedDataParallel` dependency from `_state_dict_utils` -- `_state_dict_utils` now depends only on `_FSDPState` and all the utils modules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88637 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_init_utils.py | 6 +- torch/distributed/fsdp/_state_dict_utils.py | 108 ++++++++++---------- 2 files changed, 58 insertions(+), 56 deletions(-) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 966e61f7fe12..1265ee3578ed 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -213,10 +213,8 @@ def _init_state_dict_state(state: _FSDPState) -> _FSDPState: state._state_dict_type = StateDictType.FULL_STATE_DICT state_dict_config: StateDictConfig = FullStateDictConfig() state._state_dict_config = state_dict_config - full_param_ctx: Optional[Generator] = None - # TODO: For composable API, this should be a dict that maps from a module to - # handles. - state._full_param_ctx = full_param_ctx + unshard_params_ctx: Dict[nn.Module, Generator] = {} + state._unshard_params_ctx = unshard_params_ctx return state diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index eee5522340b4..54191cb55ece 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -8,7 +8,6 @@ import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper # Import the entire FSDP file to avoid circular imports -import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn import torch.nn.functional as F @@ -42,6 +41,7 @@ from ._unshard_param_utils import ( _deregister_orig_params, _register_orig_params, + _unshard_params, FLAT_PARAM, ) from .flat_param import FlatParamHandle @@ -58,7 +58,9 @@ def _convert_to_wrapped_module_name(module_name: str) -> str: return module_name -def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str]]: +def _param_fqns( + module: nn.Module, fsdp_state: _FSDPState +) -> Iterator[Tuple[str, str, str]]: if not _has_fsdp_params(fsdp_state, module): return for param_name, module_name in _module_handles(fsdp_state, module)[ @@ -69,7 +71,7 @@ def _param_fqns(module, fsdp_state: _FSDPState) -> Iterator[Tuple[str, str, str] yield fqn, param_name, module_name -def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: +def _shared_param_fqns(module: nn.Module, fsdp_state) -> Iterator[Tuple[str, str, str]]: for param_name, module_name in _module_handles(fsdp_state, module)[ 0 ].shared_parameter_module_names(): @@ -78,7 +80,9 @@ def _shared_param_fqns(module, fsdp_state) -> Iterator[Tuple[str, str, str]]: yield fqn, param_name, module_name -def _enter_full_param_ctx( +@no_type_check +def _enter_unshard_params_ctx( + module: nn.Module, fsdp_state: _FSDPState, recurse: bool = False, writeback: bool = False, @@ -89,32 +93,32 @@ def _enter_full_param_ctx( """ state_dict hooks cannot use the pure context call as the checkpoint flow requires to enter the context in the pre-hook but leave the context in the - post-hook. This API enters the context of ``summon_full_params``. + post-hook. This API enters the context of ``_unshard_params``. 
""" - assert fsdp_state._full_param_ctx is None, ( - "Entering the ``summon_full_params`` context but fsdp_state._full_param_ctx " + assert module not in fsdp_state._unshard_params_ctx, ( + "Entering the ``_unshard_params`` context but _unshard_params_ctx[module] " "is not None." ) - fsdp_state._full_param_ctx = fsdp_state._summon_full_params( - recurse=recurse, + fsdp_state._unshard_params_ctx[module] = _unshard_params( + module, + fsdp_state, writeback=writeback, rank0_only=rank0_only, offload_to_cpu=offload_to_cpu, with_grads=with_grads, ) - fsdp_state._full_param_ctx.__enter__() + fsdp_state._unshard_params_ctx[module].__enter__() @no_type_check -def _exit_full_param_ctx(fsdp_state: _FSDPState) -> None: - """A helper function to exit ``summon_full_params`` context.""" - assert fsdp_state._full_param_ctx is not None - fsdp_state._full_param_ctx.__exit__(None, None, None) - fsdp_state._full_param_ctx = None +def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: + """A helper function to exit ``_unshard_params`` context.""" + fsdp_state._unshard_params_ctx[module].__exit__(None, None, None) + fsdp_state._unshard_params_ctx.pop(module) def _common_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -131,16 +135,18 @@ def _common_pre_state_dict_hook( _clear_grads_if_needed([_module_handles(fsdp_state, module)[0]]) -def _common_summon_pre_state_dict_hook( +def _common_unshard_pre_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, offload_to_cpu: bool, rank0_only: bool, ) -> None: """ Performs the pre-state_dict tasks shared by all state_dict types that require - ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + ``_unshard_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ - _enter_full_param_ctx( + _enter_unshard_params_ctx( + module, fsdp_state, recurse=False, writeback=False, @@ -151,8 +157,8 @@ def _common_summon_pre_state_dict_hook( # TODO: change to the decorator style. See ``_full_pre_state_dict_hook``. @no_type_check -def _common_summon_post_state_dict_hook( - module, +def _common_unshard_post_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -160,13 +166,13 @@ def _common_summon_post_state_dict_hook( ) -> Dict[str, Any]: """ The post-state_dict flow that shared by all state_dict types that require - ``summon_full_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + ``_unshard_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. """ _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) # Return early for trivial cases if not state_dict or not _has_fsdp_params(fsdp_state, module): - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) return state_dict # TODO: Once pre_state_dict hook is supported, this pop should be removed. 
@@ -193,7 +199,7 @@ def _common_summon_post_state_dict_hook( f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" ) state_dict.pop(f"{prefix}{clean_key}", None) - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) return state_dict # Loop only the parameters saved in this instance's wrapped module to @@ -214,7 +220,7 @@ def _common_summon_post_state_dict_hook( ) param_hook(state_dict, prefix, fqn) - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) cpu_device = torch.device("cpu") buffer_clean_fqns = [] @@ -251,7 +257,7 @@ def _common_summon_post_state_dict_hook( @no_type_check def _full_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -267,7 +273,8 @@ def _full_pre_state_dict_hook( in ``nn.Module``. """ _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) - _common_summon_pre_state_dict_hook( + _common_unshard_pre_state_dict_hook( + module, fsdp_state, offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, @@ -276,7 +283,7 @@ def _full_pre_state_dict_hook( @no_type_check def _full_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -284,7 +291,7 @@ def _full_post_state_dict_hook( """ Hook that runs after model.state_dict() is called before returning result to user. For FSDP, we may have to clone the tensors in state_dict as params go - back to sharded version after _summon_full_params ends, and also remove + back to sharded version after _unshard_params ends, and also remove the ``FSDP_WRAPPED_MODULE`` prefix. """ # TODO: remove the hack. See ``_full_pre_state_dict_hook``. @@ -303,8 +310,7 @@ def param_hook( if clean_key.startswith(clean_prefix): clean_key = clean_key[len(clean_prefix) :] - # Clone non-ignored parameters before exiting the - # `_summon_full_params()` context + # Clone non-ignored parameters before exiting the `_unshard_params()` context. if clean_key not in fsdp_state._ignored_param_names and not getattr( state_dict[fqn], "_has_been_cloned", False ): @@ -320,30 +326,30 @@ def param_hook( f"implementation of {fqn}. 
Error: {str(e)}" ) - return _common_summon_post_state_dict_hook( + return _common_unshard_post_state_dict_hook( module, fsdp_state, state_dict, prefix, param_hook ) def _full_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, ) -> None: _lazy_init(fsdp_state, module) - _enter_full_param_ctx(fsdp_state, recurse=False, writeback=True) + _enter_unshard_params_ctx(module, fsdp_state, recurse=False, writeback=True) _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") def _full_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: - _exit_full_param_ctx(fsdp_state) + _exit_unshard_params_ctx(module, fsdp_state) def _local_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -366,7 +372,7 @@ def _local_pre_state_dict_hook( @no_type_check def _local_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -408,13 +414,13 @@ def _local_post_state_dict_hook( def _local_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: pass def _local_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -457,7 +463,7 @@ def _local_pre_load_state_dict_hook( def _sharded_pre_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -477,7 +483,8 @@ def _sharded_pre_state_dict_hook( _common_pre_state_dict_hook(module, fsdp_state, state_dict, prefix) # Setting offload_to_cpu here does not work even if offload_to_cpu is True. # We have to create ShardedTensor first then move it to CPU. 
- _common_summon_pre_state_dict_hook( + _common_unshard_pre_state_dict_hook( + module, fsdp_state, offload_to_cpu=False, rank0_only=False, @@ -486,7 +493,7 @@ def _sharded_pre_state_dict_hook( @no_type_check def _sharded_post_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -512,14 +519,14 @@ def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str): sharded_tensor = sharded_tensor.cpu() state_dict[fqn] = sharded_tensor - return _common_summon_post_state_dict_hook( + return _common_unshard_post_state_dict_hook( module, fsdp_state, state_dict, prefix, param_hook ) @no_type_check def _sharded_post_load_state_dict_hook( - module, fsdp_state: _FSDPState, *args, **kwargs + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs ) -> None: if fsdp_state._use_orig_params: _register_orig_params(module, fsdp_state) @@ -527,7 +534,7 @@ def _sharded_post_load_state_dict_hook( @no_type_check def _sharded_pre_load_state_dict_hook( - module, + module: nn.Module, fsdp_state: _FSDPState, state_dict: Dict[str, Any], prefix: str, @@ -636,9 +643,8 @@ def _post_state_dict_hook( StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, } - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( - fsdp_module, fsdp_state, state_dict, prefix + module, fsdp_state, state_dict, prefix ) return processed_state_dict @@ -664,12 +670,11 @@ def _pre_load_state_dict_hook( StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, } # Code that is common for all state_dict impls - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) if torch.cuda.is_available(): torch.cuda.synchronize() # Dispatch into state_dict specific implementation of pre-hook. _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( - fsdp_module, fsdp_state, state_dict, prefix + module, fsdp_state, state_dict, prefix ) @@ -684,7 +689,6 @@ def _post_load_state_dict_hook(module: nn.Module, *args: Any) -> None: StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, } # Code that is common for all state_dict impls - fsdp_module = cast(fsdp_file.FullyShardedDataParallel, module) # Dispatch into state_dict type specific implementation of post-hook for # loading state_dict. 
- _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](fsdp_module, fsdp_state) + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state) From dfb4b73e45896851d734e34a9902fd8b151797fe Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Fri, 11 Nov 2022 21:51:10 +0000 Subject: [PATCH 078/453] Fix unused variable 'options' warning in RNN.cpp (#88753) Fixes ``` /home/rbarnes/pytorch/aten/src/ATen/native/cudnn/RNN.cpp:73:17: warning: unused variable 'options' [-Wunused-variable] TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); ^ ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88753 Approved by: https://github.com/soumith --- aten/src/ATen/native/cudnn/RNN.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index c08c5d26b63c..426243392b6f 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -70,7 +70,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see c10::optional device, c10::optional pin_memory) { // See [Note: hacky wrapper removal for TensorOptions] - TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); } From ea0ec9d71ca5428bedfcaf74990c109af8cb9a64 Mon Sep 17 00:00:00 2001 From: efiks <5167930+efiks@users.noreply.github.com> Date: Fri, 11 Nov 2022 21:58:23 +0000 Subject: [PATCH 079/453] [tourch] BatchBoxCox - fix numerical issue in vectorized code (#88875) Summary: Usage of fast math in BatchBoxCox kernel provided different math results between dev and optimized versions which cause few internal test to fail. For now disabling the compiler optimized version and relying on ATEN vectors Differential Revision: D41211784 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88875 Approved by: https://github.com/hyuen --- caffe2/perfkernels/batch_box_cox_avx2.cc | 93 ++++++++++++++---------- 1 file changed, 53 insertions(+), 40 deletions(-) diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc index 8b93293646db..6171b5bfd032 100644 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ b/caffe2/perfkernels/batch_box_cox_avx2.cc @@ -1,3 +1,4 @@ +#include #ifdef CAFFE2_PERF_USE_MKL #include #include @@ -5,30 +6,68 @@ #include "vectorizer.h" -#ifndef VECTORIZED_KERNEL +// Enable compiler vectorized version only if numerical consistency is not +// required between dev and opt versions - disabled for now +#ifndef FAST_VECTORIZED_KERNEL #define CPU_CAPABILITY_AVX2 #include namespace at::vec { +// Implements the vectorized version of std::max() operation, +// which DOESNOT propagates NaN for second argument template Vectorized max(const Vectorized& a, const Vectorized& b); -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_pd(b, a); } - template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_ps(b, a); } +// Implements recieprocal method based on newton-rapson method +// 1. 
user RCP approximiation +// 2. update with RCP = RCP * (2 - X * RCP) +template +Vectorized fast_recieprocal(const Vectorized& b); +template +scalar_t fast_recieprocal(scalar_t b); + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + auto minus2 = _mm256_set1_ps(-2.f); + auto rcp = _mm256_rcp_ps(b); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + return rcp; +} + +template <> +float fast_recieprocal(float b) { + auto minus2 = _mm_set_ss(-2.f); + auto b_reg = _mm_set_ss(b); + auto rcp = _mm_rcp_ss(b_reg); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + return _mm_cvtss_f32(rcp); +} + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + return b.reciprocal(); +} + +template <> +double fast_recieprocal(double b) { + return 1./b; +} + } #endif @@ -45,14 +84,6 @@ template void PackV(const int N, const T* a, const int* ia, T* y); template void UnpackV(const int N, const T* a, T* y, const int* iy); -template -void Pow(const int N, const T* a, const T* b, T* y); -template -void Add(const int N, const T* a, const T* b, T* y); -template -void Div(const int N, const T* a, const T* b, T* y); -template -void Ln(const int N, const T* a, T* y); #define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ template <> \ @@ -72,29 +103,7 @@ DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) #undef DELEGATE_UNPACKV_FUNCTION -#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, const T* b, T* y) { \ - OriginalFunc(N, a, b, y); \ - } -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Pow, vsPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Pow, vdPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) -#undef DELEGATE_SIMPLE_BINARY_FUNCTION - -#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, T* y) { \ - OriginalFunc(N, a, y); \ - } -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Ln, vsLn) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Ln, vdLn) -#undef DELEGATE_SIMPLE_UNARY_FUNCTION - -#ifndef VECTORIZED_KERNEL +#ifndef FAST_VECTORIZED_KERNEL template void box_cox_zero_lambda( size_t D, @@ -140,7 +149,7 @@ void box_cox_nonzero_lambda( auto sum = data + lambda2; auto max = at::vec::max(sum, k_eps_vec); auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = lambda1.reciprocal(); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); auto pow = max.pow(lambda1); auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); res.store(out + j); @@ -148,7 +157,7 @@ void box_cox_nonzero_lambda( for ( ;j < D; ++j) { auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); auto pow = std::pow(max, lambda1_ptr[j]); out[j] = pow * lambda_over_1 - lambda_over_1; } @@ -181,12 +190,16 @@ void box_cox_nonzero_lambda( FAST_MATH auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; - auto pow = std::pow(max, lambda1_ptr[j]); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = 
lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); out[j] = pow * lambda_over_1 - lambda_over_1; } } - #endif template From fbc1878265374a159639993269d40a6e08503278 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Tue, 8 Nov 2022 10:22:32 -0800 Subject: [PATCH 080/453] [ONNX] Pretty print diagnostic logging (#88261) Adds pretty print diagnostic logging. For example ```python import io import torch from torch.onnx._internal import diagnostics class CustomAdd(torch.autograd.Function): @staticmethod def forward(ctx, x, y): return x + y @staticmethod def symbolic(g, x, y): return g.op("custom::CustomAdd", x, y) class M(torch.nn.Module): def forward(self, x): return CustomAdd.apply(x, x) # trigger warning for missing shape inference. # rule = diagnostics.rules.node_missing_onnx_shape_inference torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) ``` By default, observe minimum summary of diagnostics ``` ========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 ========= verbose: False, log level: Level.ERROR ======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ======================== 3 WARNING were not printed due to the log level. ``` Adjusting the `verbose` and `level` argument. ```python diagnostics.engine.pretty_print(verbose=True, level=diagnostics.levels.WARNING) ``` Prints full log. ``` =============================== 1 Diagnostic Run =============================== ========= Diagnostic Run torch.onnx.export version 1.14.0a0+git90a69c5 ========= verbose: True, log level: Level.WARNING ======================= 0 NONE 0 NOTE 3 WARNING 0 ERROR ======================== WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. 
--------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: n, utils._params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/_patch_torch.py:82 frame: <@beartype(torch.onnx._patch_torch._graph_op) at 0x7f62184b6710>:78 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: return function(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_deprecation.py:30 frame: return g.op("custom::CustomAdd", x, y) test_pretty_print.py:14 frame: return symbolic_fn(g, *args) /home/bowbao/pytorch_dev/torch/onnx/utils.py:1716 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: graph = _C._jit_pass_onnx(graph, operator_export_type) /home/bowbao/pytorch_dev/torch/onnx/utils.py:663 frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabd026 (0x7f625b599026 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. 
--------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:688 frame: <@beartype(torch.onnx.utils._optimize_graph) at 0x7f62180e05f0>:85 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: module=module, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1123 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr&, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () WARNING: node-missing-onnx-shape-inference ========================================== The shape inference of custom::CustomAdd type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. 
--------------------------- Stack: Python call stack --------------------------- frame: diagnostic = ExportDiagnostic(rule, level, message, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/diagnostics/_diagnostic.py:151 frame: graph, params_dict, GLOBALS.export_onnx_opset_version /home/bowbao/pytorch_dev/torch/onnx/utils.py:1179 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: dynamic_axes=dynamic_axes, /home/bowbao/pytorch_dev/torch/onnx/utils.py:1539 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: export_modules_as_functions=export_modules_as_functions, /home/bowbao/pytorch_dev/torch/onnx/utils.py:519 frame: <@beartype(torch.onnx.utils.export) at 0x7f62180e0170>:347 frame: return beartyped(*args, **kwargs) /home/bowbao/pytorch_dev/torch/onnx/_internal/_beartype.py:81 frame: torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO()) test_pretty_print.py:22 ---------------------------- Stack: C++ call stack ----------------------------- frame: () frame: ( + 0x88411b (0x7f625b36011b in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Value*, std::pair const&) + 0x7d3 (0x7f625b351743 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::UpdateReliable(torch::jit::Node*) + 0x4f (0x7f625b35198f in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0xac9 (0x7f625b357179 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x87d6d1 (0x7f625b3596d1 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: (torch::jit::ONNXShapeTypeInference(std::shared_ptr&, std::map, std::allocator >, c10::IValue, std::less, std::allocator > >, std::allocator, std::allocator > const, c10::IValue> > > const&, int) + 0x33 (0x7f625b359cf3 in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0xabdbae (0x7f625b599bae in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: ( + 0x3c0fda (0x7f625ae9cfda in /home/bowbao/pytorch_dev/torch/lib/libtorch_python.so)) frame: () ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88261 Approved by: https://github.com/abock, https://github.com/justinchuby --- test/onnx/internal/test_diagnostics.py | 2 +- .../onnx/_internal/diagnostics/_diagnostic.py | 18 +-- .../_internal/diagnostics/infra/_infra.py | 110 +++++++++++++++++- .../_internal/diagnostics/infra/engine.py | 15 +++ .../_internal/diagnostics/infra/formatter.py | 18 +++ .../onnx/_internal/diagnostics/infra/utils.py | 2 +- 6 files changed, 140 insertions(+), 25 deletions(-) diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index ea9a789e91c1..884b7cb1c388 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -19,7 +19,7 @@ def _assert_has_diagnostics( rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): sarif_log = engine.sarif_log() - unseen_pairs = {(rule.id, level.value) for rule, level in rule_level_pairs} + unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs} actual_results = [] for run in sarif_log.runs: if run.results is None: diff --git 
a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 21e44f2b4467..efe5c0e34911 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -74,22 +74,6 @@ def record_cpp_call_stack(self, frames_to_skip) -> None: self.with_stack(stack) self.cpp_call_stack = stack - def with_model_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - - def with_export_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - class ExportDiagnosticEngine(infra.DiagnosticEngine): """PyTorch ONNX Export diagnostic engine. @@ -115,7 +99,6 @@ def __init__(self) -> None: name="torch.onnx", version=torch.__version__, diagnostic_type=ExportDiagnostic, - options=None, ) @property @@ -150,6 +133,7 @@ def create_export_diagnostic_context(): try: yield context finally: + context.pretty_print(context.options.log_verbose, context.options.log_level) context = engine.background_context diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index b8a4c5032f52..3414574cce73 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -17,10 +17,10 @@ class Level(enum.Enum): please use infra.Tag instead. """ - NONE = "none" - NOTE = "note" - WARNING = "warning" - ERROR = "error" + NONE = enum.auto() + NOTE = enum.auto() + WARNING = enum.auto() + ERROR = enum.auto() levels = Level @@ -107,6 +107,9 @@ def format_message(self, *args, **kwargs) -> str: """ return self.message_default_template.format(*args, **kwargs) + def pretty_print(self): + pass + @dataclasses.dataclass class Location: @@ -134,6 +137,25 @@ def sarif(self) -> sarif.Location: else None, ) + def pretty_print(self): + """Prints the location in a human-readable format.""" + location_strs = ["frame:"] + if self.snippet is not None: + location_strs.append(self.snippet) + if self.uri is not None: + line_strs = [self.uri] + line_strs.append(str(self.line)) if self.line is not None else "-1" + line_strs.append( + str(self.start_column) + ) if self.start_column is not None else "-1" + line_strs.append( + str(self.end_column) + ) if self.end_column is not None else "-1" + location_strs.append(":".join(line_strs)) + if self.message is not None: + location_strs.append(f"({self.message})") + print(" ".join(location_strs)) + @dataclasses.dataclass class StackFrame: @@ -143,6 +165,10 @@ def sarif(self) -> sarif.StackFrame: """Returns the SARIF representation of this stack frame.""" return sarif.StackFrame(location=self.location.sarif()) + def pretty_print(self): + """Prints the stack frame in a human-readable format.""" + self.location.pretty_print() + @dataclasses.dataclass class Stack: @@ -158,6 +184,12 @@ def sarif(self) -> sarif.Stack: else None, ) + def pretty_print(self): + """Prints the stack in a human-readable format.""" + formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-") + for frame in self.frames: + frame.pretty_print() + # This is a workaround for mypy not supporting Self from typing_extensions. 
_Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic") @@ -182,6 +214,9 @@ def sarif(self) -> sarif.Graph: properties=PatchedPropertyBag(name=self.name, description=self.description), ) + def pretty_print(self): + pass + @dataclasses.dataclass class Diagnostic: @@ -201,7 +236,7 @@ def sarif(self) -> sarif.Result: message = f"{message}\n{self.additional_message}" sarif_result = sarif.Result( message=sarif.Message(text=message), - level=self.level.value, + level=self.level.name.lower(), # type: ignore[arg-type] rule_id=self.rule.id, ) sarif_result.locations = [location.sarif() for location in self.locations] @@ -235,6 +270,31 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic: self.additional_message = f"{self.additional_message}\n{message}" return self + def pretty_print(self, verbose: bool = False, log_level: Level = Level.ERROR): + """Prints the diagnostics in a human-readable format. + + Args: + verbose: If True, prints all information. E.g. stack frames, graphs, etc. + Otherwise, only prints compact information. E.g., rule name and display message. + level: The minimum level of diagnostics to print. + """ + if self.level.value < log_level.value: + return + formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}") + print(self.message) + + if not verbose: + print("\n") + return + + for location in self.locations: + location.pretty_print() + for stack in self.stacks: + stack.pretty_print() + for graph in self.graphs: + graph.pretty_print() + print() + @dataclasses.dataclass class RuleCollection: @@ -284,12 +344,15 @@ class DiagnosticOptions: Options for diagnostic context. """ + log_verbose: bool = dataclasses.field(default=False) + log_level: Level = dataclasses.field(default=Level.ERROR) + @dataclasses.dataclass class DiagnosticContext: name: str version: str - options: Optional[DiagnosticOptions] = None + options: DiagnosticOptions = dataclasses.field(default_factory=DiagnosticOptions) diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic) diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) _invocation: Invocation = dataclasses.field(init=False) @@ -350,3 +413,38 @@ def diagnose( diagnostic = self.diagnostic_type(rule, level, message, **kwargs) self.add_diagnostic(diagnostic) return diagnostic + + def pretty_print( + self, verbose: bool = False, log_level: Level = Level.ERROR + ) -> None: + """Prints the diagnostics in a human-readable format. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. + """ + formatter.pretty_print_title( + f"Diagnostic Run {self.name} version {self.version}" + ) + print(f"verbose: {verbose}, log level: {log_level}") + diagnostic_stats = {level: 0 for level in Level} + for diagnostic in self.diagnostics: + diagnostic_stats[diagnostic.level] += 1 + formatter.pretty_print_title( + " ".join(f"{diagnostic_stats[level]} {level.name}" for level in Level) + ) + + for diagnostic in self.diagnostics: + diagnostic.pretty_print(verbose, log_level) + + unprinted_diagnostic_stats = [ + (level, count) + for level, count in diagnostic_stats.items() + if count > 0 and level.value < log_level.value + ] + if unprinted_diagnostic_stats: + print( + f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} " + "were not printed due to the log level." 
+ ) + print() diff --git a/torch/onnx/_internal/diagnostics/infra/engine.py b/torch/onnx/_internal/diagnostics/infra/engine.py index 2678268fbaf9..51a6057565bb 100644 --- a/torch/onnx/_internal/diagnostics/infra/engine.py +++ b/torch/onnx/_internal/diagnostics/infra/engine.py @@ -85,8 +85,23 @@ def create_diagnostic_context( Returns: A new diagnostic context. """ + if options is None: + options = infra.DiagnosticOptions() context = infra.DiagnosticContext( name, version, options, diagnostic_type=diagnostic_type ) self.contexts.append(context) return context + + def pretty_print( + self, verbose: bool = False, level: infra.Level = infra.Level.ERROR + ) -> None: + """Pretty prints all diagnostics in the diagnostic contexts. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. + """ + formatter.pretty_print_title(f"{len(self.contexts)} Diagnostic Run") + for context in self.contexts: + context.pretty_print(verbose, level) diff --git a/torch/onnx/_internal/diagnostics/infra/formatter.py b/torch/onnx/_internal/diagnostics/infra/formatter.py index 2f35489f8d45..292a2b6a47a5 100644 --- a/torch/onnx/_internal/diagnostics/infra/formatter.py +++ b/torch/onnx/_internal/diagnostics/infra/formatter.py @@ -57,3 +57,21 @@ def sarif_to_json(attr_cls_obj: _SarifClass) -> str: dict = dataclasses.asdict(attr_cls_obj) dict = _convert_key(dict, _camel_case_to_snake_case) return json.dumps(dict, indent=4) + + +def pretty_print_title(title: str, width: int = 80, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + ==================== title ==================== + """ + print(f" {title} ".center(width, fill_char)) + + +def pretty_print_item_title(title: str, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + title + ===== + """ + print(title) + print(fill_char * len(title)) diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py index c32de1c6b8ad..6a85df910463 100644 --- a/torch/onnx/_internal/diagnostics/infra/utils.py +++ b/torch/onnx/_internal/diagnostics/infra/utils.py @@ -6,7 +6,7 @@ def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame: """Returns a StackFrame for the given inspect.FrameInfo.""" snippet = ( - frame.code_context[frame.index] + frame.code_context[frame.index].strip() if frame.code_context is not None and frame.index is not None else None ) From f39cad50b765b6fd2f4927a4d1552fff5928c61e Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 11 Nov 2022 22:07:34 +0000 Subject: [PATCH 081/453] Make InductorCPU usable in internally (#88870) Test Plan: `buck2 test mode/opt //caffe2/test:test_inductor -- --exact 'caffe2/test:test_inductor - test_dtype_mismatch_issue_cuda (caffe2.test.inductor.test_torchinductor.CudaTests)'` Differential Revision: D41206109 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88870 Approved by: https://github.com/izaitsevfb --- torch/_inductor/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 87e2793782be..8f9f2c4f461d 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -92,6 +92,7 @@ class cpp: "g++-10", "clang++", "g++", + "g++.par", ) From be8d88f8d0c6825b1b19354ffbaa4466aae0d3b8 Mon Sep 17 00:00:00 2001 From: Kevin Tse Date: Thu, 10 Nov 2022 18:33:09 -0500 Subject: [PATCH 082/453] [DataLoader] Removing DataLoader2 related code 
(#88848) Removing these lines of code as `DataLoader2` has been added to [TorchData](https://github.com/pytorch/data). I'm importing this to confirm it will not impact internal codes. Differential Revision: [D41201578](https://our.internmc.facebook.com/intern/diff/D41201578) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88848 Approved by: https://github.com/ejguan --- docs/source/data.rst | 3 - test/test_dataloader.py | 111 ----------- torch/utils/data/__init__.py | 5 - torch/utils/data/communication/__init__.py | 6 - torch/utils/data/communication/eventloop.py | 70 ------- torch/utils/data/communication/iter.py | 181 ----------------- torch/utils/data/communication/map.py | 159 --------------- torch/utils/data/communication/messages.py | 75 ------- torch/utils/data/communication/protocol.py | 205 -------------------- torch/utils/data/communication/queue.py | 51 ----- torch/utils/data/dataloader_experimental.py | 150 -------------- 11 files changed, 1016 deletions(-) delete mode 100644 torch/utils/data/communication/__init__.py delete mode 100644 torch/utils/data/communication/eventloop.py delete mode 100644 torch/utils/data/communication/iter.py delete mode 100644 torch/utils/data/communication/map.py delete mode 100644 torch/utils/data/communication/messages.py delete mode 100644 torch/utils/data/communication/protocol.py delete mode 100644 torch/utils/data/communication/queue.py delete mode 100644 torch/utils/data/dataloader_experimental.py diff --git a/docs/source/data.rst b/docs/source/data.rst index de2d44920f57..b44096d10196 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -441,9 +441,6 @@ Example:: .. autoclass:: torch.utils.data.distributed.DistributedSampler -.. This module is experimental and should be private, adding it here for now -.. py:module:: torch.utils.data.communication - .. These modules are documented as part of torch/data listing them here for .. now until we have a clearer fix .. py:module:: torch.utils.data.datapipes diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 270ca89764ed..6a7ff90527d3 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -20,19 +20,16 @@ ChainDataset, ConcatDataset, DataLoader, - DataLoader2, Dataset, IterableDataset, IterDataPipe, Subset, TensorDataset, - communication, _utils ) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import random_split from torch.utils.data.datapipes.iter import IterableWrapper -from torch.utils.data.datapipes.map import SequenceWrapper from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, @@ -2222,114 +2219,6 @@ def test_excessive_thread_creation_warning(self): r"excessive worker creation might get DataLoader running slow or even freeze"): dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) -# Define a global function for testing purposes since local functions cannot be pickled -def identity(x): - return x - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. 
Dying (set die_after_fork=0 to override)") -class TestDataLoader2(TestCase): - @skipIfNoDill - def test_basics(self): - # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order - # of traversing workers - dp = IterableWrapper(list(range(1000))).sharding_filter() - dl = DataLoader(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2 = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2, parallelism_mode='thread') - self.assertEqual(list(dl), list(dl2)) - self.assertEqual(list(dl), list(dl2_threading)) - - class Sorter(IterDataPipe): - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - return iter(sorted(self.datapipe)) - - def test_shuffle(self): - items = list(range(1000)) - dp = IterableWrapper(items).sharding_filter().shuffle() - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=False) - self.assertEqual(items, list(dl)) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2_EventLoop(TestCase): - @skipIfNoDill - def test_basic_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - it = list(range(100)) - numbers_dp = IterableWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp) - - process.start() - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - - actual = list(local_datapipe) - clean_me(process, req_queue, res_queue) - - self.assertEqual(list(range(100)), actual) - - @skipIfNoDill - def test_basic_mapdatapipe_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - input_len = 100 - it = list(range(input_len)) - numbers_dp = SequenceWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline( - numbers_dp) - - process.start() - - # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - actual = list(local_datapipe) - self.assertEqual([(x, x) for x in range(100)], actual) - - # Functional Test: raise Error when input - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - with self.assertRaisesRegex(IndexError, "out of bound"): - local_datapipe[1000] - - # __len__ Test: Ensure that the correct length is returned - 
local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - self.assertEqual(input_len, len(local_datapipe)) - - clean_me(process, req_queue, res_queue) - class IntegrationTestDataLoaderDataPipe(TestCase): r""" diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 6fe6147ddc54..bc054a947069 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -39,8 +39,6 @@ runtime_validation, runtime_validation_disabled, ) -from torch.utils.data.dataloader_experimental import DataLoader2 -from torch.utils.data import communication __all__ = ['BatchSampler', 'ChainDataset', @@ -48,7 +46,6 @@ 'DFIterDataPipe', 'DataChunk', 'DataLoader', - 'DataLoader2', 'Dataset', 'DistributedSampler', 'IterDataPipe', @@ -63,8 +60,6 @@ 'WeightedRandomSampler', '_DatasetKind', 'argument_validation', - 'collate', - 'communication', 'default_collate', 'default_convert', 'functional_datapipe', diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py deleted file mode 100644 index 1b9cae401189..000000000000 --- a/torch/utils/data/communication/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import eventloop -from . import iter -from . import map -from . import messages -from . import protocol -from . import queue diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py deleted file mode 100644 index 9bf241d334df..000000000000 --- a/torch/utils/data/communication/eventloop.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import threading -import pickle - -from torch.utils.data import IterDataPipe, communication, MapDataPipe - -try: - import dill - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -__all__ = [ - "DataPipeToQueuesLoop", - "SpawnProcessForDataPipeline", - "SpawnThreadForDataPipeline", -] - -def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue): - if isinstance(source_datapipe, IterDataPipe): - pipe_type = communication.iter - protocol_type = communication.protocol.IterDataPipeQueueProtocolServer - elif isinstance(source_datapipe, MapDataPipe): - pipe_type = communication.map # type: ignore[misc] - protocol_type = communication.protocol.MapDataPipeQueueProtocolServer # type: ignore[assignment] - else: - raise Exception('Only supports IterDataPipe or MapDataPipe, got', source_datapipe) - - torch.set_num_threads(1) - for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), - blocking_request_get=True): - pass - - -def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe): - req_queue = multiprocessing_ctx.Queue() - res_queue = multiprocessing_ctx.Queue() - process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue)) - return process, req_queue, res_queue - - -def SpawnThreadForDataPipeline(datapipe): - r""" - Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with DataPipeToQueuesLoop as target, - and return the process, req_queue, res_queue, thread_local_datapipe. 
- """ - req_queue = communication.queue.ThreadingQueue() - res_queue = communication.queue.ThreadingQueue() - - try: - new_datapipe = pickle.loads(pickle.dumps(datapipe)) - except Exception as pe: - if HAS_DILL: - try: - new_datapipe = dill.loads(dill.dumps(datapipe)) - except Exception as de: - raise Exception('Unable to dill DataPipe to make thread local copy', de) - - else: - raise Exception('Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe) - - process = threading.Thread(target=DataPipeToQueuesLoop, args=( - new_datapipe, req_queue, res_queue), daemon=True) - return process, req_queue, res_queue, new_datapipe diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py deleted file mode 100644 index 94f7cd2ec703..000000000000 --- a/torch/utils/data/communication/iter.py +++ /dev/null @@ -1,181 +0,0 @@ -import time -import types - -from torch.utils.data import IterDataPipe, communication - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingDataPipe", - "InvalidStateResetRequired", - "NonBlocking", - "NotAvailable", - "QueueWrapper", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class InvalidStateResetRequired(Exception): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass - - -class NonBlocking(IterDataPipe): - not_available_hook = default_not_available_hook - - def __iter__(self): - self.reset_iterator() - return self - - def __next__(self): - while True: - try: - return self.nonblocking_next() - except StopIteration: - raise StopIteration - except NotAvailable: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - raise NotImplementedError( - "nonblocking_next is not implemented for %s" % self.__class__) - - def reset_iterator(self): - raise NotImplementedError( - "reset_iterator is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlocking.not_available_hook = hook_function - - -def EnsureNonBlockingDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, IterDataPipe): - raise Exception('Not Iterable DataPipe ' + - str(validated_datapipe.__class__)) - if isinstance(validated_datapipe, NonBlocking): - return validated_datapipe - if not hasattr(validated_datapipe, '_as_iterator'): - validated_datapipe._as_iterator = None # type: ignore[attr-defined] - if not hasattr(validated_datapipe, 'nonblocking_next'): - def nonblocking_next(self): - if self._as_iterator is None: - self._as_iterator = iter(self) - return next(self._as_iterator) - validated_datapipe.nonblocking_next = types.MethodType( # type: ignore[attr-defined] - nonblocking_next, validated_datapipe) - if not hasattr(validated_datapipe, 'reset_iterator'): - def reset_iterator(self): - self._as_iterator = None - validated_datapipe.reset_iterator = types.MethodType( # type: ignore[attr-defined] - reset_iterator, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not 
isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer): - raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol) - source_datapipe = EnsureNonBlockingDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.ResetIteratorRequest): - source_datapipe.reset_iterator() - protocol.response_reset_iterator() - - elif isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.GetNextRequest): - while forever: - try: - value = source_datapipe.nonblocking_next() - except NotAvailable: - yield True - continue - except StopIteration: - protocol.response_stop_iteration() - if full_stop: - forever = False - else: - yield True - break - except InvalidStateResetRequired: - protocol.response_invalid_state() - if full_stop: - forever = False - else: - yield True - break - protocol.response_next(value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapper(NonBlocking): - """ - Creates iter.DataPipe which reads data from the DataLoader.Queue - """ - - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def reset_iterator(self): - self._stop_iteration = False - self.counter = 0 - self.protocol.request_reset_iterator() - while True: - try: - self.protocol.get_response_reset_iterator() - break - except communication.protocol.EmptyQueue: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - if self._stop_iteration: - raise Exception( - '`next` or `nonblocking_next` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_next() - try: - response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise StopIteration - if isinstance(response, communication.messages.InvalidStateResponse): - raise NotAvailable - return response.value diff --git a/torch/utils/data/communication/map.py b/torch/utils/data/communication/map.py deleted file mode 100644 index 8af63bf0c73e..000000000000 --- a/torch/utils/data/communication/map.py +++ /dev/null @@ -1,159 +0,0 @@ -import time -import types - -from torch.utils.data import communication, MapDataPipe - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingMapDataPipe", - "NonBlockingMap", - "NotAvailable", - "QueueWrapperForMap", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class NonBlockingMap(MapDataPipe): - not_available_hook = default_not_available_hook - - def __getitem__(self, index): - while True: - try: - return self.nonblocking_getitem(index) - 
except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def __len__(self): - try: - return self.nonblocking_len() - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def nonblocking_len(self): - raise NotImplementedError( - "nonblocking_len is not implemented for %s" % self.__class__) - - def nonblocking_getitem(self, index): - raise NotImplementedError( - "nonblocking_getitem is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlockingMap.not_available_hook = hook_function - - -def EnsureNonBlockingMapDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, MapDataPipe): - raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}') - if isinstance(validated_datapipe, NonBlockingMap): - return validated_datapipe - if not hasattr(validated_datapipe, 'nonblocking_len'): - def nonblocking_len(self): - return self.__len__() - validated_datapipe.nonblocking_len = types.MethodType( # type: ignore[attr-defined] - nonblocking_len, validated_datapipe) - if not hasattr(validated_datapipe, 'nonblocking_getitem'): - def nonblocking_getitem(self, index): - return self.__getitem__(index) - validated_datapipe.nonblocking_getitem = types.MethodType( # type: ignore[attr-defined] - nonblocking_getitem, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): - raise Exception('Expecting MapDataPipeQueueProtocolServer, got', protocol) - source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.LenRequest): - size = source_datapipe.nonblocking_len() - protocol.response_len(size) - - elif isinstance(request, communication.messages.GetItemRequest): - while forever: - try: - value = source_datapipe.nonblocking_getitem(request.key) - except NotAvailable: - yield True - continue - except IndexError as e: - # Alternatively, we can just allow the underlying DataPipe to throw an exception? 
- protocol.response_index_out_of_bound() - if full_stop: - forever = False - else: - yield True - break - protocol.response_item(request.key, value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapperForMap(NonBlockingMap): - """ - Creates map.DataPipe which reads data from the DataLoader.Queue - """ - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def nonblocking_getitem(self, index): - if self._stop_iteration: - raise Exception( - '`getitem` or `nonblocking_getitem` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_item(index) - try: - response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise IndexError(f"Index {index} is out of bound.") - return response.key, response.value - - def nonblocking_len(self): - if self._stop_iteration: - raise Exception( - '`len` or `nonblocking_len` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_len() - try: - response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - return response.len diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py deleted file mode 100644 index 449cf23cfc01..000000000000 --- a/torch/utils/data/communication/messages.py +++ /dev/null @@ -1,75 +0,0 @@ -class DataLoaderQueueMessage(object): - pass - - -class Request(DataLoaderQueueMessage): - pass - - -class Response(DataLoaderQueueMessage): - pass - - -class ResetIteratorRequest(Request): - pass - - -class ResetIteratorResponse(Response): - pass - - -class TerminateRequest(Request): - pass - - -class TerminateResponse(Response): - pass - - -class LenRequest(Request): - pass - - -class LenResponse(Response): - __slots__ = ('len') - - def __init__(self, len): - self.len = len - - -class GetItemRequest(Request): - __slots__ = ('key') - - def __init__(self, key): - self.key = key - - -class GetItemResponse(Response): - __slots__ = ('key', 'value') - - def __init__(self, key, value): - self.key = key - self.value = value - - -class GetNextRequest(Request): - pass - - -class GetNextResponse(Response): - __slots__ = ('value') - - def __init__(self, value): - self.value = value - - -class StopIterationResponse(Response): - pass - - -class InvalidStateResponse(Response): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py deleted file mode 100644 index 5bf5fe1af062..000000000000 --- a/torch/utils/data/communication/protocol.py +++ /dev/null @@ -1,205 +0,0 @@ -from torch.utils.data import communication - - -class Protocol(object): - __slots__ = ('request_queue', 'response_queue') - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = 
response_queue - - -class ProtocolClient(Protocol): - """ - ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue. - """ - _req_sent = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_sent = None - - def can_take_request(self): - return self._req_sent is None - - def waiting_for_response(self): - return self._req_sent is not None - - def request_sent(self, request=True): - if not self.can_take_request(): - raise Exception('Protocol only supports one request in the Queue') - self._req_sent = request - - def request_served(self, result=None): - if not self.waiting_for_response(): - raise Exception( - 'Expected no peding requests, but something got served', result) - self._req_sent = None - - -class ProtocolServer(Protocol): - """ - ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe. - """ - _req_received = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_received = None - - def have_pending_request(self): - return self._req_received is not None - - def get_new_request(self, block=False): - if self.have_pending_request(): - raise Exception( - 'Trying to get next request, while having one unserved') - try: - response = self.request_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self._req_received = response - return response - # TODO: Validate supported requests - - def response_terminate(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.TerminateRequest): - raise Exception( - "Replaying with terminate status to other type of message") - self.response_queue.put(communication.messages.TerminateResponse()) - self._req_received = None - - -class MapDataPipeQueueProtocolServer(ProtocolServer): - def response_item(self, key, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetItemResponse(key, value)) - self._req_received = None - - def response_len(self, size): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.LenResponse(size)) - self._req_received = None - - def response_index_out_of_bound(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - -class MapDataPipeQueueProtocolClient(ProtocolClient): - def request_len(self): - if not self.can_take_request(): - raise Exception('Can not request len while we are still waiting response for previous request') - request = communication.messages.LenRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_item(self, index): - if not self.can_take_request(): - raise Exception('Can not request item while we are still waiting response for previous request') - request = communication.messages.GetItemRequest(index) - self.request_queue.put(request) - self.request_sent(request) - - def get_response_len(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise 
Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - if not isinstance(response, communication.messages.LenResponse): - raise Exception('Invalid response received') - return response - - def get_response_item(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - # if not isinstance(response, communication.messages.GetItemResponse): - # raise Exception('Invalid response received') - return response - - -class EmptyQueue(Exception): - pass - - -class IterDataPipeQueueProtocolServer(ProtocolServer): - def response_reset_iterator(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.ResetIteratorRequest): - raise Exception( - "Replaying with reset status to other type of message") - self.response_queue.put(communication.messages.ResetIteratorResponse()) - self._req_received = None - - def response_next(self, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetNextResponse(value)) - self._req_received = None - - def response_stop_iteration(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - - def response_invalid_state(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.InvalidStateResponse()) - self._req_received = None - - -class IterDataPipeQueueProtocolClient(ProtocolClient): - def request_reset_iterator(self): - if not self.can_take_request(): - raise Exception('Can not reset while we are still waiting response for previous request') - request = communication.messages.ResetIteratorRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_next(self): - if not self.can_take_request(): - raise Exception('Can not request next item while we are still waiting response for previous request') - request = communication.messages.GetNextRequest() - self.request_queue.put(request) - self.request_sent(request) - - def get_response_reset_iterator(self, block=False): - try: - response = self.response_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - if not isinstance(response, communication.messages.ResetIteratorResponse): - raise Exception('Invalid response received') - - def get_response_next(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception( - 'Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - # TODO(VitalyFedyunin): Add possible response types validation here - return response diff --git 
a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py deleted file mode 100644 index 85c33d4799cd..000000000000 --- a/torch/utils/data/communication/queue.py +++ /dev/null @@ -1,51 +0,0 @@ -import threading -import time - - -class LocalQueue(): - ops = 0 - stored = 0 - uid = 0 - empty = 0 - - def __init__(self, name='unnamed'): - self.items = [] - self.name = name - self.uid = LocalQueue.uid - LocalQueue.uid += 1 - - def put(self, item, block=True): - LocalQueue.ops += 1 - LocalQueue.stored += 1 - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - LocalQueue.ops += 1 - if not len(self.items): - LocalQueue.empty += 1 - raise Exception('LocalQueue is empty') - LocalQueue.stored -= 1 - return self.items.pop() - - -class ThreadingQueue(): - def __init__(self, name='unnamed'): - self.lock = threading.Lock() - self.items = [] - self.name = name - - def put(self, item, block=True): - with self.lock: - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - while True: - with self.lock: - if len(self.items) > 0: - return self.items.pop() - if not block: - raise Exception("Not available") - # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue - time.sleep(0.000001) diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py deleted file mode 100644 index 8a8d536b7985..000000000000 --- a/torch/utils/data/dataloader_experimental.py +++ /dev/null @@ -1,150 +0,0 @@ -import time - -from typing import Any, List - -import torch.utils.data.backward_compatibility - -import torch.utils.data.graph_settings -from torch.utils.data import DataLoader, IterDataPipe, communication -from torch.utils.data.datapipes.iter import IterableWrapper - -__all__ = [ - "DataLoader2", -] - - -class _ThreadingDataLoader2: - - def __init__(self, datapipe, num_workers=0, collate_fn=None): - self.threads = [] - self.datapipes = [] - self.collate_fn = collate_fn - for worker_id in range(num_workers): - (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe) - torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id) - thread.start() - self.threads.append((thread, req_queue, res_queue)) # These queues are independent - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - self.datapipes.append(local_datapipe) - - def __iter__(self): - not_available = False - forever = True - exclude_datapipes: List[Any] = [] - while len(exclude_datapipes) < len(self.datapipes): - for dp in self.datapipes: - if dp not in exclude_datapipes: - try: - value = dp.nonblocking_next() - yield value - except StopIteration: - exclude_datapipes.append(dp) - except communication.iter.NotAvailable: - not_available = True - if not_available: - time.sleep(0.001) - - def __del__(self): - self._cleanup_all_threads() - - def _cleanup_all_threads(self): - def clean_me(thread, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - thread.join() - - for thread, req_queue, res_queue in self.threads: - clean_me(thread, req_queue, res_queue) - -class DataLoader2: - def __new__(cls, - dataset, - batch_size=1, - shuffle=None, - sampler=None, - batch_sampler=None, - num_workers=0, - collate_fn=None, - 
pin_memory=False, - drop_last=False, - timeout=0, - worker_init_fn=None, - *, - prefetch_factor=2, - persistent_workers=False, - batch_outside_worker=False, - parallelism_mode='mp'): - if isinstance(dataset, IterDataPipe): - data_loader: Any = None - if batch_sampler is not None: - raise Exception( - 'batch_sampler is not yet supported by DataPipes') - if sampler is not None: - raise Exception( - 'sampler is not yet supported by DataPipes') - datapipe = dataset - datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle) # type: ignore[assignment] - if batch_outside_worker and pin_memory: - raise Exception( - 'pin_memory is not yet compatible with batch_outside_worker') - if not batch_outside_worker: - if batch_size is not None: - datapipe = datapipe.batch(batch_size, drop_last=drop_last) - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - - # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does nothing - # for Iterable, but required to set Pipes correctly. - data_loader = DataLoader(datapipe, - batch_size=None, # Replaced by .batch DataPipe - shuffle=shuffle, - sampler=None, - batch_sampler=None, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=False, # Replaced by .batch DataPipe - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) - elif parallelism_mode == 'thread': - if collate_fn is not None and not batch_outside_worker: - datapipe = datapipe.map(collate_fn) - if pin_memory: - raise Exception( - 'pin_memory is not yet supported by DataPipes with Threading') - if worker_init_fn is not None: - raise Exception( - 'worker_init_fn is not yet supported by DataPipes with Threading') - data_loader = _ThreadingDataLoader2(datapipe, - num_workers=num_workers, - collate_fn=collate_fn) - else: - raise Exception('Unsupported parallelism mode', parallelism_mode) - if not batch_outside_worker: - return data_loader - else: - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - datapipe = IterableWrapper(data_loader).batch( - batch_size, drop_last=drop_last).map(collate_fn) - return datapipe - else: - if parallelism_mode == 'thread': - raise Exception( - 'thread parallelism mode is not supported for old DataSets') - return DataLoader(dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) From 6fe47b682fe1ba2dd2c7da02ff1bb06f8670e3a7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 11 Nov 2022 22:31:32 +0000 Subject: [PATCH 083/453] [Dynamo] Fix str(Guard.obj_weakref) bug to re-ennable support overriding __getattr__ (#88564) See my inline comments! 
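As a quick illustration of the underlying Python behavior (a minimal standalone sketch based on the docstring added to `guards.py` below; it is not part of the patch itself):

```python
import weakref

class MyConfig(dict):
    # Attribute access falls through to dict lookup, so a missing key
    # raises KeyError rather than AttributeError.
    def __getattr__(self, name):
        return self[name]

obj = MyConfig(offset=5)
obj_weakref = weakref.ref(obj)

# As described in the docstring added below: str() on the weakref ends up
# looking up '__name__' on the referent, and the KeyError escapes.
str(obj_weakref)  # KeyError: '__name__'
```

The workaround below dereferences the weakref and formats the live referent (or a dead-weakref marker) itself, instead of relying on `str()` of the `weakref.ref` object.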
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88564 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- test/dynamo/test_misc.py | 2 -- torch/_dynamo/guards.py | 27 ++++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 4df7153b8fb2..a8bf86e46411 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -579,8 +579,6 @@ def fn(count): self.assertEqual(cnts.frame_count, 0) self.assertEqual(cnts.op_count, 0) - # KeyError: '__name__' - @patch.object(torch._dynamo.config, "suppress_errors", True) def test_user_getattr1(self): class MyConfig(dict): def __getattr__(self, name): diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9edd6f60560d..382734412b2b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -101,13 +101,38 @@ def sort_key(self): def __lt__(self, other): return self.sort_key() < other.sort_key() + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + def __str__(self): s = f""" {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__} {{ 'guard_types': {self.guard_types}, 'code': {self.code_list}, - 'obj_weakref': {self.obj_weakref} + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} 'guarded_class': {self.guarded_class_weakref} }} """ From a7fa423f48af8af220e9286a6b4c374d533f77e0 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Fri, 11 Nov 2022 14:41:35 +0000 Subject: [PATCH 084/453] copy_: Short-circuit when self and src view the same data (#88884) This comes up if you use inplace operators on a slice, e.g. ```python import torch a = torch.rand(1000000, device="cuda") a[::2] *= 2 ``` The last line looks as if it should be fully inplace, but is actually equivalent to: ```python tmp = a[::2] tmp *= 2 a[::2] = tmp ``` Which results in `mul_` and `copy_` being called. With this PR, the redundant copy becomes a no-op and the above example is 2x faster. 
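One rough way to see the two kernels from Python (a sketch, not part of this patch; CPU tensors are used for simplicity, and the exact rows in the profile depend on the build):

```python
import torch
from torch.profiler import profile

a = torch.rand(1_000_000)
with profile() as prof:
    a[::2] *= 2  # dispatches mul_ on the slice, then copy_ back into `a`

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Before this change the `aten::copy_` row does real work; with the short-circuit it should still appear in the trace but return almost immediately, since `self` and `src` view the same data with identical sizes, strides, storage offset and dtype.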
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88884 Approved by: https://github.com/ngimel --- aten/src/ATen/native/Copy.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index a44f39c5bb2e..c6b82426d3bf 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -220,6 +220,18 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return at::metal::metal_copy_(self, src); } + // Exit early if self and src are views of the same data + const bool is_same_data = ( + self.is_alias_of(src) && + self.storage_offset() == src.storage_offset() && + self.strides().equals(src.strides()) && + self.sizes().equals(src.sizes()) && + self.scalar_type() == src.scalar_type() + ); + if (is_same_data) { + return self; + } + auto iter = TensorIteratorConfig() .add_output(self) From 7c3adddd6c3fe1bda4a9e5bfb9f992a802329551 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Wed, 9 Nov 2022 12:20:16 -0800 Subject: [PATCH 085/453] [functorch] delete some unused files (#88763) Some post-merge cleanup. - packaging/ was for building standalone windows binaries - our flake8 config got superceded by PyTorch's. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88763 Approved by: https://github.com/samdow --- functorch/.flake8 | 20 - functorch/packaging/build_wheel.sh | 19 - functorch/packaging/pkg_helpers.bash | 414 ------------------ .../windows/internal/cuda_install.bat | 264 ----------- .../windows/internal/driver_update.bat | 25 -- .../windows/internal/vc_env_helper.bat | 43 -- .../windows/internal/vc_install_helper.sh | 16 - 7 files changed, 801 deletions(-) delete mode 100644 functorch/.flake8 delete mode 100644 functorch/packaging/build_wheel.sh delete mode 100644 functorch/packaging/pkg_helpers.bash delete mode 100644 functorch/packaging/windows/internal/cuda_install.bat delete mode 100644 functorch/packaging/windows/internal/driver_update.bat delete mode 100644 functorch/packaging/windows/internal/vc_env_helper.bat delete mode 100644 functorch/packaging/windows/internal/vc_install_helper.sh diff --git a/functorch/.flake8 b/functorch/.flake8 deleted file mode 100644 index a6d73773e3b5..000000000000 --- a/functorch/.flake8 +++ /dev/null @@ -1,20 +0,0 @@ -[flake8] -select = B,C,E,F,P,T4,W,B9 -max-line-length = 120 -# C408 ignored because we like the dict keyword argument syntax -# E501 is not flexible enough, we're using B950 instead -ignore = - E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # shebang has extra meaning in fbcode lints, so I think it's not worth trying - # to line this up with executable bit - EXE001, - # these ignores are from flake8-bugbear; please fix! - B007,B008, - # these ignores are from flake8-comprehensions; please fix! - C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -exclude = - ./.git, - ./benchmarks, - ./docs, - ./examples, - ./notebooks diff --git a/functorch/packaging/build_wheel.sh b/functorch/packaging/build_wheel.sh deleted file mode 100644 index 074e7dde7714..000000000000 --- a/functorch/packaging/build_wheel.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -set -ex - -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. 
"$script_dir/pkg_helpers.bash" - -export BUILD_TYPE=wheel -setup_env 0.2.0 -setup_wheel_python -pip_install numpy pyyaml future ninja -pip_install --upgrade setuptools -setup_pip_pytorch_version -python setup.py clean - -if [[ "$OSTYPE" == "msys" ]]; then - "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel -else - python setup.py bdist_wheel -fi diff --git a/functorch/packaging/pkg_helpers.bash b/functorch/packaging/pkg_helpers.bash deleted file mode 100644 index 329891a07216..000000000000 --- a/functorch/packaging/pkg_helpers.bash +++ /dev/null @@ -1,414 +0,0 @@ -# A set of useful bash functions for common functionality we need to do in -# many build scripts - - -# Setup CUDA environment variables, based on CU_VERSION -# -# Inputs: -# CU_VERSION (cpu, cu92, cu100) -# NO_CUDA_PACKAGE (bool) -# BUILD_TYPE (conda, wheel) -# -# Outputs: -# VERSION_SUFFIX (e.g., "") -# PYTORCH_VERSION_SUFFIX (e.g., +cpu) -# WHEEL_DIR (e.g., cu100/) -# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) -# FORCE_CUDA (respected by torchvision setup.py) -# NVCC_FLAGS (respected by torchvision setup.py) -# -# Precondition: CUDA versions are installed in their conventional locations in -# /usr/local/cuda-* -# -# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building -# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == -# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a -# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always -# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU -# version of a Python package. But that doesn't apply if you're on OS X, -# since the default CU_VERSION on OS X is cpu. -setup_cuda() { - - # First, compute version suffixes. 
By default, assume no version suffixes - export VERSION_SUFFIX="" - export PYTORCH_VERSION_SUFFIX="" - export WHEEL_DIR="" - # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) - if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then - export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" - # Match the suffix scheme of pytorch, unless this package does not have - # CUDA builds (in which case, use default) - if [[ -z "$NO_CUDA_PACKAGE" ]]; then - export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" - export WHEEL_DIR="$CU_VERSION/" - fi - fi - - # Now work out the CUDA settings - case "$CU_VERSION" in - cu115) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5" - else - export CUDA_HOME=/usr/local/cuda-11.5/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu113) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" - else - export CUDA_HOME=/usr/local/cuda-11.3/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu112) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" - else - export CUDA_HOME=/usr/local/cuda-11.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu111) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" - else - export CUDA_HOME=/usr/local/cuda-11.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu110) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" - else - export CUDA_HOME=/usr/local/cuda-11.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" - ;; - cu102) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" - else - export CUDA_HOME=/usr/local/cuda-10.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu101) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" - else - export CUDA_HOME=/usr/local/cuda-10.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu100) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0" - else - export CUDA_HOME=/usr/local/cuda-10.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu92) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" - else - export CUDA_HOME=/usr/local/cuda-9.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" - ;; - cpu) - ;; - rocm*) - export FORCE_CUDA=1 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - if [[ -n "$CUDA_HOME" ]]; then - # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one - export PATH="$CUDA_HOME/bin:$PATH" - export FORCE_CUDA=1 - fi -} - -# Populate build version if necessary, and add version suffix -# -# Inputs: -# BUILD_VERSION (e.g., 0.2.0 or empty) -# VERSION_SUFFIX (e.g., +cpu) -# -# Outputs: -# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) -# -# Fill BUILD_VERSION if it doesn't exist already with a nightly string -# Usage: setup_build_version 0.2.0 -setup_build_version() { - if [[ -z "$BUILD_VERSION" ]]; then - 
export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" - else - export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" - fi - - # Set build version based on tag if on tag - if [[ -n "${CIRCLE_TAG}" ]]; then - # Strip tag - export BUILD_VERSION="$(echo "${CIRCLE_TAG}" | sed -e 's/^v//' -e 's/-.*$//')${VERSION_SUFFIX}" - fi -} - -# Set some useful variables for OS X, if applicable -setup_macos() { - if [[ "$(uname)" == Darwin ]]; then - export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ - fi -} - - -# Top-level entry point for things every package will need to do -# -# Usage: setup_env 0.2.0 -setup_env() { - setup_cuda - setup_build_version "$1" - setup_macos -} - -# Function to retry functions that sometimes timeout or have flaky failures -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Inputs: -# PYTHON_VERSION (3.7, 3.8, 3.9) -# UNICODE_ABI (bool) -# -# Outputs: -# PATH modified to put correct Python version in PATH -# -# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image -setup_wheel_python() { - if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then - eval "$(conda shell.bash hook)" - conda env remove -n "env$PYTHON_VERSION" || true - conda create ${CONDA_CHANNEL_FLAGS} -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" - conda activate "env$PYTHON_VERSION" - # Install libpng from Anaconda (defaults) - conda install ${CONDA_CHANNEL_FLAGS} libpng "jpeg<=9b" -y - else - # Install native CentOS libJPEG, freetype and GnuTLS - yum install -y libjpeg-turbo-devel freetype gnutls - case "$PYTHON_VERSION" in - 3.7) python_abi=cp37-cp37m ;; - 3.8) python_abi=cp38-cp38 ;; - 3.9) python_abi=cp39-cp39 ;; - 3.10) python_abi=cp310-cp310 ;; - *) - echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - ;; - esac - # Download all the dependencies required to compile image and video_reader - # extensions - - mkdir -p ext_libraries - pushd ext_libraries - popd - export PATH="/opt/python/$python_abi/bin:$(pwd)/ext_libraries/bin:$PATH" - fi -} - -# Install with pip a bit more robustly than the default -pip_install() { - retry pip install --progress-bar off "$@" -} - -# Install torch with pip, respecting PYTORCH_VERSION, and record the installed -# version into PYTORCH_VERSION, if applicable -setup_pip_pytorch_version() { - if [[ -z "$PYTORCH_VERSION" ]]; then - # Install latest prerelease version of torch, per our nightlies, consistent - # with the requested cuda version - pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # CUDA and CPU are ABI compatible on the CPU-only parts, so strip - # in this case - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" - else - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//')" - fi - else - pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ - -f "https://download.pytorch.org/whl/${CU_VERSION}/torch_stable.html" \ - -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" - fi -} - -# Fill PYTORCH_VERSION with the latest conda nightly version, and -# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions -# -# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. 
-setup_conda_pytorch_constraint() { - if [[ -z "$PYTORCH_VERSION" ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly -c pytorch" - export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | \ - python -c "import os, sys, json, re; cuver = os.environ.get('CU_VERSION'); \ - cuver_1 = cuver.replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - cuver_2 = (cuver[:-1] + '.' + cuver[-1]).replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - print(re.sub(r'\\+.*$', '', \ - [x['version'] for x in json.load(sys.stdin)['pytorch'] \ - if (x['platform'] == 'darwin' or cuver_1 in x['fn'] or cuver_2 in x['fn']) \ - and 'py' + os.environ['PYTHON_VERSION'] in x['fn']][-1]))")" - if [[ -z "$PYTORCH_VERSION" ]]; then - echo "PyTorch version auto detection failed" - echo "No package found for CU_VERSION=$CU_VERSION and PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - fi - else - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}" - fi - if [[ "$CU_VERSION" == cpu ]]; then - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" - else - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - fi - if [[ "$OSTYPE" == msys && "$CU_VERSION" == cu92 ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev" - fi -} - -# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT -setup_conda_cudatoolkit_constraint() { - export CONDA_BUILD_VARIANT="cuda" - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" - ;; - cu110) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -setup_conda_cudatoolkit_plain_constraint() { - export CONDA_BUILD_VARIANT="cuda" - export CMAKE_USE_CUDA=1 - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.5" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.2" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.1" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.2" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.1" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.0" - ;; - 
cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=9.2" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -# Build the proper compiler package before building the final package -setup_visual_studio_constraint() { - if [[ "$OSTYPE" == "msys" ]]; then - export VSTOOLCHAIN_PACKAGE=vs$VC_YEAR - conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE - cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchvision/conda_build_config.yaml - fi -} - -setup_junit_results_folder() { - if [[ "$CI" == "true" ]]; then - export CONDA_PYTORCH_BUILD_RESULTS_DIRECTORY="${SOURCE_ROOT_DIR}/build_results/results.xml" - fi -} - - -download_copy_ffmpeg() { - if [[ "$OSTYPE" == "msys" ]]; then - # conda install -yq ffmpeg=4.2 -c pytorch - # curl -L -q https://anaconda.org/pytorch/ffmpeg/4.3/download/win-64/ffmpeg-4.3-ha925a31_0.tar.bz2 --output ffmpeg-4.3-ha925a31_0.tar.bz2 - # bzip2 --decompress --stdout ffmpeg-4.3-ha925a31_0.tar.bz2 | tar -x --file=- - # cp Library/bin/*.dll ../torchvision - echo "FFmpeg is disabled currently on Windows" - else - if [[ "$(uname)" == Darwin ]]; then - conda install -yq ffmpeg=4.2 -c pytorch - conda install -yq wget - else - # pushd ext_libraries - # wget -q https://anaconda.org/pytorch/ffmpeg/4.2/download/linux-64/ffmpeg-4.2-hf484d3e_0.tar.bz2 - # tar -xjvf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # rm -rf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # ldconfig - # which ffmpeg - # popd - echo "FFmpeg is disabled currently on Linux" - fi - fi -} diff --git a/functorch/packaging/windows/internal/cuda_install.bat b/functorch/packaging/windows/internal/cuda_install.bat deleted file mode 100644 index 41960224ebae..000000000000 --- a/functorch/packaging/windows/internal/cuda_install.bat +++ /dev/null @@ -1,264 +0,0 @@ -@echo on - -if "%CU_VERSION%" == "cpu" ( - echo Skipping for CPU builds - exit /b 0 -) - -set SRC_DIR=%~dp0\.. 
- -if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" - -rem in unit test workflow, we get CUDA_VERSION, for example 11.1 -if defined CUDA_VERSION ( - set CUDA_VER=%CUDA_VERSION:.=% -) else ( - set CUDA_VER=%CU_VERSION:cu=% -) - -set /a CUDA_VER=%CU_VERSION:cu=% -set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% -set CUDA_VER_MINOR=%CUDA_VER:~-1,1% -set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% - - -if %CUDA_VER% EQU 92 goto cuda92 -if %CUDA_VER% EQU 100 goto cuda100 -if %CUDA_VER% EQU 101 goto cuda101 -if %CUDA_VER% EQU 102 goto cuda102 -if %CUDA_VER% EQU 110 goto cuda110 -if %CUDA_VER% EQU 111 goto cuda111 -if %CUDA_VER% EQU 112 goto cuda112 -if %CUDA_VER% EQU 113 goto cuda113 -if %CUDA_VER% EQU 115 goto cuda115 - - -echo CUDA %CUDA_VERSION_STR% is not supported -exit /b 1 - -:cuda92 -if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" -) - -goto cuda_common - -:cuda100 - -if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" -) - -goto cuda_common - -:cuda101 - -if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -) - -if not exist 
"%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" -) - -goto cuda_common - -:cuda102 - -if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" -) - -rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. -if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( - curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" - if errorlevel 1 exit /b 1 -) - -echo Installing GPU driver DLLs -7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" - -goto cuda_common - -:cuda110 - -if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" -) - -goto cuda_common - -:cuda111 - -if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" -) 
- -if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" -) - -goto cuda_common - -:cuda112 - -if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( - curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" -) - -goto cuda_common - -:cuda113 - -set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" - -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda115 - -set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda_common - -if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( - curl -k -L 
https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" - if errorlevel 1 exit /b 1 -) - -echo Installing CUDA toolkit... -7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" -pushd "%SRC_DIR%\temp_build\cuda" -sc config wuauserv start= disabled -sc stop wuauserv -sc query wuauserv - -start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" -echo %errorlevel% - -popd - -echo Installing VS integration... -rem It's for VS 2019 -if "%CUDA_VER_MAJOR%" == "10" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) -if "%CUDA_VER_MAJOR%" == "11" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) - -echo Installing NvToolsExt... -7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" - -echo Setting up environment... -set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" -set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" - -if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( - echo CUDA %CUDA_VERSION_STR% installed failed. - echo --------- RunDll32.exe.log - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" - echo --------- setup.exe.log ------- - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" - exit /b 1 -) - -echo Installing cuDNN... 
-7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" - -echo Cleaning temp files -rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/functorch/packaging/windows/internal/driver_update.bat b/functorch/packaging/windows/internal/driver_update.bat deleted file mode 100644 index 00b43affc01c..000000000000 --- a/functorch/packaging/windows/internal/driver_update.bat +++ /dev/null @@ -1,25 +0,0 @@ -set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -if errorlevel 1 exit /b 1 - -start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot -if errorlevel 1 exit /b 1 - -del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL - -setlocal EnableDelayedExpansion -set NVIDIA_GPU_EXISTS=0 -for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( - set GPUS=%%i - if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( - SET NVIDIA_GPU_EXISTS=1 - goto gpu_check_end - ) -) -:gpu_check_end -endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% - -if "%NVIDIA_GPU_EXISTS%" == "0" ( - echo "CUDA Driver installation Failed" - exit /b 1 -) diff --git a/functorch/packaging/windows/internal/vc_env_helper.bat b/functorch/packaging/windows/internal/vc_env_helper.bat deleted file mode 100644 index e85a372f93d5..000000000000 --- a/functorch/packaging/windows/internal/vc_env_helper.bat +++ /dev/null @@ -1,43 +0,0 @@ -@echo on - -set VC_VERSION_LOWER=16 -set VC_VERSION_UPPER=17 -if "%VC_YEAR%" == "2017" ( - set VC_VERSION_LOWER=15 - set VC_VERSION_UPPER=16 -) - -for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( - if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( - set "VS15INSTALLDIR=%%i" - set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" - goto vswhere - ) -) - -:vswhere -if "%VSDEVCMD_ARGS%" == "" ( - call "%VS15VCVARSALL%" x64 || exit /b 1 -) else ( - call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 -) - -@echo on - -set DISTUTILS_USE_SDK=1 - -set args=%1 -shift -:start -if [%1] == [] goto done -set args=%args% %1 -shift -goto start - -:done -if "%args%" == "" ( - echo Usage: vc_env_helper.bat [command] [args] - echo e.g. 
vc_env_helper.bat cl /c test.cpp -) - -%args% || exit /b 1 diff --git a/functorch/packaging/windows/internal/vc_install_helper.sh b/functorch/packaging/windows/internal/vc_install_helper.sh deleted file mode 100644 index cdae18065b9f..000000000000 --- a/functorch/packaging/windows/internal/vc_install_helper.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -set -ex - -if [[ "$CU_VERSION" == "cu92" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="-vcvars_ver=14.13" - powershell packaging/windows/internal/vs2017_install.ps1 -elif [[ "$CU_VERSION" == "cu100" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="" - powershell packaging/windows/internal/vs2017_install.ps1 -else - export VC_YEAR=2019 - export VSDEVCMD_ARGS="" -fi From 37c5b42fa6597ebf7dbfb6db4ada2c7803950555 Mon Sep 17 00:00:00 2001 From: Horace He Date: Fri, 11 Nov 2022 19:17:47 +0000 Subject: [PATCH 086/453] Fix matmul decomp to use reshape instead of contiguous().view() (#88832) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88832 Approved by: https://github.com/bertmaher, https://github.com/ngimel --- torch/_decomp/decompositions.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index fe63e0db007a..1a2d332e99fd 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2261,9 +2261,7 @@ def matmul(tensor1, tensor2): t2_is_matrix = t2.dim() == 2 if t2_is_matrix: output_shape.append(t2.shape[1]) - # HACK: We need reshape with symint support - t1 = t1.contiguous() - t1_folded = t1.view(folded_dim1, sizes_1[-1]) + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) if t2_is_matrix: # FIXME This path always does an unnecessary copy when transpose == True as the returned # result from BLAS is already C-transposed @@ -2296,15 +2294,11 @@ def matmul(tensor1, tensor2): expand_batch_product = prod(expand_batch_portion) # HACK: We need reshape with symint support - tensor1_expanded = ( - tensor1.expand(tensor1_expand_size) - .contiguous() - .view(expand_batch_product, n, m1) + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 ) - tensor2_expanded = ( - tensor2.expand(tensor2_expand_size) - .contiguous() - .view(expand_batch_product, m2, p) + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p ) output_shape = expand_batch_portion From 5ff600aa6e40c6b4d426594bbb1f446f005b7fb3 Mon Sep 17 00:00:00 2001 From: William Wen Date: Sat, 12 Nov 2022 00:22:25 +0000 Subject: [PATCH 087/453] Add comprehensive minifier tests (#88022) Adds tests for https://github.com/pytorch/torchdynamo/issues/1241. To run: `pytest test/dynamo/test_minifier.py`. Actually runs minifier launcher script and repro scripts, rather than just checking for existence of the minifier launcher script. 
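For illustration only (nothing below is part of the patch; the backend name, exception class, and toy function are made up), a minimal sketch of the failure pattern these tests drive: a backend registered with `register_backend` that raises whenever the captured graph contains `torch.relu`, run with `repro_after = "dynamo"` so dynamo writes out a `minifier_launcher.py` for the tests to execute:

```python
import torch
import torch._dynamo
from torch._dynamo.optimizations.backends import register_backend

torch._dynamo.config.repro_after = "dynamo"  # dump minifier state when the backend fails
torch._dynamo.config.repro_level = 2         # dump on compile/runtime errors


class ToyCompileError(Exception):
    pass


@register_backend
def toy_relu_compile_error(gm: torch.fx.GraphModule, example_inputs):
    # Reject any captured graph that calls torch.relu; pass everything else through.
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            raise ToyCompileError("relu found")
    return gm


@torch._dynamo.optimize("toy_relu_compile_error")
def fn(x):
    return torch.cos(torch.relu(torch.sin(x)))


try:
    fn(torch.randn(8, 8))
except Exception:
    # The failure should leave a minifier_launcher.py under the debug dir,
    # which is what the new tests run and then re-run as a repro script.
    pass
```

The actual tests drive the same pattern from a subprocess so a segfaulting backend cannot take down the test runner.
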
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88022 Approved by: https://github.com/mlazos, https://github.com/anijain2305 --- test/dynamo/test_minifier.py | 630 +++++++++++++++++++++++++++++++---- torch/_dynamo/debug_utils.py | 78 ++++- 2 files changed, 632 insertions(+), 76 deletions(-) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 0cec7d202a9d..51b79a5e7511 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,27 +1,138 @@ # Owner(s): ["module: dynamo"] +import functools import os +import re import shutil +import subprocess +import textwrap import unittest -from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.optimizations.backends import create_backend +import torch._inductor.utils +from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT +_HAS_TRITON = torch._inductor.utils.has_triton() +requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() +RELU_COMPILE_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend - def forward(self, x): - for _ in range(10): - x = torch.sin(x) - x = torch._foobar(x) - for _ in range(10): - x = torch.cos(x) - return x +class DynamoCompileError(Exception): + pass + +@register_backend +def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise DynamoCompileError("relu found") + return gm +""" + +RELU_RUNTIME_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): + gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "DynamoRuntimeError") + gm.recompile() + return gm +""" + +RELU_ACCURACY_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): + gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm +""" + +RELU_CUSTOM_ERROR_BACKEND = """\ +class CustomError(Exception): + pass + +def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise CustomError("relu found") + return gm +""" + +CPP_COMPILE_ERROR = """\ +def cpp_compile_error(x): + return "compile error!" +""" + +CPP_RUNTIME_ERROR = """\ +def cpp_runtime_error(x): + return f"{x}; throw 1" +""" + +CPP_ACCURACY_ERROR = """\ +def cpp_accuracy_error(x): + return f"{x} + 1" +""" + +TRITON_COMPILE_ERROR = """\ +def triton_compile_error(x): + return "compile error!" +""" + +# NOTE: there is currently not an easy way to cause a triton runtime error. +TRITON_RUNTIME_ERROR = """\ +def triton_runtime_error(x): + return f"{x}; assert?" +""" + +TRITON_ACCURACY_ERROR = """\ +def triton_accuracy_error(x): + return f"{x} + 1" +""" + +DEBUG_DIR = "/tmp/_torchdynamo_debug_/" + +# Search for the name of the first function defined in a code string. 
+def get_fn_name(code): + fn_name_match = re.search(r"def (\w+)\(", code) + if fn_name_match is not None: + return fn_name_match.group(1) + return None + + +# Generates code that patches CppOverrides/TritonOverrides. +def gen_codegen_fn_patch_code(old_fn_name, new_fn_code, device): + new_fn_name = get_fn_name(new_fn_code) + if new_fn_name is not None: + patch_code = f"""\ +import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen +overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} +{new_fn_code} +overrides.{old_fn_name} = staticmethod({new_fn_name}) +""" + return f"""\ +{patch_code} +isolate_fails_code_str = \"\"\"\\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +\"\"\" +""" + + return None class MinfierTests(torch._dynamo.test_case.TestCase): @@ -32,9 +143,10 @@ def setUpClass(cls): unittest.mock.patch.object( torch._dynamo.config, "debug_dir_root", - "/tmp/_torchdynamo_debug_/", + DEBUG_DIR, ) ) + os.makedirs(DEBUG_DIR, exist_ok=True) @classmethod def tearDownClass(cls): @@ -47,65 +159,455 @@ def setUp(self): def tearDown(self): super().tearDown() - def test_after_dynamo(self): - @create_backend - def bad_dynamo_backend(subgraph): - import sys - - def f(*args): - # Shifted the forced exception to runtime as this is more common - # in JIT compilers. - for node in subgraph.model.graph.nodes: - if node.op == "call_function" and node.target is torch._foobar: - sys.stdout.write("Dynamo compiled failed\n") - raise NotImplementedError("foobar is not implemented") - return subgraph.model(*args) - - return f - - mod = MockModule() - opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - - @patch.object(torch._dynamo.config, "repro_after", "dynamo") - def inner(): - x = torch.randn(4) - try: - opt_mod(x) - except Exception: - pass - - inner() - self.assertTrue(os.path.exists(repro_file)) + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. + def _run_test_code(self, code): + proc = subprocess.run( + ["python3", "-c", code], capture_output=True, cwd=DEBUG_DIR + ) + + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + # Print repro directory for debugging generated code. + # Make sure to comment out `shutil.rmtree...` above as well. + print("repro dir:", repro_dir_match.group(1)) + return proc, repro_dir_match.group(1) + return proc, None - # If error_at_aot is True, an error will be produced when AOTAutograd - # attempts to generate the backward graph. - # If error_after_aot is False, an error will be produced in inductor. - def _test_around_aot(self, error_at_aot): - mod = MockModule() - opt_mod = torch._dynamo.optimize("inductor")(mod) + # Patch generated files with testing patches + def _inject_code(self, patch_code, filename): + patch_code = f"""\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +""" + with open(filename, "r") as f: + code = f.read() + code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + with open(filename, "w") as f: + f.write(code) + return code - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - repro_after = "dynamo" if error_at_aot else "aot" + # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. 
+ def _run_minifier_launcher(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + launch_file = os.path.join(repro_dir, "minifier_launcher.py") + self.assertTrue(os.path.exists(launch_file)) + launch_code = self._inject_code(patch_code, launch_file) - @patch.object(torch._dynamo.config, "repro_after", repro_after) - def inner(): - x = torch.randn(4) - x.requires_grad = error_at_aot - try: - opt_mod(x) - except Exception: - pass + launch_proc = subprocess.run( + ["python3", launch_file], + capture_output=True, + cwd=repro_dir, + ) - inner() + return launch_proc, launch_code + # Runs the repro script in `repro_dir`, patched with `patch_code` + def _run_repro(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + repro_file = os.path.join(repro_dir, "repro.py") self.assertTrue(os.path.exists(repro_file)) + repro_code = self._inject_code(patch_code, repro_file) + + repro_proc = subprocess.run( + ["python3", repro_file], capture_output=True, cwd=repro_dir + ) + + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file. + def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): + return f"""\ +import torch +import torch._dynamo +{patch_code} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code (in a separate process since it could segfault) + # 2. Run the generated minifier launcher script + # 3. Run the generated repro script + def _run_full_test(self, run_code, repro_after, repro_level, patch_code): + test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) + test_proc, repro_dir = self._run_test_code(test_code) + self.assertIsNotNone(repro_dir) + launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) + repro_proc, repro_code = self._run_repro(patch_code, repro_dir) + return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) + + # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) + def _test_after_dynamo(self, device, repro_level, backend_code, error_name): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "dynamo", repro_level, backend_code + ) + + self.assertIn(error_name, test_proc.stderr.decode("utf-8")) + self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) + + def test_after_dynamo_cpu_compile_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + def test_after_dynamo_cpu_runtime_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + def test_after_dynamo_cpu_accuracy_error(self): + self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + @requires_cuda() + def test_after_dynamo_cuda_compile_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + @requires_cuda() + def 
test_after_dynamo_cuda_runtime_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_error(self): + self._test_after_dynamo("cuda", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + # Ensure that the testing backends pass when relu is not present. + def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_dynamo_cpu_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) + + def test_after_dynamo_cpu_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) + + def test_after_dynamo_cpu_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + + # Ensure that generated code with a custom backends generates a runnable minifier + # launcher script that results in a RuntimeError + def test_after_dynamo_custom_backend(self): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize({get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + test_code = self._gen_test_code( + run_code, "dynamo", 2, RELU_CUSTOM_ERROR_BACKEND + ) + _, repro_dir = self._run_test_code(test_code) + launch_proc, launch_code = self._run_minifier_launcher("", repro_dir) + self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) + + # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @requires_cuda() + def test_cpu_cuda_module_after_dynamo(self): + backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + class CpuCudaModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.m_x = torch.nn.Linear(20, 20).cuda() + self.m_y = torch.nn.Linear(20, 20) + self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) + self.p_y = torch.nn.Parameter(torch.randn(20, 20)) + self.register_buffer("b_x", torch.ones(20, 20).cuda()) + self.register_buffer("b_y", torch.ones(20, 20)) + + def forward(self, x, y): + return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y + + mod = CpuCudaModule() + + @torch._dynamo.optimize("{backend_name}") + def inner(x1, y1): + x2 = torch.randn(20, 20).cuda() + y2 = torch.randn(20, 20) + x3, y3 = mod(x1 + x2, y1 + y2) + return torch.relu(x3.cpu() + y3) + + inner(torch.randn(20, 20).cuda(), 
torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + # Check if generated minifier code covers all cpu/cuda cases + self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) + self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) + # search for Linear(...).cuda() + self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) + # search for Linear(...) + self.assertIsNotNone( + re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + # search for + # = torch.randn(...) + # ... = .cuda() + self.assertIsNotNone( + re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) + ) + # search for + # = torch.randn(...) + # no followup call to .cuda() + self.assertIsNotNone( + re.search( + r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL + ) + ) + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # Test if we can actually get a minified graph + def test_if_graph_minified(self): + backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{backend_name}") + def inner(x): + for _ in range(20): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(20): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # compare the length of the forward functions + match = re.search(r"def forward.*return", launch_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertGreater(match.group(0).count("\n"), 40) + + match = re.search(r"def forward.*return", repro_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertLess(match.group(0).count("\n"), 5) + + # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) + def _test_after_aot(self, device, backend_code, repro_level): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", repro_level, patch_code + ) + return ( + (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), + (test_proc.returncode, repro_proc.returncode), + ) + + def test_after_aot_cpu_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) + self.assertIn("CppCompileError", tb1) + self.assertIn("CppCompileError", tb2) + + def test_after_aot_cpu_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + 
self.assertIn("AccuracyError", tb2) + + @requires_cuda() + def test_after_aot_cuda_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) + self.assertIn("SyntaxError", tb1) + self.assertIn("SyntaxError", tb2) + + @requires_cuda() + def test_after_aot_cuda_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + # Test that runtime errors after aot can be repro'd (CPU only for now) + def _test_after_aot_runtime_error(self, device, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", 3, patch_code + ) + + self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) + + self.assertEqual(test_proc.returncode, repro_proc.returncode) + self.assertNotEqual(test_proc.returncode, 0) + + def test_after_aot_cpu_runtime_error(self): + self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_error(self): + self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) + + # Ensure that inductor codegen patches pass when relu is not present. + def _test_after_aot_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_aot_cpu_compile_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) + + def test_after_aot_cpu_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) + + def test_after_aot_cpu_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) + + @requires_cuda() + def test_after_aot_cuda_compile_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) - def test_at_aot(self): - self._test_around_aot(True) + # NOTE: there is currently not an easy way to cause a triton runtime error. 
+ @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) - def test_after_aot(self): - self._test_around_aot(False) + @requires_cuda() + def test_after_aot_cuda_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) if __name__ == "__main__": diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index f09991f9bf34..98a269fe8c9e 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -84,6 +84,11 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -95,12 +100,16 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" + if param.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. @@ -145,6 +154,9 @@ def _cuda_system_info_comment(): return model_str +TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" + + def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -155,6 +167,8 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx + {TEST_REPLACEABLE_COMMENT} + """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -170,7 +184,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' + model_str += "mod = make_fx(Repro())(*args)\n" return model_str @@ -197,7 +211,8 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - repro_path = os.path.join(config.base_dir, "repro.py") + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -216,7 +231,10 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" + class AccuracyError(Exception): + pass + if not same_two_models(mod, compiled, args, only_fwd=True): + raise AccuracyError("Bad accuracy detected") """ ) ) @@ -231,7 +249,7 @@ def save_graph_repro(fd, gm, args, compiler_name): ) -def isolate_fails(fx_g, args, compiler_name: str, env=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): if env is None: env = {} subdir = 
os.path.join(os.getcwd(), "isolate") @@ -239,7 +257,10 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - fd.write(generate_compiler_repro_string(fx_g, args)) + repro_code = generate_compiler_repro_string(fx_g, args) + if patch_code is not None: + repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + fd.write(repro_code) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -263,6 +284,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], + cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -329,6 +351,8 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" +isolate_fails_code_str = None + {generate_compiler_repro_string(gm, args)} from functools import partial @@ -343,7 +367,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -351,6 +375,10 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) +class AccuracyError(Exception): + pass + + def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both @@ -410,7 +438,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise ValueError("Bad accuracy detected") + raise AccuracyError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -435,7 +463,8 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - raise e + log.error("CompilerError") + raise if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -544,9 +573,14 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() + +class AccuracyError(Exception): + pass + with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - assert same_two_models(mod, opt_mod, args), "Dynamo failed" + if not same_two_models(mod, opt_mod, args): + raise AccuracyError("Dynamo failed") """ ) @@ -561,12 +595,14 @@ def generate_dynamo_fx_repro_string( from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -705,6 +741,21 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" + custom_compiler_error = ( + textwrap.dedent( + """\ + raise RuntimeError( + 'Compiler name is None - this likely means that a custom compiler ' + 'was called by torchdynamo. 
Please remove this error, import your ' + 'custom compiler function, and replace the compiler_name="None" ' + 'line below to compiler_name=' + ) + """ + ) + if compiler_name is None + else "" + ) + contents = textwrap.dedent( f""" import os @@ -718,14 +769,17 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] +{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -769,7 +823,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc = ValueError("Bad accuracy detected.") + exc = AccuracyError("Bad accuracy detected.") exc.minifier_path = os.path.join( minifier_dir(), "minifier_launcher.py" ) From a3f3ec8fac98151f31373ba3bcfe2d601584a840 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 11 Nov 2022 21:22:49 +0000 Subject: [PATCH 088/453] [FSDP+dynamo]: forward treats parameter-views as params (#88781) Dynamo+AotAutograd needs a way to wrap all tensors (whether inputs or params/buffers) in FakeTensor wrappers, and FSDP's mangling of parameters hides them from this wrapping. This PR unblocks running hf_bert and hf_T5 with FSDP under dynamo, whether using recursive wrapping around transformer layers or only applying FSDP around the whole model. Perf/memory validation and possibly optimization is the next step. `python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager` `python benchmarks/dynamo/distributed.py --torchbench_model hf_Bert --fsdp --dynamo aot_eager --fsdp_wrap` `python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager` `python benchmarks/dynamo/distributed.py --torchbench_model hf_T5 --fsdp --dynamo aot_eager --fsdp_wrap` The problem: Dynamo (Actually aot_autograd) trips up with FSDP becuase it must wrap all input tensors in FakeTensor wrappers, and it only knows to wrap graph inputs or named_(parameters, buffers). FSDP's pre_forward hook sets views (which are not nn.param) into the flatparam as attrs on the module with the same name as the original param, but they will not show up in named_parameters. - in use_orig_params mode, FSDP still de-registers params during pre-forward hook, then re-registers them post-forward - during forward (between the hooks), the params are setattr'd on the module as regular view tensors, not nn.Parameters - note: use_orig_params is the recommended way to use FSDP, and use_orig_params=False is being deprecated. So i only consider use_orig_params=True for this enablement The solution: - adding them to named_buffers is not possible because it interferes with how FSDP's `_apply` works - since they are not actual nn.parameters, register_parameter will complain about registering them - simply seting `module._parameters[name] = view` seems to be a viable workaround, despite being hacky, and FSDP code does modify _parameters directly already. Note: Manual checkpointing still isn't working with FSDP+dynamo, so that will have to be addressed in a follow up. 
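As a rough usage sketch of what this unblocks (assumptions: a single CUDA device with NCCL available; the toy model and port below are made up), mirroring the new `test_fsdp_aot_eager` test rather than adding anything to the patch:

```python
import os
import torch
import torch.distributed as dist
import torch._dynamo
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-rank process group; NCCL is what the new tests use
# (gloo lacks the _all_gather_base support FSDP needs here).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12355")
dist.init_process_group("nccl", rank=0, world_size=1)

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)
).cuda()

# use_orig_params=True is the mode this change targets: during forward the
# original parameter names resolve to views into the flat parameter, and the
# flat_param.py change re-inserts those views into module._parameters so
# dynamo/aot_autograd can find them and wrap them in FakeTensors.
fsdp_model = FSDP(model, use_orig_params=True)
compiled = torch._dynamo.optimize("aot_eager")(fsdp_model)
out = compiled(torch.randn(4, 16, device="cuda"))

dist.destroy_process_group()
```

The recursive-wrapping variant in the tests is the same idea with an `auto_wrap_policy` (for example `transformer_auto_wrap_policy` over `BertLayer`), so each wrapped submodule gets its own flat parameter.
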
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88781 Approved by: https://github.com/ezyang, https://github.com/awgu --- benchmarks/dynamo/dist_util.py | 20 +-- benchmarks/dynamo/distributed.py | 5 +- test/distributed/test_dynamo_distributed.py | 131 ++++++++++++++++---- torch/distributed/fsdp/flat_param.py | 4 + 4 files changed, 124 insertions(+), 36 deletions(-) diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index 9e2f086ca8b7..d30b5a63cfe5 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -20,6 +20,9 @@ except ImportError: from torchbench import setup_torchbench_cwd +from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead +from transformers.models.t5.modeling_t5 import T5Block + def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" @@ -122,26 +125,25 @@ def check_fn(submodule): ) -# from transformers.models.t5.modeling_t5 import T5Block - MODEL_FSDP_WRAP = { - ToyModel: (MyModule,) - # TODO T5: (T5Block,) + "toy_model": (MyModule,), + "hf_Bert": (BertLayer, BertLMPredictionHead), + "hf_T5": (T5Block,), } -def apply_fsdp(model, use_checkpointing=False, use_wrap_policy=True): - blocks = MODEL_FSDP_WRAP[model.__class__] - +def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True): wrap_policy = None + blocks = MODEL_FSDP_WRAP[ + "toy_model" if model.__class__ is ToyModel else args.torchbench_model + ] if use_wrap_policy: # transformer policy is really a generic policy that wraps modules of specified classes wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=blocks ) - model = FSDP(model, auto_wrap_policy=wrap_policy) + model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True) if use_checkpointing: fsdp_checkpointing_base(model, blocks) - return model diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py index c2db15563348..32e3b544d87d 100644 --- a/benchmarks/dynamo/distributed.py +++ b/benchmarks/dynamo/distributed.py @@ -50,6 +50,7 @@ def move_tensor(maybe_tensor): if args.fsdp: model = apply_fsdp( + args, model, use_checkpointing=args.fsdp_checkpoint, use_wrap_policy=args.fsdp_wrap, @@ -160,7 +161,9 @@ def experiment(fn, key, world_size, results): ) args = parser.parse_args() - model_name = "ToyModel" if args.toy_model else args.torchbench_model + model_name = args.torchbench_model + if args.toy_model: + model_name = "ToyModel" model, inputs = get_model(args) fn = partial(run_model, args, model, inputs) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 3dd3c5de7725..b6bc16edb941 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1,4 +1,6 @@ # Owner(s): ["module: dynamo"] +import copy +import functools import logging import os import random @@ -16,7 +18,9 @@ from torch._dynamo.utils import same from torch._dynamo.testing import collect_results from torch._inductor.utils import has_triton +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import ( MultiProcessTestCase, import_transformers_or_skip, @@ -175,6 +179,7 @@ def test_ddp_baseline_aot_eager_multiprocess(self): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() + @unittest.skipIf(not 
has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp(self): @@ -199,6 +204,106 @@ def test_hf_bert_ddp(self): opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) self.assertTrue(same(correct_results, opt_results)) + + @skip_if_lt_x_gpu(1) + # TODO(whc) delete aot_eager test, if inductor test lands stably + def test_fsdp_aot_eager(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @skip_if_lt_x_gpu(1) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_fsdp_inductor(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert + @patch.object(torch._inductor.config.triton, "cudagraphs", False) + @patch.object(torch._inductor.config, "fallback_random", True) + def test_hf_bert_fsdp(self): + from transformers.models.bert.modeling_bert import BertLayer + + def apply_fsdp(model, wrap_policy): + model = FSDP( + copy.deepcopy(model), + auto_wrap_policy=wrap_policy, + use_orig_params=True + ) + return model + + with _per_rank_init(self.rank, self.world_size): + for (wrap_policy, test_instance) in ( + ( + None, + "FSDP without recursive wrapping" + ), + ( + functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, ) + ), + "FSDP with recursive wrapping BertLayer instances" + ) + ): + print(f"Running hf_bert test for {test_instance}") + model, inputs = get_hf_bert(self.rank) + reset_rng_state() + eager_model = apply_fsdp(model, wrap_policy) + correct_outputs = eager_model(**inputs) + correct_loss = correct_outputs.loss + correct_loss.backward() + + reset_rng_state() + opt_model = apply_fsdp(model, wrap_policy) + + opt_model = torch._dynamo.optimize("inductor")(opt_model) + opt_outputs = opt_model(**inputs) + opt_loss = opt_outputs.loss + 
opt_loss.backward() + + inputs_flat = [inputs[k] for k in inputs] + correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat) + opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) + self.assertTrue(same(correct_results, opt_results)) + + @requires_nccl() class TestDistributed(torch._dynamo.test_case.TestCase): """ @@ -257,32 +362,6 @@ def test_ddp_baseline_inductor(self): outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - # TODO(whc) move these tests to 'distributed' shard to get nccl, or see if it's available already in pytorch CI? - @unittest.skip( - "can't run with gloo (no support for _allgather_base) and nccl not available in CI" - ) - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_aot_eager(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @unittest.skip("hangs/crashes with inductor currently") - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_inductor(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - @patch.object(config, "optimize_ddp", True) def test_graph_split(self): """ diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 0978f0875a28..b790590c7943 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -1306,6 +1306,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: assert tensor is not None # mypy param_var = tensor setattr(module, param_name, param_var) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = param_var # type: ignore[assignment] for i, ( param_name, module, @@ -1336,6 +1338,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: module.register_parameter(param_name, prim_param) else: setattr(module, param_name, prim_param) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = prim_param # type: ignore[assignment] def _use_unsharded_grad_views(self) -> None: """ From 2cd05a2818bacbc2e252052b6b71085e4de16b0d Mon Sep 17 00:00:00 2001 From: Jiaxu Zhu Date: Sat, 12 Nov 2022 01:20:52 +0000 Subject: [PATCH 089/453] Support torch.qint32 in Convert (#88871) Enable the `torch.qint32` when creating `quantize_per_tensor` function call in `convert_fx` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88871 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/fx/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 61bb2cdc1b03..a5a989ec2148 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -183,7 +183,7 @@ def get_quantize_node_info( if hasattr(activation_post_process, 
"compute_dtype"): compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] quantize_op : Optional[Union[Callable, str]] = None - if dtype in [torch.quint8, torch.qint8] and \ + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ not hasattr(activation_post_process, 'compute_dtype'): node_type = "call_function" scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] From 2b166532f7ac280232daf6c44620e96e258867cf Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 11 Nov 2022 09:00:55 -0500 Subject: [PATCH 090/453] Remove incorrect assert about hermetic state. (#88885) I'm not sure why I thought this assert was valid in the first place, and there's no comment about it. The assert is tantamount to saying, "no tensor objects should become dead via SafePyObject when hermetic mode is on." But suppose we run a Python GC while we're inside hermetic mode. This could result in us disposing non-hermetic tensors, which would hit decref. So the assert seems invalid. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88885 Approved by: https://github.com/anjali411, https://github.com/malfet --- torch/csrc/autograd/python_variable.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 920d0e7344b5..002b904d4072 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -305,10 +305,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor) // THPVariable_clear). // 2. We are decref-ing some other Python object. We don't do // PyObject resurrection on non-Tensors, so we just carry on as usual - if (is_tensor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !c10::impl::HermeticPyObjectTLS::get_state()); - } if (is_tensor && Py_REFCNT(pyobj) > 1) { // It's still alive! This can happen if a weak ref resurrected // the PyObject without flipping ownership. At this point it is From 66736ff425d7163df0eed48e9944c8539e92b577 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 11 Nov 2022 09:33:41 -0500 Subject: [PATCH 091/453] Fix bug in OptionalTensorList (#88887) Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88887 Approved by: https://github.com/anjali411 --- aten/src/ATen/core/PythonFallbackKernel.cpp | 5 ++++- test/test_python_dispatch.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index e16874a83f96..2d8834afe59e 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -74,10 +74,13 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { (*interpreter)->dispatch(op, stack); return; } - } else if (ivalue.isTensorList() || (ivalue.isOptionalTensorList() && !ivalue.isNone())) { + } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef // is not a thing) for (const auto& nv : ivalue.toListRef()) { + if (nv.isNone()) { + continue; + } auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { (*interpreter)->dispatch(op, stack); diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 380f85f568f7..33465217bbbc 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -390,6 +390,24 @@ def test_produce_real_type(self) -> None: $4 = torch._ops.aten.select.int($3, 1, 1) $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') + def test_optional_tensor_list(self) -> None: + def weird(xs): + print("woof") + return torch.empty(()) + + my_lib = Library("my_lib", "DEF") + my_lib.define("weird(Tensor?[] self) -> Tensor") + my_lib.impl("weird", weird, "CPU") + with capture_logs() as logs: + x = LoggingTensor(torch.ones(2, 2)) + log_input("x", x) + torch.ops.my_lib.weird.default([None, x]) + + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], + [1., 1.]]))])''') + def test_list_ret(self) -> None: # test all sequence types are permissible returns for list_type in (list, tuple): From 1e2327baf7a2d9c63bef08e5f996ef983e199429 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Sat, 12 Nov 2022 02:23:48 +0000 Subject: [PATCH 092/453] fix fx tests (#88886) Summary: Some source files are missing and TPX couldn't handle the default test names. Test Plan: Rely on CI. 
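For illustration, this is roughly the naming behavior the patch adds. The helper below mirrors the `name_fn` introduced in test/fx/test_common_passes.py; `FakePass` and `fake_case` are made-up stand-ins, not real passes or test cases:

```
# Sketch of the parameterized-test naming this patch enables; names are illustrative.
def name_fn(common_pass, f, device):
    return f"{type(common_pass()).__name__}_{f.__name__}_{device}"


class FakePass:  # stand-in for a common-pass class such as those in `Passes`
    pass


def fake_case(x):  # stand-in for a test-case function such as those in `Test_Cases`
    return x


print(name_fn(FakePass, fake_case, "cpu"))  # -> FakePass_fake_case_cpu
```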
Differential Revision: D41218564 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88886 Approved by: https://github.com/zou3519 --- test/fx/test_common_passes.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py index 9c59abce4da6..407e707db879 100644 --- a/test/fx/test_common_passes.py +++ b/test/fx/test_common_passes.py @@ -73,10 +73,15 @@ def MutationMetadata(x): if torch.cuda.is_available(): Devices.append("cuda") + +def name_fn(common_pass, f, device): + """Names parameterized test cases.""" + return f'{type(common_pass()).__name__}_{f.__name__}_{device}' + @instantiate_parametrized_tests class TestCommonPass(TestCase): - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn) def test_correctness(self, common_pass, f, device): inp = torch.randn(10, device=device) @@ -94,7 +99,7 @@ def test_correctness(self, common_pass, f, device): self.assertEqual(result, expected) - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn) def test_correctness_factory(self, common_pass, f, device): inp = torch.randn(10, device=device) traced_m = make_fx(f)(inp, device) From 4108367123c1b44289b5c731c3bb7022976b816d Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 11 Nov 2022 20:41:36 +0000 Subject: [PATCH 093/453] Exclude poolformer_m36 from the inductor model test (#88908) Summary: The root cause is still to be investigated. Issue tracked at https://github.com/pytorch/torchdynamo/issues/1856 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88908 Approved by: https://github.com/malfet --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 758f4396b5b1..198877e0313d 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -156,6 +156,7 @@ "hrnet_w18", # accuracy "lcnet_0500", # accuracy "levit_128", # levit_128 + "poolformer_m36", "rexnet_100", # accuracy "swin_base_patch4_window7_224", "twins_pcpvt_base", # time out From ae4074669ecbf2a6d8bf99db745d29dce98d0c10 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 10 Nov 2022 21:19:22 +0000 Subject: [PATCH 094/453] [FSDP][state_dict][6/N] Remove most FSDP module dependency from _optim_utils (#88638) **What** This PR removes most `FullyShardedDataParallel` dependencies from `optim_utils`. 
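For context, a small sketch of the import direction after this change. The module paths are taken from the diff below; the commented usage line is illustrative and assumes `model` is an already-constructed module:

```
# Sketch only: _optim_utils now reaches shared helpers directly instead of going
# through the FullyShardedDataParallel frontend (e.g. fsdp_file._get_param_to_fqns).
from torch.distributed.fsdp._common_utils import _get_param_to_fqns
from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
from torch.distributed.fsdp.api import ShardingStrategy

# e.g. param_to_fqns = _get_param_to_fqns(model)  # `model` assumed to be an nn.Module
```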
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88638 Approved by: https://github.com/awgu --- torch/distributed/fsdp/_optim_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 530a8480d552..70fb4156d537 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -22,9 +22,11 @@ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp._common_utils import _get_param_to_fqns from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor -from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed +from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init from torch.distributed.fsdp._shard_utils import _gather_state_dict +from torch.distributed.fsdp.api import ShardingStrategy from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle @@ -185,7 +187,7 @@ def _communicate_optim_state( # we take the target rank's value if ( fsdp_module.world_size == 1 - or fsdp_module.sharding_strategy == fsdp_file.ShardingStrategy.NO_SHARD + or fsdp_module.sharding_strategy == ShardingStrategy.NO_SHARD ): tensor_state[state_name] = value continue @@ -293,7 +295,7 @@ def _flatten_optim_state_dict( '"param_groups" to be a valid optimizer state dict' ) flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # Construct the "state" part flat_osd_state: Dict[_OptimStateKey, Any] = {} @@ -897,7 +899,7 @@ def _rekey_sharded_optim_state_dict( if using_optim_input else _get_param_to_param_id(optim) ) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # All parameter keys in `param_to_flat_param_id` should be in # `param_to_fqns` -- strict inequality follows when not all parameters are # passed to the optimizer @@ -951,7 +953,7 @@ def _get_flat_param_to_fsdp_module(model: torch.nn.Module): flat_param_to_fsdp_module = {} for module in model.modules(): if isinstance(module, fsdp_file.FullyShardedDataParallel): - fsdp_file._lazy_init(module, module) + _lazy_init(module, module) for param in module.params: # may have none flat_param_to_fsdp_module[param] = module return flat_param_to_fsdp_module @@ -1165,9 +1167,7 @@ def _optim_state_dict( # Construct the local mapping between unflattened parameter names # (`_OptimStateKey`s) and parameter IDs and broadcast rank 0's mapping - param_to_fqns: Dict[torch.nn.Parameter, List[str]] = fsdp_file._get_param_to_fqns( - model - ) + param_to_fqns: Dict[torch.nn.Parameter, List[str]] = _get_param_to_fqns(model) flat_param_id_to_param: List[torch.nn.Parameter] = ( _get_param_id_to_param_from_optim_input(model, optim_input) if using_optim_input From b2b0a0d3baf6258fbf728572719937810fd890ce Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 03:21:06 +0000 Subject: [PATCH 095/453] [vision hash update] update the pinned vision hash (#88920) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88920 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 48685938a146..b9eda365de0c 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -d72e90640ec8514e0369b5419d7f3b74a387b1d7 +deba056203d009fec6b58afb9fa211f6ee3328c8 From d01bf1d1f11ab1fb9ae21a007138e2c4ecc31b63 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Sat, 12 Nov 2022 01:05:46 +0000 Subject: [PATCH 096/453] [FSDP] Introduce `ModuleWrapPolicy` for simplicity (#88450) **BC Breaking Change** This renames `unwrapped_params` to `nonwrapped_numel`. I prefer `nonwrapped` over `unwrapped` because "unwrap" suggests that some wrapping has been undone. I prefer `numel` over `params` because that is unit of measurement; I think we should keep "params" to refer to `nn.Parameter`s themselves. This only breaks anything that passes `unwrapped_params` as a keyword argument, but I did not see anything that did that (except the one internal benchmark file but that does not actually depend on our `pytorch` code). In a follow-up, I want to rename `min_num_params` to `min_nonwrapped_numel` in `size_based_auto_wrap_policy`, which is also BC breaking. Again, this is to differentiate between "params" being `nn.Parameter`s and "numel" being the unit for `param.numel()`. **Overview** This PR introduces `ModuleWrapPolicy` as a lightweight layer over the existing `transformer_auto_wrap_policy`. The most common auto wrapping paradigm is: ``` module_classes: Set[Type[nn.Module]] = ... auto_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=module_classes, ) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` Now, users can instead write: ``` auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_model = FSDP(model, auto_wrap_policy=auto_wrap_policy, ...) ``` This hides the unused arguments expected from the callable (`recurse` and `unwrapped_params`/`nonwrapped_numel`). `ModuleWrapPolicy` inherits from an abstract base class `FSDPPolicy` that expects a `policy` property. This decouples the construct of such `FSDPPolicy` classes and their actual `policy`, which must abide by the `_recursive_wrap` interface. Any existing auto wrap policy can be rewritten as a class that inherits from `FSDPPolicy`, so this approach is fully backward compatible from a functionality perspective. I call this base class `FSDPPolicy` to generalize over the cases where we may not want to actually perform any nested wrapping. In reality, the policy is meant for constructing `FlatParameter`s, which just happened to be induced by a nested wrapping before. Given this, I am changing the constructor argument in `fully_shard()` to simply `policy` instead of `auto_wrap_policy`. This PR migrates usages of `transformer_auto_wrap_policy` within our unit test suite to `ModuleWrapPolicy` as much as possible. 
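To make the new interface concrete, here is a minimal sketch of how an existing callable policy could be exposed through `_FSDPPolicy`, mirroring the `ModuleWrapPolicy` implementation in this PR. The class name `SizeBasedPolicy` and the threshold default are illustrative only; this PR does not add such a class:

```
import functools

from torch.distributed.fsdp.wrap import _FSDPPolicy, size_based_auto_wrap_policy


class SizeBasedPolicy(_FSDPPolicy):
    """Illustrative wrapper exposing the size-based callable via the new interface."""

    def __init__(self, min_num_params: int = int(1e8)):
        # Bind the custom argument; `recurse` and `nonwrapped_numel` are still
        # supplied by `_recursive_wrap()` when the policy is invoked.
        self._policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=min_num_params
        )

    @property
    def policy(self):
        return self._policy
```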
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88450 Approved by: https://github.com/zhaojuanmao --- .../_composable/test_fully_shard.py | 27 +- .../fsdp/test_fsdp_clip_grad_norm.py | 10 +- test/distributed/fsdp/test_fsdp_misc.py | 22 +- test/distributed/fsdp/test_fsdp_state_dict.py | 12 +- .../fsdp/test_fsdp_use_orig_params.py | 9 +- test/distributed/fsdp/test_utils.py | 7 +- test/distributed/fsdp/test_wrap.py | 16 + torch/distributed/_composable/fully_shard.py | 8 +- torch/distributed/fsdp/__init__.py | 1 - torch/distributed/fsdp/_init_utils.py | 5 +- torch/distributed/fsdp/_wrap_utils.py | 17 +- torch/distributed/fsdp/flat_param.py | 3 +- .../fsdp/fully_sharded_data_parallel.py | 155 ++-------- torch/distributed/fsdp/wrap.py | 288 ++++++++---------- torch/testing/_internal/common_fsdp.py | 20 +- 15 files changed, 244 insertions(+), 356 deletions(-) diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py index 27e0fb855fba..ba08deeafcdf 100644 --- a/test/distributed/_composable/test_fully_shard.py +++ b/test/distributed/_composable/test_fully_shard.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import copy -import functools import sys from typing import Any, Tuple @@ -12,7 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from torch.distributed.fsdp._runtime_utils import _root_pre_forward -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -62,10 +61,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return z @staticmethod - def auto_wrap_policy(): - return functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls={SubModel} - ) + def policy(): + return ModuleWrapPolicy({SubModel}) def get_input(self, device=torch.device) -> Tuple[Any, ...]: return (torch.randn((8, 5), device=device),) @@ -85,13 +82,13 @@ def test_auto_wrap_policy(self): local_model = Model(device=torch.device("cuda")) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) # Check that the composable module has the same names as the local @@ -138,7 +135,7 @@ def test_device_id(self): assert param.device == cpu_device fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), device_id=self.rank, ) for param in composable_module.parameters(): @@ -157,12 +154,12 @@ def test_sync_module_states(self): param.zero_() fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), sync_module_states=True, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -197,13 +194,13 @@ def _param_init_fn(module: nn.Module): composable_module = Model(device="meta") fsdp_wrapped_model = FSDP( Model(device="meta"), - auto_wrap_policy=Model.auto_wrap_policy(), + 
auto_wrap_policy=Model.policy(), param_init_fn=_param_init_fn, use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), param_init_fn=_param_init_fn, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -227,13 +224,13 @@ def test_training(self): local_model = Model(device=device) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) del local_model # not needed anymore LR = 1e-2 diff --git a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py index ddba50a9e456..e587065c5c77 100644 --- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py +++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from typing import Union @@ -12,7 +11,7 @@ CPUOffload, FullyShardedDataParallel as FSDP, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -102,12 +101,11 @@ def _test_ddp_parity( ) ddp_model = DDP(local_model, device_ids=[self.rank]) fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "cpu_offload": CPUOffload(offload_params=offload_params), "use_orig_params": use_orig_params, diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 79ed6da6240f..8c972f851563 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -15,7 +15,11 @@ FullyShardedDataParallel as FSDP, ShardingStrategy, ) -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ( + always_wrap_policy, + ModuleWrapPolicy, + transformer_auto_wrap_policy, +) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -211,10 +215,20 @@ def forward(self, x, y): def test_device_id_auto_wrap(self): """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all nested FSDP instances.""" - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + self.run_subtests( + {"use_callable": [False, True]}, + self._test_device_id_auto_wrap, ) + + def _test_device_id_auto_wrap(self, use_callable: bool): + module_classes = {TransformerEncoderLayer, TransformerDecoderLayer} + if use_callable: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=module_classes, + ) + else: + auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, "device_id": torch.cuda.current_device(), diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py 
b/test/distributed/fsdp/test_fsdp_state_dict.py index ba51ae66ed1b..6fafc8e8fdf4 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -26,7 +26,7 @@ ) from torch.distributed.fsdp._shard_utils import _gather_state_dict from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM -from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap +from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD @@ -350,9 +350,8 @@ def test_state_dict_with_manual_ac_wrapper( @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_with_shared_parameters(self, state_dict_type): - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) model_creator = partial( TransformerWithSharedParams.init, @@ -377,9 +376,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): """Tests saving a model checkpoint only on rank 0 and loading it only on rank 0 with ``sync_module_states=True`` to emulate the workflow to avoid redundant CPU memory usage.""" - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 24829ff408d9..0f5ffa564c2d 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -15,7 +15,7 @@ ShardingStrategy, ) from torch.distributed.fsdp._common_utils import clean_tensor_name -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -117,12 +117,11 @@ def _get_fsdp_transformer_and_optim( # combination with the parameter group construction, ensures different # hyperparameter settings within one `FlatParameter` fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "use_orig_params": True, "sharding_strategy": sharding_strategy, diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index e797325ccbc9..37c52547e847 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import random import sys import unittest @@ -14,7 +13,7 @@ from torch import distributed as dist from torch.distributed.fsdp._utils import _apply_to_tensors from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.utils import _replace_by_prefix from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -173,9 +172,7 @@ def test_module_wrap_policy(self): # Compute the mapping from submodule to states according to a logical # module wrap policy module_classes = (nn.Sequential,) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=set(module_classes) - ) + auto_wrap_policy = ModuleWrapPolicy(set(module_classes)) submodule_to_states = _get_submodule_to_states( model, auto_wrap_policy, set(), set() ) diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index cd0d11ba9b4b..e157f041ae1b 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -5,6 +5,7 @@ import tempfile import unittest from enum import auto, Enum +from typing import Callable, Union import torch import torch.nn as nn @@ -15,10 +16,12 @@ FullyShardedDataParallel as FSDP, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _wrap_batchnorm_individually, always_wrap_policy, enable_wrap, + ModuleWrapPolicy, size_based_auto_wrap_policy, transformer_auto_wrap_policy, wrap, @@ -373,6 +376,19 @@ def test_transformer_auto_wrap_policy(self): transformer_auto_wrap_policy, transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) + self._test_transformer_wrapping(auto_wrap_policy) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") + def test_module_wrap_policy(self): + """Tests the ``ModuleWrapPolicy``.""" + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} + ) + self._test_transformer_wrapping(auto_wrap_policy) + + def _test_transformer_wrapping( + self, auto_wrap_policy: Union[Callable, _FSDPPolicy] + ): fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} fsdp_model = TransformerWithSharedParams.init( self.process_group, diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 2d9e9329795b..174b2ca89a78 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -24,6 +24,7 @@ MixedPrecision, ShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy @contract @@ -32,7 +33,7 @@ def fully_shard( process_group: Optional[dist.ProcessGroup] = None, mixed_precision: Optional[MixedPrecision] = None, cpu_offload: Optional[CPUOffload] = None, - auto_wrap_policy: Optional[Callable] = None, + policy: Optional[_FSDPPolicy] = None, ignored_modules: Optional[Iterable[torch.nn.Module]] = None, device_id: Optional[Union[int, torch.device]] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, @@ -41,6 +42,9 @@ def fully_shard( """ Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``. 
""" + # Enforce the new auto wrap policy + if policy is not None and not isinstance(policy, _FSDPPolicy): + raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}") state = fully_shard.state(module) state = _init_ignored_module_states(state, module, ignored_modules) state = _init_process_group_state(state, process_group) @@ -64,7 +68,7 @@ def fully_shard( state = _init_param_handles_from_module( state, module, - auto_wrap_policy, + policy, device_id, param_init_fn, sync_module_states, diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index 324a3442dea9..b1bffdb25a0e 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -11,4 +11,3 @@ ShardingStrategy, StateDictType, ) -from .wrap import ParamExecOrderWrapPolicy diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 1265ee3578ed..7e128251fcc4 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -47,6 +47,7 @@ HandleConfig, HandleShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy from torch.distributed.utils import _sync_params_and_buffers from torch.utils.hooks import RemovableHandle @@ -262,7 +263,7 @@ def _init_param_handle_from_module( def _init_param_handles_from_module( state: _FSDPState, root_module: nn.Module, - auto_wrap_policy: Callable, + policy: _FSDPPolicy, device_id: Optional[Union[int, torch.device]], param_init_fn: Optional[Callable[[nn.Module], None]], sync_module_states: bool, @@ -273,7 +274,7 @@ def _init_param_handles_from_module( """ submodule_to_states = _get_submodule_to_states( root_module, - auto_wrap_policy, + policy, state._ignored_modules, state._ignored_params, ) diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 34d1c9c1ac24..cdda065df199 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -1,7 +1,7 @@ import collections import functools import warnings -from typing import Any, Callable, Deque, Dict, List, NamedTuple, Set, Tuple +from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple import torch import torch.nn as nn @@ -10,6 +10,7 @@ _override_batchnorm_mixed_precision, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _recursive_wrap, _wrap_batchnorm_individually, @@ -45,6 +46,9 @@ def _auto_wrap( ``fsdp_kwargs`` contains all FSDP arguments except ``module``. """ auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + # Support new way to pass an auto wrap policy + if isinstance(auto_wrap_policy, _FSDPPolicy): + auto_wrap_policy = auto_wrap_policy.policy root_module = auto_wrap_kwargs["module"] assert auto_wrap_policy is not None # For auto wrapping, submodules should not already be wrapped with FSDP @@ -68,13 +72,13 @@ def _auto_wrap( "instances with mixed precision disabled since some batch norm " "kernels do not support low precision." 
) - auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy + auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) def _get_submodule_to_states( root_module: nn.Module, - auto_wrap_policy: Callable, + auto_wrap_policy: _FSDPPolicy, ignored_modules: Set[nn.Module], ignored_params: Set[nn.Parameter], ) -> Dict[nn.Module, SubmoduleState]: @@ -99,7 +103,7 @@ def _get_submodule_to_states( wrapper_cls = functools.partial(_record_module_wrapper_cls, wrapped_modules) _recursive_wrap( root_module, - auto_wrap_policy=auto_wrap_policy, + auto_wrap_policy=auto_wrap_policy.policy, wrapper_cls=wrapper_cls, ignored_modules=ignored_modules, ignored_params=ignored_params, @@ -158,8 +162,9 @@ def _record_module_wrapper_cls( **kwargs, ) -> nn.Module: """ - This defines a wrapper class to be passed to ``_recursive_wrap()`` that - records the wrapped module to the input ``wrapped_modules``. + This defines a pseudo-wrapper class to be passed to ``_recursive_wrap()`` + that records the wrapped module to the input ``wrapped_modules`` without + actually wrapping with a class. """ wrapped_modules.append(module) return module diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index b790590c7943..b5892bca683a 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -838,7 +838,8 @@ def needs_unshard(self) -> bool: return False unsharded_flat_param = self._get_padded_unsharded_flat_param() already_unsharded = ( - unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel() + unsharded_flat_param._typed_storage()._size() + == unsharded_flat_param.numel() ) return not already_unsharded diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 510f90de2023..69c8dd92ed8d 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -96,14 +96,6 @@ ) from ._utils import p_assert from .flat_param import FlatParameter, FlatParamHandle -from .wrap import ParamExecOrderWrapPolicy - - -_TORCH_FX_AVAIL = True -if not hasattr(torch, "fx"): - _TORCH_FX_AVAIL = False -if _TORCH_FX_AVAIL: - from ._symbolic_trace import _init_execution_info, _patch_tracer, TracingConfig __all__ = [ @@ -207,37 +199,36 @@ class FullyShardedDataParallel(nn.Module): This configures CPU offloading. If this is set to ``None``, then no CPU offloading happens. See :class:`CPUOffload` for details. (Default: ``None``) - auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): - A callable specifying a policy to recursively wrap layers with FSDP. - Note that this policy currently will only apply to child modules of - the passed in module. The remainder modules are always wrapped in - the returned FSDP root instance. - ``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is - an example of ``auto_wrap_policy`` callable, this policy wraps layers - with the number of parameters larger than 100M. ``transformer_auto_wrap_policy`` - written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` - callable for transformer-like model architectures. 
Users can supply the customized - ``auto_wrap_policy`` callable that should accept following arguments: - ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return - a ``bool`` specifying whether the passed in ``module``` should be wrapped - (if ``recurse=False``) or whether we should recurse down the subgraph of ``module`` - children (if ``recurse=True``). Extra customized arguments could be added to - the customized ``auto_wrap_policy`` callable as well. It is a good practice to - print out the sharded model and check whether the sharded model is what - the application wants and then adjust accordingly. + auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]): + This is either ``None``, an ``_FSDPPolicy``, or a callable of + a fixed signature. If it is ``None``, then ``module`` is wrapped + with only a top-level FSDP instance without any nested wrapping. If + it is an ``_FSDPPolicy``, then the wrapping follows the given + policy. ``ModuleWrapPolicy`` in ``torch.distributed.fsdp.wrap.py`` + is an example. If it is a callable, then it should take in three + arguments ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should be wrapped if + ``recurse=False`` or if the traversal should continue down the + subtree if ``recurse=True``. Additional custom arguments may be + added to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + wraps a module if the parameters in its subtree exceed 100M numel. + A good practice is to print the model after wrapping and adjust as + needed. Example:: >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, - >>> unwrapped_params: int, - >>> # These are customizable for this policy function. + >>> nonwrapped_numel: int, + >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: - >>> return unwrapped_params >= min_num_params - >>> # Configure a custom min_num_params - >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) backward_prefetch (Optional[BackwardPrefetch]): This configures explicit backward prefetching of all-gathers. 
See @@ -337,25 +328,6 @@ def __init__( limit_all_gathers: bool = False, use_orig_params: bool = False, ): - if isinstance(auto_wrap_policy, ParamExecOrderWrapPolicy): - self._init_param_exec_order_wrap_policy( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - device_id=device_id, - sync_module_states=sync_module_states, - forward_prefetch=forward_prefetch, - limit_all_gathers=limit_all_gathers, - use_orig_params=use_orig_params, - ) - return - torch._C._log_api_usage_once("torch.distributed.fsdp") super().__init__() @@ -1815,89 +1787,6 @@ def register_comm_hook(self, state: object, hook: callable): submodule._communication_hook_state = state submodule._communication_hook = hook - def _init_param_exec_order_wrap_policy(self, *args, **kwargs) -> None: - auto_wrap_policy = kwargs["auto_wrap_policy"] - module = kwargs["module"] - assert hasattr(auto_wrap_policy, "tracing_config") - if not _TORCH_FX_AVAIL: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should be None when torch.fx is not enabled" - elif isinstance(auto_wrap_policy.tracing_config, TracingConfig): - tracer = auto_wrap_policy.tracing_config.tracer - execution_info = _init_execution_info(module) - - for m in module.modules(): - assert not isinstance( - m, FullyShardedDataParallel - ), "The input module of _patch_tracer should not contain FSDP modules" - - with _patch_tracer( - tracer=tracer, - root_module=module, - execution_info=execution_info, - ): - try: - tracer.trace(module, auto_wrap_policy.tracing_config.concrete_args) - except BaseException as e: - raise RuntimeError( - "tracer.trace failed inside _init_param_exec_order_wrap_policy" - f" with the error: {e}." - ) - else: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should either be an instance of TracingConfig or be None" - # The initial FSDP wrapping is done with auto_wrap_policy.init_policy - kwargs["auto_wrap_policy"] = auto_wrap_policy.init_policy - self.__init__(*args, **kwargs) - self._param_exec_order_policy: bool = True - # self._param_exec_order_prep_stage is set to True before we get the execution order - self._param_exec_order_prep_stage: bool = True - # A list that stores the flatten parameters and its name based on the parameter execution order - self._fsdp_params_exec_order: List[FlatParameter] = [] - if _TORCH_FX_AVAIL and isinstance( - auto_wrap_policy.tracing_config, TracingConfig - ): - # Initialize a dict that maps each module to its parent FSDP wrap - module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict() - for wrap in self.fsdp_modules(self): - module_to_fsdp[wrap.module] = wrap - # Set self._fsdp_params_exec_order based on execution_info.module_forward_order. - # TODO (linjianma): self._fsdp_params_exec_order will be set based on - # the parameter execution order rather than module_forward_order, - # once the non-recursive wrapping policy is fully implemented. 
- for m in execution_info.module_forward_order: - if m in module_to_fsdp: - for flat_param in module_to_fsdp[m].params: - self._fsdp_params_exec_order.append(flat_param) - self._param_exec_order_prep_stage = False - - for m in self.modules(): - if m is not self and isinstance(m, FullyShardedDataParallel): - # Assignment by reference, so each children FSDP wrap has access to - # the _fsdp_params_exec_order of the root module - m._fsdp_params_exec_order = self._fsdp_params_exec_order - m._param_exec_order_policy = self._param_exec_order_policy - m._param_exec_order_prep_stage = self._param_exec_order_prep_stage - - def _use_param_exec_order_policy(self) -> bool: - return ( - hasattr(self, "_param_exec_order_policy") and self._param_exec_order_policy - ) - - def _is_param_exec_order_prep_stage(self) -> bool: - is_prep_stage = ( - hasattr(self, "_param_exec_order_prep_stage") - and self._param_exec_order_prep_stage - ) - if not is_prep_stage: - for p in self.parameters(): - assert not hasattr( - p, "_params_exec_order_hook_handle" - ), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed." - return is_prep_stage - def _get_grad_norm( params: List[nn.Parameter], diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index c529bcde8c85..e20c07f18d13 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import contextlib -from dataclasses import dataclass +import functools +from abc import ABC, abstractmethod from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type import torch.nn as nn @@ -17,22 +18,84 @@ "size_based_auto_wrap_policy", "enable_wrap", "wrap", - "ParamExecOrderWrapPolicy", + "ModuleWrapPolicy", ] def always_wrap_policy(*args, **kwargs) -> bool: """ - A simple wrapper policy that always returns ``True``, - i.e. when passed as the `auto_wrap_policy` into FSDP, - this will result in all submodules being wrapped as - distinct FSDP instances. + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. """ return True +class _FSDPPolicy(ABC): + """ + This defines an abstract base class that represents an FSDP policy for + constructing ``FlatParameter`` s. + """ + + # The motivation for this abstract base class is to hide the interface + # expected by `_recursive_wrap()` from users (i.e. the `recurse` argument). + def __init__(self): + ... + + @property + @abstractmethod + def policy(self) -> Callable: + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: Set[Type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. 
+ module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_FSDPPolicy): + """This is a wrapper around :func:`_module_wrap_policy`.""" + + def __init__(self, module_classes: Set[Type[nn.Module]]): + self._policy: Callable = functools.partial( + _module_wrap_policy, + module_classes=module_classes, + ) + + @property + def policy(self): + return self._policy + + def lambda_auto_wrap_policy( - module: nn.Module, recurse: bool, unwrapped_params: int, lambda_fn: Callable + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable ) -> bool: """ A convenient auto wrap policy to wrap submodules based on an arbitrary user @@ -44,70 +107,34 @@ def lambda_auto_wrap_policy( The first three parameters are required by :func:`_recursive_wrap`. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - lambda_fn (Callable[nn.Module] -> bool): - If this returns ``True``, this module will be wrapped by - wrapper_cls individually. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. """ if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return lambda_fn(module) + return True # always recurse + return lambda_fn(module) def transformer_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, transformer_layer_cls: Set[Type[nn.Module]], ) -> bool: """ - A convenient auto wrap policy for transformer models. If the submodule - is an instance of transformer_layer_cls, the submodule will be wrapped - as a FSDP unit. Otherwise, all the other remainder submodules are wrapped - by the outermost FSDP unit. Right now, FSDP requires submodules that share - weights to be wrapped in the same FSDP unit, this auto wrap policy can - conviniently wrap the shared embeddings into the same FSDP unit for transformer - models. In the near future, FSDP will support submodules that share weights - to be wrapped in the separated FSDP units. - - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are required by :func:`_recursive_wrap`. - - - Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. 
- unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - transformer_layer_cls (int): - Submodules with one of the `transformer_layer_cls` names - will be wrapped as separated FSDP units + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. """ - if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return isinstance(module, tuple(transformer_layer_cls)) + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) def _wrap_batchnorm_individually( @@ -117,7 +144,7 @@ def _wrap_batchnorm_individually( **kwargs, ) -> bool: """ - A policy that wraps ``BatchNorm`` instances in their own FSDP unit. + A policy that wraps ``BatchNorm`` instances in their own FSDP instance. """ if recurse: # always recurse @@ -131,52 +158,46 @@ def _wrap_batchnorm_individually( def _or_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, policies, ) -> bool: """ A policy that wraps ``module`` if any policy in the passed in iterable of ``policies`` returns ``True``. """ - return any(policy(module, recurse, unwrapped_params) for policy in policies) + return any(policy(module, recurse, nonwrapped_numel) for policy in policies) def size_based_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, - # These are customizable for this policy function. + nonwrapped_numel: int, + # Additional custom arguments min_num_params: int = int(1e8), force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, ) -> bool: - """A size based auto_wrap_policy function for FSDP API. - - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are used by :func:`_recursive_wrap`. If - you write a custom version of this policy function, your version - needs to at least accept the first three parameters and free - to do whatever you want in the function. + """ + A size-based auto wrap policy. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - min_num_params (int): - Customizable policy input. It controls the size threshold - on how big should a module be to be considered wrapped. - force_leaf_modules (Set[Type[nn.Module]]): set of module types to - keep as leaves, i.e., their children will never be wrapped. - exclude_wrap_modules (Set[Type[nn.Module]]): - Customizable set of module types to be excluded in wrapping. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. 
+ + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. """ force_leaf_modules = ( size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] @@ -189,7 +210,10 @@ def size_based_auto_wrap_policy( else exclude_wrap_modules ) - is_large = unwrapped_params >= min_num_params + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel if recurse: # We should recurse if the module is big enough but not in force_leaf_modules list. return is_large and not isinstance(module, tuple(force_leaf_modules)) @@ -276,56 +300,6 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: return module -@dataclass -class ParamExecOrderWrapPolicy: - """ - This is the class used for the wrapping policy that wraps parameters and performs - the communication scheduling based on the parameter execution order in the forward pass - (also called non-recursive wrapping policy). - - The policy contains multiple wraps. Each wrap contains original parameters that will be executed together, - and the wrap transfers these parameters into one ``FlattenParameter``. In both forward and the backward passes, - the sharded parameters in each wrap will be gathered just before these parameters are used in the passes. - These parameters will then be reshaded once they have been used. - - TODO (linjianma): For now, the parameters contained in each wrap of ``ParamExecOrderWrapPolicy`` - are the parameters in each wrap of the ``init_policy`` (a recursive wrapping policy). - Later we will wrap parameters based on bucket size. - - Args: - init_policy (Callable): - The initial recursive wrapping policy used to guide the wrapping of - this policy. If tracing_config is none, in the first forward and - backward iteration, ``init_policy`` is used to record parameter - execution order. Otherwise, init_policy is only used in FSDP - constructor for module level wrapping. - - The default ``always_wrap_policy`` might not be the best choice for every model. For example, for - transformer based models, setting ``transformer_auto_wrap_policy`` as the ``init_policy`` will guarantee - wrapping each transformer layer into one FSDP unit, and can be easily combined with checkpointing - within each transformer layer. - - tracing_config (Optional[TracingConfig]): - The configuration used to perform symbolic tracing at FSDP - constructor to get the module and parameter execution order. The - type of ``tracing_config`` needs to be either ``None`` or - ``TracingConfig``. If set as ``None``, then symbolic tracing is not - enabled, and one forward as well as backward iteration are needed to - get the parameter execution order. - - ..warning :: Note that not all modules can be successfully traced when - ``tracing_config`` is not None and symbolic tracing is enabled. The two - cases below may be unable to trace: 1. when there is a data-dependent - branch, 2. when the forward pass contains operators that don't support - ``torch.fx.Proxy`` as the input type (e.g. 
``arange``, ``zeros``, ``ones``, - ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases, - users can set ``tracing_config = None`` to disable symbolic tracing. - """ - - init_policy: Callable = always_wrap_policy - tracing_config: Any = None - - def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: assert wrapper_cls is not None if hasattr(module, "_wrap_overrides"): @@ -349,13 +323,13 @@ def _recursive_wrap( **kwargs: Any, ) -> Tuple[nn.Module, int]: """ - Automatically wrap child modules of *module* that meet the given - criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap. + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + Args: - module (nn.Module): - module to recursively wrap - auto_wrap_policy (Callable): - A callable specifying a policy to recursively wrap layers with FSDP. + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. ignored_modules (Set[torch.nn.Module]): Modules to ignore when wrapping. ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when @@ -363,7 +337,7 @@ def _recursive_wrap( in ``ignored_modules``. Returns: (nn.Module, int): - Wrapped module and the number parameters wrapped recursively. + ``module`` after wrapping and the numel recursively wrapped. """ assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." assert wrapper_cls is not None, "Must specify wrapper_cls" @@ -378,11 +352,13 @@ def _recursive_wrap( pass # We count all params, assuming none of them are already wrapped. - num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params) + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) assert auto_wrap_policy is not None - if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): - total_wrapped_params = 0 + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 # Iterate through the children, recursively wrap if necessary for name, child in module.named_children(): if child in ignored_modules: @@ -397,17 +373,17 @@ def _recursive_wrap( ) setattr(module, name, wrapped_child) # Keep track of how many parameters have been wrapped - total_wrapped_params += num_wrapped_params + total_wrapped_numel += num_wrapped_params # decide if we need to wrap the current module, # since the left over parameters exceed the number of params to wrap - remainder = num_params - total_wrapped_params + remainder = nonwrapped_numel - total_wrapped_numel if not only_wrap_children and auto_wrap_policy( - module=module, recurse=False, unwrapped_params=remainder + module=module, recurse=False, nonwrapped_numel=remainder ): # Leaf node or final wrapping of the remainder both happen here. 
- return _wrap(module, wrapper_cls, **kwargs), num_params + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel else: - return module, total_wrapped_params + return module, total_wrapped_numel return module, 0 diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 0dca22f48092..b4650adff569 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from abc import ABC, abstractmethod @@ -21,11 +20,7 @@ ShardingStrategy, ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp.wrap import ( - always_wrap_policy, - transformer_auto_wrap_policy, - wrap, -) +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS @@ -285,8 +280,8 @@ def init( fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap any modules with FSDP. If ``RECURSIVE``, then wraps with top-level FSDP. By default, the top-level FSDP uses the - ``transformer_auto_wrap_policy()`` for encoder and decoder - layers, but a different auto wrap policy may be specified via + ``ModuleWrapPolicy`` for encoder and decoder layers, but a + different auto wrap policy may be specified via ``fsdp_kwargs``. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments @@ -302,14 +297,13 @@ def init( group, cuda_init_mode, add_bn, deterministic ) elif fsdp_init_mode == FSDPInitMode.RECURSIVE: - # Default to the `transformer_auto_wrap_policy()` + # Default to the `ModuleWrapPolicy` if "auto_wrap_policy" not in fsdp_kwargs: - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + auto_wrap_policy = ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ) else: auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy") From c83348597b195f2da1cca0e8318c878b104bce5d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 12 Nov 2022 04:45:17 +0000 Subject: [PATCH 097/453] [dynamo][api] Better support of torch.nn.Module (#88629) This is an API change, so please review carefully. With this PR, torchdynamo returns an `OptimizedModule` class object, a subclass of `torch.nn.Module`, when asked to optimize a `nn.Module` object. Most of the methods are redirected to the original `nn.Module`, which is installed as `_mod` in the `OptimizedModule`. This is helpful for many cases ``` mod = MockModule() opt_mod = torch._dynamo.optimize()(mod) print(opt_mod) # Works opt_mod = opt_mod.to(device="cuda") print(opt_mod) # Works opt_mod(input) # Triggers recompile if necessary, earlier we were shedding the TorchDynamo wrapper opt_mod.parameters() # Refers to the original module ``` Topics unclear to me * I have overridden many methods to raise NotImplementedError. A careful review of those will be good. 
* hooks * For the optimized forward, should we call torchdynamo optimization on `__call__` or `forward` * What else to test Pull Request resolved: https://github.com/pytorch/pytorch/pull/88629 Approved by: https://github.com/Chillee, https://github.com/jansel, https://github.com/msaroufim --- test/dynamo/test_modules.py | 127 +++++++++++++++++++++++++++++++++++ torch/_dynamo/__init__.py | 2 + torch/_dynamo/debug_utils.py | 8 +++ torch/_dynamo/eval_frame.py | 74 ++++++++++++++------ torch/_dynamo/testing.py | 13 ++++ 5 files changed, 204 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 2fb83b3add6c..930035f99a30 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,6 +904,133 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.relu(self.linear(x) + self.buf0) + + +class OptimizedModuleTest(torch._dynamo.test_case.TestCase): + def test_nn_module(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_to(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 1) + + opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + x = torch.randn(10, 10).to(dtype=torch.float64) + opt_mod(x) + # Ensure that there is a recompilation + self.assertEqual(cnt.frame_count, 2) + + def test_attr(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.r(torch.sin(x)) + self.buf0 + + mod = MockModule() + opt_mod = torch._dynamo.optimize("eager")(mod) + + # Check parameteres and buffers + for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): + self.assertTrue(id(p1) == id(p2)) + + def test_recursion(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + + for _ in range(5): + opt_mod = torch._dynamo.optimize(cnt)(opt_mod) + opt_mod(torch.randn(10, 10)) + self.assertEqual(cnt.frame_count, 1) + + def test_composition(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + opt_inner_mod = InnerModule() + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + 
self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_composition_with_opt_mod(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + inner_mod = InnerModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + # There will be a graph break for the inner mod being OptimizedModule + self.assertEqual(cnt.frame_count, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 80f927aeef2f..5eee609b0852 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,6 +7,7 @@ export, optimize, optimize_assert, + OptimizedModule, reset_code, run, skip, @@ -25,6 +26,7 @@ "reset", "list_backends", "skip", + "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 98a269fe8c9e..29d830167b10 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -515,8 +515,16 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ + from .eval_frame import OptimizedModule + from .testing import named_parameters_for_optimized_module from .utils import same + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8d9e3b7b6aa1..20e8c7de085e 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,6 +5,7 @@ import logging import os import sys +import textwrap import threading import traceback import types @@ -44,6 +45,27 @@ most_recent_backend = None +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. + """ + + def __init__(self, mod): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def forward(self, *args, **kwargs): + # This will be monkey patched later + raise RuntimeError("Should not be here") + + def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -118,31 +140,15 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - optimized_forward = self(mod.forward) - - class TorchDynamoNNModuleWrapper: - """ - A wrapper that redirects the forward call to the optimized - forward, while for rest it redirects the calls to the original - module. 
- """ - - def __getattr__(self, name): - return getattr(mod, name) - - def forward(self, *args, **kwargs): - return optimized_forward(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - new_mod = TorchDynamoNNModuleWrapper() + new_mod = OptimizedModule(mod) + new_mod.forward = self(mod.forward) # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod + new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) + callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -184,6 +190,34 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please optimize the individual methods that you are interested in optimizing. + + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function + + >> mod.__call__ = torch._dynamo.optimize(mod.__call__) + + """ + ) + ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d6082ce48acf..b37299ffd579 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,6 +32,17 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def remove_optimized_module_prefix(name): + prefix = "_orig_mod." + assert name.startswith(prefix) + return name[len(prefix) :] + + def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -44,6 +55,8 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 34641c4384328ad9a3d2dc928de5b60a239427ee Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 05:16:41 +0000 Subject: [PATCH 098/453] Revert "Add comprehensive minifier tests (#88022)" This reverts commit 5ff600aa6e40c6b4d426594bbb1f446f005b7fb3. 
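For context, a minimal sketch of the configuration the reverted tests toggle (the config names and the /tmp path are taken from the test diff below; pre-creating the directory is only illustrative):

```
# Sketch only: the knobs the minifier tests rely on, not code added by this patch.
import os
import torch._dynamo

torch._dynamo.config.repro_after = "dynamo"  # or "aot"; selects where repro dumping hooks in
torch._dynamo.config.debug_dir_root = "/tmp/_torchdynamo_debug_/"
# repro.py / minifier_launcher.py artifacts are written under debug_dir_root,
# so the directory has to exist (or be creatable) on the machine running the tests.
os.makedirs(torch._dynamo.config.debug_dir_root, exist_ok=True)
```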
Reverted https://github.com/pytorch/pytorch/pull/88022 on behalf of https://github.com/wconstab due to Seems to be causing CI failures relating to minifier test and some /tmp/ path not existing --- test/dynamo/test_minifier.py | 630 ++++------------------------------- torch/_dynamo/debug_utils.py | 78 +---- 2 files changed, 76 insertions(+), 632 deletions(-) diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 51b79a5e7511..0cec7d202a9d 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,138 +1,27 @@ # Owner(s): ["module: dynamo"] -import functools import os -import re import shutil -import subprocess -import textwrap import unittest +from unittest.mock import patch import torch import torch._dynamo import torch._dynamo.test_case import torch._dynamo.testing -import torch._inductor.utils -from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT +from torch._dynamo.optimizations.backends import create_backend -_HAS_TRITON = torch._inductor.utils.has_triton() -requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") -RELU_COMPILE_ERROR_BACKEND = """\ -from torch._dynamo.optimizations.backends import register_backend +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() -class DynamoCompileError(Exception): - pass - -@register_backend -def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): - for node in gm.graph.nodes: - if node.target == torch.relu: - raise DynamoCompileError("relu found") - return gm -""" - -RELU_RUNTIME_ERROR_BACKEND = """\ -import copy -from torch._dynamo.optimizations.backends import register_backend - -@register_backend -def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): - gm = copy.deepcopy(gm) - for node in gm.graph.nodes: - if node.target == torch.relu: - node.target = torch._assert - node.args = (False, "DynamoRuntimeError") - gm.recompile() - return gm -""" - -RELU_ACCURACY_ERROR_BACKEND = """\ -import copy -from torch._dynamo.optimizations.backends import register_backend - -@register_backend -def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): - gm = copy.deepcopy(gm) - for node in gm.graph.nodes: - if node.target == torch.relu: - node.target = torch.add - node.args = (node.args[0], 1) - gm.recompile() - - return gm -""" - -RELU_CUSTOM_ERROR_BACKEND = """\ -class CustomError(Exception): - pass - -def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): - for node in gm.graph.nodes: - if node.target == torch.relu: - raise CustomError("relu found") - return gm -""" - -CPP_COMPILE_ERROR = """\ -def cpp_compile_error(x): - return "compile error!" -""" - -CPP_RUNTIME_ERROR = """\ -def cpp_runtime_error(x): - return f"{x}; throw 1" -""" - -CPP_ACCURACY_ERROR = """\ -def cpp_accuracy_error(x): - return f"{x} + 1" -""" - -TRITON_COMPILE_ERROR = """\ -def triton_compile_error(x): - return "compile error!" -""" - -# NOTE: there is currently not an easy way to cause a triton runtime error. -TRITON_RUNTIME_ERROR = """\ -def triton_runtime_error(x): - return f"{x}; assert?" -""" - -TRITON_ACCURACY_ERROR = """\ -def triton_accuracy_error(x): - return f"{x} + 1" -""" - -DEBUG_DIR = "/tmp/_torchdynamo_debug_/" - -# Search for the name of the first function defined in a code string. 
-def get_fn_name(code): - fn_name_match = re.search(r"def (\w+)\(", code) - if fn_name_match is not None: - return fn_name_match.group(1) - return None - - -# Generates code that patches CppOverrides/TritonOverrides. -def gen_codegen_fn_patch_code(old_fn_name, new_fn_code, device): - new_fn_name = get_fn_name(new_fn_code) - if new_fn_name is not None: - patch_code = f"""\ -import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen -overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} -{new_fn_code} -overrides.{old_fn_name} = staticmethod({new_fn_name}) -""" - return f"""\ -{patch_code} -isolate_fails_code_str = \"\"\"\\ -{patch_code} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -\"\"\" -""" - - return None + def forward(self, x): + for _ in range(10): + x = torch.sin(x) + x = torch._foobar(x) + for _ in range(10): + x = torch.cos(x) + return x class MinfierTests(torch._dynamo.test_case.TestCase): @@ -143,10 +32,9 @@ def setUpClass(cls): unittest.mock.patch.object( torch._dynamo.config, "debug_dir_root", - DEBUG_DIR, + "/tmp/_torchdynamo_debug_/", ) ) - os.makedirs(DEBUG_DIR, exist_ok=True) @classmethod def tearDownClass(cls): @@ -159,455 +47,65 @@ def setUp(self): def tearDown(self): super().tearDown() - # Run `code` in a separate python process. - # Returns the completed process state and the directory containing the - # minifier launcher script, if `code` outputted it. - def _run_test_code(self, code): - proc = subprocess.run( - ["python3", "-c", code], capture_output=True, cwd=DEBUG_DIR - ) - - repro_dir_match = re.search( - r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") - ) - if repro_dir_match is not None: - # Print repro directory for debugging generated code. - # Make sure to comment out `shutil.rmtree...` above as well. - print("repro dir:", repro_dir_match.group(1)) - return proc, repro_dir_match.group(1) - return proc, None - - # Patch generated files with testing patches - def _inject_code(self, patch_code, filename): - patch_code = f"""\ -{patch_code} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -""" - with open(filename, "r") as f: - code = f.read() - code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) - with open(filename, "w") as f: - f.write(code) - return code - - # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. - def _run_minifier_launcher(self, patch_code, repro_dir): - self.assertIsNotNone(repro_dir) - launch_file = os.path.join(repro_dir, "minifier_launcher.py") - self.assertTrue(os.path.exists(launch_file)) - launch_code = self._inject_code(patch_code, launch_file) - - launch_proc = subprocess.run( - ["python3", launch_file], - capture_output=True, - cwd=repro_dir, - ) - - return launch_proc, launch_code - - # Runs the repro script in `repro_dir`, patched with `patch_code` - def _run_repro(self, patch_code, repro_dir): - self.assertIsNotNone(repro_dir) - repro_file = os.path.join(repro_dir, "repro.py") + def test_after_dynamo(self): + @create_backend + def bad_dynamo_backend(subgraph): + import sys + + def f(*args): + # Shifted the forced exception to runtime as this is more common + # in JIT compilers. 
+ for node in subgraph.model.graph.nodes: + if node.op == "call_function" and node.target is torch._foobar: + sys.stdout.write("Dynamo compiled failed\n") + raise NotImplementedError("foobar is not implemented") + return subgraph.model(*args) + + return f + + mod = MockModule() + opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) + repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() + + @patch.object(torch._dynamo.config, "repro_after", "dynamo") + def inner(): + x = torch.randn(4) + try: + opt_mod(x) + except Exception: + pass + + inner() self.assertTrue(os.path.exists(repro_file)) - repro_code = self._inject_code(patch_code, repro_file) - - repro_proc = subprocess.run( - ["python3", repro_file], capture_output=True, cwd=repro_dir - ) - - return repro_proc, repro_code - - # Template for testing code. - # `run_code` is the code to run for the test case. - # `patch_code` is the code to be patched in every generated file. - def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): - return f"""\ -import torch -import torch._dynamo -{patch_code} -torch._dynamo.config.repro_after = "{repro_after}" -torch._dynamo.config.repro_level = {repro_level} -torch._dynamo.config.debug_dir_root = "{DEBUG_DIR}" -{run_code} -""" - - # Runs a full minifier test. - # Minifier tests generally consist of 3 stages: - # 1. Run the problematic code (in a separate process since it could segfault) - # 2. Run the generated minifier launcher script - # 3. Run the generated repro script - def _run_full_test(self, run_code, repro_after, repro_level, patch_code): - test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) - test_proc, repro_dir = self._run_test_code(test_code) - self.assertIsNotNone(repro_dir) - launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) - repro_proc, repro_code = self._run_repro(patch_code, repro_dir) - return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) - - # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) - def _test_after_dynamo(self, device, repro_level, backend_code, error_name): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{get_fn_name(backend_code)}") - def inner(x): - for _ in range(10): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(10): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "dynamo", repro_level, backend_code - ) - - self.assertIn(error_name, test_proc.stderr.decode("utf-8")) - self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) - - def test_after_dynamo_cpu_compile_error(self): - self._test_after_dynamo( - "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" - ) - - def test_after_dynamo_cpu_runtime_error(self): - self._test_after_dynamo( - "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" - ) - - def test_after_dynamo_cpu_accuracy_error(self): - self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") - - @requires_cuda() - def test_after_dynamo_cuda_compile_error(self): - self._test_after_dynamo( - "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" - ) - - @requires_cuda() - def test_after_dynamo_cuda_runtime_error(self): - self._test_after_dynamo( - "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" - ) - - @requires_cuda() - def test_after_dynamo_cuda_accuracy_error(self): - self._test_after_dynamo("cuda", 4, 
RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") - - # Ensure that the testing backends pass when relu is not present. - def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{get_fn_name(backend_code)}") - def inner(x): - for _ in range(10): - x = torch.sin(x) - for _ in range(10): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - - test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) - proc, repro_dir = self._run_test_code(test_code) - self.assertEqual(proc.returncode, 0) - self.assertIsNone(repro_dir) - - def test_after_dynamo_cpu_compile_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) - - def test_after_dynamo_cpu_runtime_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) - - def test_after_dynamo_cpu_accuracy_backend_passes(self): - self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) - @requires_cuda() - def test_after_dynamo_cuda_compile_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + # If error_at_aot is True, an error will be produced when AOTAutograd + # attempts to generate the backward graph. + # If error_after_aot is False, an error will be produced in inductor. + def _test_around_aot(self, error_at_aot): + mod = MockModule() + opt_mod = torch._dynamo.optimize("inductor")(mod) - @requires_cuda() - def test_after_dynamo_cuda_runtime_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() + repro_after = "dynamo" if error_at_aot else "aot" - @requires_cuda() - def test_after_dynamo_cuda_accuracy_backend_passes(self): - self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + @patch.object(torch._dynamo.config, "repro_after", repro_after) + def inner(): + x = torch.randn(4) + x.requires_grad = error_at_aot + try: + opt_mod(x) + except Exception: + pass - # Ensure that generated code with a custom backends generates a runnable minifier - # launcher script that results in a RuntimeError - def test_after_dynamo_custom_backend(self): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize({get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) - def inner(x): - for _ in range(10): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(10): - x = torch.cos(x) - return x + inner() - inner(torch.randn(20, 20)) - """ - ) - - test_code = self._gen_test_code( - run_code, "dynamo", 2, RELU_CUSTOM_ERROR_BACKEND - ) - _, repro_dir = self._run_test_code(test_code) - launch_proc, launch_code = self._run_minifier_launcher("", repro_dir) - self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) - - # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd - @requires_cuda() - def test_cpu_cuda_module_after_dynamo(self): - backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) - - run_code = textwrap.dedent( - f"""\ - class CpuCudaModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.m_x = torch.nn.Linear(20, 20).cuda() - self.m_y = torch.nn.Linear(20, 20) - self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) - self.p_y = torch.nn.Parameter(torch.randn(20, 20)) - self.register_buffer("b_x", torch.ones(20, 20).cuda()) - self.register_buffer("b_y", torch.ones(20, 20)) - - 
def forward(self, x, y): - return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y - - mod = CpuCudaModule() - - @torch._dynamo.optimize("{backend_name}") - def inner(x1, y1): - x2 = torch.randn(20, 20).cuda() - y2 = torch.randn(20, 20) - x3, y3 = mod(x1 + x2, y1 + y2) - return torch.relu(x3.cpu() + y3) - - inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) - """ - ) - - (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( - run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND - ) - - tb1 = test_proc.stderr.decode("utf-8") - tb2 = repro_proc.stderr.decode("utf-8") - - # Check if generated minifier code covers all cpu/cuda cases - self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) - self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) - # search for Linear(...).cuda() - self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) - # search for Linear(...) - self.assertIsNotNone( - re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) - self.assertIsNotNone( - re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) - self.assertIsNotNone( - re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) - ) - # search for - # = torch.randn(...) - # ... = .cuda() - self.assertIsNotNone( - re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) - ) - # search for - # = torch.randn(...) - # no followup call to .cuda() - self.assertIsNotNone( - re.search( - r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL - ) - ) - - self.assertIn(backend_name, tb1) - self.assertIn(backend_name, tb2) - - # Test if we can actually get a minified graph - def test_if_graph_minified(self): - backend_name = get_fn_name(RELU_COMPILE_ERROR_BACKEND) - - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("{backend_name}") - def inner(x): - for _ in range(20): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(20): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20)) - """ - ) - - (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( - run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND - ) - - tb1 = test_proc.stderr.decode("utf-8") - tb2 = repro_proc.stderr.decode("utf-8") - - self.assertIn(backend_name, tb1) - self.assertIn(backend_name, tb2) - - # compare the length of the forward functions - match = re.search(r"def forward.*return", launch_code, re.DOTALL) - self.assertIsNotNone(match) - self.assertGreater(match.group(0).count("\n"), 40) - - match = re.search(r"def forward.*return", repro_code, re.DOTALL) - self.assertIsNotNone(match) - self.assertLess(match.group(0).count("\n"), 5) - - # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) - def _test_after_aot(self, device, backend_code, repro_level): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "aot", repro_level, patch_code - ) - return ( - (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), - (test_proc.returncode, repro_proc.returncode), - ) 
- - def test_after_aot_cpu_compile_error(self): - (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) - self.assertIn("CppCompileError", tb1) - self.assertIn("CppCompileError", tb2) - - def test_after_aot_cpu_accuracy_error(self): - (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) - self.assertIn("AccuracyError", tb1) - self.assertIn("AccuracyError", tb2) - - @requires_cuda() - def test_after_aot_cuda_compile_error(self): - (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) - self.assertIn("SyntaxError", tb1) - self.assertIn("SyntaxError", tb2) - - @requires_cuda() - def test_after_aot_cuda_accuracy_error(self): - (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) - self.assertIn("AccuracyError", tb1) - self.assertIn("AccuracyError", tb2) - - # Test that runtime errors after aot can be repro'd (CPU only for now) - def _test_after_aot_runtime_error(self, device, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - x = torch.relu(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - - (test_proc, _, repro_proc), _ = self._run_full_test( - run_code, "aot", 3, patch_code - ) - - self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) - - self.assertEqual(test_proc.returncode, repro_proc.returncode) - self.assertNotEqual(test_proc.returncode, 0) - - def test_after_aot_cpu_runtime_error(self): - self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) - - # NOTE: there is currently not an easy way to cause a triton runtime error. - @unittest.skip - @requires_cuda() - def test_after_aot_cuda_runtime_error(self): - self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) - - # Ensure that inductor codegen patches pass when relu is not present. - def _test_after_aot_backend_passes(self, device, repro_level, backend_code): - run_code = textwrap.dedent( - f"""\ - @torch._dynamo.optimize("inductor") - def inner(x): - for _ in range(3): - x = torch.sin(x) - for _ in range(3): - x = torch.cos(x) - return x - - inner(torch.randn(20, 20).to("{device}")) - """ - ) - patch_code = gen_codegen_fn_patch_code("relu", backend_code, device) - self.assertIsNotNone(patch_code) - - test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) - proc, repro_dir = self._run_test_code(test_code) - self.assertEqual(proc.returncode, 0) - self.assertIsNone(repro_dir) - - def test_after_aot_cpu_compile_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) - - def test_after_aot_cpu_runtime_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) - - def test_after_aot_cpu_accuracy_backend_passes(self): - self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) - - @requires_cuda() - def test_after_aot_cuda_compile_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) + self.assertTrue(os.path.exists(repro_file)) - # NOTE: there is currently not an easy way to cause a triton runtime error. 
- @unittest.skip - @requires_cuda() - def test_after_aot_cuda_runtime_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) + def test_at_aot(self): + self._test_around_aot(True) - @requires_cuda() - def test_after_aot_cuda_accuracy_backend_passes(self): - self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) + def test_after_aot(self): + self._test_around_aot(False) if __name__ == "__main__": diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 29d830167b10..089ef172d625 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -84,11 +84,6 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" - # module should be a core torch.nn.Module, so all parameters - # should be on the same device. - example_param = next(module.parameters(), None) - if example_param is not None and example_param.is_cuda: - module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -100,16 +95,12 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) - if buffer.is_cuda: - tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" - if param.is_cuda: - tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. @@ -154,9 +145,6 @@ def _cuda_system_info_comment(): return model_str -TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" - - def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -167,8 +155,6 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx - {TEST_REPLACEABLE_COMMENT} - """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -184,7 +170,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += "mod = make_fx(Repro())(*args)\n" + model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' return model_str @@ -211,8 +197,7 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - curdir = os.getcwd() - repro_path = os.path.join(curdir, "repro.py") + repro_path = os.path.join(config.base_dir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -231,10 +216,7 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - class AccuracyError(Exception): - pass - if not same_two_models(mod, compiled, args, only_fwd=True): - raise AccuracyError("Bad accuracy detected") + assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" """ ) ) @@ -249,7 +231,7 @@ class AccuracyError(Exception): ) -def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None): if 
env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -257,10 +239,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - repro_code = generate_compiler_repro_string(fx_g, args) - if patch_code is not None: - repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) - fd.write(repro_code) + fd.write(generate_compiler_repro_string(fx_g, args)) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -284,7 +263,6 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], - cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -351,8 +329,6 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" -isolate_fails_code_str = None - {generate_compiler_repro_string(gm, args)} from functools import partial @@ -367,7 +343,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -375,10 +351,6 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) -class AccuracyError(Exception): - pass - - def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both @@ -438,7 +410,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise AccuracyError("Bad accuracy detected") + raise ValueError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -463,8 +435,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - log.error("CompilerError") - raise + raise e if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -581,14 +552,9 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() - -class AccuracyError(Exception): - pass - with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - if not same_two_models(mod, opt_mod, args): - raise AccuracyError("Dynamo failed") + assert same_two_models(mod, opt_mod, args), "Dynamo failed" """ ) @@ -603,14 +569,12 @@ class AccuracyError(Exception): from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models -{TEST_REPLACEABLE_COMMENT} - args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro() +mod = Repro().cuda() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -749,21 +713,6 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" - custom_compiler_error = ( - textwrap.dedent( - """\ - raise RuntimeError( - 'Compiler name is None - this likely means that a custom compiler ' - 'was called by 
torchdynamo. Please remove this error, import your ' - 'custom compiler function, and replace the compiler_name="None" ' - 'line below to compiler_name=' - ) - """ - ) - if compiler_name is None - else "" - ) - contents = textwrap.dedent( f""" import os @@ -777,17 +726,14 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided -{TEST_REPLACEABLE_COMMENT} - args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro() +mod = Repro().cuda() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] -{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -831,7 +777,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc = AccuracyError("Bad accuracy detected.") + exc = ValueError("Bad accuracy detected.") exc.minifier_path = os.path.join( minifier_dir(), "minifier_launcher.py" ) From 6b775c42dd2d40992611fb5636e787560663902c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 12 Nov 2022 07:52:44 +0000 Subject: [PATCH 099/453] [quant][executorch] Support quant fusion for reshape in quant in executorch stack (#88858) Summary: This diff added support for fusing "dq - reshape - q" to a reshape op, the op is needed in wakeword model Test Plan: buck test executorch/exir/tests:quant_fusion_pass Reviewed By: qihqi, JacobSzwejbka Differential Revision: D41111069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88858 Approved by: https://github.com/JacobSzwejbka --- torch/_C/__init__.pyi.in | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d20da2a04f3..5833d7d7f2a4 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -972,11 +972,14 @@ class AggregationType(Enum): AVG = 1 class FileCheck(object): - # TODO (add more FileCheck signature) - def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... def run(self, test_string: str) -> None: ... def check(self, test_string: str) -> 'FileCheck': ... def check_not(self, test_string: str) -> 'FileCheck': ... + def check_same(self, test_string: str) -> 'FileCheck': ... + def check_next(self, test_string: str) -> 'FileCheck': ... + def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ... + def check_dag(self, test_string: str) -> 'FileCheck': ... + def check_source_highlighted(self, test_string: str) -> 'FileCheck': ... ... # Defined in torch/csrc/jit/python/init.cpp From ae2c668cc044d841853e2672d96bfe0afb38a89c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 12 Nov 2022 07:52:53 +0000 Subject: [PATCH 100/453] Revert "[dynamo][api] Better support of torch.nn.Module (#88629)" This reverts commit c83348597b195f2da1cca0e8318c878b104bce5d. 
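For reference, a minimal sketch of the user-facing behavior this revert undoes (`OptimizedModule` and `optimize` are the names used in the diffs above; the backend string and module are only examples):

```
# Sketch only: illustrates the API difference being reverted, not code added by this patch.
import torch
import torch._dynamo

mod = torch.nn.Linear(10, 10)
opt_mod = torch._dynamo.optimize("eager")(mod)  # "eager" backend is just an example

# With #88629 applied, opt_mod is an OptimizedModule, i.e. a real nn.Module:
#   isinstance(opt_mod, torch._dynamo.OptimizedModule) -> True
#   isinstance(opt_mod, torch.nn.Module)               -> True
# After this revert, optimize() returns the duck-typed TorchDynamoNNModuleWrapper
# again, which is not an nn.Module and forwards attribute lookups to `mod`.
out = opt_mod(torch.randn(2, 10))
print(type(opt_mod).__name__, out.shape)
```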
Reverted https://github.com/pytorch/pytorch/pull/88629 on behalf of https://github.com/anijain2305 due to job failing on master https://github.com/pytorch/pytorch/actions/runs/3449914495/jobs/5758267231 --- test/dynamo/test_modules.py | 127 ----------------------------------- torch/_dynamo/__init__.py | 2 - torch/_dynamo/debug_utils.py | 8 --- torch/_dynamo/eval_frame.py | 74 ++++++-------------- torch/_dynamo/testing.py | 13 ---- 5 files changed, 20 insertions(+), 204 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 930035f99a30..2fb83b3add6c 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,133 +904,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.relu(self.linear(x) + self.buf0) - - -class OptimizedModuleTest(torch._dynamo.test_case.TestCase): - def test_nn_module(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def test_to(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - # Ensure that there is no recompilation - opt_mod(x) - self.assertEqual(cnt.frame_count, 1) - - opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - x = torch.randn(10, 10).to(dtype=torch.float64) - opt_mod(x) - # Ensure that there is a recompilation - self.assertEqual(cnt.frame_count, 2) - - def test_attr(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.r(torch.sin(x)) + self.buf0 - - mod = MockModule() - opt_mod = torch._dynamo.optimize("eager")(mod) - - # Check parameteres and buffers - for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): - self.assertTrue(id(p1) == id(p2)) - - def test_recursion(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - - for _ in range(5): - opt_mod = torch._dynamo.optimize(cnt)(opt_mod) - opt_mod(torch.randn(10, 10)) - self.assertEqual(cnt.frame_count, 1) - - def test_composition(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - opt_inner_mod = InnerModule() - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - 
self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def test_composition_with_opt_mod(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - inner_mod = InnerModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - # There will be a graph break for the inner mod being OptimizedModule - self.assertEqual(cnt.frame_count, 2) - - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 5eee609b0852..80f927aeef2f 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,7 +7,6 @@ export, optimize, optimize_assert, - OptimizedModule, reset_code, run, skip, @@ -26,7 +25,6 @@ "reset", "list_backends", "skip", - "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 089ef172d625..f09991f9bf34 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ - from .eval_frame import OptimizedModule - from .testing import named_parameters_for_optimized_module from .utils import same - if isinstance(gm, OptimizedModule): - gm.named_parameters = named_parameters_for_optimized_module(gm) - - if isinstance(opt_gm, OptimizedModule): - opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) - ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 20e8c7de085e..8d9e3b7b6aa1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,7 +5,6 @@ import logging import os import sys -import textwrap import threading import traceback import types @@ -45,27 +44,6 @@ most_recent_backend = None -class OptimizedModule(torch.nn.Module): - """ - Wraps the original nn.Module object and later patches its - forward method to optimized self.forward method. - """ - - def __init__(self, mod): - super().__init__() - # Installs the params/buffer - self._orig_mod = mod - - def __getattr__(self, name): - if name == "_orig_mod": - return self._modules["_orig_mod"] - return getattr(self._orig_mod, name) - - def forward(self, *args, **kwargs): - # This will be monkey patched later - raise RuntimeError("Should not be here") - - def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -140,15 +118,31 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - new_mod = OptimizedModule(mod) - new_mod.forward = self(mod.forward) + optimized_forward = self(mod.forward) + + class TorchDynamoNNModuleWrapper: + """ + A wrapper that redirects the forward call to the optimized + forward, while for rest it redirects the calls to the original + module. 
+ """ + + def __getattr__(self, name): + return getattr(mod, name) + + def forward(self, *args, **kwargs): + return optimized_forward(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + new_mod = TorchDynamoNNModuleWrapper() # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod.forward + new_mod._torchdynamo_orig_callable = mod return new_mod assert callable(fn) - callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -190,34 +184,6 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): - if not hasattr(fn, "__code__"): - raise RuntimeError( - textwrap.dedent( - """ - - torch._dynamo.optimize is called on a non function object. - If this is a callable class, please optimize the individual methods that you are interested in optimizing. - - >> class CallableClass: - >> def __init__(self): - >> super().__init__() - >> self.relu = torch.nn.ReLU() - >> - >> def __call__(self, x): - >> return self.relu(torch.sin(x)) - >> - >> def print_hello(self): - >> print("Hello world") - >> - >> mod = CallableClass() - - If you want to optimize the __call__ function - - >> mod.__call__ = torch._dynamo.optimize(mod.__call__) - - """ - ) - ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index b37299ffd579..d6082ce48acf 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,17 +32,6 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) -def named_parameters_for_optimized_module(mod): - assert isinstance(mod, eval_frame.OptimizedModule) - return mod._orig_mod.named_parameters - - -def remove_optimized_module_prefix(name): - prefix = "_orig_mod." 
- assert name.startswith(prefix) - return name[len(prefix) :] - - def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -55,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): - if isinstance(model, eval_frame.OptimizedModule): - name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 6e5f736d86be09bd86a5da276ce2f5dcbe0bfc09 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 11 Nov 2022 08:21:48 -0800 Subject: [PATCH 101/453] [15/N] Add allreduce_coalesced custom op with CPU/CUDA implementations (#88846) Differential Revision: [D41227740](https://our.internmc.facebook.com/intern/diff/D41227740) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88846 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_common.py | 15 +++++++++++ test/distributed/test_c10d_gloo.py | 4 +++ test/distributed/test_c10d_nccl.py | 5 ++++ torch/csrc/distributed/c10d/Ops.cpp | 36 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 5 ++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 34 +++++++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 6 ++--- 7 files changed, 102 insertions(+), 3 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index cf46f89b353c..77ee7487a0af 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1503,6 +1503,21 @@ def _test_collectives(self, backend): with self.subTest(collective=collective, args=args): self._call_collective_with_varying_tensors(backend, collective, *args) + def _test_allreduce_coalesced(self, backend): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) + # TODO: this will be updated in the future to not be backend specific + device = "cuda" if backend == "nccl" else "cpu" + tensors = [torch.ones(10, 10, device=torch.device(device))] + dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM) + for tensor in tensors: + self.assertEqual(tensor, torch.ones(10, 10) * self.world_size) + class CompilerTest(MultiProcessTestCase): def setUp(self): super(CompilerTest, self).setUp() diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index e0c7c64f7b83..ba214a02696f 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2363,6 +2363,10 @@ class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="gloo") + @requires_gloo() + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="gloo") + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5d412dd3fb1b..b3790b082ed5 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2953,6 +2953,11 @@ class NcclProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="nccl") + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="nccl") + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git 
a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ea77bb337b4a..15e186fe3d22 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -40,6 +40,19 @@ std::tuple, c10::intrusive_ptr> allreduce_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + c10::intrusive_ptr reduce_( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -177,6 +190,10 @@ TORCH_LIBRARY(c10d, m) { m.def( "allreduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_)); + m.def( + "allreduce_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_)); m.def( "allgather_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); @@ -249,6 +266,25 @@ c10::intrusive_ptr allreduce( opts.timeout.count())); } +c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allreduce_coalesced_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + + return op.call( + tensors, + process_group, + c10::make_intrusive(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index adc64066a885..8ef78126e5b9 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -21,6 +21,11 @@ TORCH_API c10::intrusive_ptr allreduce( at::TensorList tensors, const AllreduceOptions& opts = {}); +TORCH_API c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts = {}); + TORCH_API c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 03ec6892857e..94f5febec14d 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -149,6 +149,32 @@ std::tuple, c10::intrusive_ptr> allreduce_cuda_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_cpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + +c10::intrusive_ptr allreduce_coalesced_cuda_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + 
opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + std::tuple>, c10::intrusive_ptr> allgather_cpu_( const std::vector>& output_tensors, @@ -367,6 +393,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allreduce_", allreduce_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("allgather_", allgather_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6515a3d9a87d..673f481d6025 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1134,10 +1134,10 @@ that adds a prefix to each key inserted to the store. .def( "allreduce_coalesced", - [](::c10d::ProcessGroup& self, - std::vector& xs, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& xs, ::c10d::AllreduceCoalescedOptions opts) { - return self.allreduce_coalesced(xs, opts); + return ::c10d::ops::allreduce_coalesced(self, xs, opts); }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), From 4270bb37dacf7e3b2b784fa4ff4002ee6bf87e56 Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Sat, 12 Nov 2022 00:41:57 +0100 Subject: [PATCH 102/453] [primTorch] Improve `narrow` and `narrow_copy`: refs, tests, docs (#87045) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87045 Approved by: https://github.com/mruberry --- aten/src/ATen/native/TensorShape.cpp | 13 +- test/test_meta.py | 1 - torch/_refs/__init__.py | 38 +++- torch/_tensor_docs.py | 13 +- torch/_torch_docs.py | 27 +-- .../_internal/common_methods_invocations.py | 163 ++++++++++++++---- 6 files changed, 188 insertions(+), 67 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index deb9b949aa5d..e8c87a2f1f5c 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1196,6 +1196,8 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ auto output = at::empty_like(self); return narrow_copy_dense_cpu_out(self, dim, start, length, output); @@ -1205,9 +1207,10 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ int64_t allDim = self.dim(); int64_t end = start+length; TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); TORCH_CHECK(dim >= 0 && dim < allDim, "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, "."); - TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim), + TORCH_CHECK(start >= 0 && end <= self.size(dim), "Invalid range to narrow. 
range(start, start+length) must be a subset of range(0, ", self.size(dim), ").") Tensor indices = self._indices(); int64_t sparse_dim = self.sparse_dim(); @@ -1235,6 +1238,8 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ return newTensor._coalesced_(self.is_coalesced()); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor& narrow_copy_dense_cpu_out( const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output ) { @@ -1318,22 +1323,24 @@ Tensor& narrow_copy_dense_cpu_out( Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice(self, dim, start, start + length, 1); } Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.sym_size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice_symint(self, dim, start, start + length, 1); } diff --git a/test/test_meta.py b/test/test_meta.py index ef25d184c842..ae248a90cffb 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -745,7 +745,6 @@ def run_meta_crossref( } meta_function_device_skips['cpu'] = { - torch.narrow_copy: {b8, bf16, c128, c32, c64, f16, f32, f64, i16, i32, i64, i8, u8}, torch.native_batch_norm: {f32, f64}, } diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 43b0c74192de..70edbff2237f 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2750,19 +2750,39 @@ def flipud(a: TensorLikeType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp -def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + check(length >= 0, lambda: "narrow(): length must be non-negative.") dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + # Start being the end is usually invalid since it's out of bounds. So it's + # not allowed by canonicalize_dim. But for narrow it's valid as long as + # the length is 0, which is handled by the check below. 
+ if start != dim_length: + # Negative start means indexing from the end of dim. + # Note: a dimension isn't being canonicalized here, this reuses + # canonicalize_dim because the semantics are similar. + start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] + check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) return prims.slice_in_dim(a, start, start + length, axis=dim) -@register_decomposition(torch.ops.aten.narrow_copy) -@out_wrapper() -def narrow_copy(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: - # TODO: This must return a sparse tensor if the input is sparse, but refs - # have no sparse support. See narrow_copy_sparse in core. - if a.is_sparse: - raise NotImplementedError("narrow_copy ref doesn't support sparse tensors") - return torch.clone(torch.narrow(a=a, dim=dim, start=start, length=length)) # type: ignore[call-overload] +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(narrow) def _normalize( diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 8c734a1f3774..726ae5137e6a 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3436,18 +3436,7 @@ def callable(a, b) -> number r""" narrow(dimension, start, length) -> Tensor -See :func:`torch.narrow` - -Example:: - - >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> x.narrow(0, 0, 2) - tensor([[ 1, 2, 3], - [ 4, 5, 6]]) - >>> x.narrow(1, 1, 2) - tensor([[ 2, 3], - [ 5, 6], - [ 8, 9]]) +See :func:`torch.narrow`. """, ) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 40375bae3e27..2ff2e9be315d 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7980,8 +7980,10 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (Tensor or int): the starting dimension - length (int): the distance to the ending dimension + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive Example:: @@ -7993,6 +7995,10 @@ def merge_dicts(*dicts): tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) """, ) @@ -8008,8 +8014,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (int): the starting offset - length (int): the distance to the ending dimension + start (int): index of the element to start the narrowed dimension from. Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive Keyword args: {out} @@ -8027,13 +8034,13 @@ def merge_dicts(*dicts): >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) >>> torch.narrow_copy(s, 0, 0, 1) tensor(indices=tensor([[0, 0], - [0, 1]]), - values=tensor([[[0, 1], - [2, 3]], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], - [[4, 5], - [6, 7]]]), - size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) .. 
seealso:: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5178ec978bd1..8ab1ea8a047c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -4391,29 +4391,127 @@ def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(shape), rep_dim) -def sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): shapes_and_args = ( - ((S, S, S), (1, 2, 2)), - ((S, S, S), (-1, 2, 2)), - ((S, S, S), (1, 0, 0)), - ((S, S, S), (-1, 0, 0)), - ((S, S, S), (2, 1, 2)), + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), ) - for shape, args in shapes_and_args: + for shape, dim, start, length in shapes_and_args: tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) - yield SampleInput(tensor, args=args) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) -def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs): - ''' - sample_inputs_narrow accepts the same inputs as narrow_copy, in addition - narrow also accepts `start` argument to be a Tensor. - ''' - for sample in sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): - yield sample - yield SampleInput(sample.input, args=(sample.args[0], torch.tensor(sample.args[1]), sample.args[2])) + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 
0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=RuntimeError, + error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. 
+ # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): @@ -12407,7 +12505,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - sample_inputs_func=sample_inputs_narrow, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), skips=( # Use of .item() DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), @@ -12423,15 +12523,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=False, supports_autograd=False, # https://github.com/pytorch/pytorch/issues/86931 - sample_inputs_func=sample_inputs_narrow_copy, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), skips=( # https://github.com/pytorch/pytorch/issues/84577 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - # Not implemented - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta', device_type='cuda'), + # Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), )), UnaryUfuncInfo('neg', aliases=('negative', ), @@ -18061,22 +18162,20 @@ def reference_flatten(input, start_dim=0, end_dim=-1): "_refs.narrow", torch_opinfo_name="narrow", supports_nvfuser=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), - ) - ), - PythonRefInfo( - "_refs.nn.functional.group_norm", - torch_opinfo_name="nn.functional.group_norm", - supports_nvfuser=False, - validate_view_consistency=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), ), PythonRefInfo( "_refs.narrow_copy", torch_opinfo_name="narrow_copy", supports_out=True, supports_nvfuser=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + supports_nvfuser=False, + validate_view_consistency=False, ), PythonRefInfo( "_refs.native_layer_norm", From 27dc03e09b6b1948e416a9fd78e6ca2b0a0bb1c7 Mon Sep 17 
00:00:00 2001 From: soulitzer Date: Fri, 11 Nov 2022 11:51:22 -0500 Subject: [PATCH 103/453] Turn internal assert when saved tensor is detached inplace into torch check (#88860) Fixes https://github.com/pytorch/pytorch/issues/88809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88860 Approved by: https://github.com/albanD --- test/test_autograd.py | 14 ++++++++++++++ torch/csrc/autograd/saved_variable.cpp | 11 ++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index e08047860e42..33cf188af065 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6776,6 +6776,20 @@ def inplace_double(x): # not leaf, not output test(lambda: (1 + torch.randn(5, requires_grad=True)), False) + def test_saved_variable_saved_original_inplace_detach(self): + # Detaching a tensor that is saved input raises + a = torch.tensor(1., requires_grad=True).clone() + b = a.sin() + a.detach_() + with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"): + b.backward() + + # Detaching a tensor that is saved as output is OK + a = torch.tensor(1., requires_grad=True).clone() + b = a.exp() + a.detach_() + b.backward() + def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self): # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks # The saved_original / did_not_save_original distinction corresponds to the `save_original` diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index a2e0f05b6394..d438205e8947 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { : grad_fn_; if (!is_leaf_ && !grad_fn) { - TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor"); + // This issue was introduced when we added logic to save the original + // because now we rely on data_.grad_fn(), but can be unreliable if the + // autograd_meta of that saved tensor is cleared with an in-place detach. + // As a simple fix, we choose to disallow that behavior here even though + // it makes behavior inconsistent depending on whether you are saving + // input or output. + TORCH_CHECK( + saved_for, + "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_()." + "This is not supported, please use out-of-place `.detach()` instead"); grad_fn = std::move(saved_for); } From 3765621356c645ead1d712c5b7e4d57d6803cc81 Mon Sep 17 00:00:00 2001 From: ydwu4 Date: Sat, 12 Nov 2022 20:00:51 +0000 Subject: [PATCH 104/453] torchdynamo support self.modules() for nn_module (#88695) This PR allows models to call self.modules() during dynamo tracing. 
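A minimal sketch of the newly supported pattern (it mirrors the `test_modules` repro added in this patch; the "eager" backend name is just an illustrative choice):

```python
import torch
import torch._dynamo

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 3)

    def forward(self, inp):
        # Iterating over self.modules() inside forward() is traced by dynamo
        # after this change.
        res = torch.zeros(3, 3)
        for mod in self.modules():
            res += self.fc(inp)
        return res

mod = Foo()
opt_mod = torch._dynamo.optimize("eager")(mod)
out = opt_mod(torch.ones(3, 4))
```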
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88695 Approved by: https://github.com/voznesenskym --- test/dynamo/test_repros.py | 20 ++++++++++++++++++++ torch/_dynamo/guards.py | 2 +- torch/_dynamo/variables/nn_module.py | 2 ++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2103e075fffc..913d59322ac7 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1792,6 +1792,26 @@ def fn(x): res = opt_fn(a) self.assertTrue(same(ref, res)) + def test_modules(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 3) + + def forward(self, inp): + res = torch.zeros(3, 3) + for mod in self.modules(): + res += self.fc(inp) + return res + + mod = Foo() + args = (torch.ones(3, 4),) + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod) + self.assertTrue(same(mod(*args), opt_mod(*args))) + self.assertEqual(cnt.op_count, 5) + self.assertEqual(cnt.frame_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 382734412b2b..d4903964aac6 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -543,7 +543,7 @@ def __init__(self, expr_to_tensor_ref, id_to_name_map): self.id_to_name_map = id_to_name_map def _print_Symbol(self, expr) -> str: - assert isinstance(expr, sympy.core.symbol.Symbol) + assert isinstance(expr, sympy.Symbol) if expr == 0: return "0" if expr == 1: diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 6f7c2ff28737..1922980fc957 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -337,6 +337,8 @@ def named_embed(name, obj): ): result.append(named_embed(name, submod)) return ListIteratorVariable(result, mutable_local=MutableLocal(), **options) + elif name == "modules": + return wrap_values(module.named_modules()) elif name == "parameters": return wrap_values(module.named_parameters(**get_kwargs("recurse"))) elif name == "values": From df1df9d10a7a2f4d7b1327fa85d0bb5fb6e9b693 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 11 Nov 2022 11:44:00 -0800 Subject: [PATCH 105/453] [16/N] Add _allgather_base custom op with CPU/CUDA implementation (#88889) Differential Revision: [D41227739](https://our.internmc.facebook.com/intern/diff/D41227739) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88889 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 17 +++++++++++++++++ torch/csrc/distributed/c10d/Ops.cpp | 25 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 6 ++++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 22 ++++++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 8 +++++++- 5 files changed, 77 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index b3790b082ed5..c514ea4ab31f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2958,6 +2958,23 @@ def test_collectives(self): def test_allreduce_coalesced(self): self._test_allreduce_coalesced(backend="nccl") + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allgather_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + 
store=store, + ) + device = "cuda" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) + + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 15e186fe3d22..f825afca2a1d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -88,6 +88,13 @@ allgather_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_( const std::vector& output_tensors, const std::vector>& input_tensors, @@ -197,6 +204,9 @@ TORCH_LIBRARY(c10d, m) { m.def( "allgather_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); + m.def( + "_allgather_base_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_)); m.def( "reduce_scatter_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); @@ -303,6 +313,21 @@ c10::intrusive_ptr allgather( output_tensors, input_tensors, process_group, opts.timeout.count())); } +c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::_allgather_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + + return op.call(output_tensor, input_tensor, process_group); +} + c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index 8ef78126e5b9..72f09e341d7d 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -32,6 +32,12 @@ TORCH_API c10::intrusive_ptr allgather( const std::vector& input_tensors, const AllgatherOptions& opts = {}); +TORCH_API c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 94f5febec14d..78e26c9656d8 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -211,6 +211,20 @@ allgather_cuda_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_cpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + +c10::intrusive_ptr _allgather_base_cuda_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_cpu_( const std::vector& output_tensors, @@ -409,6 +423,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allgather_", allgather_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + 
m.impl("_allgather_base_", _allgather_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("_allgather_base_", _allgather_base_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("reduce_scatter_", reduce_scatter_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 673f481d6025..2424506eef0f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1187,7 +1187,13 @@ that adds a prefix to each key inserted to the store. .def( "_allgather_base", - &::c10d::ProcessGroup::_allgather_base, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::_allgather_base( + self, output_tensor, input_tensor, opts); + }, py::arg("output"), py::arg("input"), py::arg("opts") = ::c10d::AllgatherOptions(), From 2aca97cc9ae7081f00ebc7d58367c443cd4528cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Sun, 13 Nov 2022 00:31:11 +0000 Subject: [PATCH 106/453] Vectorized CPU code implementing left shift operator. (#88607) This PR adds vectorized implementation for CPU version of left shift operator. All of the tests run by `pytest test/test_ops.py -vk left_shift` pass. Here are some additional details:
Benchmarking script (writen by Philip, with small tweaks by Mario) comparing left shifts with multiplications - on par now ```python import torch from torch import Tensor from torch.utils.benchmark import Timer, Compare from itertools import product from functools import partial # These functions exist, because torch.jit.script does not support `torch.iinfo` def _num_value_bits(dtype): if dtype == torch.uint8: return 8 else: # torch.int32 return 31 def _max_value(dtype): if dtype == torch.uint8: return 255 else: # torch.int32 return 2147483647 def bitshift(image, dtype): num_value_bits_input = _num_value_bits(image.dtype) num_value_bits_output = _num_value_bits(dtype) return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) def mul(image, dtype): input_max = float(_max_value(image.dtype)) output_max = float(_max_value(dtype)) factor = int((output_max + 1) // (input_max + 1)) image = image.to(dtype) return image * factor size = 256 image = torch.randint(0, 256, (3, size, size), dtype=torch.uint8) dtype = torch.int32 def gen_inputs(): devices = ("cpu",) fns = (mul, bitshift) threads = (1,) for device, fn, threads in product(devices, fns, threads): yield f"Bitshift {device} {image.dtype}", str(tuple(image.shape)), threads, fn, image, dtype def benchmark(label, sub_label, threads, f, *args, **kwargs): return Timer("f(*args, **kwargs)", globals=locals(), label=label, description=f.__name__, sub_label=sub_label, num_threads=threads).blocked_autorange() results = [] for args in gen_inputs(): results.append(benchmark(*args)) compare = Compare(results) compare.trim_significant_figures() compare.print() ```
Test script exercising large number of combinations of left shift operands that I've used for further testing (validates results through comparing with results generated by NumPy) ```python import numpy as np import torch # Testing shifting of non-negative numbers only, but will test all # possible RHS shift values for given type. For int8 and int16, we'll # test shifting all of non-negative values represntable by type. For # the rest of data types, we'll test shifting some random numbers in # the corresponding range. def _create_inputs(dtype): info = torch.iinfo(dtype) if dtype == torch.int8 or dtype == torch.int16: ntests = info.max + 1 x = torch.arange(info.max + 1, dtype=dtype, device="cpu", requires_grad=False) else: ntests = 100000 x = torch.randint(info.max + 1 if dtype != torch.int64 else info.max, (ntests,), dtype=dtype, device="cpu", requires_grad=False) y = torch.tensor(range(info.bits), dtype=dtype, device="cpu", requires_grad=False) xy = torch.cartesian_prod(x, y) return (xy[:, 0], xy[:, 1]) torch.manual_seed(0) # Perform testing for each datatype supported, and compare results # with ones generated by numpy. for dtype in (torch.int8, torch.int16, torch.int32, torch.int64): (x, y) = _create_inputs(dtype) z = x << y xnp = x.numpy() ynp = y.numpy() znp = z.numpy() assert((znp == (xnp << ynp)).all()) ```
Benchmarking script running the left shift operator on tensors of different length (and varying number of bits to shift) ```python import torch import pickle import itertools from torch.utils.benchmark import Timer, Compare torch.manual_seed(0) # Edit this part if needed. lengths = [1024, 4096, 16384, 65536] rhss = [1, 2, 7, 8, 15, 16, 31, 32, 63, 64] benchmark_name = "lshift" label = "" dtypes = [torch.int8, torch.int16, torch.int32, torch.int64] results = [] # Create an argument pair for testing. Argument are tensors of given # datatype and length, LHS for each shift operation is a random # number, and RHS is given value that is same for all of them. def _make_args(dtype, length, rhs): info = torch.iinfo(dtype) imax = info.max return (torch.randint(info.max, (length,), dtype=dtype, device="cpu", requires_grad=False), rhs * torch.ones((length,), dtype=dtype, device="cpu", requires_grad=False)) # Run shift operation for vectors of given lenghts and for given # number of bits to be shifted, and remember timings. for dtype, length, rhs in itertools.product(dtypes, lengths, rhss): x, y = _make_args(dtype, length, rhs) timer = Timer("x << y", globals=globals(), label=benchmark_name, description=label, sub_label=f"dtype={dtype},length={length}", num_threads=1) results.append(timer.blocked_autorange()) # Gather results. compare = Compare(results) compare.trim_significant_figures() compare.print() # Print results. with open("{}.pickle".format(label), "wb") as f: pickle.dump(results, f) ```
Results of running above benchmarking script - results manually merged for runs of viable/strict (labeled "master" in the table below) and my branch (labeled "mybranch" in the table below) ``` [------------------- lshift -------------------------------] | master | mybranch 1 threads: ------------------------------------------------ dtype=torch.int8,length=1024 | 3 | 3 dtype=torch.int8,length=4096 | 5 | 3 dtype=torch.int8,length=16384 | 14 | 5 dtype=torch.int8,length=65536 | 51 | 15 dtype=torch.int16,length=1024 | 3 | 3 dtype=torch.int16,length=4096 | 4 | 3 dtype=torch.int16,length=16384 | 11 | 5 dtype=torch.int16,length=65536 | 39 | 13 dtype=torch.int32,length=1024 | 3 | 2 dtype=torch.int32,length=4096 | 4 | 3 dtype=torch.int32,length=16384 | 10 | 4 dtype=torch.int32,length=65536 | 35 | 12 dtype=torch.int64,length=1024 | 3 | 3 dtype=torch.int64,length=4096 | 4 | 3 dtype=torch.int64,length=16384 | 11 | 6 dtype=torch.int64,length=65536 | 36 | 20 Times are in microseconds (us). ```
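As a side note, a hypothetical sketch (not part of this PR) of how the two pickled result sets written by the script above could be merged into a single table like the one shown. This assumes the `label` variable in the script was set to "master" and "mybranch" for the respective runs, so that the saved file names and the `description` columns differ per run:

```python
import pickle
from torch.utils.benchmark import Compare

results = []
for path in ("master.pickle", "mybranch.pickle"):  # placeholder file names
    with open(path, "rb") as f:
        results.extend(pickle.load(f))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```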
All of the testing/benchmarking was conducted on qpu3, which supports AVX2 only. For basic validation of the AVX-512 update of the left shift implementation for 8-bit operands (the only case that is non-trivial for AVX-512), [Compiler Explorer](https://godbolt.org/) was used, with GCC trunk and the `-mavx512f -mavx512bw` flags added. Here are further details:
C program used for basic validation of AVX-512 vectorized version for 8-bit operands ``` #include #include #include #include static void print_m512i_int8(const __m512i* x) { int8_t val[64]; memcpy(val, x, sizeof(val)); for (int i = 0; i < 64; ++i) { if (i > 0) printf(", "); printf("%d", (int)val[i]); } printf("\n"); } int main() { __m512i a = _mm512_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1); __m512i b = _mm512_set_epi8(7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0); // ------- Copied code from vec512_int.h // Mask used to set upper 8 bits of each 16-bit value to 0, and keep // lower 8 bits. __m512i mask = _mm512_set_epi16(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff); // Convert 8-bit operands from lower lanes to 16-bit values, and // perform vectorized shift. Make sure that upper 8 bits of 16-bit // results are all 0. __m256i a_lo_8 = _mm512_extracti64x4_epi64(a, 0); __m256i b_lo_8 = _mm512_extracti64x4_epi64(b, 0); __m512i a_lo_16 = _mm512_cvtepi8_epi16(a_lo_8); __m512i b_lo_16 = _mm512_cvtepi8_epi16(b_lo_8); __m512i c_lo_16 = _mm512_and_si512(_mm512_sllv_epi16(a_lo_16, b_lo_16), mask); // Convert 8-bit operands from upper lanes to 16-bit values, and // perform vectorized shift. Make sure that upper 8 bits of 16-bit // results are all 0. __m256i a_hi_8 = _mm512_extracti64x4_epi64(a, 1); __m256i b_hi_8 = _mm512_extracti64x4_epi64(b, 1); __m512i a_hi_16 = _mm512_cvtepi8_epi16(a_hi_8); __m512i b_hi_16 = _mm512_cvtepi8_epi16(b_hi_8); __m512i c_hi_16 = _mm512_and_si512(_mm512_sllv_epi16(a_hi_16, b_hi_16), mask); // Cast 16-bit results back into 8-bit values and merge them // together (using unsigned saturation with higher 8 bits set to 0 // above ensures that results are correct). Values are merged per // lanes, so this is not yet the final result. __m512i c_perm = _mm512_packus_epi16(c_lo_16, c_hi_16); // Permute values so that final result is produced. __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0); __m512i c = _mm512_permutexvar_epi64(idx, c_perm); // ------- End copied print_m512i_int8(&c); // Expected output: 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8), 64(x8), 128(x8), -128(x8) return 0; } ```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88607 Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10 --- aten/src/ATen/cpu/vec/vec256/vec256_int.h | 195 +++++++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_int.h | 93 +++++++++ aten/src/ATen/cpu/vec/vec_base.h | 13 ++ aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 11 +- 4 files changed, 308 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 0cc36d590019..7737f4a0037c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -1133,6 +1133,201 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +template +Vectorized inline shift_256_16(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80, + 21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80, + 13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80, + 5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26, + 0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18, + 0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10, + 0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. 
+ __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template +Vectorized inline shift_256_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifially, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. + __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80, + 20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80, + 12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80, + 4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25, + 0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17, + 0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9, + 0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1); + __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80, + 21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80, + 13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80, + 5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80); + __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26, + 0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18, + 0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10, + 0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2); + __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80, + 22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80, + 14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80, + 6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80); + __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, + 0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, + 0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, + 0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3); + __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, + 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, + 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, + 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80); + __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80, + 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80, + 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80, + 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. 
Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Peform shifting the same way for input array elements with + // idx%4==1. + __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Peform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Peform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. + __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index c2cbc0b1d7f9..590c3254e379 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -1163,6 +1163,99 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +template +Vectorized inline shift_512_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t, so emulating it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. 
+ __m512i ctl_0_1 = _mm512_set_epi8(62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80, + 54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80, + 46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80, + 38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80, + 30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80, + 22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80, + 14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80, + 6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80); + __m512i ctl_1_0 = _mm512_set_epi8(0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57, + 0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49, + 0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41, + 0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33, + 0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25, + 0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17, + 0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9, + 0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. + __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. 
+ __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index b9b3745e99d5..f045437ac368 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -799,6 +799,13 @@ inline Vectorized operator~(const Vectorized& a) { return a ^ ones; } +template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] << b[i]; + } + return c; +} template inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { @@ -826,6 +833,12 @@ inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { return a; } +template +inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + template inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index a5dde16024ab..c2497a6949f1 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -314,10 +314,13 @@ void bitwise_xor_kernel(TensorIteratorBase& iter) { void lshift_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return static_cast>(a) << b; - }); + cpu_kernel_vec(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return static_cast>(a) << b; + }, + [](Vectorized a, Vectorized b) { + return a << b; + }); }); } From 46796fe5e9b74602d45927304773fdcda1c3215a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 12 Nov 2022 06:19:02 -0800 Subject: [PATCH 107/453] Fix XLA symbolic shapes binding (#88928) Obsoletes https://github.com/pytorch/pytorch/pull/88772 Mostly revolves around NOT assuming that the inside is a SymNode, but instead duck-typed to be a SymNode. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88928 Approved by: https://github.com/SherlockNoMad --- c10/core/SymNodeImpl.h | 3 - test/test_dynamic_shapes.py | 6 +- torch/__init__.py | 2 - torch/csrc/jit/python/init.cpp | 77 ++++++++++++++++-------- torch/csrc/utils/pybind.cpp | 14 ++++- torch/csrc/utils/python_symnode.h | 4 -- torch/fx/experimental/symbolic_shapes.py | 39 ++++++------ 7 files changed, 85 insertions(+), 60 deletions(-) diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index d2f3aafaad8b..fcec452821d7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -85,9 +85,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode clone() { TORCH_CHECK(false, "NYI"); }; - virtual SymNode sym_int() { - TORCH_CHECK(false, "NYI"); - } virtual SymNode sym_float() { TORCH_CHECK(false, "NYI"); } diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0f1f49d2e6ea..3a8e31151bf3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -20,7 +20,7 @@ from torch.utils._pytree import tree_map from torch.fx.experimental import symbolic_shapes from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int +from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node from torch.utils._python_dispatch import TorchDispatchMode from torch import SymInt @@ -478,9 +478,9 @@ def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): def get_sym_inp(inp): if isinstance(inp, int): - return torch.SymInt(seed_node.to_node(inp)) + return torch.SymInt(to_node(seed_node, inp)) else: - return torch.SymFloat(seed_node.to_node(inp)) + return torch.SymFloat(to_node(seed_node, inp)) def maybe_xfail(inp1, inp2): key = (fn, type(inp1).__name__, type(inp2).__name__) diff --git a/torch/__init__.py b/torch/__init__.py index 19be59282cca..6def80d1dc59 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -204,8 +204,6 @@ class SymInt: """ def __init__(self, node): - from torch.fx.experimental.symbolic_shapes import SymNode - assert isinstance(node, SymNode) # This field MUST be named node; C++ binding code assumes that this # class has a field named node that stores SymNode self.node = node diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a72a8a2c1150..7ee48635cdff 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1148,38 +1148,65 @@ void initJITBindings(PyObject* module) { // NB: This isn't actually used for regular PyTorch symbolic tracing; // XLA is what needs this #define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); }) -#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); }) #define SYMNODE_BINARY(n) \ .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); }) auto symnode_class = py::class_(m, "_SymNode") + // clang-format off // These DO NOT install magic methods; the SymInt/SymFloat wrapper in // Python is responsible for this SYMNODE_UNARY(clone) - // Named these for consistency with inner python class, but maybe - // should change the python side - SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_) - SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2( - __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub) - SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow) - 
SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY( - eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt) - SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min) - SYMNODE_BINARY(max) SYMNODE_UNARY(ceil) - SYMNODE_UNARY(floor) SYMNODE_UNARY(neg) - // Intentionally don't set file line, as the - // Python backtrace matters more here - .def( - "guard_int", - [](c10::SymNode a) { - return a->guard_int(nullptr, 0); - }) - .def( - "__str__", - [](c10::SymNode a) { return a->str(); }) - .def("__repr__", [](c10::SymNode a) { - return a->str(); - }); + SYMNODE_UNARY(is_int) + SYMNODE_UNARY(is_float) + SYMNODE_UNARY(bool_) + SYMNODE_UNARY(int_) + SYMNODE_UNARY(sym_float) + SYMNODE_BINARY(add) + SYMNODE_BINARY(sub) + SYMNODE_BINARY(mul) + SYMNODE_BINARY(truediv) + SYMNODE_BINARY(pow) + SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(mod) + SYMNODE_BINARY(eq) + SYMNODE_BINARY(gt) + SYMNODE_BINARY(lt) + SYMNODE_BINARY(le) + SYMNODE_BINARY(ge) + SYMNODE_BINARY(min) + SYMNODE_BINARY(max) + SYMNODE_UNARY(ceil) + SYMNODE_UNARY(floor) + SYMNODE_UNARY(neg) + // Intentionally don't set file line, as the + // Python backtrace matters more here + .def( + "guard_int", + [](c10::SymNode a) { + return a->guard_int(nullptr, 0); + }) + .def( + "guard_float", + [](c10::SymNode a) { + return a->guard_float(nullptr, 0); + }) + .def( + "wrap_int", + [](c10::SymNode a, int64_t b) { + return a->wrap_int(b); + }) + .def( + "wrap_float", + [](c10::SymNode a, double b) { + return a->wrap_float(b); + }) + .def( + "__str__", + [](c10::SymNode a) { return a->str(); }) + .def("__repr__", [](c10::SymNode a) { + return a->str(); + }); + // clang-format on // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") diff --git a/torch/csrc/utils/pybind.cpp b/torch/csrc/utils/pybind.cpp index 37e37a873774..4cd148fdfa91 100644 --- a/torch/csrc/utils/pybind.cpp +++ b/torch/csrc/utils/pybind.cpp @@ -25,11 +25,19 @@ py::handle type_caster::cast( return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { - // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); - TORCH_INTERNAL_ASSERT(py_node); - return torch::get_symint_class()(py_node->getPyObj()).release(); + if (py_node) { + // Return the Python directly (unwrap) + return torch::get_symint_class()(py_node->getPyObj()).release(); + } else { + // Wrap the C++ into Python + auto inner = py::cast(si.toSymNodeImpl()); + if (!inner) { + throw python_error(); + } + return torch::get_symint_class()(inner).release(); + } } else { return py::cast(si.as_int_unchecked()).release(); } diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index be402e4d5439..3a9fa79d37d6 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -164,10 +164,6 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__FUNCTION__); } - c10::SymNode sym_int() override { - return dispatch_common_(__FUNCTION__); - } - c10::SymNode sym_float() override { return dispatch_common_(__FUNCTION__); } diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index d9b0a8fc2019..9b55af3c555c 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -126,6 +126,18 @@ def sym_int(a): return sym_floor(a) if a > 0 else sym_ceil(a) return int(a) +def to_node(self, num): + if isinstance(num, (SymInt, SymFloat)): + return num.node + elif isinstance(num, int): + return 
self.wrap_int(num) + elif isinstance(num, float): + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + # TODO: An incomplete list # 1. Set variables to be equal when we do equality # 2. Specialize on 0/1 when we do subtraction @@ -148,18 +160,6 @@ def expr(self): def _update_expr(self): self._expr = self.shape_env.replace(self._expr) - def to_node(self, num): - if isinstance(num, (SymInt, SymFloat)): - return num.node - elif isinstance(num, int): - return self.wrap_int(num) - elif isinstance(num, float): - return self.wrap_float(num) - else: - # NotImplemented is important so that Python tries the - # other magic method - return NotImplemented - def is_int(self): return self.pytype is int @@ -297,16 +297,15 @@ def _nyi(): always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"} def wrap_node(x): - if not isinstance(x, SymNode): - return x - if x.constant is not None: + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: return x.constant - if x.pytype is int: + if x.is_int(): return SymInt(x) - elif x.pytype is float: + elif x.is_float(): return SymFloat(x) else: - raise AssertionError(f"unrecognized return type {x.pytype}") + raise AssertionError(f"unrecognized return type {x}") def _make_node_magic(method, func): func = lru_cache(256)(func) @@ -378,13 +377,13 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method)()) def binary_magic_impl(self, other): - other_node = self.node.to_node(other) + other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented return wrap_node(getattr(self.node, method)(other_node)) def rbinary_magic_impl(self, other): - other_node = self.node.to_node(other) + other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented return wrap_node(getattr(other_node, method)(self.node)) From 4f2639e56ad5b26d2f5383dcc14e0f91c250d355 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Sat, 12 Nov 2022 20:27:00 +0000 Subject: [PATCH 108/453] [FSDP] Fix `FSDP.clip_grad_norm_()` for `NO_SHARD` (#88955) This PR fixes `FSDP.clip_grad_norm_()` for `NO_SHARD`, which previously "double-counted" each gradient `world_size`-many times. This does not address any discrepancies between `FULL_SHARD` and DDP. (Note that the unit tests do show parity between `FULL_SHARD` and DDP when using `FSDP.clip_grad_norm_()` and `nn.utils.clip_grad_norm_()` respectively on one iteration.) The added unit test code path tests mixing nested FSDP instances with both `FULL_SHARD` and `NO_SHARD` to ensure that the `local_sharded_norm` and `local_nonsharded_norm` computations are interoperating correctly. I want to test non-FSDP root instance in the future, but this is BC breaking since we need to make `clip_grad_norm_()` a static method, which would require a different method call syntax (`FSDP.clip_grad_norm_(root_module, ...)` vs. `root_module.clip_grad_norm_(...)`). 
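
A minimal sketch of the corrected reconstruction, with hypothetical standalone names (the real logic lives in `torch/distributed/fsdp/fully_sharded_data_parallel.py` in the diff below): only the sharded contributions are summed across ranks, and the replicated (`NO_SHARD` or non-FSDP-managed) contribution is added exactly once, whereas all-reducing a single local norm over all parameters adds the replicated term `world_size` times.

```python
import torch

def reconstruct_total_norm(sharded_norms_all_ranks, local_nonsharded_norm, norm_type):
    # sharded_norms_all_ranks stands in for what the all-reduce sees: one norm
    # per rank, computed only over that rank's shard of the gradients.
    if norm_type == float("inf"):
        # max is idempotent, so the replicated term cannot be over-counted here
        return torch.maximum(
            torch.stack(sharded_norms_all_ranks).max(), local_nonsharded_norm
        )
    # Sum the sharded contributions across ranks (the all-reduce), then add the
    # replicated contribution once rather than world_size times.
    total = sum(n ** norm_type for n in sharded_norms_all_ranks)
    total = total + local_nonsharded_norm ** norm_type
    return total ** (1.0 / norm_type)

# 4 ranks, sharded norm 3.0 on each rank, replicated norm 4.0, L2 norm:
# sqrt(4 * 3**2 + 4**2) = sqrt(52) ~= 7.211, not sqrt(4 * (3**2 + 4**2)).
print(reconstruct_total_norm([torch.tensor(3.0)] * 4, torch.tensor(4.0), 2.0))
```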
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88955 Approved by: https://github.com/zhaojuanmao --- .../fsdp/test_fsdp_clip_grad_norm.py | 74 ++++++++++++++----- .../fsdp/fully_sharded_data_parallel.py | 36 +++++++-- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py index e587065c5c77..1a742da889ac 100644 --- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py +++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from torch import distributed as dist +from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, @@ -42,10 +43,6 @@ class TestClipGradNorm(FSDPTest): """Tests :meth:`FullyShardedDataParallel.clip_grad_norm_`.""" - @property - def world_size(self) -> int: - return 2 - @skip_if_lt_x_gpu(2) def test_non_root(self): """ @@ -80,6 +77,11 @@ def test_ddp_parity(self): { "max_norm": [1, 2.5], "norm_type": [1, 2, float("inf")], + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + ShardingStrategy.NO_SHARD, + "mixed_strategy", + ], "use_orig_params": [False, True], "offload_params": [False, True], }, @@ -90,8 +92,9 @@ def _test_ddp_parity( self, max_norm: Union[float, int], norm_type: Union[float, int], - offload_params: bool, + sharding_strategy: Union[ShardingStrategy, str], use_orig_params: bool, + offload_params: bool, ): local_model = TransformerWithSharedParams.init( self.process_group, @@ -101,22 +104,52 @@ def _test_ddp_parity( ) ddp_model = DDP(local_model, device_ids=[self.rank]) fsdp_kwargs = { - "auto_wrap_policy": ModuleWrapPolicy( - { - TransformerEncoderLayer, - TransformerDecoderLayer, - } - ), "cpu_offload": CPUOffload(offload_params=offload_params), "use_orig_params": use_orig_params, } - fsdp_model = TransformerWithSharedParams.init( - self.process_group, - FSDPInitMode.RECURSIVE, - CUDAInitMode.CUDA_BEFORE, - deterministic=True, - fsdp_kwargs=fsdp_kwargs, - ) + if sharding_strategy == "mixed_strategy": + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) + # Apply `NO_SHARD` to the encoder + fsdp_model.transformer.encoder = FSDP( + fsdp_model.transformer.encoder, + sharding_strategy=ShardingStrategy.NO_SHARD, + **fsdp_kwargs, + ) + # Apply `FULL_SHARD` to the decoder + fsdp_model.transformer.decoder = FSDP( + fsdp_model.transformer.decoder, + sharding_strategy=ShardingStrategy.FULL_SHARD, + **fsdp_kwargs, + ) + # TODO: FSDP's `clip_grad_norm_()` is not a static method, so we + # must make the root module an FSDP instance + fsdp_model = FSDP( + fsdp_model, sharding_strategy=ShardingStrategy.FULL_SHARD, **fsdp_kwargs + ) + else: + fsdp_kwargs.update( + { + "sharding_strategy": sharding_strategy, + "auto_wrap_policy": ModuleWrapPolicy( + { + TransformerEncoderLayer, + TransformerDecoderLayer, + } + ), + } + ) + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + fsdp_kwargs=fsdp_kwargs, + ) LR = 1e-2 ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR) @@ -125,7 +158,10 @@ def _test_ddp_parity( inp = ddp_model.module.get_input(device) for model in (ddp_model, fsdp_model): out = model(*inp) - loss = 
model.module.get_loss(inp, out) + if isinstance(model, (DDP, FSDP)): + loss = model.module.get_loss(inp, out) + else: + loss = model.get_loss(inp, out) loss.backward() # Multiply gradients by a large factor to ensure that gradients will diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 69c8dd92ed8d..3e84315a4e11 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -1161,23 +1161,45 @@ def clip_grad_norm_( self._streams["unshard"], self._streams["pre_unshard"], ) - max_norm = float(max_norm) norm_type = float(norm_type) - # Compute the local gradient norm (only including this rank's shard - # of the gradients) - local_norm = _get_grad_norm(self.parameters(), norm_type).to( + # Perform local gradient norm computation, where sharded and + # non-sharded parameters must be handled separately + sharded_params = set() + nonsharded_params = set() # `NO_SHARD` or not FSDP-managed + for handle in FullyShardedDataParallel._fsdp_handles(self): + target_set = ( + sharded_params if handle.uses_sharded_strategy else nonsharded_params + ) + if handle._use_orig_params: + for param in handle.flat_param._params: + target_set.add(param) + else: + target_set.add(handle.flat_param) + for param in self.parameters(): + not_fsdp_managed = ( + param not in sharded_params and param not in nonsharded_params + ) + if not_fsdp_managed: + nonsharded_params.add(param) + local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to( + self.compute_device + ) + local_nonsharded_norm = _get_grad_norm(nonsharded_params, norm_type).to( self.compute_device ) # Reconstruct the total gradient norm depending on the norm type if norm_type == math.inf: - total_norm = local_norm + total_norm = torch.maximum(local_sharded_norm, local_nonsharded_norm) dist.all_reduce( total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group ) else: - total_norm = local_norm**norm_type + total_norm = local_sharded_norm**norm_type dist.all_reduce(total_norm, group=self.process_group) + # All-reducing the local non-sharded norm would count it an extra + # world-size-many times + total_norm += local_nonsharded_norm**norm_type total_norm = total_norm ** (1.0 / norm_type) if self.cpu_offload.offload_params: total_norm = total_norm.cpu() @@ -1789,7 +1811,7 @@ def register_comm_hook(self, state: object, hook: callable): def _get_grad_norm( - params: List[nn.Parameter], + params: Iterable[nn.Parameter], norm_type: float, ) -> torch.Tensor: """ From 06ce1338bced2d2cb933a383157b335f65a35e71 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Sun, 13 Nov 2022 04:50:21 +0000 Subject: [PATCH 109/453] [dynamo] Port all pytorch/dynamo and test/dynamo pieces over from symbolic-shapes branch (#88768) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88768 Approved by: https://github.com/jansel, https://github.com/ezyang --- functorch/_src/compilers.py | 30 ++ test/distributed/test_dynamo_distributed.py | 2 + test/dynamo/test_dynamic_shapes.py | 109 ++---- test/dynamo/test_export.py | 26 ++ test/dynamo/test_misc.py | 2 + test/dynamo/test_no_fake_tensors.py | 5 - test/dynamo/test_repros.py | 38 +- test/dynamo/test_unspec.py | 2 + test/inductor/test_torchinductor_opinfo.py | 1 + torch/_dynamo/codegen.py | 2 + torch/_dynamo/guards.py | 49 ++- torch/_dynamo/optimizations/analysis.py | 25 +- torch/_dynamo/optimizations/training.py | 6 +- 
torch/_dynamo/output_graph.py | 56 ++- torch/_dynamo/symbolic_convert.py | 53 ++- torch/_dynamo/utils.py | 135 ++++++- torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/builder.py | 238 +++++++++++- torch/_dynamo/variables/builtin.py | 85 ++++- torch/_dynamo/variables/constant.py | 32 +- torch/_dynamo/variables/lists.py | 75 +++- torch/_dynamo/variables/misc.py | 7 +- torch/_dynamo/variables/nn_module.py | 7 +- torch/_dynamo/variables/tensor.py | 387 ++++---------------- torch/_dynamo/variables/torch.py | 43 ++- torch/_subclasses/fake_tensor.py | 6 +- torch/fx/experimental/symbolic_shapes.py | 1 - 27 files changed, 921 insertions(+), 502 deletions(-) diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py index 3f52fede57eb..55de63e5c344 100644 --- a/functorch/_src/compilers.py +++ b/functorch/_src/compilers.py @@ -19,6 +19,8 @@ draw_graph, min_cut_rematerialization_partition, ) +import torch.utils._pytree as pytree + # These canonicalizations are needed here (and not decompositions), as the ops @@ -113,6 +115,34 @@ def nop(fx_g: fx.GraphModule, _) -> Callable: """ return fx_g +class DebugInterpreter(fx.Interpreter): + def run_node(self, n): + # TODO: This will fail once we start caching in AOTAutograd + # again, because we need to remap SymInts to their new values + # in the presence of dynamism + r = super().run_node(n) + if 'val' in n.meta: + n_vals, n_spec = pytree.tree_flatten(n.meta['val']) + r_vals, r_spec = pytree.tree_flatten(r) + assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + assert nv.size() == rv.size(), f"output {i}: {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"output {i}: {nv.dtype} != {rv.dtype}" + assert torch._prims_common.check_significant_strides(nv, rv), f"output {i}: {nv.stride()} != {rv.stride()}" + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) + """ + return DebugInterpreter(fx_g).run @make_boxed_compiler def simple_ts_compile(fx_g, _): diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index b6bc16edb941..21550a0120e4 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -258,6 +258,8 @@ def test_fsdp_inductor(self): # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) + # TODO(voz): Flaky on CI failure, consistent failure on local master. 
+ @unittest.skipIf(True, "Flaky on CI failure, consistent failure on local master") def test_hf_bert_fsdp(self): from transformers.models.bert.modeling_bert import BertLayer diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index d82cc6925fe9..294ea9e54952 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -51,22 +51,6 @@ def make_dynamic_cls(cls): ) -# DynamicShapesReproTests -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_eval_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_train_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - -unittest.expectedFailure( - DynamicShapesReproTests.test_issue175_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - unittest.expectedFailure( DynamicShapesReproTests.test_do_paste_mask_dynamic_shapes # aten.min.dim - couldn't find symbolic meta function/decomposition @@ -77,97 +61,66 @@ def make_dynamic_cls(cls): # Could not infer dtype of torch._C.SymIntNode ) -unittest.expectedFailure( - DynamicShapesReproTests.test_ellipsis_dynamic_shapes - # Cannot call sizes() on tensor with symbolic sizes/strides -) - unittest.expectedFailure( DynamicShapesReproTests.test_hf_t5_forward_dynamic_shapes # Cannot call sizes() on tensor with symbolic sizes/strides ) +# DynamicShapesExportTests unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes - # Unable to cast Python instance to C++ type -) - -unittest.expectedFailure( - DynamicShapesReproTests.test_guard_fail_tensor_bool_dynamic_shapes - # RuntimeError: aten.allclose.default - couldn't find symbolic meta function/decomposition + DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes ) - -# DynamicShapesMiscTests unittest.expectedFailure( - DynamicShapesMiscTests.test_unsupported_fake_tensor_dynamic_shapes - # aten.quantize_per_tensor.default - couldn't find symbolic meta function/decomposition + DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes ) unittest.expectedFailure( - DynamicShapesMiscTests.test_module_deepcopy_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decompositio + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes ) - -# DynamicShapesUnspecTests unittest.expectedFailure( - DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes - # float() argument must be a string or a real number, not 'torch._C.SymIntNode' + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes ) -# DynamicShapesNNModuleTests -unittest.expectedFailure( - DynamicShapesNNModuleTests.test_unsupportedmethod_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition -) - +# DynamicShapesSubGraphTests unittest.expectedFailure( - DynamicShapesNNModuleTests.test_unsupportedmodule_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes ) +unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes) +# DynamicShapesUnspecTests +# Missing decomp +# RuntimeError: Failed running call_function +# (*(FakeTensor(FakeTensor(..., device='meta', size=(5, 1, 28, 28)), cpu), +# 
FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), False, 0.1, +# FakeTensor(FakeTensor(..., device='meta', size=()), cpu)), **{}): +# aten._local_scalar_dense.default +unittest.expectedFailure(test_unspec.UnspecReproTests.test_batch_norm_act_unspec) + +# SymIntArrayRef expected to contain only concrete integers unittest.expectedFailure( - DynamicShapesNNModuleTests.test_self_mutating1_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes ) +# DynamicShapesReproTests unittest.expectedFailure( - DynamicShapesNNModuleTests.test_call_fn_with_non_const_inputs_safe_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesReproTests.test_reformer_eval_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer ) - -# DynamicShapesExportTests -unittest.expectedFailure( - DynamicShapesExportTests.test_export_compare_optimize_with_make_fx_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes -) unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_stack_trace_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dynamic_shapes + DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes + # Unable to cast Python instance to C++ type ) - -# DynamicShapesSubGraphTests unittest.expectedFailure( - DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes + DynamicShapesReproTests.test_reformer_train_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer ) -unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes) if __name__ == "__main__": diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index a157926422c8..21c0d2004bb9 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -71,6 +71,32 @@ def func(x): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + def test_export_shape_control_flow_1(self): + def func(x): + if x.shape[0] > 10: + return x.cos() + return x.sin() + + opt_func = torch._dynamo.optimize("eager")(func) + real_result = opt_func(torch.ones(6, 4)) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, torch.ones(6, 4)) + out_graph, out_guards = exported + + dynamo_result = out_graph(torch.ones(6, 4)) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + hit = False + for guard in out_guards: + if guard.name == "symbolic_shape_expression": + hit = True + self.assertTrue("x.size()[0] <= 10" 
in guard.code_list) + + self.assertTrue(hit) + def test_export_graph_bypass(self): inp = [ torch.tensor([0.1, 0.1]), diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a8bf86e46411..e270852fc526 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1144,6 +1144,7 @@ def fn(x): torch._dynamo.run()(fn2)(torch.randn(4)) self.assertEqual(cnts2.frame_count, 0) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_nested_disable_decorator(self): cnts = torch._dynamo.testing.CompileCounter() @@ -1616,6 +1617,7 @@ def fn(x, func): self.assertEqual(cnts.op_count, 1) @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_unsupported_fake_tensor(self): def f(x): return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8) diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index df511f1affd5..f7943c1d7ab9 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -1,6 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest - from torch._dynamo.testing import make_test_cls_with_patches try: @@ -25,9 +23,6 @@ def make_no_fake_cls(cls): NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) -unittest.expectedFailure( - NoFakeTensorsReproTests.test_guard_fail_tensor_bool_no_fake_tensors -) NoFakeTensorsReproTests.test_numpy_list_no_fake_tensors.__unittest_expecting_failure__ = ( False ) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 913d59322ac7..6a1c654a4873 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -11,6 +11,8 @@ from typing import List from unittest.mock import patch +import functorch._src.config + import numpy as np import torch @@ -803,7 +805,6 @@ def test_do_paste_mask(self): ) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) - # Graph break because of dynamic slicing self.assertEqual( torch._dynamo.utils.counters["frames"]["total"], torch._dynamo.utils.counters["frames"]["ok"] + 1, @@ -961,7 +962,7 @@ def test_maml_item_capture(self): self.assertEqual(cnt.frame_count, ifdyn(3, 2)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (36, 35, 29, 28)) + self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27)) # see: https://github.com/pytorch/pytorch/issues/80067 @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) @@ -980,7 +981,7 @@ def test_maml_no_item_capture(self): self.assertEqual(cnt.frame_count, ifdyn(5, 4)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (31, 36, 35, 29, 28)) + self.assertIn(cnt.op_count, (31, 36, 35, 34, 29, 28)) def test_hf_model_output(self): ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) @@ -1316,6 +1317,7 @@ def blah(self, x): self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 3) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_guard_fail_tensor_bool(self): @torch._dynamo.skip def fn(): @@ -1402,8 +1404,17 @@ def fn(x): self.assertTrue(same(ref1, res1)) @unittest.skipIf(not HAS_REFS, "requires recent PT version") - @unittest.expectedFailure def test_primtorch(self): + @torch._dynamo.optimize("eager") + def fn(x): + torch._refs.abs(x) + + 
fn(torch.randn(3)) + + @unittest.skipIf(not HAS_REFS, "requires recent PT version") + @unittest.expectedFailure + # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)] + def test_primtorch_no_graph_break(self): @torch._dynamo.optimize("eager", nopython=True) def fn(x): torch._refs.abs(x) @@ -1456,14 +1467,14 @@ def fn(x): fn(torch.randn(3)) - # AssertionError: ABCMeta + # Bug with storage meta - torch.BoolStorage is becoming torch.storage._LegacyStorageMeta @unittest.expectedFailure def test_isinstance_storage(self): @torch._dynamo.optimize("eager") def fn(x): f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) bools = torch.BoolStorage.from_buffer(f, "big") - self.assertTrue(isinstance(bools, torch.BoolStorage)) + assert isinstance(bools, torch.BoolStorage) return x fn(torch.randn(3)) @@ -1662,6 +1673,21 @@ def fn(x): opt_fn(x) self.assertEqual(cnt.frame_count, 1) + @patch.object(functorch._src.config, "use_dynamic_shapes", True) + def test_bigbird_unsqueeze_inplace(self): + def fn(reshape_2): + view_2 = reshape_2.clone() + view_2.unsqueeze_(2) + cat_11 = torch.cat([view_2], dim=2) + view_13 = cat_11.view((2, 12, 64, -1)) + return (view_13,) + + x = torch.randn(2, 12, 64, 64, requires_grad=True) + ref = fn(x) + opt_fn = torch._dynamo.optimize("aot_eager")(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + # This doesn't work without fake tensors but I don't care @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) def test_issue1466_size_aot_autograd(self): diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index fd5396981b74..e46d79208de0 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -50,6 +50,8 @@ class UnspecTest(cls): UnspecReproTests = make_unspec_cls(test_repros.ReproTests) UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests) +unittest.expectedFailure(UnspecReproTests.test_batch_norm_act_unspec) + @patch.object(torch._dynamo.config, "specialize_int_float", False) class UnspecTests(torch._dynamo.test_case.TestCase): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 5cee29920b77..3d384efea0ae 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -279,6 +279,7 @@ def process(device_type): "baddbmm": {f16}, "bernoulli": {f16, f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 2ba29981c366..e469ce02ebd6 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -14,6 +14,7 @@ from .variables.base import VariableTracker from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -95,6 +96,7 @@ def __call__(self, value, allow_cache=True): value, ( TensorVariable, + DynamicShapeVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index d4903964aac6..9cbcb93fcc5c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -92,7 +92,7 @@ def __hash__(self): def sort_key(self): return ( - self.source.value, + self.source.value if self.source else -1, len(self.name), self.name, 
self.create_fn.__code__.co_firstlineno, @@ -128,7 +128,7 @@ def __getattr__(self, x): def __str__(self): s = f""" - {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__} + {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.create_fn.__name__} {{ 'guard_types': {self.guard_types}, 'code': {self.code_list}, @@ -438,6 +438,13 @@ def GRAD_MODE(self, guard: Guard): code = "not ___is_grad_enabled()" self._produce_guard_code(guard, [code]) + # This is a bit of a crutch for export case for symbolic shape guards. + # SYMBOL_MATCH is only ever, and must only ever, be used for setting this value on + # the create_fn field for tracking guards in export. + @staticmethod + def SYMBOL_MATCH(): + pass + def TENSOR_MATCH(self, guard: Guard): if guard.is_nn_module(): self.ID_MATCH(guard) @@ -537,10 +544,14 @@ def tensor_ref_as_str(tensor_ref, id_to_name_map): return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()[{tensor_ref.idx}]" return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()" - def __init__(self, expr_to_tensor_ref, id_to_name_map): + def __init__( + self, expr_to_tensor_ref, id_to_name_map, shape_env, intermediary_symbols + ): super().__init__() self.expr_to_tensor_ref = expr_to_tensor_ref self.id_to_name_map = id_to_name_map + self.shape_env = shape_env + self.intermediary_symbols = intermediary_symbols def _print_Symbol(self, expr) -> str: assert isinstance(expr, sympy.Symbol) @@ -548,7 +559,7 @@ def _print_Symbol(self, expr) -> str: return "0" if expr == 1: return "1" - assert expr in self.expr_to_tensor_ref, f"Unknown expression {expr}" + assert expr in (self.expr_to_tensor_ref) or (expr in self.intermediary_symbols) refs = self.expr_to_tensor_ref[expr] if len(refs) == 0: return super()._print_Symbol(expr) @@ -599,7 +610,7 @@ def combine_scopes(left, right): if not config.guard_nn_modules and guard.is_nn_module(): continue guard.create(local_builder, global_builder) - self.check_fn = self.compile_check_fn(local_builder, global_builder) + self.check_fn = self.compile_check_fn(local_builder, global_builder, guards) self._seen_ids.clear() """ @@ -632,7 +643,12 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids return None expr_to_tensor_ref = {} - guard_printer = DynamoGuardPrinter(expr_to_tensor_ref, id_to_name_map) + guard_printer = DynamoGuardPrinter( + expr_to_tensor_ref, + id_to_name_map, + self.output_graph.shape_env, + self.output_graph.intermediary_symbols, + ) # tensor_check_names is the primary tensor association mechanism in dynamo. # All other guards installations are driven off of it, so these ones will too. @@ -649,7 +665,6 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids if obj_expr not in expr_to_tensor_ref: expr_to_tensor_ref[obj_expr] = {} expr_to_tensor_ref[obj_expr][tensor_ref] = "" - finished_expressions.append(f"isinstance({name}, torch.Tensor)") guard_expression = self.output_graph.shape_env.get_guard_expr() expr_as_str = guard_printer.doprint(guard_expression) @@ -668,7 +683,6 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids if len(equality_candidates) > 1: equality_expr = " == ".join(equality_candidates) - # breakpoint() finished_expressions.append(equality_expr) # Redundant with code_parts, but allows us to wrap it with parens nicely. 
@@ -678,7 +692,7 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids expression = " and ".join(finished_expressions) return f"({expression})" - def compile_check_fn(self, local_builder, global_builder): + def compile_check_fn(self, local_builder, global_builder, guards_out): assert not (set(local_builder.argnames) & set(global_builder.argnames)) # see parallel handling of ".0" / "___implicit0" in _eval_frame.c args = [a for a in local_builder.scope.keys() if a == "___implicit0"] @@ -707,10 +721,6 @@ def compile_check_fn(self, local_builder, global_builder): symbolic_shape_expression = self._parse_symbolic_shape_expressions( tensor_check_names, tensor_check_ids ) - if symbolic_shape_expression: - code_parts.append(symbolic_shape_expression) - verbose_code_parts.append(symbolic_shape_expression) - tensor_check_examples = ( local_builder.tensor_check_examples + global_builder.tensor_check_examples @@ -725,6 +735,17 @@ def compile_check_fn(self, local_builder, global_builder): tensor_check_names + ["tensor_check_names=tensor_check_names"] ) verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})") + if symbolic_shape_expression: + code_parts.append(symbolic_shape_expression) + verbose_code_parts.append(symbolic_shape_expression) + guards_out.add( + Guard( + name="symbolic_shape_expression", + source=None, + create_fn=GuardBuilder.SYMBOL_MATCH, + code_list=symbolic_shape_expression, + ) + ) def direct_equality(a, b): return a == b @@ -739,6 +760,8 @@ def direct_negation(a, b): ("___check_tensors", check_tensors_fn), ("___check_tensors_verbose", check_tensors_verbose_fn), ("tensor_check_names", tensor_check_names), + ("floor", math.floor), + ("ceiling", math.ceil), ("Eq", direct_equality), ("Ne", direct_negation), ("Mod", sympy.Mod), diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py index b3f6ed79eb06..c4ed04ca8c39 100644 --- a/torch/_dynamo/optimizations/analysis.py +++ b/torch/_dynamo/optimizations/analysis.py @@ -15,7 +15,7 @@ if fake_tensors_available: from torch._subclasses import FakeTensorMode # noqa: F401 - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor + from ..utils import deepcopy_to_fake_tensor class ShapeAliasingAndMutationProp(ShapeProp): @@ -122,9 +122,26 @@ def has_mutation(gm, example_inputs, inputs_only=False): # TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way. if fake_tensors_available and config.fake_tensor_propagation: - with FakeTensorMode() as fake_mode: - pass - fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=fake_mode) + + def _wrap_to_fake_tensor(t, *, f_mode): + if type(t) in (torch.Tensor, torch.nn.Parameter): + static_shapes_ = config.dynamic_shapes is False + return fake_mode.from_tensor( + t, static_shapes=config.dynamic_shapes is not False + ) + else: + return t + + # Our analysis pass should use dynamic shape tensor inputs + # when dynamic shapes are enabled. + # We don't actually care about the guards that are created + # on those shapes though, so just create a fresh ShapeEnv here. 
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv + + fake_mode = FakeTensorMode( + shape_env=ShapeEnv() if config.dynamic_shapes else None + ) + fake_wrapper = functools.partial(_wrap_to_fake_tensor, f_mode=fake_mode) example_inputs = tree_map(fake_wrapper, example_inputs) new_gm = deepcopy_to_fake_tensor(gm, fake_mode) with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode: diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py index 49f9a4397dd9..a56a74ad5aea 100644 --- a/torch/_dynamo/optimizations/training.py +++ b/torch/_dynamo/optimizations/training.py @@ -140,9 +140,13 @@ class AotNop(AotAutogradStrategy): """Useful for debugging purpose""" def candidate(self): + from functorch._src.compilers import debug_nop from functorch.compile import nop - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, fw_compiler=nop) + DEBUG = False + return BACKENDS["aot_autograd"]( + self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop + ) aot_eager = AotNop.compile_fn diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 9dd9a713a25c..ee5079581be7 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -6,7 +6,7 @@ import re import traceback from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch.nn from torch import fx @@ -15,7 +15,7 @@ from . import config, logging as torchdynamo_logging, variables from .bytecode_transformation import create_instruction, Instruction, unique_id from .codegen import PyCodegen -from .exc import BackendCompilerFailed, unimplemented +from .exc import BackendCompilerFailed from .guards import GuardBuilder from .mutation_guard import is_dynamic_nn_module from .side_effects import SideEffects @@ -27,9 +27,10 @@ fake_tensors_available, format_graph_tabular, ) -from .variables.builder import VariableBuilder +from .variables.builder import VariableBuilder, wrap_fx_proxy from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, @@ -93,7 +94,7 @@ def __init__( self.side_effects = SideEffects() self.code_options = dict(code_options) self.output_instructions = [] - # Node => computed real value (see TensorVariable.get_real_value) + # Node => computed real value (see utils.get_real_value) self.real_value_cache = {} # Not checkpointed @@ -107,6 +108,7 @@ def __init__( self.unspec_variable_map = {} self.shape_env = ShapeEnv() if config.dynamic_shapes else None self.tensor_id_to_sym_shape_ref = {} + self.intermediary_symbols = {} @property def output(self): @@ -194,43 +196,63 @@ def update_co_names(self, name): name, ) - def register_attr_or_module(self, mod: torch.nn.Module, *names, **options): - if is_dynamic_nn_module(mod): - return variables.UnspecializedNNModuleVariable(mod, **options) + def register_attr_or_module( + self, target: Union[torch.nn.Module, torch.Tensor, Any], *names, **options + ): + if is_dynamic_nn_module(target): + return variables.UnspecializedNNModuleVariable(target, **options) options = dict(options) options["guards"] = set(options.get("guards", [])) source: Source = options.get("source", None) - if isinstance(mod, torch.Tensor): + if isinstance(target, torch.Tensor): if source: options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH)) def wrap_name(module_key): - return 
TensorVariable.create( + return wrap_fx_proxy( self, self.create_proxy("get_attr", module_key, tuple(), {}), - example_value=mod, + example_value=target, **options, ) - elif isinstance(mod, torch.nn.Module): - assert isinstance(mod, torch.nn.Module) + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE)) def wrap_name(module_key): - return NNModuleVariable(type(mod), module_key, **options) + return NNModuleVariable(type(target), module_key, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + self.intermediary_symbols.update({target.get_pyobj().expr: None}) + + def wrap_name(module_key): + return DynamicShapeVariable.create( + self, + self.create_proxy("get_attr", module_key, tuple(), {}), + dyn_shape=target, + **options, + ) + # HACKY CODE REGION END else: def wrap_name(module_key): self.output.update_co_names(module_key) - self.root_globals[module_key] = mod + self.root_globals[module_key] = target return VariableBuilder(self, ConstantSource(source_name=module_key))( - mod + target ) for k, v in self.nn_modules.items(): - if v is mod: + if v is target: # it already exists return wrap_name(k) @@ -246,7 +268,7 @@ def wrap_name(module_key): base = name for i in itertools.count(): if name not in self.nn_modules: - self.nn_modules[name] = mod + self.nn_modules[name] = target return wrap_name(name) name = f"{base}_{i}" diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e06f62a6bf62..88e0df5470bc 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -55,7 +55,7 @@ istype, ) from .variables.base import MutableLocal, typestr, VariableTracker -from .variables.builder import VariableBuilder +from .variables.builder import VariableBuilder, wrap_fx_proxy from .variables.builtin import BuiltinVariable from .variables.constant import ConstantVariable from .variables.dicts import ConstDictVariable @@ -81,7 +81,7 @@ WithExitFunctionVariable, ) from .variables.nn_module import NNModuleVariable -from .variables.tensor import TensorVariable +from .variables.tensor import DynamicShapeVariable, TensorVariable from .variables.torch import TorchVariable from .variables.user_defined import UserDefinedVariable @@ -129,7 +129,9 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if truth_fn(value.as_python_constant()): push and self.push(value) self.jump(inst) - elif isinstance(value, TensorVariable) and self.should_compile_partial_graph(): + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): # compile a partial subgraph prefix then jump into user code self.push(value) self.output.compile_subgraph( @@ -155,6 +157,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if truth_fn(len(value.unpack_var_sequence(self))): push and self.push(value) self.jump(inst) + elif isinstance(value, DynamicShapeVariable): + eval_result = value.evaluate_expr(self.output) + if truth_fn(eval_result): + push and self.push(value) + self.jump(inst) else: unimplemented(f"generic_jump {typestr(value)}") @@ -700,6 +707,7 @@ def COMPARE_OP(self, inst): left, ( TensorVariable, + 
DynamicShapeVariable, NNModuleVariable, BaseListVariable, UserDefinedVariable, @@ -717,16 +725,6 @@ def COMPARE_OP(self, inst): supported_is_const[op](object(), right.value), **options ) ) - elif ( - isinstance(left, TensorVariable) or isinstance(right, TensorVariable) - ) and op in supported_tensors: - self.push( - TensorVariable.create( - self, - supported_tensors[op](left.as_proxy(), right.as_proxy()), - **options, - ) - ) elif ( left.is_python_constant() and right.is_python_constant() @@ -741,6 +739,28 @@ def COMPARE_OP(self, inst): **options, ) ) + elif ( + isinstance(left, TensorVariable) or isinstance(right, TensorVariable) + ) and op in supported_tensors: + self.push( + wrap_fx_proxy( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + **options, + ) + ) + elif ( + isinstance(left, DynamicShapeVariable) + or isinstance(right, DynamicShapeVariable) + ) and op in supported_tensors: + self.push( + DynamicShapeVariable.create( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + dyn_shape=None, + **options, + ) + ) elif op in ("in", "not in"): self.push(right.call_method(self, "__contains__", [left], {})) if op == "not in": @@ -1029,12 +1049,12 @@ def UNPACK_SEQUENCE(self, inst): elif isinstance(seq, TensorVariable): proxy = seq.as_proxy() for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + self.push(wrap_fx_proxy(self, proxy[i], **options)) elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + self.push(wrap_fx_proxy(self, proxy[i], **options)) else: unimplemented(f"UNPACK_SEQUENCE {seq}") @@ -1109,7 +1129,8 @@ def FORMAT_VALUE(self, inst): fmt_spec = ConstantVariable("") value = self.pop() - + if isinstance(value, DynamicShapeVariable): + value = ConstantVariable(str(value.dyn_shape)) if (flags & 0x03) == 0x01: value = BuiltinVariable(str).call_function(self, [value], {}) elif (flags & 0x03) == 0x02: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 067a80807374..0b87be7393b5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -29,7 +29,9 @@ import torch from torch import fx +from torch._dispatch.python import enable_python_dispatcher from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._pytree import tree_map from . 
import config, logging as torchdynamo_logging @@ -679,10 +681,8 @@ def rename_implicit(v): UnsupportedFakeTensorException, ) - def make_fake_tensor(e, fake_mode, tx=None): - fake_tensor = fake_mode.from_tensor( - e, static_shapes=config.dynamic_shapes is False - ) + def make_fake_tensor(e, fake_mode, static_shapes=False, tx=None): + fake_tensor = fake_mode.from_tensor(e, static_shapes=static_shapes) if tx is not None: from torch._dynamo.guards import TensorReference @@ -728,13 +728,23 @@ def wrap_fake_exception(fn): def wrap_to_fake_tensor(e, fake_mode): if type(e) in (torch.Tensor, torch.nn.Parameter): - return wrap_fake_exception(lambda: make_fake_tensor(e, fake_mode)) + return wrap_fake_exception( + lambda: make_fake_tensor( + e, fake_mode, static_shapes=config.dynamic_shapes is False + ) + ) else: return e def wrap_to_fake_tensor_and_record(e, tx): if type(e) in (torch.Tensor, torch.nn.Parameter): - return wrap_fake_exception(lambda: make_fake_tensor(e, tx.fake_mode, tx)) + static_shapes = config.dynamic_shapes is False + if type(e) is torch.nn.Parameter: + # Always static for params + static_shapes = True + return wrap_fake_exception( + lambda: make_fake_tensor(e, tx.fake_mode, static_shapes, tx) + ) else: return e @@ -997,3 +1007,116 @@ def _get_debug_dir(root_dir): def get_debug_dir(): debug_root = config.debug_dir_root return _get_debug_dir(debug_root) + + +def get_fake_value(node, tx): + """ + Run the computation represented by `node` using fake tensors and return the result. + """ + from .exc import TorchRuntimeError, unimplemented, Unsupported + + op = node.op + fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) + + def visit(n: torch.fx.Node): + return n.meta["example_value"] + + args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) + args = tree_map(fake_wrapper, args) + kwargs = tree_map(fake_wrapper, kwargs) + + nnmodule = None + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if not is_lazy_module(nnmodule): + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + if op == "call_module" and is_lazy_module(nnmodule): + assert nnmodule is not None + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nnmodule(*args, **kwargs) + try: + with tx.fake_mode, enable_python_dispatcher(): + return wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + if isinstance(e, torch._subclasses.fake_tensor.DataDependentOutputException): + if config.capture_scalar_outputs and node.target == "item": + return torch.zeros(size=(), dtype=args[0].dtype).item() + else: + unimplemented(f"data dependent operator: {e.func}") + elif isinstance(e, torch._subclasses.fake_tensor.DynamicOutputShapeException): + unimplemented(f"dynamic shape operator: {e.func}") + raise TorchRuntimeError() from e + + +def run_node(output_graph, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dicatated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The output_graph arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. 
+ """ + op = node.op + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return output_graph.get_submodule(node.target) + except Exception as e: + raise RuntimeError( + f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)" + ) from e + raise AssertionError(op) + + +def get_real_value(node, output_graph): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. + """ + cache = output_graph.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( + (node.args, node.kwargs), + lambda n: get_real_value(n, output_graph), + ) + + if op == "call_module": + nn_module = output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(output_graph, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError() from e + return real_value diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 8c80557e3fd0..2305afc226ac 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -35,6 +35,7 @@ ) from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, FakeItemVariable, TensorVariable, UnspecializedNumpyVariable, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d3c5140fa4a9..67e506b5b435 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3,15 +3,19 @@ import enum import functools import inspect +import math +import numbers +import operator import re import types from abc import ABCMeta -from typing import Any, List +from typing import Any, List, Union import numpy as np from functorch.experimental.ops import PyOperator import torch +from torch.fx.immutable_collections import immutable_list from .. 
import config, mutation_guard, replay_record, skipfiles from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy @@ -31,6 +35,10 @@ TupleIteratorGetItemSource, ) from ..utils import ( + clone_input, + fake_tensors_available, + get_fake_value, + get_real_value, getfile, global_key_name, is_namedtuple, @@ -38,11 +46,14 @@ istensor, istype, odict_values, + preserve_rng_state, tuple_iterator, tuple_iterator_getitem, tuple_iterator_len, + wrap_to_fake_tensor_and_record, ) -from .base import MutableLocal + +from .base import MutableLocal, typestr from .builtin import BuiltinVariable from .constant import ConstantVariable, EnumVariable from .dicts import ( @@ -57,6 +68,7 @@ ListVariable, NamedTupleVariable, RangeVariable, + SizeVariable, SliceVariable, TupleVariable, ) @@ -72,6 +84,7 @@ ) from .nn_module import UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -86,6 +99,10 @@ from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable +class _missing: + pass + + @dataclasses.dataclass class GraphArg: source: Source @@ -187,6 +204,8 @@ def make_guards(self, *guards): def _wrap(self, value): make_guards = self.make_guards + if istype(value, (torch.SymInt, torch.SymFloat)): + return self.wrap_sym(value) if istensor(value): return self.wrap_tensor(value) elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value): @@ -490,6 +509,26 @@ def tensor_should_specialize(self): ) ) + def wrap_sym(self, value: Union[torch.SymInt, torch.SymFloat]): + if not is_constant_source(self.get_source()): + self.tx.output.graphargs.append(GraphArg(self.get_source(), value, False)) + elif is_constant_source(self.get_source()): + return self.tx.output.register_attr_or_module( + value, + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=None, + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + return DynamicShapeVariable.create( + tx=self.tx, + proxy=self.tx.output.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) + ), + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + def wrap_tensor(self, value: torch.Tensor): if self.get_source().guard_source().is_nn_module(): return self.tx.output.register_attr_or_module( @@ -514,7 +553,7 @@ def wrap_tensor(self, value: torch.Tensor): source=None, # Guards are added inside register_attr_or_module ) - tensor_variable = TensorVariable.create( + tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=self.tx.output.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) @@ -556,14 +595,16 @@ def wrap_unspecialized_primitive(self, value): ) if isinstance(value, np.number): - unspec_var = UnspecializedNumpyVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedNumpyVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, **options, ) else: - unspec_var = UnspecializedPythonVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, @@ -589,3 +630,190 @@ def _dataclasses_fields_lambda(obj): ) items.append(UserDefinedObjectVariable(field, source=source).add_options(obj)) return TupleVariable(items).add_options(obj) + + +def wrap_fx_proxy(tx, proxy, example_value=None, **options): + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + +# Note: Unfortunate split due to some gross 
classes existing that subclass TensorVariable +# Should be compositional instead +def wrap_fx_proxy_cls(target_cls, tx, proxy, example_value=None, **options): + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta + if not config.dynamic_propagation: + if isinstance(example_value, torch.Tensor): + options.update(target_cls.specialize(example_value)) + return target_cls(proxy, **options) + + use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation + + initial_example_value = example_value + + def _clone_input(value): + if isinstance(value, torch.Tensor): + use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not use_fake_tensors or not isinstance( + value, torch._subclasses.fake_tensor.FakeTensor + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + with preserve_rng_state(): + if example_value is None: + if use_fake_tensors: + example_value = get_fake_value(proxy.node, tx) + else: + example_value = get_real_value(proxy.node, tx.output) + + else: + proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) + if use_fake_tensors: + fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) + example_value = fake_wrapper(example_value) + + if isinstance(example_value, torch.Tensor): + is_parameter = isinstance(example_value, torch.nn.Parameter) + should_specialize = options.pop("should_specialize", False) + if is_parameter or should_specialize: + specialized_value = initial_example_value + else: + specialized_value = None + + example_value = _clone_input(example_value) + proxy.node.meta["example_value"] = example_value + specialized_props = target_cls.specialize(example_value) + if use_fake_tensors and isinstance( + example_value, torch._subclasses.fake_tensor.FakeTensor + ): + specialized_props["class_type"] = ( + torch.nn.Parameter if is_parameter else torch.Tensor + ) + + specialized_props["specialized_value"] = specialized_value + + options.update(specialized_props) + return target_cls(proxy, **options) + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target == torch.random.set_rng_state + ): + from . import TorchVariable + + return TorchVariable(proxy.node.target) + elif ( + proxy.node.target == torch._C._DisableFuncTorch + or proxy.node.target == torch.cuda._is_in_bad_fork + ): + from . 
import UserDefinedObjectVariable + + return UserDefinedObjectVariable(example_value) + elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + elif istype(example_value, torch.Size) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + sizes = [] + for i, v in enumerate(example_value): + proxy_i = proxy[i] + sizes.append(DynamicShapeVariable.create(tx, proxy_i, v, **options)) + return SizeVariable(sizes, proxy, **options) + elif istype(example_value, int) and proxy.node.target in ( + torch.seed, + operator.mod, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + ): + if config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + else: + return ConstantVariable(example_value, **options) + elif istype(example_value, torch.Size) and all( + [isinstance(x, int) for x in example_value] + ): + sizes = [ConstantVariable(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable(None, **options), + ) + else: + unpacked.append( + wrap_fx_proxy( + tx, + proxy.tracer.create_proxy( + "call_function", operator.getitem, (proxy, i), {} + ), + example_value=val, + **options, + ) + ) + if istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, mutable_local=MutableLocal(), **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ("namedtuple?") + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable(None, **options) + elif ( + isinstance(example_value, int) + and proxy.node.target is torch._utils._element_size + ): + proxy.node.meta["example_value"] = example_value + return ConstantVariable(example_value, **options) + elif ( + isinstance(example_value, numbers.Number) + and (proxy.node.target == "item" or proxy.node.target in {math.sqrt, math.pow}) + and config.capture_scalar_outputs + ): + if use_fake_tensors: + # item raw value should not be accessed + return wrap_fx_proxy_cls( + FakeItemVariable, + tx=tx, + proxy=proxy, + example_value=torch.tensor(example_value), + **options, + ) + else: + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=tx, + proxy=proxy, + example_value=torch.tensor(example_value), + raw_value=None if use_fake_tensors else example_value, + need_unwrap=False, + **options, + ) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat)): + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable(proxy, example_value, **options) + else: + raise AssertionError( + "torch.* op returned non-Tensor " + + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" + ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5a88f375c9c2..904ed8a49f81 100644 --- a/torch/_dynamo/variables/builtin.py +++ 
b/torch/_dynamo/variables/builtin.py @@ -10,6 +10,7 @@ import numpy as np import torch +from torch.fx.experimental.symbolic_shapes import sym_float, sym_int from .. import config, variables from ..allowed_functions import is_allowed @@ -26,7 +27,7 @@ ) from .base import MutableLocal, VariableTracker from .dicts import ConstDictVariable -from .tensor import DynamicShapeVariable, FakeItemVariable +from .tensor import DynamicShapeVariable, FakeItemVariable, UnspecializedPythonVariable log = logging.getLogger(__name__) @@ -226,6 +227,7 @@ def unwrap_unspec_args_kwargs(args, kwargs): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls constant_args = check_constant_args(args, kwargs) tensor_args = self.tensor_args(*args, **kwargs) @@ -234,7 +236,7 @@ def call_function( has_constant_handler = self.can_constant_fold_through() and ( constant_args or unspec_python_args ) - assert isinstance(args, list) + assert isinstance(args, (list, tuple)) assert isinstance(kwargs, dict) if ( @@ -274,7 +276,8 @@ def call_function( "call_function", fn, *proxy_args_kwargs(args, kwargs), current_tx=tx ) if any([isinstance(arg, FakeItemVariable) for arg in args]): - return variables.FakeItemVariable.create( + return wrap_fx_proxy_cls( + FakeItemVariable, tx, proxy, **options, @@ -282,7 +285,8 @@ def call_function( elif self.unspec_numpy_args(*args, **kwargs): _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) raw_value = self.fn(*_args, **_kwargs) - return variables.UnspecializedNumpyVariable.create( + return wrap_fx_proxy_cls( + variables.UnspecializedNumpyVariable, tx, proxy, raw_value=raw_value, @@ -298,7 +302,8 @@ def call_function( if isinstance(x, variables.UnspecializedPythonVariable) ) - return variables.UnspecializedPythonVariable.create( + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx, proxy, raw_value=raw_value, @@ -312,14 +317,27 @@ def call_function( args[0], variables.UnspecializedPythonVariable ): args[0] = args[0].convert_to_constant(tx) - return variables.TensorVariable.create(tx, proxy, **options) + return wrap_fx_proxy(tx, proxy, **options) except NotImplementedError: unimplemented(f"partial tensor op: {self} {args} {kwargs}") # Handle cases like int(torch.seed()) - if self.fn is int and isinstance(args[0], DynamicShapeVariable): - return args[0] + # Also handle sym_float to sym_int cases + if self.fn in (int, float) and isinstance(args[0], DynamicShapeVariable): + fn_ = sym_int if self.fn is int else sym_float + out = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (args[0].as_proxy(),), + {}, + current_tx=tx, + ), + **options, + ) + return out handler = getattr(self, f"call_{self.fn.__name__}", None) if handler: @@ -353,7 +371,6 @@ def call_function( ), **options, ) - return super().call_function(tx, args, kwargs) def _call_min_max(self, tx, a, b): @@ -368,7 +385,9 @@ def _call_min_max(self, tx, a, b): # Dynamic input does not get resolved, rather, gets stored as call_function if isinstance(a, DynamicShapeVariable): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -437,7 +456,13 @@ def _call_min_max(self, tx, a, b): return variables.ConstantVariable(max(a.value, b.value)) else: return variables.ConstantVariable(min(a.value, b.value)) + elif isinstance(a, DynamicShapeVariable) or isinstance(b, 
DynamicShapeVariable): + proxy = tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs([a, b], {}) + ) + return DynamicShapeVariable.create(tx, proxy, None) else: + unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}") call_min = _call_min_max @@ -454,11 +479,48 @@ def call_range(self, tx, *args, **kwargs): **{k: v.value for k, v in kwargs.items()}, ), ) + elif self._dynamic_args(*args, **kwargs): + assert len(kwargs) == 0 + + def guard_if_dyn(arg): + if isinstance(arg, DynamicShapeVariable): + return arg.evaluate_expr(tx.output) + return arg + + args = [guard_if_dyn(arg) for arg in args] + value = self.fn(*args) + return variables.RangeVariable(value=value) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args, **kwargs): + return any([isinstance(x, DynamicShapeVariable) for x in args]) or any( + [isinstance(x, DynamicShapeVariable) for x in kwargs.values()] + ) def call_slice(self, tx, *args): return variables.SliceVariable(args) - def _call_iter_tuple_list(self, tx, obj=None): + def _dyn_proxy(self, tx, *args, **kwargs): + assert self._dynamic_args(*args, **kwargs) + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + **options, + ) + + def call_mod(self, tx, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -551,6 +613,7 @@ def call_getitem(self, tx, *args, **kwargs): def call_isinstance(self, tx, arg, isinstance_type): arg_type = arg.python_type() + isinstance_type = isinstance_type.as_python_constant() if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index d3366448e379..63eed37ccbec 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -13,6 +13,8 @@ class ConstantVariable(VariableTracker): def __init__(self, value, **kwargs): super(ConstantVariable, self).__init__(**kwargs) assert not isinstance(value, torch.Tensor) + assert not isinstance(value, torch.SymInt) + assert not isinstance(value, torch.SymFloat) self.value = value def as_proxy(self): @@ -70,6 +72,8 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": + from .tensor import DynamicShapeVariable + options = VariableTracker.propagate(self, args, kwargs.values()) if istype(self.value, tuple): @@ -78,6 +82,20 @@ def call_method( items=self.unpack_var_sequence(tx), source=self.source, **options ).call_method(tx, name, args, kwargs) + if any([isinstance(x, DynamicShapeVariable) for x in args]): + # NOTE! DANGER! 
THIS ONLY WORKS FOR COMMUTATIVE OPS + # we are relying on add to have arg[0] be a DynamicShapeVariable + # because we are in ConstantVariable land + # This transforms + # constant + dynamic + # into + # dynamic + constant + # Which already has infra built for writing to the graph + if name == "__add__": + assert len(args) == 1 + return args[0].call_method(tx, name, [self], {}) + # Unfortunate constant + return super(ConstantVariable, self).call_method(tx, name, args, kwargs) try: const_args = [a.as_python_constant() for a in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} @@ -98,7 +116,19 @@ def has_arith_binop(num_ty): return ConstantVariable(method(*const_args, **const_kwargs), **options) elif has_arith_binop(int) or has_arith_binop(float): op = getattr(operator, name) - return ConstantVariable(op(self.value, const_args[0]), **options) + add_target = const_args[0] + if isinstance(add_target, (torch.SymInt, torch.SymFloat)): + from .tensor import DynamicShapeVariable + + # Addition between a non sym and sym makes a sym + # dyn_shape = tx.output.register_attr_or_module( + # add_target, f"sym_shape_{add_target}", source=None + # ) + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return DynamicShapeVariable.create(tx, proxy, add_target, **options) + return ConstantVariable(op(self.value, add_target), **options) elif name == "__len__" and not (args or kwargs): return ConstantVariable(len(self.value), **options) elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index f63283819f35..151619d0e4ab 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -7,7 +7,7 @@ from ..bytecode_transformation import create_instruction from ..exc import unimplemented from ..source import GetItemSource -from ..utils import namedtuple_fields +from ..utils import namedtuple_fields, proxy_args_kwargs from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -308,6 +308,58 @@ def reconstruct(self, codegen): ] return build_torch_size + def unpack_var_sequence(self, tx): + return [x.add_options(self) for x in self.items] + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + options = VariableTracker.propagate(self, args, kwargs.values()) + if name == "__getitem__": + assert not kwargs and len(args) == 1 + if config.dynamic_shapes: + out = self.get_item_dyn(tx, args[0]) + else: + out = self.getitem_const(args[0]) + return out + return super(SizeVariable, self).call_method(tx, name, args, kwargs) + + def get_item_dyn(self, tx, arg: VariableTracker): + from .tensor import DynamicShapeVariable + + index = arg.as_python_constant() + if isinstance(index, slice): + + def _dynamo_get_item_lambda(target, index): + return torch.Size.__getitem__(target, index) + + parent_proxy = self.as_proxy() + proxy = tx.output.create_proxy( + "call_function", + _dynamo_get_item_lambda, + *proxy_args_kwargs([self, arg], {}), + current_tx=tx, + ) + items = self.items[index] + + def _unpack_into_example(item): + if isinstance(item, DynamicShapeVariable): + return item.dyn_shape + return item.as_python_constant() + + # Mirror the indexing into example_value for downstream correctness + proxy.node.meta["example_value"] = parent_proxy.node.meta["example_value"][ + index + ] + return SizeVariable(items, 
proxy=proxy).add_options(arg, self) + else: + assert isinstance(index, int) + return self.items[index].add_options(arg, self) + class ShapeVariable(TupleVariable): """ @@ -349,13 +401,20 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": class SliceVariable(BaseListVariable): def __init__(self, items, **kwargs): + from .tensor import DynamicShapeVariable + + if any([isinstance(x, DynamicShapeVariable) for x in items]): + unimplemented("Dynamic slicing not supported") + + items_to_map = items start, stop, step = [variables.ConstantVariable(None)] * 3 - if len(items) == 1: - (stop,) = items - elif len(items) == 2: - start, stop = items - elif len(items) == 3: - start, stop, step = items + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map else: raise AssertionError() @@ -366,7 +425,7 @@ def __init__(self, items, **kwargs): # more complete support for breaking on data dependent operators. if not config.capture_scalar_outputs: for limit in (start, stop, step): - if isinstance(limit, variables.TensorVariable): + if isinstance(limit, (variables.TensorVariable, DynamicShapeVariable)): unimplemented("Dynamic slicing not supported") super().__init__([start, stop, step], **kwargs) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index da327122a6a7..5d7336cefeae 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -513,6 +513,7 @@ def reconstruct(self, codegen): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy # This variable is True when it corresponds to user code such as # @@ -530,7 +531,7 @@ def call_function( if is_original_tensor_torch_function: # Instead of tracing inside torch.Tensor.__torch_function__, # record the `call_function` or `call_method` call into the graph. - from . import TensorVariable, TorchVariable + from . import TorchVariable original_torch_or_getattr_variable = args[0] new_args = args[2].items @@ -540,7 +541,7 @@ def call_function( # example tensor from going into the override. 
with torch._C.DisableTorchFunction(): if isinstance(args[0], TorchVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -551,7 +552,7 @@ def call_function( **options, ) elif isinstance(args[0], GetAttrVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 1922980fc957..848f022525d9 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -197,8 +197,9 @@ def record_nn_module_stack(): # The module type will change after it is called if is_lazy: self.module_type = mod.cls_to_become + from .builder import wrap_fx_proxy - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_module", @@ -454,7 +455,9 @@ def make_attr(name): proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index e87b1d87bac9..8867f7e6cc93 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1,161 +1,28 @@ -import copy -import functools import itertools -import math -import numbers import operator from typing import Dict, List import torch.fx import torch.random -from ..utils import fake_tensors_available - -if fake_tensors_available: - from torch._subclasses import FakeTensor - from torch._subclasses.fake_tensor import ( - DataDependentOutputException, - DynamicOutputShapeException, - ) - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor_and_record - -import torch.utils._python_dispatch as py_dispatch -from torch.fx.immutable_collections import immutable_list -from torch.utils._pytree import tree_map - from .. import config, variables -from ..exc import TorchRuntimeError, unimplemented, Unsupported +from ..exc import unimplemented from ..guards import GuardBuilder from ..source import AttrSource + from ..utils import ( - clone_input, - is_lazy_module, - istype, - preserve_rng_state, + fake_tensors_available, + get_fake_value, + get_real_value, product, proxy_args_kwargs, tensortype_to_dtype, ) -from .base import MutableLocal, typestr, VariableTracker +from .base import VariableTracker from .constant import ConstantVariable from .lists import ShapeVariable, SizeVariable -class _missing: - pass - - -def _run_node(output_graph, node, args, kwargs, nnmodule): - op = node.op - if op == "call_function": - return node.target(*args, **kwargs) - elif op == "call_method": - return getattr(args[0], node.target)(*args[1:], **kwargs) - elif op == "call_module": - assert nnmodule is not None - return nnmodule(*args, **kwargs) - elif op == "get_attr": - return output_graph.get_submodule(node.target) - raise AssertionError(op) - - -def _get_real_value(node, output_graph): - """ - Run the actual computation represented by `node` and return the result. - This will execute any dependent nodes in the graph as well. 
- """ - cache = output_graph.real_value_cache - if node in cache: - return cache[node] - - op = node.op - args, kwargs = torch.fx.node.map_arg( - (node.args, node.kwargs), - lambda n: _get_real_value(n, output_graph), - ) - - if op == "call_module": - nn_module = output_graph.nn_modules[node.target] - if not is_lazy_module(nn_module): - nn_module = copy.deepcopy(nn_module) - else: - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nn_module(*args, **kwargs) - else: - nn_module = None - - try: - real_value = _run_node(output_graph, node, args, kwargs, nn_module) - cache[node] = real_value - except RuntimeError as e: - raise TorchRuntimeError() from e - return real_value - - -def _get_fake_value(node, tx): - """ - Run the computation represented by `node` using fake tensors and return the result. - """ - op = node.op - fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) - from ..utils import wrap_fake_exception - - def visit(n: torch.fx.Node): - return n.meta["example_value"] - - args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) - args = tree_map(fake_wrapper, args) - kwargs = tree_map(fake_wrapper, kwargs) - - nnmodule = None - if op == "call_module": - nnmodule = tx.output.nn_modules[node.target] - - if not is_lazy_module(nnmodule): - nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) - - def context(): - if hasattr(py_dispatch, "enable_torch_dispatch_mode"): - return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode) - else: - return tx.fake_mode - - if op == "call_module" and is_lazy_module(nnmodule): - assert nnmodule is not None - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nnmodule(*args, **kwargs) - try: - with context(): - return wrap_fake_exception( - lambda: _run_node(tx.output, node, args, kwargs, nnmodule) - ) - except Unsupported: - raise - except RuntimeError as e: - if isinstance(e, DataDependentOutputException): - if config.capture_scalar_outputs and node.target == "item": - return torch.zeros(size=(), dtype=args[0].dtype).item() - else: - unimplemented(f"data dependent operator: {e.func}") - elif isinstance(e, DynamicOutputShapeException): - unimplemented(f"dynamic shape operator: {e.func}") - else: - raise TorchRuntimeError() from e - - -def _clone_input(value): - if isinstance(value, torch.Tensor): - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - # tensor subclasses will not be converted to FakeTensors and need to be cloned - if not use_fake_tensors or not isinstance(value, FakeTensor): - # NB: ensure strides are preserved - value = clone_input(value) - - return value - - class TensorVariable(VariableTracker): """A torch.Tensor input or an intermediate value in the FX graph""" @@ -178,173 +45,7 @@ def get_real_value(self): NOTE: this runs actual tensor computation and may be slow and memory-intensive. 
""" - return _get_real_value(self.proxy.node, self.proxy.tracer) - - @classmethod - def create(cls, tx, proxy, example_value=None, **options): - if "guards" in options and options["guards"] is not None: - tx.output.guards.update(options["guards"]) - - assert "example_value" not in proxy.node.meta - if not config.dynamic_propagation: - if isinstance(example_value, torch.Tensor): - options.update(cls.specialize(example_value)) - return cls(proxy, **options) - - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - - initial_example_value = example_value - - with preserve_rng_state(): - if example_value is None: - if use_fake_tensors: - example_value = _get_fake_value(proxy.node, tx) - else: - example_value = _get_real_value(proxy.node, tx.output) - - else: - proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) - if use_fake_tensors: - fake_wrapper = functools.partial( - wrap_to_fake_tensor_and_record, tx=tx - ) - example_value = fake_wrapper(example_value) - - if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - should_specialize = options.pop("should_specialize", False) - if is_parameter or should_specialize: - specialized_value = initial_example_value - else: - specialized_value = None - - example_value = _clone_input(example_value) - proxy.node.meta["example_value"] = example_value - specialized_props = cls.specialize(example_value) - if use_fake_tensors and isinstance(example_value, FakeTensor): - specialized_props["class_type"] = ( - torch.nn.Parameter if is_parameter else torch.Tensor - ) - - specialized_props["specialized_value"] = specialized_value - - options.update(specialized_props) - return cls(proxy, **options) - elif ( - hasattr(proxy.node.target, "__name__") - and proxy.node.target.__name__ == "set_state" - and isinstance(proxy.node.target.__self__, torch._C.Generator) - or proxy.node.target == torch.random.set_rng_state - ): - from . 
import TorchVariable - - return TorchVariable(proxy.node.target) - elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, example_value, **options) - elif istype(example_value, torch.Size) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - sizes = [] - for i, v in enumerate(example_value): - proxy_i = proxy[i] - proxy_i.node.meta["example_value"] = v - sizes.append(DynamicShapeVariable(proxy_i, v)) - return SizeVariable(sizes, proxy, **options) - elif istype(example_value, int) and proxy.node.target in ( - torch.seed, - operator.mod, - # some mac builds are missing torch.distributed.get_rank() - getattr(torch.distributed, "get_rank", _missing), - getattr(torch.distributed, "get_world_size", _missing), - ): - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, example_value, **options) - elif istype(example_value, torch.Size) and all( - [isinstance(x, int) for x in example_value] - ): - sizes = [variables.ConstantVariable(x) for x in example_value] - return SizeVariable(sizes, **options) - elif isinstance(example_value, (tuple, list)): - unpacked = [] - for i, val in enumerate(example_value): - if val is None: - # nn.MultiheadAttention() can return None, see issue #175 - unpacked.append( - variables.ConstantVariable(None, **options), - ) - else: - unpacked.append( - cls.create( - tx, - proxy.tracer.create_proxy( - "call_function", operator.getitem, (proxy, i), {} - ), - example_value=val, - **options, - ) - ) - if istype(example_value, tuple): - return variables.TupleVariable(unpacked, **options) - elif istype(example_value, (list, immutable_list)): - return variables.ListVariable( - unpacked, mutable_local=MutableLocal(), **options - ) - else: - assert ( - example_value.__class__.__module__ == "torch.return_types" - or hasattr(example_value, "_fields") - ), "namedtuple?" - return variables.NamedTupleVariable( - unpacked, example_value.__class__, **options - ) - elif example_value is None or proxy.node.target is torch.manual_seed: - return variables.ConstantVariable(None, **options) - elif ( - isinstance(example_value, int) - and proxy.node.target is torch._utils._element_size - ): - proxy.node.meta["example_value"] = example_value - return variables.ConstantVariable(example_value, **options) - elif ( - isinstance(example_value, numbers.Number) - and ( - proxy.node.target == "item" - or proxy.node.target in {math.sqrt, math.pow} - ) - and config.capture_scalar_outputs - ): - if use_fake_tensors: - # item raw value should not be accessed - return FakeItemVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - **options, - ) - else: - return UnspecializedPythonVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - raw_value=None if use_fake_tensors else example_value, - need_unwrap=False, - **options, - ) - elif ( - proxy.node.target == torch._C._DisableFuncTorch - or proxy.node.target == torch.cuda._is_in_bad_fork - ): - from . 
import UserDefinedObjectVariable - - return UserDefinedObjectVariable(example_value) - elif isinstance(example_value, torch.SymInt): - proxy.node.meta["example_value"] = example_value - return cls(proxy, **options) - else: - raise AssertionError( - "torch.* op returned non-Tensor " - + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" - ) + return get_real_value(self.proxy.node, self.proxy.tracer) def __init__( self, @@ -482,15 +183,26 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from . import ConstantVariable, TupleVariable + from .builder import wrap_fx_proxy kwargs = dict(kwargs) - options = VariableTracker.propagate(self, args, kwargs.values()) if name == "stride" and self.stride is not None: constant_result = ConstantVariable(self.stride, **options) elif name == "size" and self.size is not None: sizes = [variables.ConstantVariable(x) for x in self.size] constant_result = SizeVariable(sizes, **options) + elif name == "size" and self.size is None and config.dynamic_shapes: + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + args, kwargs), + current_tx=tx, + ), + **options, + ) elif name == "numel" and self.size is not None: constant_result = ConstantVariable(product(self.size), **options) elif name in ("ndimension", "dim") and self.ndim is not None: @@ -531,11 +243,19 @@ def call_method( unimplemented(f"Tensor.{name}") elif name == "item": if config.capture_scalar_outputs: - return self.__class__.create( + use_fake_tensors = ( + fake_tensors_available and config.fake_tensor_propagation + ) + if use_fake_tensors: + example_value = get_fake_value(self.proxy.node, tx) + else: + example_value = get_real_value(self.proxy.node, tx.output).item() + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", "item", (self.as_proxy(),), {}, current_tx=tx ), + example_value=example_value, **options, ) else: @@ -545,7 +265,7 @@ def call_method( assert not config.dynamic_shapes return ConstantVariable(self.size[0], **options) else: - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", len, (self.as_proxy(),), {}, current_tx=tx @@ -584,7 +304,7 @@ def call_method( self.ndim = args[0].ndim self.is_contiguous = (memory_format,) - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", @@ -604,8 +324,7 @@ def call_method( and not config.dynamic_shapes ): name = "new_empty" - - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", @@ -617,13 +336,23 @@ def call_method( ) -class DynamicShapeVariable(TensorVariable): +class DynamicShapeVariable(VariableTracker): """ Represents a symbolic size, e.g., as returned by tensor.size(0) """ + @classmethod + def create(cls, tx, proxy, dyn_shape, **options): + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == dyn_shape + if dyn_shape is None: + dyn_shape = get_fake_value(proxy.node, tx) + proxy.node.meta["example_value"] = dyn_shape + return DynamicShapeVariable(proxy, dyn_shape, **options) + def __init__(self, proxy, dyn_shape, **kwargs): - super(DynamicShapeVariable, self).__init__(proxy, **kwargs) + super(DynamicShapeVariable, self).__init__(**kwargs) + self.proxy = proxy self.dyn_shape = dyn_shape def python_type(self): @@ -632,6 +361,36 @@ def python_type(self): def unpack_var_sequence(self, tx): super(DynamicShapeVariable, self).unpack_var_sequence(tx) + def as_proxy(self): + return 
self.proxy + + def evaluate_expr(self, output_graph): + if not isinstance(self.dyn_shape, torch.SymInt): + return self.dyn_shape + return output_graph.shape_env.evaluate_expr(self.dyn_shape.get_pyobj().expr) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + list(args), kwargs), + current_tx=tx, + ), + **options, + ) + class TensorWithTFOverrideVariable(VariableTracker): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c55a64cff50c..0debfe9e9f3c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,4 +1,6 @@ import logging + +import math import re import types from typing import Dict, List @@ -170,7 +172,15 @@ def can_constant_fold_through(self): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": - from . import ConstantVariable, GradModeVariable, TensorVariable + from . import ( + ConstantVariable, + DynamicShapeVariable, + GradModeVariable, + TensorVariable, + ) + + # print("CALLING ON TORCH", self.value) + from .builder import wrap_fx_proxy constant_args = check_constant_args(args, kwargs) unspec_python_args = check_unspec_python_args(args, kwargs) @@ -302,7 +312,7 @@ def call_function( def get_state_from_generator(): return self.value() - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -338,7 +348,7 @@ def get_state_from_generator(): example_value = args[0].proxy.node.meta["example_value"] self.value.__module__ = self.__module__ - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -357,7 +367,7 @@ def get_state_from_generator(): ): # TODO(voz): This is rewritten as a call_method because # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", @@ -380,11 +390,21 @@ def get_state_from_generator(): if isinstance(x.value, numpy.generic): x.value = x.value.item() - tensor_variable = TensorVariable.create( + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. 
+ fn_ = self.value + if any([isinstance(x, DynamicShapeVariable) for x in args]): + if self.value == math.sqrt: + from torch.fx.experimental.symbolic_shapes import sym_sqrt + + fn_ = sym_sqrt + + tensor_variable = wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", - self.value, + fn_, *proxy_args_kwargs(args, kwargs), current_tx=tx, ), @@ -450,7 +470,9 @@ def _call_softmax(self, tx, args, kwargs, options): dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None)) def fake_softmax(input): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -502,7 +524,9 @@ def normalize_args( ) = normalize_args(*args, **kwargs) def fake_cross_entropy_loss(input, target): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -577,6 +601,7 @@ def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": from . import ListVariable, TensorVariable, UserFunctionVariable + from .builder import wrap_fx_proxy assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet" @@ -688,7 +713,7 @@ def register_as_subgraph(fn, name, args): p_args[2] = false_node # Store the invocation as a call - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 14f5cd2de0a7..65f571f93ec0 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1081,7 +1081,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # clone will get called in Parameter deepcopy if func == torch._C._TensorBase.clone: - return func(self.fake_mode.from_tensor(args[0]), **kwargs) + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) elif func == torch.Tensor.__deepcopy__: assert len(args) == 2 and len(kwargs) == 0 tensor, memo = args @@ -1089,7 +1091,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if id(tensor) in memo: return memo[id(tensor)] - out = self.fake_mode.from_tensor(tensor) + out = self.fake_mode.from_tensor(tensor, static_shapes=True) memo[id(tensor)] = out return out else: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 9b55af3c555c..ae4427e2320e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -456,7 +456,6 @@ def create_symbolic_sizes_strides(self, ex: torch.Tensor): We try our best to express stride in terms of the sizes, so as to not introduce new symbolic variables. 
""" - size = [self.create_symbol(i) for i in ex.size()] stride: List[Optional[sympy.Expr]] = [None] * len(size) for i, val in enumerate(ex.stride()): From e950afc3958c9bae5d61cbc99bc088309141df6d Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 13 Nov 2022 08:19:45 +0000 Subject: [PATCH 110/453] [reland][dynamo] Better support for nn.Module (#88959) Relanding https://github.com/pytorch/pytorch/pull/88629 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88959 Approved by: https://github.com/msaroufim --- test/dynamo/test_modules.py | 127 +++++++++++++++++++++++++++++++++++ torch/_dynamo/__init__.py | 2 + torch/_dynamo/debug_utils.py | 8 +++ torch/_dynamo/eval_frame.py | 74 ++++++++++++++------ torch/_dynamo/testing.py | 14 ++++ 5 files changed, 205 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 2fb83b3add6c..930035f99a30 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,6 +904,133 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.relu(self.linear(x) + self.buf0) + + +class OptimizedModuleTest(torch._dynamo.test_case.TestCase): + def test_nn_module(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_to(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 1) + + opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + x = torch.randn(10, 10).to(dtype=torch.float64) + opt_mod(x) + # Ensure that there is a recompilation + self.assertEqual(cnt.frame_count, 2) + + def test_attr(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.r(torch.sin(x)) + self.buf0 + + mod = MockModule() + opt_mod = torch._dynamo.optimize("eager")(mod) + + # Check parameteres and buffers + for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): + self.assertTrue(id(p1) == id(p2)) + + def test_recursion(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + + for _ in range(5): + opt_mod = torch._dynamo.optimize(cnt)(opt_mod) + opt_mod(torch.randn(10, 10)) + self.assertEqual(cnt.frame_count, 1) + + def test_composition(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + opt_inner_mod = InnerModule() + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def 
forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_composition_with_opt_mod(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + inner_mod = InnerModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + # There will be a graph break for the inner mod being OptimizedModule + self.assertEqual(cnt.frame_count, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 80f927aeef2f..5eee609b0852 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,6 +7,7 @@ export, optimize, optimize_assert, + OptimizedModule, reset_code, run, skip, @@ -25,6 +26,7 @@ "reset", "list_backends", "skip", + "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index f09991f9bf34..089ef172d625 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -486,8 +486,16 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ + from .eval_frame import OptimizedModule + from .testing import named_parameters_for_optimized_module from .utils import same + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8d9e3b7b6aa1..20e8c7de085e 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,6 +5,7 @@ import logging import os import sys +import textwrap import threading import traceback import types @@ -44,6 +45,27 @@ most_recent_backend = None +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. 
+ """ + + def __init__(self, mod): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def forward(self, *args, **kwargs): + # This will be monkey patched later + raise RuntimeError("Should not be here") + + def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -118,31 +140,15 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - optimized_forward = self(mod.forward) - - class TorchDynamoNNModuleWrapper: - """ - A wrapper that redirects the forward call to the optimized - forward, while for rest it redirects the calls to the original - module. - """ - - def __getattr__(self, name): - return getattr(mod, name) - - def forward(self, *args, **kwargs): - return optimized_forward(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - new_mod = TorchDynamoNNModuleWrapper() + new_mod = OptimizedModule(mod) + new_mod.forward = self(mod.forward) # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod + new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) + callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -184,6 +190,34 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please optimize the individual methods that you are interested in optimizing. + + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function + + >> mod.__call__ = torch._dynamo.optimize(mod.__call__) + + """ + ) + ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d6082ce48acf..6e0d32d21f97 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,6 +32,18 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def remove_optimized_module_prefix(name): + prefix = "_orig_mod." 
+ assert name.startswith(prefix) + name = name[len(prefix) :] + return torch.distributed.fsdp._common_utils.clean_tensor_name(name) + + def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -44,6 +56,8 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 2b12bfce8800cfcc54222e913955914994bb4daf Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 13 Nov 2022 09:53:38 +0000 Subject: [PATCH 111/453] [dynamo] Skip frame when graph break in a loop (#88857) This fixes excessing recompilation issue in tacotron2 but has few caveats - https://github.com/pytorch/torchdynamo/issues/330 For tacotron2, the repro is something like this ~~~ def inner(x): return torch.sin(x) def fn(x): for _ in range(100): inner(x) torch._dynamo.graph_break() return x ~~~ The problem here is that Dynamo has guards on the TUPLE_ITERATOR_LEN whenever a graph break happens. Therefore, we keep on recompiling. This PR checks if there is a backedge (helps with while loop) in presence of a graph break. If there is, Dynamo skips processing this frame. Therefore, Dynamo gets called when inner is called, and we compile only once. Note that, if there was no graph break, we will unroll the original loop, and see one graph with 100 sin operations (just as before, so no changes there). The caveat is - We are skipping the frame, so if we have something like this ~~~ def fn(x): for _ in range(100): # 1000s of lines of PyTorch code torch._dynamo.graph_break() return x ~~~ Dynamo will skip processing this frame, and might miss on the optimization. Completely open for suggestions. Happy to re-implement if there is a better way to handle this. 
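For illustration, here is a minimal, self-contained sketch of the backedge check described above. It is a simplified model, not the actual Dynamo code: instructions are faked as `(opname, offset, jump_target)` tuples instead of real bytecode objects, and the opcode names assume CPython 3.10-style bytecode.

~~~
# Sketch (assumed/simplified): a jump whose target offset lies *before* the
# current instruction offset is a backedge, i.e. the current block sits
# inside a loop, so the whole frame is skipped instead of being split at
# the graph break.
JUMP_OPS = ("JUMP_ABSOLUTE", "POP_JUMP_IF_TRUE", "POP_JUMP_IF_FALSE")


def has_backedge(instructions, current_offset):
    for opname, offset, jump_target in instructions:
        if offset <= current_offset:
            continue  # only instructions after the current one matter
        if opname in JUMP_OPS and jump_target is not None:
            if jump_target < current_offset:
                return True  # backward jump -> loop backedge
    return False


# Toy instruction stream: the current instruction (offset 6) is inside a
# loop whose header is at offset 2, and offset 12 jumps back to it.
stream = [
    ("FOR_ITER", 2, 14),          # loop header
    ("CALL_FUNCTION", 6, None),   # current instruction; graph break here
    ("JUMP_ABSOLUTE", 12, 2),     # backedge to the header at offset 2
    ("RETURN_VALUE", 14, None),
]
print(has_backedge(stream, current_offset=6))  # True -> skip this frame
~~~

The real check added by this patch is `InstructionTranslatorBase.has_backedge` in `torch/_dynamo/symbolic_convert.py`, which runs the same scan over the translator's actual instruction list.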
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88857 Approved by: https://github.com/jansel, https://github.com/yanboliang --- test/dynamo/test_optimizers.py | 3 +- test/dynamo/test_repros.py | 55 +++++++++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 37 +++++++++++++++++---- 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 92b163b76d6d..2f204a7a1199 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import inspect -import sys import unittest import torch @@ -126,7 +125,7 @@ def training_iter_fn(batch, model, optimizer): batch = {"x": input1, "y": input2} for _ in range(2): opt_training_iter_fn(batch, net, optimizer) - self.assertEqual(cnts.frame_count, (2 if sys.version_info < (3, 8) else 6)) + self.assertEqual(cnts.frame_count, 2) if __name__ == "__main__": diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 6a1c654a4873..aa30affd5144 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1838,6 +1838,61 @@ def forward(self, inp): self.assertEqual(cnt.op_count, 5) self.assertEqual(cnt.frame_count, 1) + def test_for_loop_graph_break(self): + def inner(x): + return torch.sin(x) + + def fn(x): + for _ in range(100): + inner(x) + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + + def test_for_loop_graph_break_before(self): + # Checks that the backedge is calculated correctly + def inner(x): + return torch.sin(x) + + def fn(x): + torch._dynamo.graph_break() + for _ in range(100): + inner(x) + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 100) + + def test_while_loop_graph_break(self): + # Repro of tacotron2 cache_size_recompilation + def inner(x): + return torch.sin(x) + + def fn(x): + i = 20 + while i > 10: + x = inner(x) + i -= 1 + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 88e0df5470bc..d707bee930ee 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -133,6 +133,13 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() ): # compile a partial subgraph prefix then jump into user code + if self.has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop" + ) + log.debug(msg) + raise exc.SkipFrame(msg) + self.push(value) self.output.compile_subgraph( self, @@ -179,10 +186,15 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): reason = None try: return inner_fn(self, inst) - except Unsupported as exc: + except Unsupported as excp: + if self.has_backedge(): + msg = "Skipping frame because there is a graph break in a for/while loop" + log.debug(msg) + raise 
exc.SkipFrame(msg) + if not self.should_compile_partial_graph(): raise - user_stack = [self.frame_summary()] + list(reversed(exc.real_stack)) + user_stack = [self.frame_summary()] + list(reversed(excp.real_stack)) user_stack_formatted = "".join(traceback.format_list(user_stack)) frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) # torch._dynamo.explain() formats this a little nicer, and presents a slightly @@ -193,12 +205,12 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): and graph_break_dup_warning_checker.add(frame_loc) ): log.warning( - f"Graph break: {exc} from user code at {user_stack_formatted}" + f"Graph break: {excp} from user code at {user_stack_formatted}" ) - exc.remove_from_stats() - exc.add_to_stats("graph_break") - reason = GraphCompileReason(exc.msg, user_stack) + excp.remove_from_stats() + excp.add_to_stats("graph_break") + reason = GraphCompileReason(excp.msg, user_stack) self.restore_graphstate(state) self.output.compile_subgraph(self, reason=reason) self.popn(push - dis.stack_effect(inst.opcode, inst.arg)) @@ -237,6 +249,19 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): class InstructionTranslatorBase(object): + def has_backedge(self): + cur_offset = self.current_instruction.offset + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ( + "JUMP_ABSOLUTE", + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + ): + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + def cell_and_freevars(self): if not hasattr(self, "_cell_and_freevars"): self._cell_and_freevars = tuple( From bca75fd2d36de72c2682b47d62eab01f6f897b75 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Sat, 12 Nov 2022 21:41:31 -0800 Subject: [PATCH 112/453] Move xnnpack taget to fb code base (#88909) 1. Move the source file list to the `build_variables.bzl`, as it's the source of truth for both internal buck build and oss build 2. Move target definitions to `fb` internal folder 3. Some changes are triggered from auto format. Differential Revision: [D40906961](https://our.internmc.facebook.com/intern/diff/D40906961/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D40906961/)! 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88909 Approved by: https://github.com/mcr229 --- build_variables.bzl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/build_variables.bzl b/build_variables.bzl index e476341b9ac0..473ed1c1de1b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1489,3 +1489,33 @@ aten_cuda_with_sort_by_key_source_list = [ aten_cuda_cu_with_sort_by_key_source_list = [ "aten/src/ATen/native/cuda/Unique.cu", ] + +# Followings are source code for xnnpack delegate + +xnnpack_delegate_serializer_header = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.h", +] + +xnnpack_delegate_serializer_source_list = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp", +] + +xnnpack_delegate_core_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", +] + +xnnpack_delegate_core_header = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h", + "torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h", +] + +xnnpack_backend_header = [ + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h", +] + xnnpack_delegate_core_header + +xnnpack_backend_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp", +] + xnnpack_delegate_core_source_list From 4284862db6e7c14494f27ef681036d909a5e8b67 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Sat, 12 Nov 2022 19:26:28 +0000 Subject: [PATCH 113/453] [Dynamo][FSDP] Migrate to `ModuleWrapPolicy` (#88453) Hello @wconstab! As you saw, `transformer_auto_wrap_policy()` is a misnomer and actually works for any module classes. The PR before this one tries to add a class `ModuleWrapPolicy` that takes in the `module_classes` in its constructor and works just like `transformer_auto_wrap_policy()` without requiring the `functools.partial()`. I hope you do not mind if we update the dynamo benchmarks util file with this migration. The PR before this one might require some back and forth within FSDP devs, so I apologize for any consequent updates to this PR, which in itself is an easy change. I will request review once we know the previous PR is good for land. 
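For illustration, a minimal sketch of the policy migration this commit performs; the module classes below are placeholders, not the benchmark's actual block types, and the FSDP call is left as a comment because constructing FSDP needs an initialized process group.

~~~
# Sketch of the auto-wrap policy change: same wrapping behavior, but
# ModuleWrapPolicy binds the module classes directly instead of going
# through functools.partial. Block classes here are illustrative only.
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy, transformer_auto_wrap_policy

blocks = {nn.TransformerEncoderLayer, nn.TransformerDecoderLayer}

# Before: the "transformer" policy is really generic over module classes,
# but callers must partially apply it to bind them.
old_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls=blocks
)

# After: the classes go straight into the policy's constructor.
new_policy = ModuleWrapPolicy(blocks)

# Either policy is then handed to FSDP, e.g.
#   FSDP(model, auto_wrap_policy=new_policy, use_orig_params=True)
~~~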
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88453 Approved by: https://github.com/wconstab --- benchmarks/dynamo/dist_util.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index d30b5a63cfe5..d0267cbca307 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -13,7 +13,7 @@ CheckpointImpl, ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy try: from .torchbench import setup_torchbench_cwd @@ -138,10 +138,7 @@ def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True): "toy_model" if model.__class__ is ToyModel else args.torchbench_model ] if use_wrap_policy: - # transformer policy is really a generic policy that wraps modules of specified classes - wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=blocks - ) + wrap_policy = ModuleWrapPolicy(blocks) model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True) if use_checkpointing: From 897d029a738c831448c0984bc0ab91544ca04545 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sun, 13 Nov 2022 16:20:45 +0000 Subject: [PATCH 114/453] [reland][dynamo] fixes dict changed during runtime error (#88877) Reland https://github.com/pytorch/pytorch/pull/87526 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88877 Approved by: https://github.com/ezyang --- test/dynamo/test_aot_cudagraphs.py | 3 --- test/dynamo/test_repros.py | 30 ++++++++++++++++++++++++++++++ torch/_dynamo/convert_frame.py | 10 ++++++---- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index cb1d2a0e601f..fdb7c88762b8 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -71,7 +71,6 @@ def fn(x, y): y = torch.randn(3, device="cuda") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_dtoh(self): def model(x, y): @@ -105,7 +104,6 @@ def fn(x, y): y = torch.randn((), device="cpu") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch("functorch._src.config.use_functionalize", True) @patch_all(ok=False) # input mutation not supported yet def test_mutate_input(self): @@ -145,7 +143,6 @@ def fn(x, y): y = torch.randn(1, device="cuda") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_factory(self): def model(y): diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index aa30affd5144..fd0fcf9e08bc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1818,6 +1818,36 @@ def fn(x): res = opt_fn(a) self.assertTrue(same(ref, res)) + def test_tokenization(self): + from collections import UserDict + + class BatchEncoding(UserDict): + """ + Copied from tokenization + """ + + def __init__( + self, + data, + ): + super().__init__(data) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def tokenization(x): + encoding = BatchEncoding({"key": x}) + return encoding["key"] + + opt_fn = torch._dynamo.optimize("eager")(tokenization) + x = torch.rand((1, 4)) + ref = tokenization(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + def test_modules(self): class Foo(torch.nn.Module): def 
__init__(self): diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index f1ce83727a19..c612fe3c167d 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -156,7 +156,11 @@ def has_tensor(obj): seen_ids[obj_id] = any([has_tensor(v) for v in obj]) return seen_ids[obj_id] elif istype(obj, dict): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) + # Some packages like pytest can be updated during runtime. So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any([has_tensor(v) for v in values]) return seen_ids[obj_id] elif istype(obj, (str, int, float, type(None), bool)): seen_ids[obj_id] = False @@ -164,9 +168,6 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) - return seen_ids[obj_id] else: # if config.debug: # print( @@ -302,6 +303,7 @@ def _convert_frame_assert(frame: types.FrameType, cache_size: int): # setattr could be tricky to handle generally, # but also not likely useful to compile- skip the whole frame return None + # Check if the frame is generated by an exec builtin call # TODO - Running exec generated frame seems propagates f_globals to the # next frames. From 98bcb4acb651378d7eaae7532d52f08939464c06 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 13 Nov 2022 16:21:12 +0000 Subject: [PATCH 115/453] Revert "[reland][dynamo] Better support for nn.Module (#88959)" This reverts commit e950afc3958c9bae5d61cbc99bc088309141df6d. 
Reverted https://github.com/pytorch/pytorch/pull/88959 on behalf of https://github.com/malfet due to Broke `test_accuracy_issue1` --- test/dynamo/test_modules.py | 127 ----------------------------------- torch/_dynamo/__init__.py | 2 - torch/_dynamo/debug_utils.py | 8 --- torch/_dynamo/eval_frame.py | 74 ++++++-------------- torch/_dynamo/testing.py | 14 ---- 5 files changed, 20 insertions(+), 205 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 930035f99a30..2fb83b3add6c 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,133 +904,6 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.relu(self.linear(x) + self.buf0) - - -class OptimizedModuleTest(torch._dynamo.test_case.TestCase): - def test_nn_module(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def test_to(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - x = torch.randn(10, 10) - self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - # Ensure that there is no recompilation - opt_mod(x) - self.assertEqual(cnt.frame_count, 1) - - opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) - self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) - x = torch.randn(10, 10).to(dtype=torch.float64) - opt_mod(x) - # Ensure that there is a recompilation - self.assertEqual(cnt.frame_count, 2) - - def test_attr(self): - class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(10, 10) - self.register_buffer("buf0", torch.randn(10, 10)) - - def forward(self, x): - return self.r(torch.sin(x)) + self.buf0 - - mod = MockModule() - opt_mod = torch._dynamo.optimize("eager")(mod) - - # Check parameteres and buffers - for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): - self.assertTrue(id(p1) == id(p2)) - - def test_recursion(self): - mod = MockModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_mod = torch._dynamo.optimize(cnt)(mod) - - for _ in range(5): - opt_mod = torch._dynamo.optimize(cnt)(opt_mod) - opt_mod(torch.randn(10, 10)) - self.assertEqual(cnt.frame_count, 1) - - def test_composition(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - opt_inner_mod = InnerModule() - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - self.assertEqual(cnt.frame_count, 1) - - def 
test_composition_with_opt_mod(self): - class InnerModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(torch.sin(x)) - - inner_mod = InnerModule() - cnt = torch._dynamo.testing.CompileCounter() - opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mod = opt_inner_mod - - def forward(self, x): - return self.mod(torch.cos(x)) - - outer_mod = OuterModule() - opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) - - x = torch.randn(4) - self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) - self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) - # There will be a graph break for the inner mod being OptimizedModule - self.assertEqual(cnt.frame_count, 2) - - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 5eee609b0852..80f927aeef2f 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -7,7 +7,6 @@ export, optimize, optimize_assert, - OptimizedModule, reset_code, run, skip, @@ -26,7 +25,6 @@ "reset", "list_backends", "skip", - "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 089ef172d625..f09991f9bf34 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ - from .eval_frame import OptimizedModule - from .testing import named_parameters_for_optimized_module from .utils import same - if isinstance(gm, OptimizedModule): - gm.named_parameters = named_parameters_for_optimized_module(gm) - - if isinstance(opt_gm, OptimizedModule): - opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) - ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 20e8c7de085e..8d9e3b7b6aa1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,7 +5,6 @@ import logging import os import sys -import textwrap import threading import traceback import types @@ -45,27 +44,6 @@ most_recent_backend = None -class OptimizedModule(torch.nn.Module): - """ - Wraps the original nn.Module object and later patches its - forward method to optimized self.forward method. - """ - - def __init__(self, mod): - super().__init__() - # Installs the params/buffer - self._orig_mod = mod - - def __getattr__(self, name): - if name == "_orig_mod": - return self._modules["_orig_mod"] - return getattr(self._orig_mod, name) - - def forward(self, *args, **kwargs): - # This will be monkey patched later - raise RuntimeError("Should not be here") - - def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -140,15 +118,31 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - new_mod = OptimizedModule(mod) - new_mod.forward = self(mod.forward) + optimized_forward = self(mod.forward) + + class TorchDynamoNNModuleWrapper: + """ + A wrapper that redirects the forward call to the optimized + forward, while for rest it redirects the calls to the original + module. 
+ """ + + def __getattr__(self, name): + return getattr(mod, name) + + def forward(self, *args, **kwargs): + return optimized_forward(*args, **kwargs) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + new_mod = TorchDynamoNNModuleWrapper() # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod.forward + new_mod._torchdynamo_orig_callable = mod return new_mod assert callable(fn) - callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -190,34 +184,6 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): - if not hasattr(fn, "__code__"): - raise RuntimeError( - textwrap.dedent( - """ - - torch._dynamo.optimize is called on a non function object. - If this is a callable class, please optimize the individual methods that you are interested in optimizing. - - >> class CallableClass: - >> def __init__(self): - >> super().__init__() - >> self.relu = torch.nn.ReLU() - >> - >> def __call__(self, x): - >> return self.relu(torch.sin(x)) - >> - >> def print_hello(self): - >> print("Hello world") - >> - >> mod = CallableClass() - - If you want to optimize the __call__ function - - >> mod.__call__ = torch._dynamo.optimize(mod.__call__) - - """ - ) - ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 6e0d32d21f97..d6082ce48acf 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,18 +32,6 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) -def named_parameters_for_optimized_module(mod): - assert isinstance(mod, eval_frame.OptimizedModule) - return mod._orig_mod.named_parameters - - -def remove_optimized_module_prefix(name): - prefix = "_orig_mod." - assert name.startswith(prefix) - name = name[len(prefix) :] - return torch.distributed.fsdp._common_utils.clean_tensor_name(name) - - def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -56,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): - if isinstance(model, eval_frame.OptimizedModule): - name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From 52be0c42abfcf566e730d927b6a3e90e4380017b Mon Sep 17 00:00:00 2001 From: anjali411 Date: Sun, 13 Nov 2022 15:56:16 +0000 Subject: [PATCH 116/453] meta function for max_pool2d_with_indices_backward (#88743) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88743 Approved by: https://github.com/lezcano, https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 3 + test/inductor/test_torchinductor.py | 3 + test/test_proxy_tensor.py | 2 +- torch/_meta_registrations.py | 56 ++++++++++++++++--- .../_internal/common_methods_invocations.py | 35 ++++++++++++ 5 files changed, 90 insertions(+), 9 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index f4782b8a595d..ea00842a4e00 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -973,6 +973,9 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cholesky'), xfail('linalg.cholesky'), + # Given input size: (s0xs1x2). 
Calculated output size: ... + skip('max_pool2d_with_indices_backward'), + # Misc xfail('to_sparse'), xfail('corrcoef'), diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index aea8013bdfac..d331559a3a8b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4787,6 +4787,9 @@ def forward(self, x): for param in model_opt.parameters(): param.add_(1.0) + # Probably fails due to the symint math issue caught while adding + # max_pool2d_with_indices_backward + @unittest.skip("Accuracy failure, needs debugging") def test_accuracy_issue1(self): class Repro(torch.nn.Module): def __init__(self): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 86beb651cb2d..42ecc3d376ab 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -721,7 +721,6 @@ def deco(cls): @xfail_inherited_tests([ "test_mode_tracing_factory_function", "test_make_fx_overloads", - "test_resnet18_backward_trace", "test_trace_subclasses", ]) class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): @@ -1229,6 +1228,7 @@ def f(a, b, c, d, e): xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('max_pool2d_with_indices_backward', ''), # (symint math failure) Given input size: (s0xs1x2). Calculated ... xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 5d583de67d19..be7370e344f0 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1349,9 +1349,8 @@ def pool2d_shape_check( ) -@register_meta(aten.max_pool2d_with_indices.default) -def meta_max_pool2d_with_indices( - input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +def max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode ): # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp def unpack(name, val): @@ -1376,6 +1375,9 @@ def unpack(name, val): padH, padW = unpack("padding", padding) dilationH, dilationW = unpack("dilation", dilation) + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) memory_format = utils.suggest_memory_format(input) if memory_format == torch.channels_last: @@ -1394,11 +1396,6 @@ def unpack(name, val): lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", ) - nbatch = input.size(-4) if input.dim() == 4 else 1 - nInputPlane = input.size(-3) - inputHeight = input.size(-2) - inputWidth = input.size(-1) - outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) @@ -1420,6 +1417,49 @@ def unpack(name, val): memory_format, ) + return nInputPlane, outputHeight, outputWidth + + +@register_meta(aten.max_pool2d_with_indices_backward.default) +def meta_max_pool2d_with_indices_backward( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + check( + self.dtype == grad_output.dtype, + lambda: "expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", + ) + + nOutputPlane = nInputPlane + ndim = self.ndim + + def _check_dim_size(t): + check_dim_size(t, ndim, ndim - 3, nOutputPlane) + check_dim_size(t, ndim, ndim - 2, outputHeight) + check_dim_size(t, ndim, ndim - 1, outputWidth) + + _check_dim_size(grad_output) + _check_dim_size(indices) + + memory_format = utils.suggest_memory_format(self) + return torch.empty( + self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format + ) + + +@register_meta(aten.max_pool2d_with_indices.default) +def meta_max_pool2d_with_indices( + input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = utils.suggest_memory_format(input) if input.dim() == 3: size = [nInputPlane, outputHeight, outputWidth] else: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8ab1ea8a047c..441bc7adcf83 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2925,6 +2925,7 @@ def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): 'nn.functional.max_pool1d': _TestParamsMaxPool1d, 'nn.functional.max_pool2d': _TestParamsMaxPool2d, 'nn.functional.max_pool3d': _TestParamsMaxPool3d, + 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d, } params_generator = params_generator_type_dict[op_info.name]() @@ -2932,6 +2933,15 @@ def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) yield SampleInput(arg, kwargs=kwargs) +def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs): + out, indices = torch.nn.functional.max_pool2d_with_indices( + *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True) + grad_out = torch.ones_like(out) + if stride is None: + stride = kernel_size + out_b = torch.ops.aten.max_pool2d_with_indices_backward.default( + grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices) + return out_b def error_inputs_max_pool1d(op_info, device, **kwargs): # Toggle requires_grad because `max_pool1d` has different path @@ -11567,6 +11577,31 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), error_inputs_func=error_inputs_max_pool2d, sample_inputs_func=sample_inputs_max_pool), + OpInfo('max_pool2d_with_indices_backward', + op=max_pool2d_backward, + # We've defined a custom op, so there's no corresponding aten op + aten_name=None, + method_variant=None, + inplace_variant=None, + operator_variant=None, + inplace_operator_variant=None, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_pool, + skips=( + # We've defined a custom op here, and we don't handle the case where we receive an out kwarg + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') + )), OpInfo('nn.functional.max_pool3d', aten_name='max_pool3d', # Runs very slowly on slow gradcheck - alternatively reduce input sizes From 8f7e519f12d165c06ea3e20b994c2d3c5c44af2c Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Sun, 13 Nov 2022 19:42:42 +0000 Subject: [PATCH 117/453] Skip dynamo benchmark tests under TSAN (#88895) Summary: Fixes T137546804 Test Plan: ``` buck2 test mode/opt-tsan //caffe2/benchmarks/dynamo:test buck2 test mode/opt //caffe2/benchmarks/dynamo:test ``` Differential Revision: D41226384 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88895 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/benchmarks/dynamo/test.py b/benchmarks/dynamo/test.py index 317e8e4ea50e..438218462030 100644 --- a/benchmarks/dynamo/test.py +++ b/benchmarks/dynamo/test.py @@ -5,8 +5,17 @@ from .torchbench import setup_torchbench_cwd, TorchBenchmarkRunner +try: + # fbcode only + from aiplatform.utils.sanitizer_status import is_asan_or_tsan +except ImportError: + + def is_asan_or_tsan(): + return False + class TestDynamoBenchmark(unittest.TestCase): + @unittest.skipIf(is_asan_or_tsan(), "ASAN/TSAN not supported") def test_benchmark_infra_runs(self) -> None: """ Basic smoke test that TorchBench runs. 
From 76af71444a43962ee3e1cef987ac2028f2b8f44d Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Sat, 12 Nov 2022 20:06:12 +0100 Subject: [PATCH 118/453] [primTorch] Add ref for `complex` (#88562) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88562 Approved by: https://github.com/ezyang --- torch/_prims/context.py | 3 +- torch/_refs/__init__.py | 1 - torch/_refs/_conversions.py | 45 ++++++++++++++++++- .../_internal/common_methods_invocations.py | 34 ++++++++++++++ 4 files changed, 80 insertions(+), 3 deletions(-) diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 203d73fd948e..b9f6e634bb49 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -68,7 +68,8 @@ def torch_to_refs_map(): # Support conversions for s in torch._refs._conversions.__all__: - r[getattr(torch.Tensor, s)] = torch._refs._conversions.__dict__.get(s) + tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s) + r[tensor_attr] = torch._refs._conversions.__dict__.get(s) return r diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 70edbff2237f..a1de9a438d77 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -122,7 +122,6 @@ "bitwise_right_shift", "bitwise_xor", "clamp_min", - # "complex", "copysign", "div", "eq", diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index 11657f7058bd..abcd5729818d 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -1,6 +1,12 @@ import torch +import torch._prims_common as utils -from torch._prims_common import TensorLikeType +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +from torch._prims_common import check, TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes # Data conversion references. # @@ -10,6 +16,7 @@ # (like int). __all__ = [ + # dtypes "bfloat16", "bool", "byte", @@ -23,6 +30,8 @@ "int", "long", "short", + # misc + "complex", ] @@ -61,3 +70,37 @@ def fn( long = _make_conversion_method("long", torch.long) short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 441bc7adcf83..62c9b4750ae9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5239,6 +5239,28 @@ def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): sizes = ((S, S), ()) return (SampleInput(make_arg(size)) for size in sizes) +def error_inputs_complex(op_info, device, is_ref=False, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + + if is_ref: + error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" + error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" + error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" + else: + error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" + error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument" + error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'" + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), + error_type=RuntimeError, error_regex=error_float) + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64), + out=make_arg(M, S, dtype=torch.complex64)), + error_type=RuntimeError, error_regex=error_out) + def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): def make_arg(shape): # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck @@ -9097,6 +9119,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_rhs_python_scalar=False, + error_inputs_func=error_inputs_complex, skips=( # Test doesn't account for complex's type promotion semantics DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), @@ -17933,6 +17956,17 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), ) ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.complex", + torch_opinfo_name="complex", + error_inputs_func=partial(error_inputs_complex, is_ref=True), + # 
prims.empty_strided.default does not support nvfuser + supports_nvfuser=False, + skips=( + # Test doesn't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ) + ), ElementwiseUnaryPythonRefInfo( "_refs._conversions.double", torch_opinfo_name="double", From 9eabcc370f4c3a04be85cb1f878038f10716bdc3 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Sun, 13 Nov 2022 06:06:24 +0000 Subject: [PATCH 119/453] Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88761 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 4 - test/functorch/test_ops.py | 4 + test/functorch/test_vmap.py | 3 + test/inductor/test_torchinductor_opinfo.py | 1 + test/test_decomp.py | 3 + test/test_proxy_tensor.py | 22 +++-- torch/_decomp/decompositions.py | 98 +++++++++++++++---- .../_internal/common_methods_invocations.py | 94 ++++++++++++++---- 8 files changed, 177 insertions(+), 52 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ea00842a4e00..e0ffcbe7d97d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1005,7 +1005,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cholesky_inverse', ''), # could not find kernel xfail('cholesky_solve', ''), # could not find kernel - xfail('chunk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition @@ -1139,7 +1138,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta... xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'bicubic'), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('nn.functional.interpolate', 'bilinear'), # Cannot call sizes() on tensor with symbolic sizes/str... xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'nearest'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st... @@ -1166,7 +1164,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... 
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1199,7 +1196,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ... - xfail('split', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 74085941c6c8..85ac70d74825 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1052,6 +1052,7 @@ def test(): xfail('segment_reduce', 'lengths'), xfail('sparse.sampled_addmm', ''), xfail("native_batch_norm"), + xfail("native_dropout_backward"), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: @@ -1216,6 +1217,8 @@ def get_vjp(cotangents, *primals): xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce xfail('index_reduce', ''), # NYI: forward-AD for index_reduce xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce + xfail('native_dropout_backward'), # NYI + })) @opsToleranceOverride('TestOperators', 'test_jvpvjp', ( tol1('masked.prod', @@ -1372,6 +1375,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): # input while the running_mean or running_var, which will be updated in # place, were not batched. xfail("native_batch_norm"), + xfail('native_dropout_backward',) })) @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6d95077b627e..9726b7feedb7 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3238,6 +3238,7 @@ def test(): xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse xfail('cross'), # The default value of dim in op is *very* weird. 
No wonder it doesn't work + skip('_softmax_backward_data'), skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format # ---------------------------------------------------------------------- @@ -3379,6 +3380,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('bernoulli', ''), xfail('linalg.lu_factor', ''), xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('native_dropout_backward'), xfail('nn.functional.kl_div', ''), xfail('multinomial', ''), xfail('column_stack', ''), @@ -3452,6 +3454,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('equal', ''), xfail('linalg.lu', ''), skip('linalg.ldl_solve', ''), + skip('_softmax_backward_data'), })) def test_op_has_batch_rule(self, device, dtype, op): # needs to be fixed diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 3d384efea0ae..89ea42c9fea7 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -425,6 +425,7 @@ def wrapper_set_seed(op, *args, **kwargs): "randn": {"assert_equal": False}, ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, + ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, "gradient": {"check_gradient": False}, # segfault on check_gradient # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error "linalg.solve_triangular": {"check_gradient": False}, diff --git a/test/test_decomp.py b/test/test_decomp.py index 67e99d5eb829..a3658792c5e7 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -294,6 +294,9 @@ def normalize_op_input_output(f, sample, requires_grad=True): (None, None, "meshgrid"), # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) (None, None, "diag"), + + # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 + ("cpu", torch.bfloat16, "_softmax_backward_data"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 42ecc3d376ab..894b35693430 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1124,7 +1124,6 @@ def f(a, b, c, d, e): xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... - xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel xfail('combinations', ''), xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba... @@ -1247,7 +1246,6 @@ def f(a, b, c, d, e): xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco... xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... - xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function... 
xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... @@ -1267,7 +1265,6 @@ def f(a, b, c, d, e): xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de... xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1313,7 +1310,6 @@ def f(a, b, c, d, e): xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... - xfail('split', ''), # 'torch._C.SymIntNode' and 'int' xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1439,10 +1435,13 @@ def _fn(t, *args, **kwargs): return _fn def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False): - def f(args, kwargs, extra_args): + def f(args, kwargs, extra_args, extra_kwargs): if extra_args: for i, t in extra_args: args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op return fn(*args, **kwargs) @@ -1463,23 +1462,26 @@ def f(args, kwargs, extra_args): # - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in # symbolic mode, a no-op otherwise) extra_args = [] + extra_kwargs = {} for i, arg in enumerate(args): if isinstance(arg, torch.Size): - extra_args.append((i, torch.empty((), device="cpu").expand(arg))) - # TODO: support kwargs + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") try: - new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args) + new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs) except DynamicOutputShapeException as e: self.skipTest("Dynamic output shape operation in trace") for arg in args: if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: arg.uniform_(0, 1) try: - old_out = f(args, kwargs, extra_args) + old_out = f(args, kwargs, extra_args, extra_kwargs) except Exception: continue - new_out = wrapper_set_seed(new_f, args, kwargs, extra_args) + new_out = 
wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs) self.assertEqual(new_out, old_out) class TestProxyTensorOpInfo(TestCase): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 1a2d332e99fd..7c84cb7e2ca8 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4,7 +4,7 @@ from enum import Enum from functools import partial, reduce from itertools import product -from typing import Callable, cast, Iterable, List, Optional, Tuple +from typing import Callable, cast, Iterable, List, Optional, Tuple, Union import torch import torch._prims_common as utils @@ -13,6 +13,7 @@ from torch._decomp import register_decomposition from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper +from torch.fx.experimental.symbolic_shapes import guard_int, sym_float, sym_int from torch.utils._pytree import tree_flatten, tree_map DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] @@ -696,7 +697,12 @@ def _softmax_backward_data( grad_input = new_grad_output - output * torch.sum( new_grad_output, dim=dim, keepdim=True ) - return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() @register_decomposition(aten._log_softmax_backward_data) @@ -912,9 +918,17 @@ def check_positive(param, param_name, strict=True): @register_decomposition(aten.native_dropout_backward) -@pw_cast_for_opmath def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): - return grad_output * (mask.type_as(grad_output) * scale) + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! 
+ # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r @register_decomposition(aten.unfold_backward) @@ -1095,8 +1109,9 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: assert dim_size == 0 return [self] chunks = (dim_size + split_size - 1) // split_size + chunks = guard_int(chunks) split_sizes = [split_size for i in range(chunks)] - split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + split_sizes[-1] = split_size - (split_size * chunks - dim_size) return torch.split(self, split_sizes, dim) @@ -1786,29 +1801,74 @@ def norm( return torch.linalg.vector_norm(self, p, dim, keepdim, dtype=dtype) +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + utils.check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + utils.check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(scale_factors) == spatial_dimensions, lambda: "") + return [ + # Returning output_size as float. We cannot convert it to int directly, + # as latter computation of scale_factor is relying output size being float + sym_float(input_size[i + 2] * scale_factors[i]) + for i in range(spatial_dimensions) + ] + utils.check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + @register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) -@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec, type="pre_autograd") +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + + # NB: osize could be a list of float when scale_factors is float + # so we cannot redispatch to aten.upsample_bilinear2d.default here + return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w) + + +@register_decomposition(torch.ops.aten.upsample_bilinear2d.default) +@torch.ops.aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) @pw_cast_for_opmath -def upsample_bilinear2d_vec( +def upsample_bilinear2d( input: Tensor, - output_size: Optional[List[int]], + output_size: List[Union[int, float]], align_corners: bool, - scale_factors: Optional[List[float]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, ) -> Tensor: # get dimensions of original image n_batch, n_channels, in_h, in_w = input.shape - if output_size is not None: - out_h = float(output_size[0]) - out_w = float(output_size[1]) - elif scale_factors is not None: - out_h = in_h * scale_factors[0] - 
out_w = in_w * scale_factors[1] + out_h = sym_float(output_size[0]) + out_w = sym_float(output_size[1]) # Calculate horizontal and vertical scaling factor + # TODO: Figure out if scales_h/scales_w matters here if out_h > 1: if align_corners: - h_scale_factor = (in_h - 1) / (int(out_h) - 1) + h_scale_factor = (in_h - 1) / (sym_int(out_h) - 1) else: h_scale_factor = in_h / out_h else: @@ -1816,14 +1876,14 @@ def upsample_bilinear2d_vec( if out_w > 1: if align_corners: - w_scale_factor = (in_w - 1) / (int(out_w) - 1) + w_scale_factor = (in_w - 1) / (sym_int(out_w) - 1) else: w_scale_factor = in_w / out_w else: w_scale_factor = 0.0 - i = torch.arange(int(out_h), dtype=input.dtype, device=input.device) - j = torch.arange(int(out_w), dtype=input.dtype, device=input.device) + i = torch.arange(sym_int(out_h), dtype=input.dtype, device=input.device) + j = torch.arange(sym_int(out_w), dtype=input.dtype, device=input.device) if align_corners: x = h_scale_factor * i diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 62c9b4750ae9..8a7968cf57d2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -406,6 +406,21 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): # running_mean and running_var are required in evaluation mode (training: False) but not in training mode yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) +def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), 0), + ((S, S), 0), + ((S, M, S), -1), + ] + input_dtypes = [dtype] + if dtype == torch.float and device == 'cuda': + input_dtypes += [torch.float16] + + for (shape, dim), input_dtype in product(cases, input_dtypes): + yield SampleInput(make_arg(shape), make_arg(shape), dim, input_dtype) def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) @@ -1173,7 +1188,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): cases = ((), (S, S, S), (S,)) for shape in cases: - yield(SampleInput(make_arg(shape))) + yield SampleInput(make_arg(shape)) # TODO: add reduction kwargs def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): @@ -3745,8 +3760,8 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): def shape(size, rank, with_batch_channel=True): if with_batch_channel: - return tuple([N, C] + ([size] * rank)) - return tuple([size] * rank) + return torch.Size([N, C] + ([size] * rank)) + return torch.Size([size] * rank) make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-1, high=1) @@ -5794,9 +5809,9 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals if list_args: cases = ( - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), - ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),), - ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], -2),) + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) ) else: 
cases = ( # type: ignore[assignment] @@ -5811,10 +5826,10 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - cases = (((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), - ((S, S, S), ([int(S / 3), S - int(S / 3), 0],)), - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], 2)), - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], -2)), + cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), ) for shape, args in cases: @@ -6190,7 +6205,7 @@ def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): else: raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") - yield(SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)) + yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6446,7 +6461,7 @@ def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield(SampleInput(make_arg(shape), args=(args, ))) + yield SampleInput(make_arg(shape), args=(args,)) def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6469,8 +6484,8 @@ def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): ) for shape, shape_other in cases: - yield(SampleInput(make_arg(shape, requires_grad=requires_grad), - args=(make_arg(shape_other, requires_grad=False), ))) + yield SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False),)) def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): @@ -6588,8 +6603,8 @@ def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): inputs.append(mixed) for input_t, as_tuple in product(inputs, [False, True]): - yield(SampleInput(input_t.clone().requires_grad_(requires_grad), - kwargs=dict(as_tuple=as_tuple))) + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple)) def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6600,7 +6615,7 @@ def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield(SampleInput(make_arg(shape), args=args)) + yield SampleInput(make_arg(shape), args=args) def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) @@ -6678,6 +6693,15 @@ def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, yield SampleInput(make_arg(case), p=p, training=training) yield SampleInput(make_arg(case)) +def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, 
requires_grad=requires_grad) + make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) + + cases = ((S, S, S, S), (S,), ()) + scale_vals = [0.0, 1.0, 2.0] + + for case, scale in product(cases, scale_vals): + yield SampleInput(make_arg(case), make_mask(case), scale) def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): def make_input(shape): @@ -8095,7 +8119,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): in_shape = input.shape in_rank = len(in_shape) for d in start_dim, end_dim: - if not((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank-1}], but got {d}") end_dim = end_dim if end_dim >= 0 else in_rank + end_dim start_dim = start_dim if start_dim >= 0 else in_rank + start_dim @@ -8424,7 +8448,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): variant_test_name='decomposed', dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, - *[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []), + *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -10554,6 +10578,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=True), + OpInfo( + '_softmax_backward_data', + op=torch.ops.aten._softmax_backward_data, + aten_name='_softmax_backward_data', + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_softmax_backward_data, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + ), # `softmin` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. 
# https://github.com/pytorch/pytorch/issues/68752 @@ -15927,6 +15967,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_dropout, inplace_variant=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "native_dropout_backward", + op=torch.ops.aten.native_dropout_backward.default, + aten_name="native_dropout_backward", + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_dropout_backward, + skips=( + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # Lazy tensor failures + DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + ), + ), OpInfo( "nn.functional.dropout2d", op=lambda input, *args, **kwargs: From 48dc24ddceb5d048ceb38f00f6d4ec0cfc3e71d0 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sun, 13 Nov 2022 22:05:41 +0000 Subject: [PATCH 120/453] Fix: [ATen] Add some missing moves (#88514) Related to #88512 , but for ATen. This should reduce a number of copies and inefficient atomic smart pointer increments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88514 Approved by: https://github.com/jgong5, https://github.com/ezyang --- aten/src/ATen/InferSize.h | 2 +- aten/src/ATen/core/Formatting.cpp | 4 ++-- aten/src/ATen/core/Formatting.h | 4 ++-- aten/src/ATen/native/TensorShape.cpp | 5 +++-- aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp | 3 ++- aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp | 3 ++- c10/core/Storage.h | 2 +- c10/core/StorageImpl.h | 2 +- c10/core/WrapDimMinimal.cpp | 3 ++- c10/core/WrapDimMinimal.h | 2 +- 10 files changed, 17 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 594b87373a20..111c7eb8f5fc 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -80,7 +80,7 @@ inline at::SymDimVector infer_size_dv( c10::SymInt numel) { auto res = at::SymDimVector(shape); infer_size_impl( - shape, numel, res); + shape, std::move(numel), res); return res; } diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index 875b9ef3d042..4537adff5aa4 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream & out, Backend b) { return out << toString(b); } -std::ostream& operator<<(std::ostream & out, Scalar s) { +std::ostream& operator<<(std::ostream & out, const Scalar& s) { if (s.isFloatingPoint()) { return out << s.toDouble(); } @@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream & out, Scalar s) { throw std::logic_error("Unknown type in Scalar"); } -std::string toString(Scalar s) { +std::string toString(const Scalar& s) { std::stringstream out; out << s; return out.str(); diff --git a/aten/src/ATen/core/Formatting.h b/aten/src/ATen/core/Formatting.h index 6dcfc6c7b3cd..9dcd14e1902e 100644 --- a/aten/src/ATen/core/Formatting.h +++ b/aten/src/ATen/core/Formatting.h @@ -8,8 +8,8 @@ namespace c10 { TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); -TORCH_API std::ostream& operator<<(std::ostream & out, Scalar s); -TORCH_API 
std::string toString(Scalar s); +TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s); +TORCH_API std::string toString(const Scalar& s); } namespace at { diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index e8c87a2f1f5c..ccaf4b464252 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -204,10 +204,11 @@ #include #endif +#include #include #include +#include #include -#include namespace at { namespace meta { @@ -416,7 +417,7 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st const auto itemsize = result.dtype().itemsize(); c10::SymInt size_bytes = at::detail::computeStorageNbytes( size, stride, itemsize, storage_offset); - storage.set_nbytes(size_bytes); + storage.set_nbytes(std::move(size_bytes)); } return result; } diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 2250e84ad7a6..9d2f1a96c31b 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -1,4 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include @@ -444,7 +445,7 @@ c10::intrusive_ptr> PackedConvWeightsOnednn< exp_wgt.init(w_desc); exp_wgt.set_scale(wgt_scales); // Also for feed_from() exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index dda600e9b41c..36523bbd1b9b 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -23,6 +23,7 @@ #include #include +#include #include int register_linear_params(); @@ -249,7 +250,7 @@ c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( dnnl::memory::data_type::u8); ideep::tensor exp_wgt(w_desc); exp_wgt.feed_from(wgt); - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); diff --git a/c10/core/Storage.h b/c10/core/Storage.h index a89a0039fdfe..09c5920b5649 100644 --- a/c10/core/Storage.h +++ b/c10/core/Storage.h @@ -76,7 +76,7 @@ struct C10_API Storage { } void set_nbytes(c10::SymInt size_bytes) const { - storage_impl_.get()->set_nbytes(size_bytes); + storage_impl_.get()->set_nbytes(std::move(size_bytes)); } bool resizable() const { diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index bbf080384253..1d80daed871a 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -112,7 +112,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } void set_nbytes(c10::SymInt size_bytes) { - size_bytes_ = size_bytes; + size_bytes_ = std::move(size_bytes); } bool resizable() const { diff --git a/c10/core/WrapDimMinimal.cpp b/c10/core/WrapDimMinimal.cpp index 6703f0638901..2375dc3ac5cf 100644 --- a/c10/core/WrapDimMinimal.cpp +++ b/c10/core/WrapDimMinimal.cpp @@ -14,7 +14,8 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { "Dimension 
specified as ", dim, " but tensor has no dimensions"); - return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false); + return c10::maybe_wrap_dim( + std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); } T min = dim_post_expr * -1; diff --git a/c10/core/WrapDimMinimal.h b/c10/core/WrapDimMinimal.h index 0f5949f65082..dda01fbe18f0 100644 --- a/c10/core/WrapDimMinimal.h +++ b/c10/core/WrapDimMinimal.h @@ -38,7 +38,7 @@ inline c10::SymInt maybe_wrap_dim( c10::SymInt dim, c10::SymInt dim_post_expr, bool wrap_scalar = true) { - return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); + return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); } } // namespace c10 From eea506aee12371a1fbde271c99fb30a8537d1db7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Nov 2022 01:58:47 +0000 Subject: [PATCH 121/453] Revert "Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761)" This reverts commit 9eabcc370f4c3a04be85cb1f878038f10716bdc3. Reverted https://github.com/pytorch/pytorch/pull/88761 on behalf of https://github.com/suo due to much broken https://hud.pytorch.org/pytorch/pytorch/commit/9eabcc370f4c3a04be85cb1f878038f10716bdc3 --- test/functorch/test_aotdispatch.py | 4 + test/functorch/test_ops.py | 4 - test/functorch/test_vmap.py | 3 - test/inductor/test_torchinductor_opinfo.py | 1 - test/test_decomp.py | 3 - test/test_proxy_tensor.py | 22 ++--- torch/_decomp/decompositions.py | 98 ++++--------------- .../_internal/common_methods_invocations.py | 94 ++++-------------- 8 files changed, 52 insertions(+), 177 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index e0ffcbe7d97d..ea00842a4e00 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1005,6 +1005,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cholesky_inverse', ''), # could not find kernel xfail('cholesky_solve', ''), # could not find kernel + xfail('chunk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition @@ -1138,6 +1139,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta... xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'bicubic'), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('nn.functional.interpolate', 'bilinear'), # Cannot call sizes() on tensor with symbolic sizes/str... xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'nearest'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st... @@ -1164,6 +1166,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... 
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1196,6 +1199,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ... + xfail('split', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 85ac70d74825..74085941c6c8 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1052,7 +1052,6 @@ def test(): xfail('segment_reduce', 'lengths'), xfail('sparse.sampled_addmm', ''), xfail("native_batch_norm"), - xfail("native_dropout_backward"), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: @@ -1217,8 +1216,6 @@ def get_vjp(cotangents, *primals): xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce xfail('index_reduce', ''), # NYI: forward-AD for index_reduce xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce - xfail('native_dropout_backward'), # NYI - })) @opsToleranceOverride('TestOperators', 'test_jvpvjp', ( tol1('masked.prod', @@ -1375,7 +1372,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): # input while the running_mean or running_var, which will be updated in # place, were not batched. xfail("native_batch_norm"), - xfail('native_dropout_backward',) })) @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 9726b7feedb7..6d95077b627e 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3238,7 +3238,6 @@ def test(): xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse xfail('cross'), # The default value of dim in op is *very* weird. 
No wonder it doesn't work - skip('_softmax_backward_data'), skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format # ---------------------------------------------------------------------- @@ -3380,7 +3379,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('bernoulli', ''), xfail('linalg.lu_factor', ''), xfail('nn.functional.feature_alpha_dropout', 'with_train'), - xfail('native_dropout_backward'), xfail('nn.functional.kl_div', ''), xfail('multinomial', ''), xfail('column_stack', ''), @@ -3454,7 +3452,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('equal', ''), xfail('linalg.lu', ''), skip('linalg.ldl_solve', ''), - skip('_softmax_backward_data'), })) def test_op_has_batch_rule(self, device, dtype, op): # needs to be fixed diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 89ea42c9fea7..3d384efea0ae 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -425,7 +425,6 @@ def wrapper_set_seed(op, *args, **kwargs): "randn": {"assert_equal": False}, ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, - ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, "gradient": {"check_gradient": False}, # segfault on check_gradient # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error "linalg.solve_triangular": {"check_gradient": False}, diff --git a/test/test_decomp.py b/test/test_decomp.py index a3658792c5e7..67e99d5eb829 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -294,9 +294,6 @@ def normalize_op_input_output(f, sample, requires_grad=True): (None, None, "meshgrid"), # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) (None, None, "diag"), - - # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 - ("cpu", torch.bfloat16, "_softmax_backward_data"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 894b35693430..42ecc3d376ab 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1124,6 +1124,7 @@ def f(a, b, c, d, e): xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... + xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel xfail('combinations', ''), xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba... @@ -1246,6 +1247,7 @@ def f(a, b, c, d, e): xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco... xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... + xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function... 
xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... @@ -1265,6 +1267,7 @@ def f(a, b, c, d, e): xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de... xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1310,6 +1313,7 @@ def f(a, b, c, d, e): xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... + xfail('split', ''), # 'torch._C.SymIntNode' and 'int' xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1435,13 +1439,10 @@ def _fn(t, *args, **kwargs): return _fn def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False): - def f(args, kwargs, extra_args, extra_kwargs): + def f(args, kwargs, extra_args): if extra_args: for i, t in extra_args: args[i] = t.size() - if extra_kwargs: - for k, t in extra_kwargs.items(): - kwargs[k] = t.size() fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op return fn(*args, **kwargs) @@ -1462,26 +1463,23 @@ def f(args, kwargs, extra_args, extra_kwargs): # - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in # symbolic mode, a no-op otherwise) extra_args = [] - extra_kwargs = {} for i, arg in enumerate(args): if isinstance(arg, torch.Size): - extra_args.append((i, torch.empty(arg, device="cpu"))) - for key, value in kwargs.items(): - if isinstance(value, torch.Size): - extra_kwargs[key] = torch.empty(value, device="cpu") + extra_args.append((i, torch.empty((), device="cpu").expand(arg))) + # TODO: support kwargs try: - new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs) + new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args) except DynamicOutputShapeException as e: self.skipTest("Dynamic output shape operation in trace") for arg in args: if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: arg.uniform_(0, 1) try: - old_out = f(args, kwargs, extra_args, extra_kwargs) + old_out = f(args, kwargs, extra_args) except Exception: continue - new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs) + 
new_out = wrapper_set_seed(new_f, args, kwargs, extra_args) self.assertEqual(new_out, old_out) class TestProxyTensorOpInfo(TestCase): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7c84cb7e2ca8..1a2d332e99fd 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4,7 +4,7 @@ from enum import Enum from functools import partial, reduce from itertools import product -from typing import Callable, cast, Iterable, List, Optional, Tuple, Union +from typing import Callable, cast, Iterable, List, Optional, Tuple import torch import torch._prims_common as utils @@ -13,7 +13,6 @@ from torch._decomp import register_decomposition from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper -from torch.fx.experimental.symbolic_shapes import guard_int, sym_float, sym_int from torch.utils._pytree import tree_flatten, tree_map DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] @@ -697,12 +696,7 @@ def _softmax_backward_data( grad_input = new_grad_output - output * torch.sum( new_grad_output, dim=dim, keepdim=True ) - - # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor - # if grad_output.device == torch.device("cpu"): - # return grad_input.contiguous() - - return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) @register_decomposition(aten._log_softmax_backward_data) @@ -918,17 +912,9 @@ def check_positive(param, param_name, strict=True): @register_decomposition(aten.native_dropout_backward) +@pw_cast_for_opmath def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): - # According to the CUDA kernel implementation we should have this test; - # but it seems to fail tests! 
- # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") - - # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format - # This different from TensorIterator's behavior - r = (grad_output * (mask.type_as(grad_output) * scale)).clone( - memory_format=utils.suggest_memory_format(grad_output) - ) - return r + return grad_output * (mask.type_as(grad_output) * scale) @register_decomposition(aten.unfold_backward) @@ -1109,9 +1095,8 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: assert dim_size == 0 return [self] chunks = (dim_size + split_size - 1) // split_size - chunks = guard_int(chunks) split_sizes = [split_size for i in range(chunks)] - split_sizes[-1] = split_size - (split_size * chunks - dim_size) + split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) return torch.split(self, split_sizes, dim) @@ -1801,74 +1786,29 @@ def norm( return torch.linalg.vector_norm(self, p, dim, keepdim, dtype=dtype) -# aten/src/ATen/native/UpSample.cpp compute_output_size -def upsample_compute_output_size(input_size, output_size, scale_factors): - spatial_dimensions = len(input_size) - 2 - if output_size is not None: - utils.check( - scale_factors is None, - lambda: "Must specify exactly one of output_size and scale_factors", - ) - utils.check(len(output_size) == spatial_dimensions, lambda: "") - return output_size - if scale_factors is not None: - # NB: this isn't necessary lol - utils.check( - output_size is None, - lambda: "Must specify exactly one of output_size and scale_factors", - ) - utils.check(len(scale_factors) == spatial_dimensions, lambda: "") - return [ - # Returning output_size as float. We cannot convert it to int directly, - # as latter computation of scale_factor is relying output size being float - sym_float(input_size[i + 2] * scale_factors[i]) - for i in range(spatial_dimensions) - ] - utils.check( - False, lambda: "Must specify exactly one of output_size and scale_factors" - ) - - -def get_scale_value(scales, idx): - if scales is None: - return None - return scales[idx] - - @register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) -@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) -@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) -def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors): - osize = upsample_compute_output_size(input.size(), output_size, scale_factors) - scale_h = get_scale_value(scale_factors, 0) - scale_w = get_scale_value(scale_factors, 1) - - # NB: osize could be a list of float when scale_factors is float - # so we cannot redispatch to aten.upsample_bilinear2d.default here - return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w) - - -@register_decomposition(torch.ops.aten.upsample_bilinear2d.default) -@torch.ops.aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) +@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec, type="pre_autograd") @pw_cast_for_opmath -def upsample_bilinear2d( +def upsample_bilinear2d_vec( input: Tensor, - output_size: List[Union[int, float]], + output_size: Optional[List[int]], align_corners: bool, - scales_h: Optional[float] = None, - scales_w: Optional[float] = None, + scale_factors: Optional[List[float]], ) -> Tensor: # get dimensions of original image n_batch, n_channels, in_h, in_w = input.shape - out_h = sym_float(output_size[0]) - out_w = sym_float(output_size[1]) + if output_size is not None: + 
out_h = float(output_size[0]) + out_w = float(output_size[1]) + elif scale_factors is not None: + out_h = in_h * scale_factors[0] + out_w = in_w * scale_factors[1] # Calculate horizontal and vertical scaling factor - # TODO: Figure out if scales_h/scales_w matters here if out_h > 1: if align_corners: - h_scale_factor = (in_h - 1) / (sym_int(out_h) - 1) + h_scale_factor = (in_h - 1) / (int(out_h) - 1) else: h_scale_factor = in_h / out_h else: @@ -1876,14 +1816,14 @@ def upsample_bilinear2d( if out_w > 1: if align_corners: - w_scale_factor = (in_w - 1) / (sym_int(out_w) - 1) + w_scale_factor = (in_w - 1) / (int(out_w) - 1) else: w_scale_factor = in_w / out_w else: w_scale_factor = 0.0 - i = torch.arange(sym_int(out_h), dtype=input.dtype, device=input.device) - j = torch.arange(sym_int(out_w), dtype=input.dtype, device=input.device) + i = torch.arange(int(out_h), dtype=input.dtype, device=input.device) + j = torch.arange(int(out_w), dtype=input.dtype, device=input.device) if align_corners: x = h_scale_factor * i diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8a7968cf57d2..62c9b4750ae9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -406,21 +406,6 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): # running_mean and running_var are required in evaluation mode (training: False) but not in training mode yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) -def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial( - make_tensor, device=device, dtype=dtype, requires_grad=requires_grad - ) - cases = [ - ((S,), 0), - ((S, S), 0), - ((S, M, S), -1), - ] - input_dtypes = [dtype] - if dtype == torch.float and device == 'cuda': - input_dtypes += [torch.float16] - - for (shape, dim), input_dtype in product(cases, input_dtypes): - yield SampleInput(make_arg(shape), make_arg(shape), dim, input_dtype) def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) @@ -1188,7 +1173,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): cases = ((), (S, S, S), (S,)) for shape in cases: - yield SampleInput(make_arg(shape)) + yield(SampleInput(make_arg(shape))) # TODO: add reduction kwargs def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): @@ -3760,8 +3745,8 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): def shape(size, rank, with_batch_channel=True): if with_batch_channel: - return torch.Size([N, C] + ([size] * rank)) - return torch.Size([size] * rank) + return tuple([N, C] + ([size] * rank)) + return tuple([size] * rank) make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-1, high=1) @@ -5809,9 +5794,9 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals if list_args: cases = ( - ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), - ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), - ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), + ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),), + ((S, S, S), 
([int(S / 2), S - int(S / 2) * 2, int(S / 2)], -2),) ) else: cases = ( # type: ignore[assignment] @@ -5826,10 +5811,10 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), - ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), - ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), - ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), + cases = (((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), + ((S, S, S), ([int(S / 3), S - int(S / 3), 0],)), + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], 2)), + ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], -2)), ) for shape, args in cases: @@ -6205,7 +6190,7 @@ def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): else: raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") - yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) + yield(SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)) def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6461,7 +6446,7 @@ def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield SampleInput(make_arg(shape), args=(args,)) + yield(SampleInput(make_arg(shape), args=(args, ))) def sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6484,8 +6469,8 @@ def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): ) for shape, shape_other in cases: - yield SampleInput(make_arg(shape, requires_grad=requires_grad), - args=(make_arg(shape_other, requires_grad=False),)) + yield(SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False), ))) def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): @@ -6603,8 +6588,8 @@ def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): inputs.append(mixed) for input_t, as_tuple in product(inputs, [False, True]): - yield SampleInput(input_t.clone().requires_grad_(requires_grad), - kwargs=dict(as_tuple=as_tuple)) + yield(SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple))) def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6615,7 +6600,7 @@ def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield SampleInput(make_arg(shape), args=args) + yield(SampleInput(make_arg(shape), args=args)) def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) @@ -6693,15 +6678,6 @@ def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, yield SampleInput(make_arg(case), p=p, training=training) yield SampleInput(make_arg(case)) -def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): - make_arg = 
partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) - - cases = ((S, S, S, S), (S,), ()) - scale_vals = [0.0, 1.0, 2.0] - - for case, scale in product(cases, scale_vals): - yield SampleInput(make_arg(case), make_mask(case), scale) def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): def make_input(shape): @@ -8119,7 +8095,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): in_shape = input.shape in_rank = len(in_shape) for d in start_dim, end_dim: - if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + if not((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank-1}], but got {d}") end_dim = end_dim if end_dim >= 0 else in_rank + end_dim start_dim = start_dim if start_dim >= 0 else in_rank + start_dim @@ -8448,7 +8424,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): variant_test_name='decomposed', dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, - *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), + *[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -10578,22 +10554,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=True), - OpInfo( - '_softmax_backward_data', - op=torch.ops.aten._softmax_backward_data, - aten_name='_softmax_backward_data', - dtypes=floating_types_and(torch.bfloat16), - dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), - sample_inputs_func=sample_inputs_softmax_backward_data, - assert_autodiffed=True, - supports_forward_ad=True, - supports_fwgrad_bwgrad=True, - supports_out=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), - ), - ), # `softmin` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. 
# https://github.com/pytorch/pytorch/issues/68752 @@ -15967,22 +15927,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_dropout, inplace_variant=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), - OpInfo( - "native_dropout_backward", - op=torch.ops.aten.native_dropout_backward.default, - aten_name="native_dropout_backward", - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), - dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), - supports_out=False, - sample_inputs_func=sample_inputs_dropout_backward, - skips=( - DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), - # Lazy tensor failures - DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), - DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), - DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), - ), - ), OpInfo( "nn.functional.dropout2d", op=lambda input, *args, **kwargs: From 06486cd0087200e08ebb8a9518e064251c7c5309 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Mon, 14 Nov 2022 03:39:43 +0000 Subject: [PATCH 122/453] fix typo: AT_MKLDNN_EBABLED => AT_MKLDNN_ENABLED (#88952) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88952 Approved by: https://github.com/XiaobingSuper --- aten/src/ATen/native/mkldnn/Prelu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/Prelu.cpp b/aten/src/ATen/native/mkldnn/Prelu.cpp index acc78211d83c..dc7d239da7b6 100644 --- a/aten/src/ATen/native/mkldnn/Prelu.cpp +++ b/aten/src/ATen/native/mkldnn/Prelu.cpp @@ -17,7 +17,7 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons }} -#else // AT_MKLDNN_EBABLED +#else // AT_MKLDNN_ENABLED #include #include @@ -76,4 +76,4 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons } }} -#endif // AT_MKLDNN_EBABLED +#endif // AT_MKLDNN_ENABLED From 4ad7b17fabd2a2b6873bc369bd223223ff1e628b Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 13 Nov 2022 22:09:53 -0500 Subject: [PATCH 123/453] TorchDynamo: Add convolution binary(inplace) fusion for cpu in inference mode (#88403) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88403 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/ir.py | 112 +++++++++++++++++++++- torch/_inductor/lowering.py | 34 +++++++ torch/_inductor/overrides.py | 174 ++++++++++++++++++++++++++++++++--- 3 files changed, 303 insertions(+), 17 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 240c196a73b6..ffb935ae440d 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1552,11 +1552,23 @@ def loader(index): @dataclasses.dataclass class Layout(IRNode): - device: torch.device - dtype: torch.dtype - size: List[Expr] - stride: List[Expr] - offset: Expr = Integer(0) + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: List[Expr], + offset: Expr = Integer(0), + ): + self.device = device + self.dtype = dtype + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride def __str__(self): offset = "" @@ -1772,6 +1784,15 @@ def __init__(self, target: IRNode): ) self.target = target + @Layout.stride.getter + def stride(self): + return 
self.real_layout().stride + + def real_layout(self): + if isinstance(self.target, MutationLayout): + return self.target.real_layout() + return self.target.data.layout + @classmethod def realize_into(cls, src, dst): dst.realize() @@ -2467,6 +2488,16 @@ def require_stride_order(cls, x, order): x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x + elif isinstance(x.get_layout(), MutationLayout): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayout's real layout shouldn't be FlexibleLayout" + ) + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): + return x + # TODO - Storage to InputBuffer if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): return x @@ -3513,6 +3544,77 @@ def apply_constraint(self): self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) +class ConvolutionBinaryInplace(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + + def __init__( + self, + kernel_layout, + inputs_layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._convolution_pointwise_.binary", + ): + super().__init__(kernel_layout, inputs, constant_args) + self.kernel = kernel + self.inputs_layout = inputs_layout + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + def get_mutation_names(self): + assert isinstance(self.layout, MutationLayout) + return (self.layout.target.get_name(),) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List], + unary_algorithm: Optional[str], + ): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + (inputs, constant_args, inputs_layout,) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.realize_input(other) + V.graph.realize_users_of(other.get_name()) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ] + return ConvolutionBinaryInplace( + kernel_layout=MutationLayout(inputs[1]), + inputs_layout=inputs_layout, + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + def apply_constraint(self): + x = self.inputs[0] + # FixedLayout of input + x = self.require_stride_order(x, self.inputs_layout.preferred_stride_order) + self.inputs[0] = x + self.freeze_layout_with_stride_order(self.inputs_layout.preferred_stride_order) + + class LinearUnary(ExternKernelAlloc): kernel = "torch.ops.mkldnn._linear_pointwise" diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index dedd39cd91c4..9924396075f6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -960,6 +960,40 @@ def convolution_binary( ) ) + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, 
+ padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index d89ee82674dd..a4a29fb02382 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -196,6 +196,79 @@ def forward(self, input, other): return self._conv_forward(input, other, self.weight, self.bias) +class ConvBinaryInplace2d(nn.Conv2d): + def __init__( + self, + conv: nn.Module, + binary_op_name: str, + ): + super(ConvBinaryInplace2d, self).__init__( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + conv.weight.device, + conv.weight.dtype, + ) + self._update_module_params(conv, binary_op_name) + + def _update_module_params(self, conv, binary_op_name): + self.__dict__ = copy.deepcopy(conv.__dict__) + self.binary_attr = binary_op_name + self.binary_alpha = None + self.unary_attr = None + self.unary_scalars = [] + self.unary_algorithm = None + + def _update_unary_params(self, unary): + self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__]( + unary + ) + + def _conv_forward(self, input, other, weight, bias): + if self.padding_mode != "zeros": + return torch.ops.mkldnn._convolution_pointwise_( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + other, + weight, + bias, + _pair(0), + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + return torch.ops.mkldnn._convolution_pointwise_( + input, + other, + weight, + bias, + self.padding, + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + + def forward(self, input, other): + return self._conv_forward(input, other, self.weight, self.bias) + + class LinearUnary(nn.Linear): def __init__( self, @@ -263,6 +336,14 @@ def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str): ) +def fused_conv_binary_inplace_eval(conv: nn.Module, binary_op_name: str): + assert not (conv.training), "Fusion only for eval!" + return ConvBinaryInplace2d( + conv, + binary_op_name, + ) + + def is_bfloat16_module(m): weight_is_bf16 = m.weight.dtype == torch.bfloat16 bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16 @@ -312,6 +393,25 @@ def check_node_is_binary(node): ) +def check_binary_op_kwargs_is_default(node): + # For binary op, we hope the kwargs values are the default value: + # torch.sub(add)(input, other, *, alpha=1, out=None). + if len(node.args) > 2: + return False + if len(node.kwargs) > 0: + if "out" in node.kwargs and node.kwargs["out"] is not None: + return False + if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0: + return False + return True + + +def check_node_is_add_inplace(node): + return (node.op == "call_function" and node.target in [operator.iadd]) or ( + node.op == "call_method" and node.target in ["add_"] + ) + + def fuse_fx(gm: torch.fx.GraphModule, example_inputs): # make sure the autograd is disabled. if torch.is_grad_enabled(): @@ -328,6 +428,7 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs): # the binary inputs have same tensor info(device, dtype, and layout). 
ShapeProp(gm).propagate(*example_inputs) gm = fuse_unary(gm) + gm = fuse_binary_inplace(gm) gm = fuse_binary(gm) return gm @@ -419,26 +520,31 @@ def replace_and_fuse_for_binary( node.replace_all_uses_with(node.args[index_node]) +def binary_inputs_meta_is_same(binary_node): + tensor0_meta = binary_node.args[0].meta.get("tensor_meta") + tensor1_meta = binary_node.args[1].meta.get("tensor_meta") + if not tensor0_meta or not tensor1_meta: + return False + if ( + tensor0_meta.shape != tensor1_meta.shape + or tensor0_meta.stride != tensor1_meta.stride + or tensor0_meta.dtype != tensor1_meta.dtype + ): + return False + + return True + + def fuse_binary(gm: torch.fx.GraphModule): modules = dict(gm.named_modules()) for node in gm.graph.nodes: - if check_node_is_binary(node) and ( - len(node.kwargs) != 2 or node.kwargs["alpha"] == 1.0 - ): + if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node): for node_kind, fuse_func in computation_op_binary_op_fusion_map.items(): if not isinstance(node.args[0], torch.fx.Node) or not isinstance( node.args[1], torch.fx.Node ): continue - tensor0_meta = node.args[0].meta.get("tensor_meta") - tensor1_meta = node.args[1].meta.get("tensor_meta") - if not tensor0_meta or not tensor1_meta: - continue - if ( - tensor0_meta.shape != tensor1_meta.shape - or tensor0_meta.stride != tensor1_meta.stride - or tensor0_meta.dtype != tensor1_meta.dtype - ): + if not binary_inputs_meta_is_same(node): continue attr = binary_attr[node.target] index_list = supported_index_list[attr] @@ -473,6 +579,46 @@ def fuse_binary(gm: torch.fx.GraphModule): return gm +def fuse_binary_inplace(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node): + for ( + node_kind, + fuse_func, + ) in computation_op_binary_op_fusion_inplace_map.items(): + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node.args[1], torch.fx.Node + ): + continue + if not binary_inputs_meta_is_same(node): + continue + if check_node_kind(node.args[1], modules, node_kind): + if len(node.args[1].users) > 1: + continue + # make sure the output and input are not same tensor. + if node.args[1].args[0] == node.args[0]: + continue + computation_node = modules[node.args[1].target] + replace_and_fuse_for_binary( + computation_node, + node, + fuse_func, + "add", + modules, + 1, # conv module index + 0, # binary op index + ) + # Make sure the fused node is post node of node's inputs nodes. + node.append(node.args[1]) + gm.graph.erase_node(node) + gm.graph.lint() + break + + gm.recompile() + return gm + + philox_rand_like = _prims._make_prim( schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor", return_type=_prims.RETURN_TYPE.NEW, @@ -629,6 +775,10 @@ def rand_like(x, **kwargs): } +computation_op_binary_op_fusion_inplace_map = { + nn.Conv2d: fused_conv_binary_inplace_eval, +} + # For add: we support conv/linear + other and other + conv # For sub/add_/sub_, we only support conv/linear - other # or conv/linear +(-)= other From 03296844aa0cb560401584545ba1412e52c87b37 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 Nov 2022 09:50:50 +0000 Subject: [PATCH 124/453] Fix typos in messages under aten (#88964) This PR fixes typos of messages and parms in c++ source files under `aten` directory. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88964 Approved by: https://github.com/lezcano --- aten/src/ATen/core/List_test.cpp | 4 ++-- aten/src/ATen/core/class_type.cpp | 4 ++-- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 2 +- aten/src/ATen/cudnn/Descriptors.cpp | 2 +- aten/src/ATen/native/LinearAlgebra.cpp | 2 +- aten/src/ATen/native/SpectralOps.cpp | 4 ++-- aten/src/ATen/native/TensorShape.cpp | 2 +- .../ao_sparse/quantized/cpu/qlinear_deserialize.cpp | 2 +- .../native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp | 2 +- aten/src/ATen/native/quantized/cpu/BinaryOps.cpp | 4 ++-- aten/src/ATen/native/quantized/cpu/qconv.cpp | 2 +- aten/src/ATen/native/quantized/cpu/qmatmul.cpp | 4 ++-- aten/src/ATen/native/quantized/cpu/qmul.cpp | 4 ++-- aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp | 4 ++-- aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp | 2 +- aten/src/ATen/native/sparse/SparseCsrTensor.cpp | 10 +++++----- aten/src/ATen/native/vulkan/api/Adapter.cpp | 2 +- aten/src/ATen/native/vulkan/ops/Clone.cpp | 2 +- test/test_sparse_csr.py | 8 ++++---- 19 files changed, 33 insertions(+), 33 deletions(-) diff --git a/aten/src/ATen/core/List_test.cpp b/aten/src/ATen/core/List_test.cpp index e16e26b6042e..f37f3c008493 100644 --- a/aten/src/ATen/core/List_test.cpp +++ b/aten/src/ATen/core/List_test.cpp @@ -1118,7 +1118,7 @@ TEST(ListTest, canAccessStringByReference) { List list({"one", "two"}); const auto& listRef = list; static_assert(std::is_same::value, - "const List acccess should be by const reference"); + "const List access should be by const reference"); std::string str = list[1]; const std::string& strRef = listRef[1]; EXPECT_EQ("two", str); @@ -1130,7 +1130,7 @@ TEST(ListTest, canAccessOptionalStringByReference) { const auto& listRef = list; static_assert( std::is_same>>::value, - "List> acccess should be by const reference"); + "List> access should be by const reference"); c10::optional str1 = list[1]; c10::optional str2 = list[2]; decltype(auto) strRef1 = listRef[1]; diff --git a/aten/src/ATen/core/class_type.cpp b/aten/src/ATen/core/class_type.cpp index 9d7b38d4d67b..2478bde034bc 100644 --- a/aten/src/ATen/core/class_type.cpp +++ b/aten/src/ATen/core/class_type.cpp @@ -86,7 +86,7 @@ std::string ClassType::getForwardPreHookErrorMessage(int pre_hook_idx) const { std::string pre_hook_schema = pre_hook_name + "(self, input: Tuple[" + input_types + "])"; std::string return_string = - "This error occured while scripting the forward pre-hook '" + + "This error occurred while scripting the forward pre-hook '" + pre_hook_name + "' on module '" + name()->name() + "'. If you did not want to script this pre-hook remove it from the " "original NN module before scripting. Pre-hooks for module '" + @@ -111,7 +111,7 @@ std::string ClassType::getForwardHookErrorMessage(int hook_idx) const { std::string hook_schema = hook_name + "(self, input: Tuple[" + input_types + "], output: " + output_types + ")"; std::string return_string = - "This error occured while scripting the forward hook '" + "This error occurred while scripting the forward hook '" + hook_name + "' on module " + name()->name() + ". If you did not want to script this hook remove it from" + " the original NN module before scripting. 
This hook was" + diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index b5e685dac65f..25e4c2b44fa9 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -82,7 +82,7 @@ void CUDAHooks::initCUDA() const { at::cuda::detail::init_p2p_access_cache(num_devices); #if AT_MAGMA_ENABLED() - TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initilaize magma, init routine not set"); + TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initialize magma, init routine not set"); magma_init_fn(); #endif } diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index f954bbf5623a..0e739a49bb33 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -164,7 +164,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo filter_format = CUDNN_TENSOR_NHWC; break; default: - TORCH_INTERNAL_ASSERT(false, "unsurpported memory_format for cuDNN filters"); + TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); } set(getDataType(t), (int) dim, size, filter_format); } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 8c5a6fc8f195..c21bc4b47531 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -877,7 +877,7 @@ std::vector> matrix_chain_order(TensorList tensors) { /** * @brief Recursively multiplies the tensors i...j using the given order * - * @param tensors matrices to multiply togther + * @param tensors matrices to multiply together * @param order optimal chain multiplication order from #matrix_chain_order * @param i index of first tensor to be multiplied * @param j index of last tensor to be multiplied diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 0acc3506cf51..e08e17af4d08 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1053,13 +1053,13 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho if (onesided) { if (n_fft / 2 + 1 != fft_size) { std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onsided=True, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onesided=True, but got " << fft_size; AT_ERROR(ss.str()); } } else { if (n_fft != fft_size) { std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onsided=False, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got " << fft_size; AT_ERROR(ss.str()); } } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index ccaf4b464252..ba6ff27661ba 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1589,7 +1589,7 @@ Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_sha c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel()); if (self.is_mkldnn()) { - TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tesnors"); + TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tensors"); } if (self.is_contiguous()) { diff --git 
a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp index c5fa0210cd58..d367dbe01103 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -209,7 +209,7 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( std::get(serialized); TORCH_CHECK( serialization_version <= SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION, - "Attemped to deserialize sparse qlinear packed params with an ", + "Attempted to deserialize sparse qlinear packed params with an ", "incompatible serialization version (", serialization_version, " > ", diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index a430e8185451..64cab80790a9 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -45,7 +45,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( const auto cols_input = static_cast(input.size(input.dim() - 1)); TORCH_CHECK( cols_input == input_channels_, - "quantized_sparse_lienar: Input tensor's last and weight tensor's" + "quantized_sparse_linear: Input tensor's last and weight tensor's" " second dimension must match."); // On empty input, no output data will be generated, diff --git a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp index 8444f9ca615b..58a7036bdd7e 100644 --- a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp @@ -36,10 +36,10 @@ namespace { inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK( qa.qscheme() == kPerTensorAffine, - "Only per tensor quantization is suported in Add."); + "Only per tensor quantization is supported in Add."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), - "Both inputs to Add must have the same quantization shceme."); + "Both inputs to Add must have the same quantization scheme."); TORCH_CHECK( qa.scalar_type() == qb.scalar_type(), "Add operands should have same data type."); diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 2cd7cd81b903..b6fa57b9e3ed 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -130,7 +130,7 @@ at::SmallVector MakeDeConvOutputShape( ", output padding: ", output_padding[idx], ", dilation: ", dilation[idx]) TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim, - "Output dimension is beyound reasonable maximum for ", idx, + "Output dimension is beyond reasonable maximum for ", idx, " axis;" " kernel: ", kernel[idx], ", stride: ", stride[idx], diff --git a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp index c1e5041a5734..4da714e0bcf0 100644 --- a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp @@ -21,7 +21,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { "MatMul operands should have same data type."); TORCH_CHECK( qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per-tensor quantization is suported in Matmul."); + "Only per-tensor quantization is supported in Matmul."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), "Both inputs to Matmul must have the same quantization scheme."); @@ 
-45,7 +45,7 @@ Tensor qmatmul( " and ", b_num_dims, " provided)"); TORCH_CHECK( num_dims >= 2, - "Quantized Matmul currently only suports operands which are at least 2-dimensional. (", + "Quantized Matmul currently only supports operands which are at least 2-dimensional. (", num_dims, " provided)"); const int64_t m = qa.size(num_dims - 2); diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index 35d2139c6c14..aa6ad0e724f5 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -40,7 +40,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK(qa.scalar_type() == qb.scalar_type(), "Mul operands should have same data type."); TORCH_CHECK(qa.qscheme() == qb.qscheme(), - "Both inputs to Mul must have the same quantization shceme."); + "Both inputs to Mul must have the same quantization scheme."); } // Note: out is assumed to be the same size as self and other. @@ -314,7 +314,7 @@ class QMulScalarTensor final { static Tensor run(Tensor qa, Tensor b) { TORCH_CHECK(qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per tensor quantization is suported in Mul."); + "Only per tensor quantization is supported in Mul."); auto qc = at::empty_like(qa, qa.suggest_memory_format()); return _mul_scalar_out(qc, qa, b.item()); } diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index d9abd8bcfc79..fbb46b4b0174 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -71,10 +71,10 @@ std::unordered_map> PackedConvWeightCudnn< int64_t groups, bool transpose) { // TODO: need to check out to implement groups for conv operator in Conv.cpp - TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currenty limited to groups = 1; received groups =", groups); + TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currently limited to groups = 1; received groups =", groups); TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme())); TORCH_CHECK( kSpatialDim == 2, // 1D is packed as 2d, hence we don't need other checks diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 2bcbe00a8720..ef205c5673ae 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -129,7 +129,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind // 3.1 TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); // For CSR/CSC formats, we define blocksize=(1, 1) so that checking @@ -380,7 +380,7 @@ DimVector _estimate_sparse_compressed_tensor_size( } TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); return size; } @@ -559,13 +559,13 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki "torch.copy_: 
expected shapes of self and src to match along dimension ", self_compressed_dim, " for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } else { TORCH_CHECK(self_compressed_dims == src_compressed_dims, "torch.copy_: expected shapes of self and src to match along dimensions ", self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{}, @@ -576,7 +576,7 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki auto src_blocksize = DimVector(src_values.sizes().slice(src_values.dim()-2, 2)); TORCH_CHECK(self_blocksize == src_blocksize, "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.", - " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectivly."); + " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively."); }); AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{ diff --git a/aten/src/ATen/native/vulkan/api/Adapter.cpp b/aten/src/ATen/native/vulkan/api/Adapter.cpp index 311648b6894e..176236611c1d 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.cpp +++ b/aten/src/ATen/native/vulkan/api/Adapter.cpp @@ -195,7 +195,7 @@ std::string get_device_type_str(const VkPhysicalDeviceType type) { case VK_PHYSICAL_DEVICE_TYPE_CPU: return "CPU"; default: - return "UNKOWN"; + return "UNKNOWN"; } } diff --git a/aten/src/ATen/native/vulkan/ops/Clone.cpp b/aten/src/ATen/native/vulkan/ops/Clone.cpp index de353a10cb93..2601d785ddb5 100644 --- a/aten/src/ATen/native/vulkan/ops/Clone.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clone.cpp @@ -21,7 +21,7 @@ Tensor clone( TORCH_CHECK( (c10::MemoryFormat::Preserve == memory_format) || (c10::MemoryFormat::Contiguous == memory_format), - "Vulkan supports Preserve and Contiguous memory foramts"); + "Vulkan supports Preserve and Contiguous memory formats"); Tensor self; if (memory_format == MemoryFormat::Preserve) { diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index cc5044da0bd5..d2e3c5fc3851 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -710,7 +710,7 @@ def _generate_invalid_input(self, layout, device): shape((2, 3)), 'compressed_indices must have dimensionality >= 1 but got 0') - yield ('compressed/plain_indices mismatch of dimensionalites', + yield ('compressed/plain_indices mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), @@ -718,14 +718,14 @@ def _generate_invalid_input(self, layout, device): 'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively') if layout in {torch.sparse_csr, torch.sparse_csc}: - yield ('indices and values mismatch of dimensionalites', + yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), shape((2, 3)), r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1') else: - yield ('indices and values mismatch of dimensionalites', 
+ yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), @@ -737,7 +737,7 @@ def _generate_invalid_input(self, layout, device): tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), (2,), - r'tensor dimensionality must be sum of batch, base, and dense dimensionalites \(=0 \+ 2 \+ 0\) but got 1') + r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1') yield ('invalid batchsize', tensor([[0, 2, 4]]), From cb4842c9495a68d2a1d4a3ee3ffc9eab30dce28c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Nov 2022 10:29:24 +0000 Subject: [PATCH 125/453] [xla hash update] update the pinned xla hash (#88982) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned xla hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88982 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/xla.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 957272e8578b..6e29f8ee3c31 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -08121e41079319cd369f82f523f5a714a0563f9d +dd9b67ff0d6ba4da6a46ca1b22e35c98dbed0d77 From 072920c281bb4d9ca899c6c781a8374ab42a9a3f Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 13 Nov 2022 22:09:54 -0500 Subject: [PATCH 126/453] TorchDynamo: Add convolution binary+unary fusion for cpu in inference mode (#88412) This PR is about enabling the fusion of **conv+binary+relu**, which will improve the vision model's performance. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88412 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_torchinductor.py | 15 +++++++++++++-- torch/_inductor/overrides.py | 23 ++++++++++++++++++++--- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d331559a3a8b..bf1b0a9e4b37 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1449,6 +1449,7 @@ def __init__( dilation, groups, bias, + has_relu, **kwargs, ): super(M, self).__init__() @@ -1471,16 +1472,18 @@ def __init__( ) ) self.binary_fn = binary_fn + self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity() def forward(self, x): x1 = self.conv1(x) x2 = self.conv2(x) - return self.binary_fn(x1, x2) + return self.relu(self.binary_fn(x1, x2)) test_memory_format = [torch.contiguous_format, torch.channels_last] options = itertools.product( binary_list, [True, False], + [True, False], [1, 3], [1, 2], [1, 4], @@ -1489,6 +1492,7 @@ def forward(self, x): for ( binary_fn, + has_relu, bias, kernel_size, dilation, @@ -1499,7 +1503,14 @@ def forward(self, x): iC = 3 * groups x_shape = (1, iC, 112, 112) mod = M( - binary_fn, iC, oC, dilation, groups, bias, kernel_size=kernel_size + binary_fn, + iC, + oC, + dilation, + groups, + bias, + has_relu, + kernel_size=kernel_size, ).eval() mod = mod.to(memory_format=memory_format) # TODO: add bf16 test diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index a4a29fb02382..8d99107d17c3 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -157,6 +157,11 @@ def _update_module_params(self, conv, binary_op_name): 
self.unary_scalars = [] self.unary_algorithm = None + def _update_unary_params(self, unary): + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) + def _conv_forward(self, input, other, weight, bias): if self.padding_mode != "zeros": return torch.ops.mkldnn._convolution_pointwise( @@ -226,9 +231,9 @@ def _update_module_params(self, conv, binary_op_name): self.unary_algorithm = None def _update_unary_params(self, unary): - self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__]( - unary - ) + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) def _conv_forward(self, input, other, weight, bias): if self.padding_mode != "zeros": @@ -344,6 +349,13 @@ def fused_conv_binary_inplace_eval(conv: nn.Module, binary_op_name: str): ) +def fused_binary_unary_eval(conv_binary: nn.Module, unary: nn.Module): + assert not (conv_binary.training), "Fusion only for eval!" + # reuse origin conv module, and just update its' unary attr. + conv_binary._update_unary_params(unary) + return conv_binary + + def is_bfloat16_module(m): weight_is_bf16 = m.weight.dtype == torch.bfloat16 bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16 @@ -430,6 +442,9 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs): gm = fuse_unary(gm) gm = fuse_binary_inplace(gm) gm = fuse_binary(gm) + # why re-run fuse_unary? we want to enable conv+binary+unary fusion, + # such as conv+add+relu for vision model. + gm = fuse_unary(gm) return gm @@ -741,6 +756,8 @@ def rand_like(x, **kwargs): computation_op_unary_op_fusion_map = { nn.Conv2d: fused_conv_unary_eval, nn.Linear: fused_linear_unary_eval, + ConvBinary2d: fused_binary_unary_eval, + ConvBinaryInplace2d: fused_binary_unary_eval, } From 8371bb8a3dddbead709bc1e9d26715818a34fa8a Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 13 Nov 2022 22:33:13 +0000 Subject: [PATCH 127/453] Run test_torchinductor_opinfo CPU tests if triton not installed (#88934) These test are not run currently because normal CI workers don't have triton installed. 
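[Editor's note, not part of the original commit message] A minimal sketch of how a test module is expected to consume the shared availability flags instead of gating the whole run in `__main__`. The flags and the `torch.testing._internal.inductor_utils` module come from the diff below; the test class here is hypothetical.

```python
# Illustrative only: gate individual tests on inductor backend availability so
# CPU coverage still runs on CI workers without triton.
import unittest

from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA


class ExampleInductorTest(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(not HAS_CPU, "requires a working C++ compiler for inductor")
    def test_cpu_path(self):
        pass  # placeholder: a CPU-only inductor test would go here

    @unittest.skipIf(not HAS_CUDA, "requires triton and CUDA")
    def test_cuda_path(self):
        pass  # placeholder: a triton-backed test would go here


if __name__ == "__main__":
    unittest.main()
```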
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88934 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 20 ++-------------- test/inductor/test_torchinductor_opinfo.py | 28 ++++++++++++++-------- torch/testing/_internal/inductor_utils.py | 23 ++++++++++++++++++ 3 files changed, 43 insertions(+), 28 deletions(-) create mode 100644 torch/testing/_internal/inductor_utils.py diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bf1b0a9e4b37..dfce58397c5c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -20,7 +20,6 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F from torch.testing._internal.common_utils import ( - IS_FBCODE, TEST_WITH_ASAN, TEST_WITH_ROCM, TestCase as TorchTestCase, @@ -41,7 +40,7 @@ from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing from torch._inductor.sizevars import SizeVarAllocator - from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed + from torch._inductor.utils import has_torchvision_roi_align, timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -53,25 +52,10 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") -HAS_CPU = False -try: - from subprocess import CalledProcessError - - from torch._inductor.codecache import CppCodeCache - - CppCodeCache.load("") - HAS_CPU = not IS_FBCODE -except ( - CalledProcessError, - OSError, - torch._inductor.exc.InvalidCxxCompiler, - torch._inductor.exc.CppCompileError, -): - pass +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten -HAS_CUDA = has_triton() requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 3d384efea0ae..36c5aaacd1dd 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -16,20 +16,22 @@ onlyNativeDeviceTypes, OpDTypes, ops, + skipCPUIf, + skipCUDAIf, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( dtype_abbrs, run_tests, skipCUDAMemoryLeakCheckIf, + skipIfCrossRef, + skipIfTorchDynamo, suppress_warnings, - TEST_WITH_ROCM, TestCase, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA try: - from torch._inductor.utils import has_triton - try: from .test_torchinductor import check_model, check_model_cuda except ImportError: @@ -120,6 +122,7 @@ def process(device_type): inductor_skips["cpu"] = { "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault + "linalg.ldl_factor": {f32, f64}, # flaky "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky } @@ -169,6 +172,8 @@ def process(device_type): "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, "bincount": {i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -209,11 +214,10 @@ def process(device_type): "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, - 
"linalg.lu_solve": {f32, f64}, - "lu_solve": {f32, f64}, - "lu_unpack": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "logdet": {f32, f64}, "masked.norm": {f16}, + "masked.normalize": {f16}, "masked_fill": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, @@ -225,8 +229,8 @@ def process(device_type): "nan_to_num": {f16}, "nanquantile": {f32, f64}, "nn.functional.avg_pool1d": {i64}, - "nn.functional.avg_pool2d": {i64}, - "nn.functional.adaptive_avg_pool2d": {f16}, + "nn.functional.avg_pool2d": {i64, f64}, + "nn.functional.adaptive_avg_pool2d": {f16, f64}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": {f32, f64}, "nn.functional.gelu": {f64}, @@ -243,6 +247,7 @@ def process(device_type): "quantile": {f32, f64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "scatter_add": {f16}, @@ -455,6 +460,10 @@ class TestInductorOpInfo(TestCase): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently + @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") + @skipIfTorchDynamo("Test uses dynamo already") + @skipIfCrossRef @_ops(op_db[START:END]) @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True) def test_comprehensive(self, device, dtype, op): @@ -599,5 +608,4 @@ def fn(*args, **kwargs): instantiate_device_type_tests(TestInductorOpInfo, globals()) if __name__ == "__main__": - if has_triton() and not TEST_WITH_ROCM: - run_tests() + run_tests() diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py new file mode 100644 index 000000000000..84750a2de3ee --- /dev/null +++ b/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,23 @@ +from subprocess import CalledProcessError + +from torch._inductor.codecache import CppCodeCache +from torch._inductor.utils import has_triton +from torch.testing._internal.common_utils import ( + IS_FBCODE, + TEST_WITH_ROCM, +) +import torch + +HAS_CPU = False +try: + CppCodeCache.load("") + HAS_CPU = not IS_FBCODE +except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, +): + pass + +HAS_CUDA = has_triton() and not TEST_WITH_ROCM From 5e6cefd258dfdb4ddf2956c0b5631d84e97027e5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Nov 2022 12:02:43 +0000 Subject: [PATCH 128/453] Revert "Run test_torchinductor_opinfo CPU tests if triton not installed (#88934)" This reverts commit 8371bb8a3dddbead709bc1e9d26715818a34fa8a. 
Reverted https://github.com/pytorch/pytorch/pull/88934 on behalf of https://github.com/peterbell10 due to Inductor tests failing on master --- test/inductor/test_torchinductor.py | 20 ++++++++++++++-- test/inductor/test_torchinductor_opinfo.py | 28 ++++++++-------------- torch/testing/_internal/inductor_utils.py | 23 ------------------ 3 files changed, 28 insertions(+), 43 deletions(-) delete mode 100644 torch/testing/_internal/inductor_utils.py diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index dfce58397c5c..bf1b0a9e4b37 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -20,6 +20,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F from torch.testing._internal.common_utils import ( + IS_FBCODE, TEST_WITH_ASAN, TEST_WITH_ROCM, TestCase as TorchTestCase, @@ -40,7 +41,7 @@ from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing from torch._inductor.sizevars import SizeVarAllocator - from torch._inductor.utils import has_torchvision_roi_align, timed + from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -52,10 +53,25 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +HAS_CPU = False +try: + from subprocess import CalledProcessError + + from torch._inductor.codecache import CppCodeCache + + CppCodeCache.load("") + HAS_CPU = not IS_FBCODE +except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, +): + pass aten = torch.ops.aten +HAS_CUDA = has_triton() requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 36c5aaacd1dd..3d384efea0ae 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -16,22 +16,20 @@ onlyNativeDeviceTypes, OpDTypes, ops, - skipCPUIf, - skipCUDAIf, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( dtype_abbrs, run_tests, skipCUDAMemoryLeakCheckIf, - skipIfCrossRef, - skipIfTorchDynamo, suppress_warnings, + TEST_WITH_ROCM, TestCase, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA try: + from torch._inductor.utils import has_triton + try: from .test_torchinductor import check_model, check_model_cuda except ImportError: @@ -122,7 +120,6 @@ def process(device_type): inductor_skips["cpu"] = { "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault - "linalg.ldl_factor": {f32, f64}, # flaky "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky } @@ -172,8 +169,6 @@ def process(device_type): "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, "bincount": {i32, i64}, - "cdouble": {b8, f16, f32, f64, i32, i64}, - "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -214,10 +209,11 @@ def process(device_type): "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, 
"linalg.matrix_rank.hermitian": {f32, f64}, - "linalg.pinv.singular": {f32, f64}, + "linalg.lu_solve": {f32, f64}, + "lu_solve": {f32, f64}, + "lu_unpack": {f32, f64}, "logdet": {f32, f64}, "masked.norm": {f16}, - "masked.normalize": {f16}, "masked_fill": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, @@ -229,8 +225,8 @@ def process(device_type): "nan_to_num": {f16}, "nanquantile": {f32, f64}, "nn.functional.avg_pool1d": {i64}, - "nn.functional.avg_pool2d": {i64, f64}, - "nn.functional.adaptive_avg_pool2d": {f16, f64}, + "nn.functional.avg_pool2d": {i64}, + "nn.functional.adaptive_avg_pool2d": {f16}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": {f32, f64}, "nn.functional.gelu": {f64}, @@ -247,7 +243,6 @@ def process(device_type): "quantile": {f32, f64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, - "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "scatter_add": {f16}, @@ -460,10 +455,6 @@ class TestInductorOpInfo(TestCase): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently - @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") - @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") - @skipIfTorchDynamo("Test uses dynamo already") - @skipIfCrossRef @_ops(op_db[START:END]) @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True) def test_comprehensive(self, device, dtype, op): @@ -608,4 +599,5 @@ def fn(*args, **kwargs): instantiate_device_type_tests(TestInductorOpInfo, globals()) if __name__ == "__main__": - run_tests() + if has_triton() and not TEST_WITH_ROCM: + run_tests() diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py deleted file mode 100644 index 84750a2de3ee..000000000000 --- a/torch/testing/_internal/inductor_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from subprocess import CalledProcessError - -from torch._inductor.codecache import CppCodeCache -from torch._inductor.utils import has_triton -from torch.testing._internal.common_utils import ( - IS_FBCODE, - TEST_WITH_ROCM, -) -import torch - -HAS_CPU = False -try: - CppCodeCache.load("") - HAS_CPU = not IS_FBCODE -except ( - CalledProcessError, - OSError, - torch._inductor.exc.InvalidCxxCompiler, - torch._inductor.exc.CppCompileError, -): - pass - -HAS_CUDA = has_triton() and not TEST_WITH_ROCM From 15ef0660c553ebb50ad639f563062cab01e5e6dc Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 13 Nov 2022 22:09:56 -0500 Subject: [PATCH 129/453] Fake Tensor For (ConvFusion) Propagation (#88414) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88414 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/ir.py | 62 +++++++++++++++---------------------------- 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index ffb935ae440d..8a2e26ee9b94 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3130,7 +3130,7 @@ def create( sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size() ] - _, _, *kernel_size = weight.get_size() + _, _, *kernel_size = weight_shape # choose runtime kernel config_conv = config.triton.convolution @@ -3355,50 +3355,28 @@ def _prepare_convolution_fusion_create( padding = tuple(padding_) dilation = tuple(dilation_) assert isinstance(groups, int) - + with FakeTensorMode(): + 
output, *_ = cls.process_kernel( + torch.ops.aten.convolution, + x, + weight, + bias, + stride, + padding, + dilation, + False, + [0, 0], + groups, + ) + + output_size = output.shape weight_shape = [ sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size() ] - - out_channels, in_channels1, *kernel_size = weight_shape - in_channels1 = in_channels1 * groups - assert len(x.get_size()) == 2 + len(kernel_size) - batch, in_channels2, *input_size = x.get_size() - output_size = [batch] - V.graph.sizevars.guard_equals(in_channels1, in_channels2) - - output_size.append(out_channels) - assert ( - len(stride) - == len(padding) - == len(dilation) - == len(kernel_size) - == len(input_size) + _, _, *kernel_size = weight_shape + output_layout_str = ( + "torch.contiguous_format" if output.is_contiguous() else "torch.channels_last" ) - for i in range(len(stride)): - output_size.append( - IndexingDiv( - input_size[i] - + 2 * padding[i] - - dilation[i] * (kernel_size[i] - 1) - - 1 - + stride[i], - stride[i], - ) - ) - output_size[-1] = sympy.Integer( - V.graph.sizevars.guard_static_shape(output_size[-1]) - ) - - output_layout_str = "torch.contiguous_format" - # If x or weight have one channels_last(2d or 3d) format, it will call channels_last path, - # which align with aten.convolutuion path(cpu only support 2d case now). - # TODO: after cpu 3d convolution support channels_last path, the size check can be removed. - if len(x.get_size()) == 4 and ( - x.get_layout().is_channels_last_stride_ordered() - or weight.get_layout().is_channels_last_stride_ordered() - ): - output_layout_str = "torch.channels_last" if output_layout_str == "torch.channels_last": stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1))) @@ -3440,6 +3418,8 @@ def codegen(self, wrapper): wrapper.writeline( f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) @classmethod def create( From 9943d46aab4465b887039aa1a9b5d9ebc0a01a35 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 13 Nov 2022 22:09:58 -0500 Subject: [PATCH 130/453] TorchDynamo: skip convolution fusion when convolution's padding is string (#88794) Currently, the fusion convolution doesn't support the case when padding is a string, we will support it at the next step. 
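[Editor's note, not part of the original commit message] The guard this PR adds reduces to a single type check on the module's `padding` attribute. A hedged sketch; `should_skip_fusion` is a hypothetical helper name, while the real checks are written inline in `fuse_unary`, `fuse_binary` and `fuse_binary_inplace` in `torch/_inductor/overrides.py` (see the diff below).

```python
# Illustrative only: Conv2d built with padding="same"/"valid" keeps the string
# in .padding, which the fused mkldnn pointwise kernels do not accept yet.
import torch.nn as nn


def should_skip_fusion(computation_node: nn.Module) -> bool:  # hypothetical name
    return isinstance(computation_node, nn.Conv2d) and isinstance(
        computation_node.padding, str
    )


conv_int_padding = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv_str_padding = nn.Conv2d(3, 8, kernel_size=3, padding="same")
assert not should_skip_fusion(conv_int_padding)
assert should_skip_fusion(conv_str_padding)
```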
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88794 Approved by: https://github.com/jansel, https://github.com/jgong5 --- test/inductor/test_torchinductor.py | 9 +++++++++ torch/_inductor/overrides.py | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bf1b0a9e4b37..8c74b1090a23 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1400,6 +1400,7 @@ def test_conv2d_unary(self): [1, 3], [1, 2], [1, 4], + ["same", 0], test_memory_format, ) @@ -1409,6 +1410,7 @@ def test_conv2d_unary(self): kernel_size, dilation, groups, + padding, memory_format, ) in options: oC = 32 * groups @@ -1419,6 +1421,7 @@ def test_conv2d_unary(self): iC, oC, kernel_size=kernel_size, + padding=padding, dilation=dilation, groups=groups, bias=bias, @@ -1448,6 +1451,7 @@ def __init__( out_channels, dilation, groups, + padding, bias, has_relu, **kwargs, @@ -1458,6 +1462,7 @@ def __init__( out_channels, dilation=dilation, groups=groups, + padding=padding, bias=bias, **kwargs, ) @@ -1467,6 +1472,7 @@ def __init__( out_channels, dilation=dilation, groups=groups, + padding=padding, bias=bias, **kwargs, ) @@ -1487,6 +1493,7 @@ def forward(self, x): [1, 3], [1, 2], [1, 4], + ["same", 0], test_memory_format, ) @@ -1497,6 +1504,7 @@ def forward(self, x): kernel_size, dilation, groups, + padding, memory_format, ) in options: oC = 32 * groups @@ -1508,6 +1516,7 @@ def forward(self, x): oC, dilation, groups, + padding, bias, has_relu, kernel_size=kernel_size, diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 8d99107d17c3..3a95aa7ce880 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -499,7 +499,11 @@ def fuse_unary(gm: torch.fx.GraphModule): eval_mode = all(not n.training for n in [computation_node, unary_node]) if not eval_mode: continue - + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue # only fuse for linear when the dtype is bf16 if type(computation_node) in [nn.Linear] and not is_bfloat16_module( computation_node @@ -570,6 +574,11 @@ def fuse_binary(gm: torch.fx.GraphModule): if len(node.args[index_node].users) > 1: continue computation_node = modules[node.args[index_node].target] + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue # only fuse for linear when the dtype is bf16 if type(computation_node) in [ nn.Linear @@ -615,6 +624,11 @@ def fuse_binary_inplace(gm: torch.fx.GraphModule): if node.args[1].args[0] == node.args[0]: continue computation_node = modules[node.args[1].target] + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue replace_and_fuse_for_binary( computation_node, node, From ec4eadac5baebcf094836108a25ef3af63d39f5d Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Fri, 11 Nov 2022 14:13:01 -0800 Subject: [PATCH 131/453] reland "Do not use unsafe restriding for subclasses (#87610)" (#88343) This reverts commit 5b75b19f51837e162cc0e5e5757dfd9bef437c67. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88343 Approved by: https://github.com/ezyang --- .../ATen/functorch/BatchRulesScatterOps.cpp | 5 ++ aten/src/ATen/native/TensorShape.cpp | 3 +- test/functorch/test_aotdispatch.py | 2 - test/functorch/test_eager_transforms.py | 10 ++- test/test_functionalization.py | 76 +++++++++---------- 5 files changed, 52 insertions(+), 44 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 5eecbedd93e7..fc51e9d74409 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -928,6 +928,11 @@ Tensor index_copy_decomp( return at::scatter(self, dim, index_, source); ; } +// Note [Fix vmap slice_scatter] +// registers a decomposition for `slice_scatter` that calls into `slice.src` +// *_scatter operators have some special semantics though, that we can't easily +// through a decomposition: slice_scatter's output needs to have the same +// size, size, strides and storage_offset as the input. Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, int64_t dim, c10::optional start, c10::optional end, int64_t step) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index ba6ff27661ba..c44f3a921afc 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -1573,7 +1574,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { // // We need to do the checks here instead of in `native_functions.yaml` // to preserve backwards compatibility. - if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) { + if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) { return self._reshape_alias_symint(shape, stride.value()); } else { return self.view_symint(shape); diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ea00842a4e00..e31ac58039ec 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1098,8 +1098,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked_fill', ''), # could not find kernel xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - # Seems flaky: https://github.com/pytorch/pytorch/issues/88883 - skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... 
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 26b64c5e70cc..ff69ed9df6e6 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3130,13 +3130,16 @@ def normalize_devices(fx_g): return fx_g class TestFunctionalize(TestCase): - def _check_functionalize_correctness(self, f, inpt): + def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False): inpt1 = inpt.clone() inpt2 = inpt.clone() inpt3 = inpt.clone() expected_outputs = f(inpt1) - actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() + if skip_vmap: + actual_outputs = functionalize(f)(inpt2) + else: + actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() # Right now the flavor of functionalize that also removes view ops # isn't being used with vmap # That's because {view}_copy ops don't have batching rules yet @@ -3206,7 +3209,8 @@ def f(x: torch.Tensor) -> torch.Tensor: z2, z3 = z1.split(2) z2.add_(tmp) return x - self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + # See Note [Fix vmap slice_scatter] + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True) # Ensure functionalize works with List[Optional[Tensor]] arguments. # See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 diff --git a/test/test_functionalization.py b/test/test_functionalization.py index c6c3d991771b..c5330664d1e8 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -147,17 +147,17 @@ def forward(self, a_1): sum_1 = torch.ops.aten.sum.default(relu) ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None - _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]); expand_copy = None - new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) - view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) - view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) - clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format); view_copy_4 = None + view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None + new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) + view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) + view_copy_5 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) + clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format); view_copy_5 = None threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None - _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]); _reshape_alias_copy = None - detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None - view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None - 
_reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]); view_copy_5 = None - detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2); _reshape_alias_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]); view_copy_3 = None + detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None + view_copy_7 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None + detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None return detach_copy_1 """) # noqa: B950 @@ -710,40 +710,40 @@ def forward(self, a_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None view_copy = torch.ops.aten.view_copy.default(add, [8]) - _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None - transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0) + view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None + transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0) unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None getitem = split_copy[0] getitem_1 = split_copy[1]; split_copy = None add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None - select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None - _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1]) - view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None - _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None - transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None + select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]) + view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None + view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None - _reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None - view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None - view_copy_3 = 
torch.ops.aten.view_copy.default(view_copy_2, [8]) - _reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None - select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None - view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None - _reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None - transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None + view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None + view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8]) + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None + select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None + view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None + view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None + transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None getitem_2 = split_copy_1[0] getitem_3 = split_copy_1[1]; split_copy_1 = None - _reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None - add_2 = torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None + view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None + add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None return add_1 """) # noqa: B950 @@ -756,30 +756,30 @@ def forward(self, a_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None view = torch.ops.aten.view.default(add, [8]) - _reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None - transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0) + view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None + transpose = torch.ops.aten.transpose.int(view_1, 1, 0) unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None getitem = split[0] getitem_1 = split[1]; split = None add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None - select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None + select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format) _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None - view_1 = torch.ops.aten.view.default(add, [8]); add = None - _reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None - transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None + view_2 = 
torch.ops.aten.view.default(add, [8]); add = None + view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None + transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None - _reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None - view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None - view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None - _reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None - select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None + view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None + view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None + view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None + view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None + select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None return getitem """) From c8f3d1c13460bbaa85b7f423bfb7f414e825c757 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 14 Nov 2022 12:36:44 +0000 Subject: [PATCH 132/453] Run test_torchinductor_opinfo CPU tests if triton not installed (#88934) These test are not run currently because normal CI workers don't have triton installed. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88934 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 20 ++------------- test/inductor/test_torchinductor_opinfo.py | 30 ++++++++++++++-------- torch/testing/_internal/inductor_utils.py | 23 +++++++++++++++++ 3 files changed, 44 insertions(+), 29 deletions(-) create mode 100644 torch/testing/_internal/inductor_utils.py diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 8c74b1090a23..ba1f9032d97f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -20,7 +20,6 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F from torch.testing._internal.common_utils import ( - IS_FBCODE, TEST_WITH_ASAN, TEST_WITH_ROCM, TestCase as TorchTestCase, @@ -41,7 +40,7 @@ from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing from torch._inductor.sizevars import SizeVarAllocator - from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed + from torch._inductor.utils import has_torchvision_roi_align, timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -53,25 +52,10 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") -HAS_CPU = False -try: - from subprocess import CalledProcessError - - from torch._inductor.codecache import CppCodeCache - - CppCodeCache.load("") - HAS_CPU = not IS_FBCODE -except ( - CalledProcessError, - OSError, - torch._inductor.exc.InvalidCxxCompiler, - torch._inductor.exc.CppCompileError, -): - pass +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten -HAS_CUDA = has_triton() requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 3d384efea0ae..3880b87c082c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -16,20 +16,22 @@ onlyNativeDeviceTypes, OpDTypes, ops, + skipCPUIf, + skipCUDAIf, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( dtype_abbrs, run_tests, skipCUDAMemoryLeakCheckIf, + skipIfCrossRef, + skipIfTorchDynamo, suppress_warnings, - TEST_WITH_ROCM, TestCase, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA try: - from torch._inductor.utils import has_triton - try: from .test_torchinductor import check_model, check_model_cuda except ImportError: @@ -120,6 +122,7 @@ def process(device_type): inductor_skips["cpu"] = { "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault + "linalg.ldl_factor": {f32, f64}, # flaky "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky } @@ -169,6 +172,9 @@ def process(device_type): "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -209,11 +215,10 @@ def process(device_type): "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, 
"linalg.matrix_rank.hermitian": {f32, f64}, - "linalg.lu_solve": {f32, f64}, - "lu_solve": {f32, f64}, - "lu_unpack": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "logdet": {f32, f64}, "masked.norm": {f16}, + "masked.normalize": {f16}, "masked_fill": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, @@ -225,8 +230,8 @@ def process(device_type): "nan_to_num": {f16}, "nanquantile": {f32, f64}, "nn.functional.avg_pool1d": {i64}, - "nn.functional.avg_pool2d": {i64}, - "nn.functional.adaptive_avg_pool2d": {f16}, + "nn.functional.avg_pool2d": {i64, f64}, + "nn.functional.adaptive_avg_pool2d": {f16, f64}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": {f32, f64}, "nn.functional.gelu": {f64}, @@ -243,6 +248,7 @@ def process(device_type): "quantile": {f32, f64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "scatter_add": {f16}, @@ -366,7 +372,6 @@ def process(device_type): "asin": {f16}, "cumprod": {f16}, "linalg.vector_norm": {f64, f64}, - "linalg.householder_product": {f32}, "kron": {f16}, "nanquantile": {f32, f64}, "native_batch_norm": {f16, f32, f64}, @@ -455,6 +460,10 @@ class TestInductorOpInfo(TestCase): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently + @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") + @skipIfTorchDynamo("Test uses dynamo already") + @skipIfCrossRef @_ops(op_db[START:END]) @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True) def test_comprehensive(self, device, dtype, op): @@ -599,5 +608,4 @@ def fn(*args, **kwargs): instantiate_device_type_tests(TestInductorOpInfo, globals()) if __name__ == "__main__": - if has_triton() and not TEST_WITH_ROCM: - run_tests() + run_tests() diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py new file mode 100644 index 000000000000..84750a2de3ee --- /dev/null +++ b/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,23 @@ +from subprocess import CalledProcessError + +from torch._inductor.codecache import CppCodeCache +from torch._inductor.utils import has_triton +from torch.testing._internal.common_utils import ( + IS_FBCODE, + TEST_WITH_ROCM, +) +import torch + +HAS_CPU = False +try: + CppCodeCache.load("") + HAS_CPU = not IS_FBCODE +except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, +): + pass + +HAS_CUDA = has_triton() and not TEST_WITH_ROCM From 06f1b52705ee360e5ac89e0f1f32f69ffde72b9a Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 14 Nov 2022 17:37:24 +0000 Subject: [PATCH 133/453] don't use prims.unsqueeze in group_norm (#88927) inductor doesn't have prims.squeeze lowering, so this breaks it. Longer term, `squeeze` with multiple dimensions is not a prim, nvfuser implements it with a loop, inductor uses `_squeeze_multiple` helper which turns it into a loop. Prim should accept only a single dimension. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88927 Approved by: https://github.com/eellison --- torch/_refs/__init__.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index a1de9a438d77..f2817f0331ac 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2820,6 +2820,12 @@ def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeT return x +def _squeeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType: + for dim in reversed(sorted(dimensions)): + x = torch.squeeze(x, dim) + return x + + @register_decomposition(torch.ops.aten.native_group_norm.default) def native_group_norm( input: Tensor, @@ -2868,8 +2874,8 @@ def native_group_norm( rstd = _maybe_convert_to_dtype(rstd, input.dtype) # type: ignore[assignment] # remove broadcast dimensions from mean and rstd - mean = prims.squeeze(mean, reduction_dims) - rstd = prims.squeeze(rstd, reduction_dims) + mean = _squeeze_multiple(mean, reduction_dims) + rstd = _squeeze_multiple(rstd, reduction_dims) return (out, mean, rstd) From b0c86caa1d46a16195682e2afe5456f97265aa53 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 14 Nov 2022 17:49:30 +0000 Subject: [PATCH 134/453] Remove cpu path from lobpcg's basis helper (#88984) Fixes https://github.com/pytorch/pytorch/issues/88650 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88984 Approved by: https://github.com/lezcano --- torch/_linalg_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index 76b8ab532fcd..bdd22f395d2d 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -76,12 +76,7 @@ def qform(A: Optional[Tensor], S: Tensor): def basis(A): """Return orthogonal basis of A columns.""" - if A.is_cuda: - # torch.orgqr is not available in CUDA - Q = torch.linalg.qr(A).Q - else: - Q = torch.orgqr(*torch.geqrf(A)) - return Q + return torch.linalg.qr(A).Q def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: From f1a5044de0639180f667d212800aa43f34026b3c Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Mon, 14 Nov 2022 18:18:45 +0000 Subject: [PATCH 135/453] [primTorch] _refs & opinfo alpha_dropout (#87989) Add _refs and OpInfo for `nn.functional.alpha_dropout` Pull Request resolved: https://github.com/pytorch/pytorch/pull/87989 Approved by: https://github.com/mruberry --- test/functorch/test_ops.py | 5 ++ test/functorch/test_vmap.py | 1 + torch/_refs/nn/functional/__init__.py | 81 ++++++++++++++++--- .../_internal/common_methods_invocations.py | 47 +++++++++++ 4 files changed, 123 insertions(+), 11 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 74085941c6c8..2e303922dfa1 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -612,6 +612,7 @@ def fn(inp, *args, **kwargs): skip("nn.functional.dropout"), # calls random op skip("nn.functional.dropout2d"), # calls random op skip("nn.functional.dropout3d"), # calls random op + skip("nn.functional.alpha_dropout"), # calls random op skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op @@ -719,6 +720,7 @@ def vjp_of_vjp(*args_and_cotangents): skip('nn.functional.dropout'), # randomness 
skip('nn.functional.dropout2d'), # randomness skip('nn.functional.dropout3d', ''), # randomness + skip('nn.functional.alpha_dropout'), # randomness skip('nn.functional._scaled_dot_product_attention'), # randomness xfail('as_strided'), # as_strided is too wild for us to support, wontfix xfail('index_put', ''), # not possible due to dynamic shapes; we support a subset @@ -808,6 +810,7 @@ def test_vmapvjp(self, device, dtype, op): skip('nn.functional.dropout2d', ''), skip('nn.functional.dropout3d', ''), skip('nn.functional._scaled_dot_product_attention'), # randomness + skip('nn.functional.alpha_dropout'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), skip('nn.functional.feature_alpha_dropout', 'with_train'), xfail('nn.functional.fractional_max_pool2d'), # Cannot access data pointer of Tensor that doesn't have storage @@ -1089,6 +1092,7 @@ def test(): skip('nn.functional.rrelu'), # randomness skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + skip('nn.functional.alpha_dropout'), # randomness skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format skip('to_sparse', ''), # non-dense output skip('ormqr', ''), # takes too long @@ -1330,6 +1334,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.dropout'), # calls random op skip('nn.functional._scaled_dot_product_attention'), # randomness xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition + xfail('nn.functional.alpha_dropout'), # calls randomn op xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op xfail('nn.functional.fractional_max_pool2d'), # calls random op xfail('nn.functional.fractional_max_pool3d'), # calls random op diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6d95077b627e..fb8722b8405b 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3219,6 +3219,7 @@ def test(): xfail('nn.functional.rrelu'), # randomness xfail('nn.functional.dropout2d', ''), # randomness xfail('nn.functional.dropout3d', ''), # randomness + xfail('nn.functional.alpha_dropout', ''), # randomness xfail('nn.functional.feature_alpha_dropout', 'with_train'), # randomness xfail('as_strided'), # Our test runner can't handle this; manual test exists skip('new_empty_strided'), # empty tensor data is garbage so it's hard to make comparisons with it diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index dcd86d8952d2..3848a738d534 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,3 +1,4 @@ +import math from typing import Callable, Optional, Union import torch @@ -27,6 +28,7 @@ from torch._subclasses.fake_tensor import FakeTensor __all__ = [ + "alpha_dropout", "celu", "dropout", "elu", @@ -59,6 +61,65 @@ Tensor = torch.Tensor + +def _dropout_helper( + self: TensorLikeType, + val: float, +) -> TensorLikeType: + """ + Helper function for all dropout-type operators. During training, + some of the elements of the input tensor are randomly masked. + + Returns the masked tensor of the boolean values. 
+ + """ + + return ( + refs.uniform( + self.shape, low=0.0, high=1.0, dtype=torch.float32, device=self.device + ) + < val + ) + + +@register_decomposition(torch.ops.aten.alpha_dropout) +def alpha_dropout( + self: TensorLikeType, p: float = 0.5, training: bool = False, inplace: bool = False +) -> TensorLikeType: + + if inplace: + raise NotImplementedError + + if not training: + return self + + utils.check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) + + if p == 1: + return torch.zeros_like(self) + + if p == 0: + return self + + dropout_mask = _dropout_helper(self, 1 - p) + + # From paper: Self-Normalizing Neural Networks (https://arxiv.org/pdf/1706.02515.pdf) + # alpha = - SELU.alpha * SELU.scale, here + # SELU.alpha = 1.6732632423543772848170429916717 and + # SELU.scale = 1.0507009873554804934193349852946 + alpha = -1.7580993408473766 + + a = 1.0 / math.sqrt((alpha * alpha * p + 1) * (1 - p)) + b = torch.logical_not(dropout_mask) + b = b * (alpha * a) + alpha * a * p + dropout_mask = a * dropout_mask + + return self * dropout_mask + b + + # celu is implemented specially because it has an alpha argument # celu is very similar to elu @register_decomposition(torch.ops.aten.celu) @@ -93,7 +154,6 @@ def celu( return torch.where(a > 0, a, rhs) -# TODO: should we allow the user to set a different dtype for the mask generation? @register_decomposition(torch.ops.aten.dropout) def dropout( a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False @@ -105,22 +165,21 @@ def dropout( if not training: return a - assert p <= 1 - assert p >= 0 + utils.check( + p <= 1 and p >= 0, + lambda: f"dropout probability has to be between 0 and 1, but got, {p}", + ) if p == 1: - return refs.zeros_like(a) + return torch.zeros_like(a) if p == 0: return a - p1m = 1 - p - scale = 1 / p1m - mask = refs.lt( - refs.uniform(a.shape, low=0.0, high=1.0, dtype=torch.float32, device=a.device), - p1m, - ) - return refs.mul(refs.mul(a, mask), scale) + scale = 1 / (1 - p) + dropout_mask = _dropout_helper(a, 1 - p) + + return a * dropout_mask * scale # elu is implemented specially because it has an alpha argument diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 62c9b4750ae9..001fd455e82e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15965,6 +15965,28 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=partial(sample_inputs_dropout, valid_input_dim=(4, 5)), inplace_variant=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.dropout3d, input, *args, **kwargs, inplace=True)), + OpInfo( + "nn.functional.alpha_dropout", + op=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs), + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + gradcheck_wrapper=wrapper_set_seed, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + sample_inputs_func=sample_inputs_dropout, + check_batched_forward_grad=False, + inplace_variant=lambda input, *args, **kwargs: + wrapper_set_seed(torch.nn.functional.alpha_dropout, input, *args, **kwargs, inplace=True), + skips=( + # lambda impl + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # AssertionError: Tensor-likes are not close! 
+ # Fails in cuda11.7 + # Error Log: https://github.com/pytorch/pytorch/actions/runs/3440108478/jobs/5738475757 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu', device_type='cuda'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),),), # In training mode, feature_alpha_dropout currently doesn't support inputs of complex dtype # unlike when `train=False`, it supports complex inputs, hence 2 OpInfos to cover all cases OpInfo( @@ -17287,6 +17309,31 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # # Elementwise Unary nn.functional OpInfos # + PythonRefInfo( + "_refs.nn.functional.alpha_dropout", + torch_opinfo_name="nn.functional.alpha_dropout", + supports_nvfuser=False, + decorators=( + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_torch_fallback'), + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_python_ref_executor', device_type='cuda'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestMathBits', + 'test_neg_view'), + # AssertionError: Tensor-likes are not close! + DecorateInfo(unittest.skip("Expected: dropout is not comparable"), + 'TestCommon', + 'test_compare_cpu'), + ) + ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.celu", torch_opinfo_name="nn.functional.celu", From cdb798faefa2520b37938311bcef1c175581a0ff Mon Sep 17 00:00:00 2001 From: Sean Ross-Ross Date: Mon, 14 Nov 2022 18:39:45 +0000 Subject: [PATCH 136/453] _get_nested_attr should return a value in the general case (#88822) Fixes https://github.com/pytorch/functorch/issues/1053 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88822 Approved by: https://github.com/zou3519 --- functorch/_src/make_functional.py | 2 +- test/functorch/test_eager_transforms.py | 31 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/functorch/_src/make_functional.py b/functorch/_src/make_functional.py index 7b8c15196e23..abb3f07ca597 100644 --- a/functorch/_src/make_functional.py +++ b/functorch/_src/make_functional.py @@ -44,7 +44,7 @@ def _get_nested_attr(obj: nn.Module, names: List[str]) -> None: if len(names) == 1: return getattr(obj, names[0]) else: - _get_nested_attr(getattr(obj, names[0]), names[1:]) + return _get_nested_attr(getattr(obj, names[0]), names[1:]) def raise_parameter_tying_error(): diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index ff69ed9df6e6..e88e8007e77e 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -2669,6 +2669,37 @@ def test_combine_state_for_ensemble_smoke(self): models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)] _ = combine_state_for_ensemble(models) + def test_state_correctly_returned_after_forward(self): + class Net(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(3, 3) + + def forward(self, x): + x = self.linear(x) + return x + + mod = Net() + func, params = make_functional(mod) + + # state in func.names_map + old_state_linear_weight = func.stateless_model.linear.weight + old_state_linear_bias = func.stateless_model.linear.bias + + 
self.assertIsNotNone(old_state_linear_weight) + self.assertIsNotNone(old_state_linear_bias) + + x = torch.randn(4, 3) + func(params, x) + + new_state_linear_weight = func.stateless_model.linear.weight + new_state_linear_bias = func.stateless_model.linear.bias + + self.assertIsNotNone(new_state_linear_weight) + self.assertIsNotNone(new_state_linear_bias) + + self.assertEqual(old_state_linear_weight, new_state_linear_weight) + self.assertEqual(old_state_linear_bias, new_state_linear_bias) class TestExamplesCorrectness(TestCase): def test_maml_regression(self, device): From 36d87465fb9b34914e6db50638c0f5bf04e3d7d9 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 14 Nov 2022 18:43:50 +0000 Subject: [PATCH 137/453] Fix long comment error on dashboard (#89002) Fix dashboard comment failure due to the following trace: ``` Traceback (most recent call last): File "/scratch/anijain/dashboard/work/pytorch/benchmarks/dynamo/runner.py", line 1180, in DashboardUpdater(args).update() File "/scratch/anijain/dashboard/work/pytorch/benchmarks/dynamo/runner.py", line 1119, in update self.comment_on_gh(comment) File "/scratch/anijain/dashboard/work/pytorch/benchmarks/dynamo/runner.py", line 1096, in comment_on_gh subprocess.check_call( File "/scratch/anijain/dashboard/env/lib/python3.9/subprocess.py", line 368, in check_call retcode = call(*popenargs, **kwargs) File "/scratch/anijain/dashboard/env/lib/python3.9/subprocess.py", line 349, in call with Popen(*popenargs, **kwargs) as p: File "/scratch/anijain/dashboard/env/lib/python3.9/subprocess.py", line 951, in __init__ self._execute_child(args, executable, preexec_fn, close_fds, File "/scratch/anijain/dashboard/env/lib/python3.9/subprocess.py", line 1821, in _execute_child raise child_exception_type(errno_num, err_msg, err_filename) OSError: [Errno 7] Argument list too long: '/data/home/anijain/miniconda/bin/gh' srun: error: a100-st-p4d24xlarge-27: task 0: Exited with exit code 1 ``` That is, we were trying to execute a gh command in the OS that was too long. 
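For illustration (this sketch is not part of the original PR description), the fix below sidesteps the OS argument-length limit by writing the comment body to a temporary file and letting `gh` read it via `-F` instead of passing the text inline with `-b`. The function name and arguments here are hypothetical placeholders:

```
import os
import subprocess
import tempfile

def post_long_comment(gh_cli_path, repo_url, issue_number, body):
    # Write the (potentially huge) comment body to a temp file so the
    # command line stays short, then have gh read the body from the file.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
        f.write(body)
        filename = f.name
    try:
        subprocess.check_call(
            [gh_cli_path, "issue", "comment",
             f"--repo={repo_url}", issue_number, "-F", filename]
        )
    finally:
        os.remove(filename)
```

This mirrors the change made to `comment_on_gh` in the diff that follows.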
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89002 Approved by: https://github.com/davidberard98 --- benchmarks/dynamo/runner.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 99c70426cd36..d27763c41b0b 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -36,6 +36,7 @@ import re import shutil import subprocess +import tempfile from collections import defaultdict from datetime import datetime from os.path import abspath, exists @@ -1093,6 +1094,10 @@ def comment_on_gh(self, comment): """ Send a commment to dashboard """ + with tempfile.NamedTemporaryFile(mode="w", delete=False) as f: + f.write(comment) + filename = f.name + subprocess.check_call( [ self.args.dashboard_gh_cli_path, @@ -1100,11 +1105,13 @@ def comment_on_gh(self, comment): "comment", "--repo=https://github.com/pytorch/torchdynamo.git", "681", - "-b", - comment, + "-F", + filename, ] ) + os.remove(filename) + def update(self): self.upload_graphs() AccuracyRegressionTracker(self.args).generate_comment() From 3d79ced8cfb2ddd250f9a31dad9b990c120e6dab Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sat, 12 Nov 2022 14:20:41 +0000 Subject: [PATCH 138/453] wrap_pybind_function: support member function pointers (#88932) This updates `wrap_pybind_function` to use `invoke` and adds the `invoke_traits` object which is analogous to `function_traits` but for member functions it includes the class as an explicit argument. To test this is working properly, I've also applied it to the `CUDAGraph` binding code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88932 Approved by: https://github.com/albanD --- aten/src/ATen/detail/FunctionTraits.h | 24 ++++++++++++++++++++++++ torch/csrc/Exceptions.h | 11 ++++++----- torch/csrc/cuda/Graph.cpp | 10 +++++----- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/detail/FunctionTraits.h b/aten/src/ATen/detail/FunctionTraits.h index aab7300b585f..f49a55e1326d 100644 --- a/aten/src/ATen/detail/FunctionTraits.h +++ b/aten/src/ATen/detail/FunctionTraits.h @@ -76,3 +76,27 @@ struct binary_function_traits { using arg1_t = typename traits::template arg<0>::type; using arg2_t = typename traits::template arg<1>::type; }; + + +// Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType +template +struct invoke_traits : public function_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : public invoke_traits{ +}; + +template +struct invoke_traits : + public function_traits { +}; + +template +struct invoke_traits : + public function_traits { +}; diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index c9069a4a7c5b..01caa6a702c0 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -375,17 +376,17 @@ struct PyWarningHandler { namespace detail { template -using Arg = typename function_traits::template arg::type; +using Arg = typename invoke_traits::template arg::type; template auto wrap_pybind_function_impl_(Func&& f, std::index_sequence) { - using traits = function_traits; + using result_type = typename invoke_traits::result_type; namespace py = pybind11; // f=f is needed to handle function references on older compilers - return [f = f](Arg... 
args) -> typename traits::result_type { + return [f = std::forward(f)](Arg... args) -> result_type { HANDLE_TH_ERRORS - return f(std::forward>(args)...); + return c10::guts::invoke(f, std::forward>(args)...); END_HANDLE_TH_ERRORS_PYBIND }; } @@ -395,7 +396,7 @@ auto wrap_pybind_function_impl_(Func&& f, std::index_sequence) { // Returns a function object suitable for registering with pybind11. template auto wrap_pybind_function(Func&& f) { - using traits = function_traits; + using traits = invoke_traits; return torch::detail::wrap_pybind_function_impl_( std::forward(f), std::make_index_sequence{}); } diff --git a/torch/csrc/cuda/Graph.cpp b/torch/csrc/cuda/Graph.cpp index 0866b82f659d..6d3a77c365e1 100644 --- a/torch/csrc/cuda/Graph.cpp +++ b/torch/csrc/cuda/Graph.cpp @@ -30,23 +30,23 @@ void THCPGraph_init(PyObject* module) { // docs aren't clear. But it works. .def( "capture_begin", - &::at::cuda::CUDAGraph::capture_begin, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_begin), py::call_guard(), py::arg("pool") = c10::cuda::MempoolId_t{0, 0}) .def( "capture_end", - &::at::cuda::CUDAGraph::capture_end, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::capture_end), py::call_guard()) .def( "replay", - &::at::cuda::CUDAGraph::replay, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::replay), py::call_guard()) .def( "reset", - &::at::cuda::CUDAGraph::reset, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::reset), py::call_guard()) .def( "pool", - &::at::cuda::CUDAGraph::pool, + torch::wrap_pybind_function(&at::cuda::CUDAGraph::pool), py::call_guard()); } From e0c194f10b20a5ab2ad8d2075bec81ca57320268 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 Nov 2022 19:06:38 +0000 Subject: [PATCH 139/453] Fix typos in messages under torch (#88961) This PR fixes typos of messages and parms in c++ source and head files under `torch` directory. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88961 Approved by: https://github.com/albanD --- torch/csrc/api/src/nn/modules/transformer.cpp | 4 ++-- torch/csrc/api/src/optim/optimizer.cpp | 4 ++-- torch/csrc/autograd/FunctionsManual.cpp | 2 +- torch/csrc/autograd/custom_function.h | 2 +- torch/csrc/autograd/python_variable.cpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 4 ++-- torch/csrc/distributed/c10d/UCCUtils.cpp | 2 +- torch/csrc/distributed/c10d/init.cpp | 2 +- torch/csrc/distributed/rpc/utils.cpp | 2 +- .../xnnpack/compiler/xnn_compiler.cpp | 2 +- torch/csrc/jit/codegen/cuda/arith.cpp | 10 +++++----- .../csrc/jit/codegen/cuda/compute_at_map.cpp | 2 +- torch/csrc/jit/codegen/cuda/disjoint_set.h | 2 +- .../csrc/jit/codegen/cuda/executor_utils.cpp | 6 +++--- torch/csrc/jit/codegen/cuda/ir_nodes.cpp | 8 ++++---- torch/csrc/jit/codegen/cuda/kernel_cache.cpp | 4 ++-- .../csrc/jit/codegen/cuda/lower_expr_sort.cpp | 6 +++--- .../cuda/lower_misaligned_vectorization.cpp | 2 +- .../cuda/lower_predicate_elimination.cpp | 2 +- .../jit/codegen/cuda/lower_validation.cpp | 2 +- .../csrc/jit/codegen/cuda/root_domain_map.cpp | 2 +- .../jit/codegen/cuda/scheduler/mma_utils.cpp | 2 +- .../csrc/jit/codegen/cuda/scheduler/utils.cpp | 2 +- torch/csrc/jit/codegen/cuda/tensor_view.cpp | 2 +- .../csrc/jit/codegen/cuda/test/test_gpu3.cpp | 2 +- .../jit/codegen/cuda/transform_rfactor.cpp | 2 +- .../csrc/jit/codegen/cuda/transform_view.cpp | 2 +- torch/csrc/jit/frontend/ir_emitter.cpp | 2 +- torch/csrc/jit/ir/irparser.cpp | 2 +- torch/csrc/jit/mobile/flatbuffer_loader.cpp | 2 +- .../jit/passes/onnx/shape_type_inference.cpp | 2 +- torch/csrc/jit/passes/peephole_non_tensor.cpp | 2 +- .../quantization/insert_quant_dequant.cpp | 2 +- torch/csrc/jit/python/script_init.cpp | 2 +- torch/csrc/jit/runtime/graph_executor.cpp | 7 +++---- torch/csrc/jit/runtime/static/ops.cpp | 2 +- .../jit/serialization/export_bytecode.cpp | 2 +- .../csrc/jit/serialization/export_module.cpp | 2 +- torch/csrc/jit/serialization/import.cpp | 2 +- torch/csrc/jit/tensorexpr/kernel.cpp | 2 +- torch/csrc/jit/tensorexpr/llvm_codegen.cpp | 20 +++++++++---------- torch/csrc/lazy/core/config.cpp | 2 +- torch/csrc/profiler/util.cpp | 4 ++-- 43 files changed, 70 insertions(+), 71 deletions(-) diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index 6d643fc7354f..df08c629da56 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -466,7 +466,7 @@ Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) { // Treat 0 dim valid here TORCH_CHECK( sz >= 0, - "Input size must be non-negative to genearte a valid square subsequent mask, but got ", + "Input size must be non-negative to generate a valid square subsequent mask, but got ", sz); // check IEEE754 support here since -inf is not guaranteed to be valid on non @@ -479,7 +479,7 @@ Tensor TransformerImpl::generate_square_subsequent_mask(int64_t sz) { // platform else { TORCH_WARN_ONCE( - "IEEE754 is not supporetd on this platform, generate_square_subsequent_mask will fill " + "IEEE754 is not supported on this platform, generate_square_subsequent_mask will fill " "the mask with smallest float number on this platform instead of -inf"); return torch::triu( torch::full({sz, sz}, std::numeric_limits::lowest()), 1); diff --git a/torch/csrc/api/src/optim/optimizer.cpp b/torch/csrc/api/src/optim/optimizer.cpp index 95165d850cf6..f73e54d2835f 
100644 --- a/torch/csrc/api/src/optim/optimizer.cpp +++ b/torch/csrc/api/src/optim/optimizer.cpp @@ -64,13 +64,13 @@ void OptimizerParamState::serialize( double OptimizerOptions::get_lr() const { TORCH_CHECK( false, - "double get_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + "double get_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } void OptimizerOptions::set_lr(const double lr) { TORCH_CHECK( false, - "double set_lr() has not been overidden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); + "double set_lr() has not been overridden and implemented in subclass of torch::optim::OptimizerOptions, you must override it in your subclass."); } std::unique_ptr OptimizerOptions::clone() const { diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 5a3f96d47e30..c0fbf5f6c0aa 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -4846,7 +4846,7 @@ Tensor log1p_backward(const Tensor& grad, const Tensor& self) { // materialized so if self is strided and grad is sparse nothing unepected // happens memory wise TORCH_WARN( - "log1p_backward: recived self with sparse layout, but backward requires materialization of a dense tensor with this shape"); + "log1p_backward: received self with sparse layout, but backward requires materialization of a dense tensor with this shape"); self_p1_conj = (self.to_dense() + 1).conj(); } else { // Although calling self.to_dense() would just return self when it has diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index bc7489292c23..d7670d924b1f 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -300,7 +300,7 @@ auto Function::apply(Args&&... 
args) TORCH_CHECK( false, "jvp is not implemented for the c++ API of custom Function yet.", - "Please open a feature request on Github if you need this."); + "Please open a feature request on GitHub if you need this."); }; auto wrapped_outputs = _wrap_outputs( diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 002b904d4072..e3ab10c7499c 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -857,7 +857,7 @@ static PyObject* THPVariable_make_wrapper_subclass( if (sizes_strides_policy.has_value()) { TORCH_CHECK( false, - "Setting sizes_strides_policy isn't suppored for this overload") + "Setting sizes_strides_policy isn't supported for this overload") } } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index c92d24af21c8..1d788a2c2e0c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1120,7 +1120,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( "[", rank_, "] is setting up NCCL communicator and " - "retreiving ncclUniqueId from [0] via c10d key-value store by key '", + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", storeKey, "', but store->get('", storeKey, @@ -1133,7 +1133,7 @@ void ProcessGroupNCCL::broadcastUniqueNCCLID( "Unknown exception while [", rank_, "] is setting up NCCL communicator and " - "retreiving ncclUniqueId from [0] via c10d key-value store by key '", + "retrieving ncclUniqueId from [0] via c10d key-value store by key '", storeKey, "'")); } diff --git a/torch/csrc/distributed/c10d/UCCUtils.cpp b/torch/csrc/distributed/c10d/UCCUtils.cpp index ef934d1597f9..590a931f2f11 100644 --- a/torch/csrc/distributed/c10d/UCCUtils.cpp +++ b/torch/csrc/distributed/c10d/UCCUtils.cpp @@ -186,7 +186,7 @@ void CommUCC::free_request(ucc_coll_req_h request) { CommUCC::~CommUCC() { if (context != nullptr) { TORCH_UCC_CHECK( - ucc_context_destroy(context), "failed to destory UCC context"); + ucc_context_destroy(context), "failed to destroy UCC context"); } if (lib != nullptr) { TORCH_UCC_CHECK(ucc_finalize(lib), "failed to finalize UCC library"); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 2424506eef0f..313aabee7cd9 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1906,7 +1906,7 @@ Example:: Returns: A ``Work`` object which is associated with the completion of the ``torch.futures.Future``. - This is the prefered way of constructing Work objects when writing a custom ProcessGroup + This is the preferred way of constructing Work objects when writing a custom ProcessGroup in python. Example:: >>> class SingleRankProcessGroup(torch.distributed.ProcessGroup): diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp index 12e3f2edf755..c20145e82d03 100644 --- a/torch/csrc/distributed/rpc/utils.cpp +++ b/torch/csrc/distributed/rpc/utils.cpp @@ -38,7 +38,7 @@ void processRemoteProfiledEvents( TORCH_CHECK( enabled, "Profiler was expected to be enabled. 
This can happen in callback " - " continutations that run in different threads, and the TLS of the " + " continuations that run in different threads, and the TLS of the " " profiler was not propagated."); std::vector events = rpcWithProfilingResp.getProfiledEvents(); const auto& profilingId = rpcWithProfilingResp.getProfilingId(); diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index 395d59a1cf21..3bbff2309904 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -55,7 +55,7 @@ XNNExecutor XNNCompiler::compileModel(std::string ser_model) { auto buffer_idx = tensor_value->constant_buffer_idx(); if (buffer_idx != 0) { // TODO: @maxren implement data handling - TORCH_CHECK(false, "Cosntant data handling not yet implemented") + TORCH_CHECK(false, "Constant data handling not yet implemented") } std::vector dims_data; for (auto dim : *tensor_value->dims()) { diff --git a/torch/csrc/jit/codegen/cuda/arith.cpp b/torch/csrc/jit/codegen/cuda/arith.cpp index 8e8d82128512..d4e1348ee693 100644 --- a/torch/csrc/jit/codegen/cuda/arith.cpp +++ b/torch/csrc/jit/codegen/cuda/arith.cpp @@ -1094,7 +1094,7 @@ static TensorView* newForReduction( TORCH_INTERNAL_ASSERT( !axes_set.empty(), - "Asked for ouput of reduction, but no reduction axis provided."); + "Asked for output of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain.size(), @@ -1183,7 +1183,7 @@ TensorView* reductionOp( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, @@ -1518,7 +1518,7 @@ WelfordResult Welford( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, @@ -2228,7 +2228,7 @@ static TensorView* newForMma( TORCH_INTERNAL_ASSERT( !axes_set.empty(), - "Asked for ouput of reduction, but no reduction axis provided."); + "Asked for output of reduction, but no reduction axis provided."); TORCH_INTERNAL_ASSERT( (*(axes_set.rbegin())) < orig_domain_a.size(), @@ -2319,7 +2319,7 @@ TensorView* fusedMultiplySum( TORCH_CHECK( axis >= 0 && axis < ndims, - "Reduction on invalid axis, recieved: ", + "Reduction on invalid axis, received: ", axis, " however tensor view only has ", ndims, diff --git a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp index 7f3de6687eb3..1c2ac627b575 100644 --- a/torch/csrc/jit/codegen/cuda/compute_at_map.cpp +++ b/torch/csrc/jit/codegen/cuda/compute_at_map.cpp @@ -331,7 +331,7 @@ void IterDomainGraph::build(Fusion* fusion) { c_tv->getRootDomain().size() == first_output_tv->getRootDomain().size(), "Multiple outputs with mismatched dimensions is not supported. 
", - "Only supported case is welford op where all outputs tvs have idential domains."); + "Only supported case is welford op where all outputs tvs have identical domains."); // p->f, c->c std::unordered_map c2f_root_map; for (const auto i : diff --git a/torch/csrc/jit/codegen/cuda/disjoint_set.h b/torch/csrc/jit/codegen/cuda/disjoint_set.h index 09cf6e8de950..8fd60dab5bd2 100644 --- a/torch/csrc/jit/codegen/cuda/disjoint_set.h +++ b/torch/csrc/jit/codegen/cuda/disjoint_set.h @@ -260,7 +260,7 @@ class DisjointSets { entry_it != disjointSetMap().end(), "Strict mapping failed on element: ", abstractToString(entry0), - " either an error occured, or non strict mapping should have been used."); + " either an error occurred, or non strict mapping should have been used."); return entry_it->second->has(entry1); } diff --git a/torch/csrc/jit/codegen/cuda/executor_utils.cpp b/torch/csrc/jit/codegen/cuda/executor_utils.cpp index 6da05cbf4dcb..217480a974ed 100644 --- a/torch/csrc/jit/codegen/cuda/executor_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/executor_utils.cpp @@ -155,7 +155,7 @@ bool validateKernelArgTensor( } if (!is_cpu_scalar(arg) && !arg.is_cuda()) { - msg << "Argumnet is a CPU tensor which is not supported in fusions.\n"; + msg << "Argument is a CPU tensor which is not supported in fusions.\n"; return false; } @@ -824,7 +824,7 @@ void bindInputForExprEvaluation( if (root_domain[dim]->hasExpandedExtent()) { TORCH_INTERNAL_ASSERT( tensor_arg_stride == 0, - "Execting an expanded dimension on dimension ", + "Expecting an expanded dimension on dimension ", dim, " but found stride ", tensor_arg_stride); @@ -838,7 +838,7 @@ void bindInputForExprEvaluation( *maybe_expanded_size == tensor_arg_size, "Expecting expanded extent of ", *maybe_expanded_size, - " but recieved value of ", + " but received value of ", tensor_arg_size); } } diff --git a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp index 3b51b807a727..c4d994f272be 100644 --- a/torch/csrc/jit/codegen/cuda/ir_nodes.cpp +++ b/torch/csrc/jit/codegen/cuda/ir_nodes.cpp @@ -600,7 +600,7 @@ BroadcastOp::BroadcastOp( id->isReduction() || id->isStride(), "Invalid broadcast op: ", id, - ". Non-reduction input dim does't match to output."); + ". Non-reduction input dim doesn't match to output."); } } @@ -2060,7 +2060,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", root_domain_.size()); @@ -2084,7 +2084,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. Received vector of size ", contiguity_.size(), " but needed one of size ", root_domain_.size()); @@ -2124,7 +2124,7 @@ TensorDomain::TensorDomain( : std::move(contiguity)) { TORCH_CHECK( contiguity_.size() == getMaybeRFactorDomain().size(), - "Invalid contiguity information provided, incorrect size. Recieved vector of size ", + "Invalid contiguity information provided, incorrect size. 
Received vector of size ", contiguity_.size(), " but needed one of size ", getMaybeRFactorDomain().size()); diff --git a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp index efcc51f231b2..c4604042bfae 100644 --- a/torch/csrc/jit/codegen/cuda/kernel_cache.cpp +++ b/torch/csrc/jit/codegen/cuda/kernel_cache.cpp @@ -493,7 +493,7 @@ void FusionKernelRuntime::startAsyncCompile(KernelArgumentHolder& args_old) { TORCH_INTERNAL_ASSERT( args.size() == segmented_fusion_->inputs().size(), - "Inputs were not set up correctly, recieved ", + "Inputs were not set up correctly, received ", args.size(), " inputs but expecting ", segmented_fusion_->inputs().size()); @@ -610,7 +610,7 @@ std::vector FusionKernelRuntime::runWithInput( TORCH_INTERNAL_ASSERT( args.size() == segmented_fusion_->inputs().size(), - "Inputs were not set up correctly, recieved ", + "Inputs were not set up correctly, received ", args.size(), " inputs but expecting ", segmented_fusion_->inputs().size()); diff --git a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp index 7f06aea2f542..1e2806b11fd4 100644 --- a/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_expr_sort.cpp @@ -927,8 +927,8 @@ bool ExprSegmentationSorter::interIterUpdate() { // If we didn't finish and we tried the fallback, throw. TORCH_INTERNAL_ASSERT( !fallback_mode_enabled_, - "Couldn't succcessfully sort out the fusion expressions. ", - "There are remaining connections of the heirarchical segmentation which should have been ", + "Couldn't successfully sort out the fusion expressions. ", + "There are remaining connections of the hierarchical segmentation which should have been ", "flattened to a single ordered group, or disjoint ordered groups."); // We didn't finish, but we haven't tried the fallback, try again with that. fallback_mode_enabled_ = true; @@ -1066,7 +1066,7 @@ void ExprSegmentationSorter::initializeForLoopDependencies() { } } - std::cerr << "Depdencies: " << std::endl; + std::cerr << "Dependencies: " << std::endl; for (const auto& dep_entry : concrete_id_dependencies) { std::cerr << " Deps of " << dep_entry.first->toString() << std::endl << " "; diff --git a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp index bd3c9baf66e1..9e713f4cf3a2 100644 --- a/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_misaligned_vectorization.cpp @@ -462,7 +462,7 @@ class MisalignedVectorizationModifier : public kir::ExprMutator { TORCH_INTERNAL_ASSERT( !gpu_lower->trivialReductionInfo().isDerived(producer_root_id), - "No trivial reduciton axis should exist: ", + "No trivial reduction axis should exist: ", producer_root_id); // If the producer ID is reduction or broadcast, it should be safe diff --git a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp index 38df8229bb77..294a2327bbba 100644 --- a/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_predicate_elimination.cpp @@ -925,7 +925,7 @@ bool PredicateElimination::setReductionInitValue( } else { TORCH_INTERNAL_ASSERT( false, - "Incosistent setting of initialization value for t", + "Inconsistent setting of initialization value for t", tv->name(), ". 
Prev: ", existing_val->toString(), diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp index da1def37cad8..f6f71c2ec123 100644 --- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp +++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp @@ -86,7 +86,7 @@ class ValidateSiblings : public IterVisitor { auto sibling_id = it->second; TORCH_INTERNAL_ASSERT( sibling->axis(i) == sibling_id, - "Invalid matching sinbling ID detected. Expr: ", + "Invalid matching sibling ID detected. Expr: ", expr->toString(), "Sibling ID: ", sibling_id->toString()); diff --git a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp index 09a740d01097..235d257e2351 100644 --- a/torch/csrc/jit/codegen/cuda/root_domain_map.cpp +++ b/torch/csrc/jit/codegen/cuda/root_domain_map.cpp @@ -486,7 +486,7 @@ bool ComputeAtRootDomainMap::canMap( const IterDomain* id_b) const { TORCH_INTERNAL_ASSERT( id_b->definition() == nullptr || id_b->isRFactorProduct(), - "Non-root domain is not supproted: ", + "Non-root domain is not supported: ", id_b); if (!id_b->isBroadcast()) { diff --git a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp index 1991cada00dd..ddf1061591ed 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/mma_utils.cpp @@ -208,7 +208,7 @@ std::vector getMmaDomains(MmaOp* mma, MmaDimension dimension) { TORCH_CHECK( a_domain.size() == b_domain.size() && a_domain.size() == accumulator_domain.size(), - "Inconsisitent dimensions in mma op", + "Inconsistent dimensions in mma op", a_domain.size(), " ", b_domain.size(), diff --git a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp index d985da926354..4ba6b241e455 100644 --- a/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp +++ b/torch/csrc/jit/codegen/cuda/scheduler/utils.cpp @@ -2366,7 +2366,7 @@ std::unordered_map domainReorderAsRfactorMap(TensorView* tv) { // Should be impossible. 
TORCH_INTERNAL_ASSERT( pos0 != pos1, - "Didn't expect merge inputs to be the same iteratrion domain:\n", + "Didn't expect merge inputs to be the same iteration domain:\n", merge->toString()); reordered_ids.erase(reordered_ids.begin() + pos0); diff --git a/torch/csrc/jit/codegen/cuda/tensor_view.cpp b/torch/csrc/jit/codegen/cuda/tensor_view.cpp index 633c98102e2e..85f320fef2e4 100644 --- a/torch/csrc/jit/codegen/cuda/tensor_view.cpp +++ b/torch/csrc/jit/codegen/cuda/tensor_view.cpp @@ -757,7 +757,7 @@ TensorView* TensorView::rFactor(const std::vector& axes) { TORCH_CHECK( !definition()->isA(), - "For GroupedReducitonOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); + "For GroupedReductionOp, use TensorView::rFactor(const std::vector& axes, const std::vector& tvs)"); // Split tensor view into 2 parts auto domain_pair = domain()->rFactor(axes); diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp index a8fb439af14f..8d24cc380374 100644 --- a/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp +++ b/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp @@ -4096,7 +4096,7 @@ TEST_F(NVFuserTest, FusionUnsqueeze1_CUDA) { fusion.addOutput(tv2); TORCH_CHECK( - tv2->nDims() == 2, "Unpected unsqueeze result: ", tv2->toString()); + tv2->nDims() == 2, "Unexpected unsqueeze result: ", tv2->toString()); TORCH_CHECK( tv2->axis(1)->isBroadcast(), "Unexpected unsqueeze result: ", diff --git a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp index dc5973c0ecd6..8d5151074563 100644 --- a/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_rfactor.cpp @@ -262,7 +262,7 @@ std::pair TransformRFactor::runReplay( std::transform(axes.begin(), axes.end(), axes.begin(), [ndims](int i) { TORCH_CHECK( i >= -ndims && i < ndims, - "Rfactor replay recieved an axis outside the number of dims in the tensor, acceptable inclusive range is ", + "Rfactor replay received an axis outside the number of dims in the tensor, acceptable inclusive range is ", -ndims, " to ", ndims - 1); diff --git a/torch/csrc/jit/codegen/cuda/transform_view.cpp b/torch/csrc/jit/codegen/cuda/transform_view.cpp index e5f9c068f16c..a543c6d0f79c 100644 --- a/torch/csrc/jit/codegen/cuda/transform_view.cpp +++ b/torch/csrc/jit/codegen/cuda/transform_view.cpp @@ -732,7 +732,7 @@ AnalyzeViewResult analyzeView( FUSER_PERF_SCOPE("analyzeView"); TORCH_INTERNAL_ASSERT( original_sizes.size() > 0, - "Empty original size not supported for view operatioon."); + "Empty original size not supported for view operation."); TORCH_INTERNAL_ASSERT( TensorDomain::noReductions(original_view_tv->getMaybeRFactorDomain()) diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index d60dd77bc8da..7c53dbd0b339 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -5640,7 +5640,7 @@ void CompilationUnit::define_interface( for (const Stmt& stmt : classDef.body()) { if (stmt.kind() != TK_DEF) { throw ErrorReport(stmt) - << "interface declartions can only contain method definitions"; + << "interface declarations can only contain method definitions"; } auto method_def = Def(stmt); if (!method_def.decl().return_type().present()) { diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index 1f790de92cb1..0673645731da 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -237,7 +237,7 @@ 
ParsedLiteral IRParser::parseScalarLiteral(Node* n) { auto text = L.expect(TK_NUMBER); if (!parse_tensor_constants_) { throw ErrorReport(token.range) - << "Single-element tensor constant encoutered but " + << "Single-element tensor constant encountered but " << "`parse_tensor_constants` is set to false " << token.text(); } L.expect('}'); diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp index 45e31fb5e174..29c29925ef09 100644 --- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp +++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp @@ -718,7 +718,7 @@ void FlatbufferLoader::extractJitSourceAndConstants( std::vector* constants) { AT_ASSERT( module_parsed_, - "Need to first parse a flatbuffer file before extracing jit_sources"); + "Need to first parse a flatbuffer file before extracting jit_sources"); const auto* ivalues = module_->ivalues(); for (uint32_t i = mobile_ivalue_size_; i < ivalues->size(); i++) { diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index f646fe77e07a..8baa439bdb58 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -2228,7 +2228,7 @@ size_t ONNXAssignOutputShape( auto& new_var = THPVariable_Unpack(list_elem); TORCH_CHECK( var.scalar_type() == new_var.scalar_type(), - "Unsupported sequence with mixed elment types in model outputs. " + "Unsupported sequence with mixed element types in model outputs. " "ONNX supports only sequences of elements of the same data type."); } auto elem_type = graph->outputs() diff --git a/torch/csrc/jit/passes/peephole_non_tensor.cpp b/torch/csrc/jit/passes/peephole_non_tensor.cpp index c114ea759e52..10ff3db0586a 100644 --- a/torch/csrc/jit/passes/peephole_non_tensor.cpp +++ b/torch/csrc/jit/passes/peephole_non_tensor.cpp @@ -15,7 +15,7 @@ namespace { * constant int value if there exists one. * * @pre node is integer arithmetic. - * @post if there's one constant in two oprands, then the second operand is + * @post if there's one constant in two operands, then the second operand is * constant. */ c10::optional checkArithNode(Node& node) { diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp index 54bd6679980e..3270ef4ced82 100644 --- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp +++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp @@ -1554,7 +1554,7 @@ Node* insertQuantDequantNodes( void checkCalculateQParamsResultTypes(const Node* out) { TORCH_CHECK( out->outputs().size() == 2, - "cacluate_qparams should produce output of size 2 (scale, zero_point)."); + "calculate_qparams should produce output of size 2 (scale, zero_point)."); Value* scale = out->output(0); Value* zp = out->output(1); TORCH_CHECK( diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 030e2525a163..2c6f8b1daca8 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -1774,7 +1774,7 @@ void initJitScriptBindings(PyObject* module) { if (def.kind() != TK_DEF) { throw ErrorReport(def.range()) << "Currently class bodies can only contain method " - "definitions. File an issue on Github if you want " + "definitions. 
File an issue on GitHub if you want " "something else!"; } methodDefs.emplace_back(Def(def)); diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index c2c84eb9e4e4..88a092c39fe0 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -922,13 +922,12 @@ void runNondiffOptimization( std::shared_ptr& graph, bool strict_fuser_check) { GRAPH_DEBUG( - "Before customPrePassses (beginning of runNondiffOptimization)\n", - *graph); + "Before customPrePasses (beginning of runNondiffOptimization)\n", *graph); // Run custom passes that different backends can register. for (const auto& passPair : getCustomPrePasses()) { passPair.first(graph); } - GRAPH_DEBUG("After customPrePassses\n", *graph); + GRAPH_DEBUG("After customPrePasses\n", *graph); // decomposition pass, decompose certain ops that will be used in the // following passes (like batchmm and jit fusion) @@ -960,7 +959,7 @@ void runNondiffOptimization( passPair.first(graph); } GRAPH_DEBUG( - "After customPostPassses (end of runNondiffOptimization)\n", *graph); + "After customPostPasses (end of runNondiffOptimization)\n", *graph); } void runOptimization( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 92044ca565a9..834a71b08161 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -45,7 +45,7 @@ C10_DEFINE_bool( static_runtime_enable_fast_math, true, - "If on, static runtime may use use optimizations that cause accurary loss " + "If on, static runtime may use use optimizations that cause accuracy loss " "vs the jit interpreter"); namespace at { diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp index b56c4980211a..6f30f82899ed 100644 --- a/torch/csrc/jit/serialization/export_bytecode.cpp +++ b/torch/csrc/jit/serialization/export_bytecode.cpp @@ -212,7 +212,7 @@ mobile::Code compileGraphToMobileCode( for (const TypePtr& element_type : input_type->containedTypes()) { TORCH_CHECK( element_type->kind() != TypeKind::ClassType, - "Returining a list or dictionary with pytorch class type ", + "Returning a list or dictionary with pytorch class type ", "is not supported in mobile module " "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). " "Workaround: instead of using pytorch class as their element type, ", diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp index b29f1e2914c0..90f9f9411b38 100644 --- a/torch/csrc/jit/serialization/export_module.cpp +++ b/torch/csrc/jit/serialization/export_module.cpp @@ -95,7 +95,7 @@ ExportModuleExtraFilesHook& GetExtraFilesHook() { * ] * ]" * - * @param compilation_unit Jit compilcation unit to look up function schema. + * @param compilation_unit Jit compilation unit to look up function schema. * @param type_ptr A type pointer and it can be possibly any type. * @param default_type_str The default string representation. 
The string can * either from type_ptr->str(), type_ptr->annotation_str(), or diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index a72abeaede8e..b79d29726bef 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -444,7 +444,7 @@ Module _load_jit_module_from_bytes( std::shared_ptr cu, c10::optional device, ExtraFilesMap& extra_files) { - TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecorgnized data format"); + TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format"); auto format = getFileFormat(data.get()); switch (format) { case FileFormat::FlatbufferFileFormat: { diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index c30ed316e48b..eb108abfb029 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -29,7 +29,7 @@ namespace tensorexpr { std::string buildErrorMessage(const std::string& s) { static const std::string generic_error_message = - "This error occured in the fuser. You can turn off the fuser with " + "This error occurred in the fuser. You can turn off the fuser with " "torch.jit.enable_fusion(False)."; if (s.empty()) { return generic_error_message; diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 1ca5665b4432..f6801973dd6b 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -694,7 +694,7 @@ void LLVMCodeGenImpl::visit(AddPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateAdd(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Add", v); + throw malformed_input("llvm_codegen: bad type in Add", v); } } @@ -712,7 +712,7 @@ void LLVMCodeGenImpl::visit(SubPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateSub(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Sub", v); + throw malformed_input("llvm_codegen: bad type in Sub", v); } } @@ -730,7 +730,7 @@ void LLVMCodeGenImpl::visit(MulPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateMul(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Mul", v); + throw malformed_input("llvm_codegen: bad type in Mul", v); } } @@ -748,7 +748,7 @@ void LLVMCodeGenImpl::visit(DivPtr v) { } else if (!lfp && !rfp) { value_ = irb_.CreateSDiv(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Div", v); + throw malformed_input("llvm_codegen: bad type in Div", v); } } @@ -763,7 +763,7 @@ void LLVMCodeGenImpl::visit(AndPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateAnd(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in And", v); + throw malformed_input("llvm_codegen: bad type in And", v); } } @@ -778,7 +778,7 @@ void LLVMCodeGenImpl::visit(OrPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateOr(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Or", v); + throw malformed_input("llvm_codegen: bad type in Or", v); } } @@ -793,7 +793,7 @@ void LLVMCodeGenImpl::visit(XorPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateXor(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Xor", v); + throw malformed_input("llvm_codegen: bad type in Xor", v); } } @@ -808,7 +808,7 @@ void LLVMCodeGenImpl::visit(LshiftPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateShl(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Lshift", v); + throw malformed_input("llvm_codegen: bad type in Lshift", v); } } @@ -827,7 +827,7 @@ void 
LLVMCodeGenImpl::visit(RshiftPtr v) { value_ = irb_.CreateLShr(lhs, rhs); } } else { - throw malformed_input("llvm_codgen: bad type in Rshift", v); + throw malformed_input("llvm_codegen: bad type in Rshift", v); } } @@ -842,7 +842,7 @@ void LLVMCodeGenImpl::visit(ModPtr v) { if (!lfp && !rfp) { value_ = irb_.CreateSRem(lhs, rhs); } else { - throw malformed_input("llvm_codgen: bad type in Mod", v); + throw malformed_input("llvm_codegen: bad type in Mod", v); } } diff --git a/torch/csrc/lazy/core/config.cpp b/torch/csrc/lazy/core/config.cpp index d87036767be5..c39fd8fef75a 100644 --- a/torch/csrc/lazy/core/config.cpp +++ b/torch/csrc/lazy/core/config.cpp @@ -10,7 +10,7 @@ C10_DEFINE_bool( C10_DEFINE_bool( torch_lazy_handle_special_scalars, false, - "Handle special scalars 0 and 1 diffrently"); + "Handle special scalars 0 and 1 differently"); C10_DEFINE_bool( torch_lazy_all_numbers_special_scalars, diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 08a20c84805e..f4fb4dd1eee1 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -347,7 +347,7 @@ static bool validateInput( const c10::ArrayRef& should_be_tensor) { std::stringstream ss; if (inputs.size() < min_size) { - ss << "Failed to save extra arguments for flops compuation of op " + ss << "Failed to save extra arguments for flops computation of op " << op_name << ", min size: " << min_size << ", actual size: " << inputs.size(); TORCH_WARN(ss.str()); @@ -355,7 +355,7 @@ static bool validateInput( } for (auto index : should_be_tensor) { if (!inputs[index].isTensor()) { - ss << "Failed to save extra arguments for flops compuation of op " + ss << "Failed to save extra arguments for flops computation of op " << op_name << ", input[" << index << "] must be a tensor."; TORCH_WARN(ss.str()); return false; From 540b42a1a883bb56235cdbf0bbbf103041c4dd8c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 14 Nov 2022 19:27:46 +0000 Subject: [PATCH 140/453] [quant][executorch] Support quant fusion for cat in quant in executorch stack (#88960) Summary: * added cat in executorch backend config * added quant fusion for "dq - cat - q" pattern Test Plan: buck run executorch/exir/tests:quant_fusion_pass -- "executorch.exir.tests.test_quant_fusion_pass.TestQuantFusionPass.test_cat" Reviewed By: qihqi Differential Revision: D41111054 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88960 Approved by: https://github.com/JacobSzwejbka --- torch/ao/quantization/backend_config/executorch.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py index 3c729327de76..627143c00099 100644 --- a/torch/ao/quantization/backend_config/executorch.py +++ b/torch/ao/quantization/backend_config/executorch.py @@ -200,6 +200,14 @@ def _get_bn_configs() -> List[BackendPatternConfig]: .set_dtype_configs(dtype_configs)) return bn_configs +def _get_cat_configs() -> List[BackendPatternConfig]: + dtype_configs = [executorch_default_op_quint8_dtype_config] + cat_configs = [] + cat_configs.append( + BackendPatternConfig(torch.cat) + .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) + .set_dtype_configs(dtype_configs)) + return cat_configs # ===================== # | BACKEND CONFIGS | @@ -214,4 +222,5 @@ def get_executorch_backend_config() -> BackendConfig: .set_backend_pattern_configs(_get_conv_configs()) \ 
.set_backend_pattern_configs(_get_binary_ops_configs()) \ .set_backend_pattern_configs(_get_share_qparams_ops_configs()) \ - .set_backend_pattern_configs(_get_bn_configs()) + .set_backend_pattern_configs(_get_bn_configs()) \ + .set_backend_pattern_configs(_get_cat_configs()) From f80992217dd2ae5ca0af5e280388cba6078ef57b Mon Sep 17 00:00:00 2001 From: anjali411 Date: Mon, 14 Nov 2022 14:43:15 +0000 Subject: [PATCH 141/453] Remove skip (#88979) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88979 Approved by: https://github.com/voznesenskym --- test/inductor/test_torchinductor.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ba1f9032d97f..23fb2f7712e0 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4791,9 +4791,6 @@ def forward(self, x): for param in model_opt.parameters(): param.add_(1.0) - # Probably fails due to the symint math issue caught while adding - # max_pool2d_with_indices_backward - @unittest.skip("Accuracy failure, needs debugging") def test_accuracy_issue1(self): class Repro(torch.nn.Module): def __init__(self): From 4570bd6030c97577d2fa994857d0a022ef7563a4 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 14 Nov 2022 14:34:01 -0500 Subject: [PATCH 142/453] woof (#89010) Signed-off-by: Edward Z. Yang Differential Revision: [D41276175](https://our.internmc.facebook.com/intern/diff/D41276175) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89010 Approved by: https://github.com/bigfootjon --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index bcce2997b25b..49bd2dfed706 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ ![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png) +woof + -------------------------------------------------------------------------------- PyTorch is a Python package that provides two high-level features: From b2082833c6082cbb25caf48bdb8f58c490b2c8a7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Nov 2022 21:21:09 +0000 Subject: [PATCH 143/453] Revert "woof (#89010)" This reverts commit 4570bd6030c97577d2fa994857d0a022ef7563a4. 
Reverted https://github.com/pytorch/pytorch/pull/89010 on behalf of https://github.com/ezyang due to whoops this actually landed --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 49bd2dfed706..bcce2997b25b 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ ![PyTorch Logo](https://github.com/pytorch/pytorch/blob/master/docs/source/_static/img/pytorch-logo-dark.png) -woof - -------------------------------------------------------------------------------- PyTorch is a Python package that provides two high-level features: From 074278f393e1a31b7ee058479cd5906ae830f5ed Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 14 Nov 2022 21:54:46 +0000 Subject: [PATCH 144/453] [CI] Push `latest` and hash+CUDAver tags (#88971) For nightly docker build to simulate the behavior of `push_nightly_docker_ghcr.yml` Tested in https://github.com/pytorch/pytorch/actions/runs/3465221336/jobs/5787694933 Fixes https://github.com/pytorch/pytorch/issues/88833 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88971 Approved by: https://github.com/seemethere --- .github/workflows/docker-release.yml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index d1b9209c4076..0f9638e210ad 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -91,6 +91,20 @@ jobs: # WITH_PUSH is used here to determine whether or not to add the --push flag run: | make -f docker.Makefile "${BUILD_IMAGE_TYPE}-image" + - name: Push nightly tags + if: ${{ github.event.ref == 'refs/heads/nightly' && matrix.image_type == 'runtime' }} + run: | + PYTORCH_DOCKER_TAG="${PYTORCH_VERSION}-runtime" + CUDA_VERSION=$(python3 -c "import re;print(re.search('CUDA_VERSION\s+=\s+([0-9\.]+)',open('docker.Makefile').read())[1],end='')") + PYTORCH_NIGHTLY_COMMIT=$(docker run ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \ + python -c 'import torch; print(torch.version.git_version[:7],end="")') + docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_DOCKER_TAG}" \ + ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" + docker push ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" + + docker tag ghcr.io/pytorch/pytorch-nightly:"${PYTORCH_NIGHTLY_COMMIT}-cu${CUDA_VERSION}" \ + ghcr.io/pytorch/pytorch-nightly:latest + docker push ghcr.io/pytorch/pytorch-nightly:latest - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() From 3b33a2794e07b5216aa473da67755af3aa6e6433 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Mon, 14 Nov 2022 22:11:29 +0000 Subject: [PATCH 145/453] support running test_mobile_profiler with buck1/buck2 and OSS (#89001) Summary: Internally we are switching to a new version of buck, but we also must keep this working in OSS. Test Plan: Rely on CI. 
Differential Revision: D41270673 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89001 Approved by: https://github.com/r-barnes, https://github.com/osalpekar, https://github.com/malfet --- .../lite_interpreter_runtime/CMakeLists.txt | 1 + test/cpp/lite_interpreter_runtime/resources.h | 19 +++++++++++ .../test_mobile_profiler.cpp | 34 ++++++++----------- 3 files changed, 34 insertions(+), 20 deletions(-) create mode 100644 test/cpp/lite_interpreter_runtime/resources.h diff --git a/test/cpp/lite_interpreter_runtime/CMakeLists.txt b/test/cpp/lite_interpreter_runtime/CMakeLists.txt index 6a2e6db6eaa9..b75ba4ed984e 100644 --- a/test/cpp/lite_interpreter_runtime/CMakeLists.txt +++ b/test/cpp/lite_interpreter_runtime/CMakeLists.txt @@ -25,6 +25,7 @@ target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_ if(LINUX) target_link_libraries(test_lite_interpreter_runtime PRIVATE "-Wl,--no-as-needed,$,--as-needed") + target_link_libraries(test_lite_interpreter_runtime PRIVATE stdc++fs) endif() if(INSTALL_TEST) diff --git a/test/cpp/lite_interpreter_runtime/resources.h b/test/cpp/lite_interpreter_runtime/resources.h new file mode 100644 index 000000000000..07f13ca8b86a --- /dev/null +++ b/test/cpp/lite_interpreter_runtime/resources.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace torch { +namespace testing { + +/// Gets the path to the resource identified by name. +/// +/// @param name identifies a resource, relative path starting from the +/// repo root +inline auto getResourcePath(std::string name) + -> std::experimental::filesystem::path { + return std::move(name); +} + +} // namespace testing +} // namespace torch diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index 08cb81ae7876..df9cb9cea28c 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -11,6 +11,8 @@ #include +#include "test/cpp/lite_interpreter_runtime/resources.h" + #ifdef EDGE_PROFILER_USE_KINETO namespace torch { namespace jit { @@ -42,16 +44,15 @@ bool checkMetaData( } // namespace TEST(MobileProfiler, ModuleHierarchy) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("to_be_profiled_module.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/to_be_profiled_module.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { KinetoEdgeCPUProfiler profiler( bc, @@ -95,16 +96,15 @@ TEST(MobileProfiler, ModuleHierarchy) { } TEST(MobileProfiler, Backend) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { 
KinetoEdgeCPUProfiler profiler( bc, @@ -130,16 +130,15 @@ TEST(MobileProfiler, Backend) { } TEST(MobileProfiler, BackendMemoryEvents) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { mobile::KinetoEdgeCPUProfiler profiler( bc, @@ -163,13 +162,8 @@ TEST(MobileProfiler, BackendMemoryEvents) { } TEST(MobileProfiler, ProfilerEvent) { - /* - * TODO: Using __FILE__ is unreliable e.g. it fails to resolve correctly when - * using buck2, works ok with buck1 - */ - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); @@ -180,7 +174,7 @@ TEST(MobileProfiler, ProfilerEvent) { torch::profiler::ProfilerPerfEvents.begin(), torch::profiler::ProfilerPerfEvents.end()); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { // Bail if something goes wrong here try { From 911a1349dd5d93b9de62d82f439b09eae9aedb92 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 14 Nov 2022 22:45:50 +0000 Subject: [PATCH 146/453] [Dynamo] Fix torch.is_tensor and torch.overrides.is_tensor_like (#88704) Fixes error from 7k github models: https://github.com/jansel/pytorch-jit-paritybench/blob/master/generated/test_arashwan_matrixnet.py Error: ``` AssertionError: torch.* op returned non-Tensor bool call_function from user code: File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 749, in scatter return scatter_map(inputs) File "/scratch/ybliang/work/repos/pytorch-jit-paritybench/generated/test_arashwan_matrixnet.py", line 741, in scatter_map assert not torch.is_tensor(obj), 'Tensors not supported in scatter.' 
``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88704 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 43 ++++++++++++++++++++++++++++++++ torch/_dynamo/variables/torch.py | 21 +++++++++------- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e270852fc526..e27f7bc5198d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -400,6 +400,23 @@ def fn(a, b): return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3) + def test_is_tensor2(self): + def fn(x): + if torch.is_tensor(x): + return x + 1 + else: + return torch.ones([2, 3]) + + x1 = {"input": torch.rand(2, 3)} + x2 = torch.rand(2, 3) + ref1 = fn(x1) + ref2 = fn(x2) + opt_fn = torch._dynamo.optimize("eager")(fn) + res1 = opt_fn(x1) + res2 = opt_fn(x2) + self.assertEqual(ref1, res1) + self.assertEqual(ref2, res2) + def test_numel(self): def fn(a): return a + a.numel() + torch.numel(a) @@ -1244,6 +1261,32 @@ def f(x): self.assertTrue(same(ref0, res0)) self.assertTrue(same(ref1, res1)) + def test_is_tensor_like2(self): + class MyTensor(object): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.max: + return torch.tensor(123) + return func(*args, **kwargs) + + def fn(x): + if torch.overrides.is_tensor_like(x): + return torch.max(x) + else: + return torch.zeros(1) + + x = MyTensor() + ref0 = fn(x) + ref1 = fn(4) + opt_fn = torch._dynamo.optimize("eager")(fn) + res0 = opt_fn(x) + res1 = opt_fn(4) + self.assertTrue(same(ref0, res0)) + self.assertTrue(same(ref1, res1)) + def test_version_ci(self): # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 0debfe9e9f3c..3b9b552542ac 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -163,8 +163,6 @@ def can_constant_fold_through(self): torch.finfo, torch.iinfo, torch.is_floating_point, - torch.is_tensor, - torch.overrides.is_tensor_like, ): return True return getattr(self.value, "__module__", None) == "math" @@ -177,9 +175,9 @@ def call_function( DynamicShapeVariable, GradModeVariable, TensorVariable, + UserDefinedObjectVariable, ) - # print("CALLING ON TORCH", self.value) from .builder import wrap_fx_proxy constant_args = check_constant_args(args, kwargs) @@ -206,21 +204,26 @@ def call_function( return self._call_cross_entropy_loss(tx, args, kwargs, options) else: unimplemented(f"construct nn.Module: {self.value.__name__}") + elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like): + assert len(args) == 1 + if isinstance(args[0], TensorVariable) or ( + self.value is torch.overrides.is_tensor_like + and isinstance(args[0], UserDefinedObjectVariable) + and hasattr(args[0].value, "__torch_function__") + ): + return ConstantVariable(True, **options) + else: + return ConstantVariable(False, **options) elif ( self.value in ( - torch.is_tensor, torch.is_floating_point, torch.is_complex, - torch.overrides.is_tensor_like, - torch.is_complex, ) and isinstance(args[0], TensorVariable) and args[0].dtype is not None ): - if self.value in (torch.is_tensor, torch.overrides.is_tensor_like): - return ConstantVariable(True, **options) - elif self.value is torch.is_floating_point: + if self.value is torch.is_floating_point: return 
ConstantVariable(args[0].dtype.is_floating_point, **options) elif self.value is torch.is_complex: return ConstantVariable(args[0].dtype.is_complex, **options) From 3c3bd55bea3424cbfc0c319dcead9c1e5c55646d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 14 Nov 2022 23:24:31 +0000 Subject: [PATCH 147/453] [testing] fix a key in parse_namespace() (#88969) This PR fixes an incorrect key name of `mappings` dict in `parse_namespace()` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88969 Approved by: https://github.com/kit1980 --- test/functorch/xfail_suggester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py index 4ae552a44bd3..cdf2cca13671 100644 --- a/test/functorch/xfail_suggester.py +++ b/test/functorch/xfail_suggester.py @@ -69,7 +69,7 @@ def parse_namespace(base): 'linalg_': 'linalg', '_masked_': '_masked', 'sparse_': 'sparse', - 'speical_': 'special', + 'special_': 'special', } for heading in mappings.keys(): if base.startswith(heading): From c53a5ac6cca7e2e7d7c47b1a816c7eaa2e7a7704 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 14 Nov 2022 23:36:17 +0000 Subject: [PATCH 148/453] Revert "support running test_mobile_profiler with buck1/buck2 and OSS (#89001)" This reverts commit 3b33a2794e07b5216aa473da67755af3aa6e6433. Reverted https://github.com/pytorch/pytorch/pull/89001 on behalf of https://github.com/kit1980 due to Broke trunk / macos-12-py3-x86-64-lite-interpreter / build --- .../lite_interpreter_runtime/CMakeLists.txt | 1 - test/cpp/lite_interpreter_runtime/resources.h | 19 ----------- .../test_mobile_profiler.cpp | 34 +++++++++++-------- 3 files changed, 20 insertions(+), 34 deletions(-) delete mode 100644 test/cpp/lite_interpreter_runtime/resources.h diff --git a/test/cpp/lite_interpreter_runtime/CMakeLists.txt b/test/cpp/lite_interpreter_runtime/CMakeLists.txt index b75ba4ed984e..6a2e6db6eaa9 100644 --- a/test/cpp/lite_interpreter_runtime/CMakeLists.txt +++ b/test/cpp/lite_interpreter_runtime/CMakeLists.txt @@ -25,7 +25,6 @@ target_link_libraries(test_lite_interpreter_runtime PRIVATE torch gtest backend_ if(LINUX) target_link_libraries(test_lite_interpreter_runtime PRIVATE "-Wl,--no-as-needed,$,--as-needed") - target_link_libraries(test_lite_interpreter_runtime PRIVATE stdc++fs) endif() if(INSTALL_TEST) diff --git a/test/cpp/lite_interpreter_runtime/resources.h b/test/cpp/lite_interpreter_runtime/resources.h deleted file mode 100644 index 07f13ca8b86a..000000000000 --- a/test/cpp/lite_interpreter_runtime/resources.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include -#include - -namespace torch { -namespace testing { - -/// Gets the path to the resource identified by name. 
-/// -/// @param name identifies a resource, relative path starting from the -/// repo root -inline auto getResourcePath(std::string name) - -> std::experimental::filesystem::path { - return std::move(name); -} - -} // namespace testing -} // namespace torch diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index df9cb9cea28c..08cb81ae7876 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -11,8 +11,6 @@ #include -#include "test/cpp/lite_interpreter_runtime/resources.h" - #ifdef EDGE_PROFILER_USE_KINETO namespace torch { namespace jit { @@ -44,15 +42,16 @@ bool checkMetaData( } // namespace TEST(MobileProfiler, ModuleHierarchy) { - auto testModelFile = torch::testing::getResourcePath( - "test/cpp/lite_interpreter_runtime/to_be_profiled_module.ptl"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("to_be_profiled_module.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace.trace"); - mobile::Module bc = _load_for_mobile(testModelFile.string()); + mobile::Module bc = _load_for_mobile(testModelFile); { KinetoEdgeCPUProfiler profiler( bc, @@ -96,15 +95,16 @@ TEST(MobileProfiler, ModuleHierarchy) { } TEST(MobileProfiler, Backend) { - auto testModelFile = torch::testing::getResourcePath( - "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend.trace"); - mobile::Module bc = _load_for_mobile(testModelFile.string()); + mobile::Module bc = _load_for_mobile(testModelFile); { KinetoEdgeCPUProfiler profiler( bc, @@ -130,15 +130,16 @@ TEST(MobileProfiler, Backend) { } TEST(MobileProfiler, BackendMemoryEvents) { - auto testModelFile = torch::testing::getResourcePath( - "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); - mobile::Module bc = _load_for_mobile(testModelFile.string()); + mobile::Module bc = _load_for_mobile(testModelFile); { mobile::KinetoEdgeCPUProfiler profiler( bc, @@ -162,8 +163,13 @@ TEST(MobileProfiler, BackendMemoryEvents) { } TEST(MobileProfiler, ProfilerEvent) { - auto testModelFile = torch::testing::getResourcePath( - "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); + /* + * TODO: Using __FILE__ is unreliable e.g. 
it fails to resolve correctly when + * using buck2, works ok with buck1 + */ + std::string filePath(__FILE__); + auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); + testModelFile.append("test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); @@ -174,7 +180,7 @@ TEST(MobileProfiler, ProfilerEvent) { torch::profiler::ProfilerPerfEvents.begin(), torch::profiler::ProfilerPerfEvents.end()); - mobile::Module bc = _load_for_mobile(testModelFile.string()); + mobile::Module bc = _load_for_mobile(testModelFile); { // Bail if something goes wrong here try { From 8df64abc6d8cd1de7017096159a93bb9c7c02bc1 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 14 Nov 2022 10:49:20 -0500 Subject: [PATCH 149/453] Fix some naughty uses of reshape/flatten (#88999) Mutating after reshape/flatten is bad! And it turns out the corresponding view operations are guaranteed to work too. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88999 Approved by: https://github.com/albanD --- torch/autograd/gradcheck.py | 2 +- torch/testing/_internal/opinfo/core.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 2f43423a2bd6..46d4f370a99a 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -1164,7 +1164,7 @@ def _vec_from_tensor(x, generator, downcast_complex=False): dtype = _to_real_dtype(x.dtype) if downcast_complex else x.dtype values = torch.rand(x_values.numel(), generator=generator) \ .to(dtype=dtype, device=x.device) \ - .reshape(x_values.shape) + .view(x_values.shape) values /= values.norm() vec = torch.sparse_coo_tensor(x._indices(), values, x.size()) else: diff --git a/torch/testing/_internal/opinfo/core.py b/torch/testing/_internal/opinfo/core.py index 1114d6851832..4f4ab79c2256 100644 --- a/torch/testing/_internal/opinfo/core.py +++ b/torch/testing/_internal/opinfo/core.py @@ -1732,11 +1732,11 @@ def generate_elementwise_binary_extremal_value_tensors( lhs = make_tensor( (128, 128), device=device, dtype=dtype, requires_grad=requires_grad ) - lhs.flatten()[::3] = nan + lhs.view(-1)[::3] = nan rhs = make_tensor( (128, 128), device=device, dtype=dtype, requires_grad=requires_grad ) - rhs.flatten()[::3] = nan + rhs.view(-1)[::3] = nan yield SampleInput(lhs, args=(rhs,)) From 92c78f37afca6c1ff6c40be7c7ed44b162b287b4 Mon Sep 17 00:00:00 2001 From: wswartworth Date: Mon, 14 Nov 2022 23:58:46 +0000 Subject: [PATCH 150/453] improving torch.linalg.lstsq documentation formatting (#89013) Fixes #80441 The highlighting in the documentation for torch.linalg.lstsq was incorrect due to a newline that sphinx doesn't parse correctly. Instead of writing the tensors directly, I used randn to generate the tensors. This seems to be more consistent with how other documentation is written. 
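For readers skimming the patch, here is a minimal, self-contained sketch of the usage pattern the updated example documents (the shapes mirror the new docstring below; the exact printed values differ on every run because the inputs are random):

```python
import torch

# One 3x3 system, broadcast against a batch of two right-hand sides,
# matching the shapes used in the updated torch.linalg.lstsq docstring.
A = torch.randn(1, 3, 3)
B = torch.randn(2, 3, 3)

X = torch.linalg.lstsq(A, B).solution        # shape (2, 3, 3)

# The least-squares solution agrees with the pseudoinverse formula
# up to floating-point error (typically around 1e-6 in float32).
print(torch.dist(X, torch.linalg.pinv(A) @ B))
```
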
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89013 Approved by: https://github.com/lezcano --- torch/linalg/__init__.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index e78cbbb3be35..3ec9a383546b 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -1084,16 +1084,26 @@ Examples:: - >>> A = torch.tensor([[[10, 2, 3], [3, 10, 5], [5, 6, 12]]], dtype=torch.float) # shape (1, 3, 3) - >>> B = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]], - [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) # shape (2, 3, 3) + >>> A = torch.randn(1,3,3) + >>> A + tensor([[[-1.0838, 0.0225, 0.2275], + [ 0.2438, 0.3844, 0.5499], + [ 0.1175, -0.9102, 2.0870]]]) + >>> B = torch.randn(2,3,3) + >>> B + tensor([[[-0.6772, 0.7758, 0.5109], + [-1.4382, 1.3769, 1.1818], + [-0.3450, 0.0806, 0.3967]], + [[-1.3994, -0.1521, -0.1473], + [ 1.9194, 1.0458, 0.6705], + [-1.1802, -0.9796, 1.4086]]]) >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) >>> torch.dist(X, torch.linalg.pinv(A) @ B) - tensor(2.0862e-07) + tensor(1.5152e-06) >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values >>> torch.dist(S, torch.linalg.svdvals(A)) - tensor(5.7220e-06) + tensor(2.3842e-07) >>> A[:, 0].zero_() # Decrease the rank of A >>> rank = torch.linalg.lstsq(A, B).rank From 0544a32ba35acd8648692a662197e3497654858e Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Tue, 15 Nov 2022 00:48:49 +0000 Subject: [PATCH 151/453] [inductor] fix could not find as_strided with config.triton.mm=triton (#88946) Summary: ReinterpretView doesn't seem to be handled properly with matrix multiply Triton kernels Reviewed By: bertmaher Differential Revision: D40836677 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88946 Approved by: https://github.com/jansel --- torch/_inductor/codegen/common.py | 12 ++++++++++++ torch/_inductor/codegen/triton_template.py | 2 +- torch/_inductor/graph.py | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index a949effb2679..932e8c91bc7d 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -283,6 +283,8 @@ def input(self, name): assert name not in V.graph.removed_buffers, name if name in self.output_buffers: return self.output_buffers[name] + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name if name.startswith("seed"): return self._lookup("seed", self.input_buffers, name) return self._lookup("in_ptr", self.input_buffers, name) @@ -290,6 +292,8 @@ def input(self, name): def output(self, name): name = V.graph.scheduler.mutation_real_name.get(name, name) assert name not in V.graph.removed_buffers, name + if name in self.inplace_buffers: + return self.inplace_buffers[name].inner_name return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name, output_name): @@ -392,6 +396,14 @@ def aliases(self): if other in self.output_buffers: yield self.output_buffers[other], inplaced.inner_name + def is_removed(self, name): + def _is_removed(name, buffers): + return name not in buffers or buffers[name] == "REMOVED" + + return _is_removed(name, self.output_buffers) and _is_removed( + name, self.inplace_buffers + ) + class CSE: """Common subexpression elimination""" diff --git a/torch/_inductor/codegen/triton_template.py 
b/torch/_inductor/codegen/triton_template.py index 0de771ff6574..cd1c2bed6bb7 100644 --- a/torch/_inductor/codegen/triton_template.py +++ b/torch/_inductor/codegen/triton_template.py @@ -330,7 +330,7 @@ def template_codegen(scheduler, scheduler_node, epilogue): kernel_buf_replace_name = None if could_remove_kernel_buf: for node in epilogue: - if kernel.args.output_buffers[node.get_name()] != "REMOVED": + if not kernel.args.is_removed(node.get_name()): kernel_buf_replace_name = node.get_name() break assert kernel_buf_replace_name is not None diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index f69a891fca7b..e0e41fd8afa5 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1,6 +1,7 @@ import logging import operator import os +import re import time import sympy @@ -90,6 +91,9 @@ def get_dtype(self, buffer_name): return self.name_to_buffer[buffer_name].get_dtype() if buffer_name in self.graph_inputs: return self.graph_inputs[buffer_name].get_dtype() + m = re.match(r"as_strided\(([a-zA-Z0-9_]+),", buffer_name) + if m: + return self.get_dtype(m.group(1)) raise KeyError(f"could not find {buffer_name}") def random_seed_buffer(self, device: torch.device): From 7a37bbed15321fa121f628053ee3c93d516700f5 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Mon, 14 Nov 2022 07:40:32 -0500 Subject: [PATCH 152/453] Take input striding for conv fusion op based on eager output (#88864) As https://github.com/pytorch/pytorch/pull/88706, we also change the input stride check using eager output. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88864 Approved by: https://github.com/jgong5, https://github.com/jansel --- torch/_inductor/ir.py | 95 ++++++++++++++----------------------------- 1 file changed, 30 insertions(+), 65 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8a2e26ee9b94..fdc10c9ca16a 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3346,21 +3346,20 @@ def _prepare_convolution_fusion_create( function only supports the CPU device since conv post-op fusion kernel is only supported on CPU right now. 
""" - - x = cls.require_stride1(cls.realize_input(x)) - weight = cls.require_stride1(cls.realize_input(weight)) - assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] stride = tuple(stride_) padding = tuple(padding_) dilation = tuple(dilation_) assert isinstance(groups, int) - with FakeTensorMode(): - output, *_ = cls.process_kernel( - torch.ops.aten.convolution, - x, - weight, - bias, + with torch._subclasses.FakeTensorMode(): + x_fake = ir_node_to_tensor(x, guard_shape=True) + weight_fake = ir_node_to_tensor(weight, guard_shape=True) + bias_fake = ( + ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias + ) + output = torch.ops.aten.convolution( + x_fake, + weight_fake, + bias_fake, stride, padding, dilation, @@ -3368,29 +3367,18 @@ def _prepare_convolution_fusion_create( [0, 0], groups, ) + req_stride_order = get_stride_order(output.stride()) - output_size = output.shape - weight_shape = [ - sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size() - ] - _, _, *kernel_size = weight_shape - output_layout_str = ( - "torch.contiguous_format" if output.is_contiguous() else "torch.channels_last" - ) - - if output_layout_str == "torch.channels_last": - stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1))) - if len(stride_order) < len(output_size): - # add batch dim if it exists - stride_order = [len(stride_order)] + stride_order - else: - stride_order = list(reversed(range(len(output_size)))) + x = cls.require_stride_order(x, req_stride_order) + weight = cls.require_stride1(cls.realize_input(weight)) + assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" + inputs = [x, weight] - kernel_layout = FlexibleLayout( - device=inputs[0].get_device(), - dtype=inputs[0].get_dtype(), - size=output_size, - stride_order=stride_order, + kernel_layout = FixedLayout( + x.get_device(), + x.get_dtype(), + output.size(), + output.stride(), ) constant_args = [padding, stride, dilation, groups] @@ -3398,7 +3386,7 @@ def _prepare_convolution_fusion_create( inputs.append(bias) else: constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout + return inputs, constant_args, kernel_layout, req_stride_order class ConvolutionUnary(ExternKernelAlloc): @@ -3436,7 +3424,7 @@ def create( algorithm, ): kernel = "torch.ops.mkldnn._convolution_pointwise" - (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create( + (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) constant_args = constant_args + [attr, scalars, algorithm] @@ -3447,13 +3435,6 @@ def create( kernel=kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - class ConvolutionBinary(ExternKernelAlloc): kernel = "torch.ops.mkldnn._convolution_pointwise.binary" @@ -3493,10 +3474,15 @@ def create( unary_algorithm: Optional[str], ): kernel = "torch.ops.mkldnn._convolution_pointwise.binary" - (inputs, constant_args, kernel_layout,) = _prepare_convolution_fusion_create( + ( + inputs, + constant_args, + kernel_layout, + req_stride_order, + ) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) - other = cls.require_stride1(cls.realize_input(other)) + other = cls.require_stride_order(other, 
req_stride_order) inputs.insert(1, other) constant_args = constant_args + [ binary_attr, @@ -3512,17 +3498,6 @@ def create( kernel=kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.layout.preferred_stride_order) - self.inputs[0] = x - other = self.inputs[1] - # FixedLayout of other - other = self.require_stride_order(other, self.layout.preferred_stride_order) - self.inputs[1] = other - self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) - class ConvolutionBinaryInplace(ExternKernelAlloc): kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" @@ -3530,14 +3505,12 @@ class ConvolutionBinaryInplace(ExternKernelAlloc): def __init__( self, kernel_layout, - inputs_layout, inputs, constant_args=(), kernel="torch.ops.mkldnn._convolution_pointwise_.binary", ): super().__init__(kernel_layout, inputs, constant_args) self.kernel = kernel - self.inputs_layout = inputs_layout def codegen(self, wrapper): wrapper.writeline( @@ -3566,7 +3539,7 @@ def create( unary_algorithm: Optional[str], ): kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" - (inputs, constant_args, inputs_layout,) = _prepare_convolution_fusion_create( + (inputs, constant_args, _, _) = _prepare_convolution_fusion_create( cls, x, weight, bias, padding_, stride_, dilation_, groups ) other = cls.realize_input(other) @@ -3581,19 +3554,11 @@ def create( ] return ConvolutionBinaryInplace( kernel_layout=MutationLayout(inputs[1]), - inputs_layout=inputs_layout, inputs=inputs, constant_args=constant_args, kernel=kernel, ) - def apply_constraint(self): - x = self.inputs[0] - # FixedLayout of input - x = self.require_stride_order(x, self.inputs_layout.preferred_stride_order) - self.inputs[0] = x - self.freeze_layout_with_stride_order(self.inputs_layout.preferred_stride_order) - class LinearUnary(ExternKernelAlloc): kernel = "torch.ops.mkldnn._linear_pointwise" From f3462833bdd1324d32ad9a78b5f142fb4d75f57c Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 15 Nov 2022 01:01:37 +0000 Subject: [PATCH 153/453] Use same retry logic as macos binary builds (#89014) Occasionally the command to download sccache via curl fails with network errors (example below). The default curl retry option only retries errors that are considered "transient", but but the set of actual transient commands is greater than what curl considers to be transient. This PR modifies the retry logic for downloading sccache to match what's in https://github.com/pytorch/pytorch/blob/master/.github/templates/macos_binary_build_workflow.yml.j2#L79-L89, using the retry action to ensure we both retry all transient errors, and including a longer retry delay to give the transient issue time to resolve itself. 
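As an illustration of the retry principle being adopted here (retry every failure rather than only the errors curl labels transient, and pause between attempts), a small Python sketch follows. The attempt count and delay mirror the `max_attempts: 3` / `retry_wait_seconds: 90` values in the workflow diff below; the helper itself is hypothetical, since the real change uses the `nick-fields/retry` GitHub action rather than any Python code:

```python
import subprocess
import time

def download_with_retries(url: str, dest: str,
                          attempts: int = 3, wait_seconds: int = 90) -> None:
    """Retry a curl download on any non-zero exit, waiting between attempts."""
    for attempt in range(1, attempts + 1):
        try:
            # --fail turns HTTP errors into a non-zero exit code
            subprocess.run(["curl", "--fail", "--location", url, "--output", dest],
                           check=True)
            return
        except subprocess.CalledProcessError:
            if attempt == attempts:
                raise
            time.sleep(wait_seconds)  # give the transient issue time to resolve itself
```
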
Example failure from [this run](https://github.com/pytorch/pytorch/actions/runs/3422664884/jobs/5700595220): ``` Run sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 0 0 0 0 0 0 0 0 --:--:-- --:--:-- --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:01 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:02 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:03 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:04 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:05 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:06 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:07 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:08 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:10 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:11 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:12 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:13 --:--:-- 0 0 0 0 0 0 0 0 0 --:--:-- 0:00:14 --:--:-- 0 curl: (35) OpenSSL SSL_connect: Connection reset by peer in connection to s3.amazonaws.com:443 Error: Process completed with exit code 35. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89014 Approved by: https://github.com/huydhn --- .github/workflows/_mac-build.yml | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index 9f0c988f4a31..faf069e7a7c3 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -109,12 +109,17 @@ jobs: brew link --force libomp - name: Install sccache (only for non-forked PRs, and pushes to trunk) + uses: nick-fields/retry@v2.8.2 if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} - run: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache - sudo chmod +x /usr/local/bin/sccache - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" + with: + timeout_minutes: 5 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo chmod +x /usr/local/bin/sccache + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" + echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" - name: Get workflow job id id: get-job-id From 35e668b5ced25e735b6e523d557ed7fd60267914 Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Tue, 15 Nov 2022 01:10:35 +0000 Subject: [PATCH 154/453] Add mem efficient backward (#88856) # Registers the derivative for mem efficient backward - Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32 - I also made updates based off of Xformer main branch and flash-attention cutlass branch. 
- This will enable the fused backward to be called for scaled dot product attention Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 5 + .../native/transformers/cuda/attention.cu | 16 +- .../transformers/cuda/attention_backward.cu | 261 ++++++++++++++++++ .../transformers/cuda/flash_attn/fmha_api.cpp | 4 + .../attention_backward_generic.cu | 166 ----------- .../attention_forward_generic.cu | 232 ---------------- .../cuda/mem_eff_attention/find_default_mma.h | 7 +- .../cuda/mem_eff_attention/kernel_backward.h | 250 +++++++++++------ .../ATen/native/transformers/cuda/sdp_utils.h | 12 +- test/test_transformers.py | 44 ++- tools/autograd/derivatives.yaml | 7 +- .../_internal/common_methods_invocations.py | 4 +- 12 files changed, 501 insertions(+), 507 deletions(-) create mode 100644 aten/src/ATen/native/transformers/cuda/attention_backward.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index de087c0b8a89..9572ccc56653 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13287,6 +13287,11 @@ dispatch: CUDA: _efficient_attention_forward +- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? 
incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index f65fedd6d795..46543d4663fa 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -746,7 +746,9 @@ std::tuple flash_attention_helper_dense_unpacked( std::tuple mem_eff_helper( const Tensor& query, const Tensor& key, - const Tensor& value){ + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) @@ -754,16 +756,18 @@ std::tuple mem_eff_helper( Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - Tensor attention = std::get<0>(at::_efficient_attention_forward( + Tensor attention, log_sumexp; + std::tie(attention, log_sumexp) = at::_efficient_attention_forward( q_t, k_t, v_t, c10::nullopt, c10::nullopt, c10::nullopt, - false, - false)).transpose(1,2); - return std::make_tuple(attention, Tensor()); + compute_log_sumexp, + is_causal); + attention = attention.transpose(1,2); + return std::make_tuple(std::move(attention), Tensor()); } std::tuple _scaled_dot_product_attention_forward_cuda( @@ -776,7 +780,7 @@ std::tuple _scaled_dot_product_attention_forward_cuda( case sdp::SDPBackend::flash_attention: return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value); + return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); case sdp::SDPBackend::math: return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); default: diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu new file mode 100644 index 000000000000..af005b2669b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -0,0 +1,261 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef USE_FLASH_ATTENTION +#include +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ + } + +#define DISPATCH_MAXK(func) \ + { \ + const auto maxK = std::max(query.size(3), value.size(3)); \ + if (maxK <= 64) { \ + constexpr int kMaxK = 64; \ + func(); \ + } else if (maxK <= 128) { \ + constexpr int kMaxK = 128; \ + func(); \ + } else { \ + constexpr int kMaxK = std::numeric_limits::max(); \ + func(); \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = \ + AttentionBackwardKernel; \ + bool isAligned = \ + (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + 
kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })); \ + } + +namespace at { + +namespace native { + +std::tuple _efficient_attention_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + const at::Tensor& out, + bool causal) { + #if defined(USE_FLASH_ATTENTION) + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + // ndim + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + TORCH_CHECK(query.dim() == 4); + + // batch size + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // seqlen + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + TORCH_CHECK(query.size(2) == grad_out_.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + TORCH_CHECK(value.size(3) == grad_out_.size(3)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::cuda::CUDAGuard device_guard(query.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t nH = query.size(2); + int64_t K = query.size(3); + + // It does not make sense to use that in practice, + // but let's still make sure we are correct + // As we iterate through keys first, we skip + // keys with no query associated, so they are not + // initialized + bool grad_kv_needs_init = causal && N > M; + at::Tensor grad_q, grad_k, grad_v; + if (!grad_kv_needs_init && query.size(1) == key.size(1) && + query.size(3) == value.size(3) && + query.storage().is_alias_of(key.storage()) && + query.storage().is_alias_of(value.storage())) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. + // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty_like(query); + grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); + grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); + } + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? 
at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)) + .sum(-1) + .transpose(-2, -1) + .contiguous(); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == nH); + TORCH_INTERNAL_ASSERT(delta.size(2) == M); + + typename Kernel::Params p; + p.query_ptr = (scalar_t*)query.data_ptr(); + p.key_ptr = (scalar_t*)key.data_ptr(); + p.value_ptr = (scalar_t*)value.data_ptr(); + p.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); + p.output_ptr = (scalar_t*)out.data_ptr(); + p.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); + p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + p.delta_ptr = (float*)delta.data_ptr(); + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = query.size(1); + p.num_keys = key.size(1); + p.num_batches = B; + p.num_heads = nH; + p.causal = causal; + + ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); + p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; + TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); + TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); + TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + + Kernel::check_supported(p); + + constexpr auto kernel_fn = attention_kernel_backward_batched; + + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + // second syntax resulted in the error below on windows + // error C3495: 'kernel_fn': a simple capture must be a variable + // with automatic storage duration declared + // in the reaching scope of the lambda +#ifdef _WIN32 + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + TORCH_INTERNAL_ASSERT( + attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, + "Something went wrong in the build process"); +#else + auto checkBinaryArchMatches = [&]() { + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); +#endif + + kernel_fn<<>>(p); + }; + + DISPATCH_KERNEL( + 
query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); + #endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index a8d6110e951d..6c86e1ff63b0 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -29,6 +29,7 @@ #ifdef USE_FLASH_ATTENTION #include #include +#include #include #include @@ -185,6 +186,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.get_device()}; + auto opts = q.options(); auto o = at::empty({ total_q, num_heads, head_size }, opts); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu deleted file mode 100644 index 07c14ad8195d..000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu +++ /dev/null @@ -1,166 +0,0 @@ -#include - -#define DISPATCH_MAXK(func) \ - { \ - const auto maxK = std::max(query.size(2), value.size(2)); \ - if (maxK <= 64) { \ - constexpr int kMaxK = 64; \ - func(); \ - } else if (maxK <= 128) { \ - constexpr int kMaxK = 128; \ - func(); \ - } else { \ - constexpr int kMaxK = std::numeric_limits::max(); \ - func(); \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(1) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ - } - -namespace { -std::tuple -mem_efficient_attention_backward_cutlass( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& logsumexp, - const at::Tensor& out, - bool causal) { - TORCH_CHECK(query.dim() == grad_out_.dim()); - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == 3); - - TORCH_CHECK(query.size(0) == grad_out_.size(0)); - TORCH_CHECK(query.size(1) == grad_out_.size(1)); - TORCH_CHECK(value.size(2) == grad_out_.size(2)); - - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - - // handle potentially non-contiguous grad_out through a copy - auto grad_out = grad_out_.contiguous(); - - CHECK_NOSPARSE_CONTIGUOUS_CUDA(query); - 
CHECK_NOSPARSE_CONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(value); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); - - at::cuda::CUDAGuard device_guard(query.device()); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t K = query.size(2); - - // It does not make sense to use that in practice, - // but let's still make sure we are correct - // As we iterate through keys first, we skip - // keys with no query associated, so they are not - // initialized - bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q = at::empty_like(query); - at::Tensor grad_k = - grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - at::Tensor grad_v = - grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - - // TODO: Fuse this into a kernel? - // This is a bottleneck for smaller sequences (M <= 128) - auto delta = Kernel::kKernelComputesDelta - ? at::empty({B, M}, query.options().dtype(at::ScalarType::Float)) - : (grad_out.to(at::kFloat) * out.to(at::kFloat)).sum(-1); - TORCH_INTERNAL_ASSERT(delta.size(0) == B); - TORCH_INTERNAL_ASSERT(delta.size(1) == M); - - typename Kernel::Params params; - params.query_ptr = (scalar_t*)query.data_ptr(); - params.key_ptr = (scalar_t*)key.data_ptr(); - params.value_ptr = (scalar_t*)value.data_ptr(); - params.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); - params.output_ptr = (scalar_t*)out.data_ptr(); - params.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); - params.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); - params.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); - params.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); - params.delta_ptr = (float*)delta.data_ptr(); - params.head_dim = query.size(2); - params.head_dim_value = value.size(2); - params.num_queries = query.size(1); - params.num_keys = key.size(1); - params.num_batches = B; - params.causal = causal; - Kernel::check_supported(params); - - constexpr auto kernel_fn = attention_kernel_backward_batched; - - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - } - - auto checkBinaryArchMatches = [&]() { - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; - }; - TORCH_INTERNAL_ASSERT( - checkBinaryArchMatches(), "Something went wrong in the build process"); - - kernel_fn<<>>( - params); - }; - - DISPATCH_KERNEL( - query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); -} // namespace - -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_cutlass"), -// TORCH_FN(mem_efficient_attention_backward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu deleted file mode 100644 index 59b3637c8a43..000000000000 --- 
a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu +++ /dev/null @@ -1,232 +0,0 @@ -#include - - -#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ - { \ - if (VALUE_HEAD_DIM <= 64) { \ - constexpr bool kIs64x64 = true; \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kIs64x64 = false; \ - if (VALUE_HEAD_DIM <= 128) { \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kSingleValueIteration = false; \ - FN(); \ - } \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_BLOCKSIZE( \ - VALUE.size(-1), ([&]() { \ - static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \ - static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - true, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - /* Run a more efficient kernel (with `isAligned=True`) \ - if memory is correctly aligned*/ \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \ - KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \ - VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \ - /* TODO: Should we warn or log somewhere when we use a \ - less efficient kernel due to wrong alignment? */ \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - kIsAligned, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - FUNC(); \ - })) \ - })) \ - })); \ - })); \ - } - -namespace { -/* - There are 2 modes for using this function. 
- (Mode BMHK) With all the heads having the same seqlen - (Mode 1MHK) `batch=1` with all tokens across batches concatenated -*/ -std::tuple efficient_attention_forward_cutlass( - const at::Tensor& query, // [b, seqlen, num_heads, K] - const at::Tensor& key, // [b, seqlen, num_heads, K] - const at::Tensor& value, // [b, seqlen, num_heads, Kv] - // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the - // position of the first query token for batch $b - const c10::optional& cu_seqlens_q, - // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the - // position of the first key token for batch $b - const c10::optional& cu_seqlens_k, - // (Mode 1MHK only) Maximum sequence length across batches - const c10::optional max_seqlen_q_, - bool compute_logsumexp, - bool causal) { - TORCH_CHECK(query.dim() == 4); - TORCH_CHECK(key.dim() == 4); - TORCH_CHECK(value.dim() == 4); - - // Batch sizes - TORCH_CHECK(query.size(0) == key.size(0)); - TORCH_CHECK(query.size(0) == value.size(0)); - - // Sequence length - TORCH_CHECK(key.size(1) == value.size(1)); - - // Num heads - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(2) == value.size(2)); - - // Embedding per head - TORCH_CHECK(query.size(3) == key.size(3)); - - int64_t max_seqlen_q, max_seqlen_k; - TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); - if (cu_seqlens_q.has_value()) { - TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); - TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); - TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); - TORCH_CHECK(max_seqlen_q_.has_value()); - max_seqlen_q = *max_seqlen_q_; - max_seqlen_k = 0; // Will be set inside the kernel - } else { - max_seqlen_q = query.size(1); - max_seqlen_k = key.size(1); - } - - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - - at::cuda::CUDAGuard device_guard(query.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t num_heads = query.size(-2); - int64_t K = query.size(-1); - int64_t Kv = value.size(-1); - - at::Tensor res; - at::Tensor logsumexp; - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - res = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - - // NOTE: Should be aligned (by padding) in case M is - // not a good number for loading during backward - constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; - logsumexp = at::empty( - {B, - num_heads, - compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, - query.options().dtype(at::ScalarType::Float)); - - typename Kernel::Params p; - p.query_ptr = (scalar_t*)query.data_ptr(); - p.key_ptr = (scalar_t*)key.data_ptr(); - p.value_ptr = (scalar_t*)value.data_ptr(); - p.logsumexp_ptr = compute_logsumexp - ? 
(typename Kernel::lse_scalar_t*)logsumexp.data_ptr() - : nullptr; - at::Tensor output_accum; - if (Kernel::kNeedsOutputAccumulatorBuffer) { - output_accum = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - p.output_accum_ptr = - (typename Kernel::output_accum_t*)output_accum.data_ptr(); - } else { - p.output_accum_ptr = nullptr; - } - p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); - - if (cu_seqlens_q.has_value()) { - p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); - p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); - } - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ - } - - p.num_heads = num_heads; - p.head_dim = query.size(3); - p.head_dim_value = value.size(3); - p.num_queries = max_seqlen_q; - p.num_keys = max_seqlen_k; - p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B; - p.causal = causal; - - ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); - - constexpr auto kernel_fn = attention_kernel_batched; - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - AT_CUDA_CHECK(cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - } - Kernel::check_supported(p); - kernel_fn<<>>(p); - }; - // Dispatch to the right kernel - DISPATCH_KERNEL(query, key, value, ([&]() { - launchKernel(Kernel{}, computeCapability); - })); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); -} -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_cutlass"), -// TORCH_FN(efficient_attention_forward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h index 399593fd0957..b0e7106f3cfc 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h @@ -1,15 +1,16 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instanciate to run a GEMM with various parameters (see + datastructures to instantiate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instanciation priority rules, it will only create an MmaMultiStage with + instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and - `default_mma_core.h` files and wrapped this template to allow our usecase. + `default_mma_core.h` files and wrapped this template to allow our use case. 
This is really only for the FastF32 case - aka using TensorCores with fp32. */ +#pragma once #include #include diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index e25701a7588a..c9652c40d38e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1,7 +1,5 @@ #pragma once - #include -#include #include #include @@ -75,46 +73,113 @@ struct AttentionBackwardKernel { struct Params { // Input tensors - scalar_t* query_ptr; // [num_queries, head_dim] - scalar_t* key_ptr; // [num_keys, head_dim] - scalar_t* value_ptr; // [num_keys, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_queries] - scalar_t* output_ptr; // [num_queries, head_dim_value] - scalar_t* grad_output_ptr; // [num_queries, head_dim_value] - accum_t* delta_ptr; // [num_queries] + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [Mq, nH] // Output tensors - scalar_t* grad_query_ptr; // [num_queries, head_dim] - scalar_t* grad_key_ptr; // [num_keys, head_dim] - scalar_t* grad_value_ptr; // [num_keys, head_dim_value] + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] // Dimensions/strides int32_t head_dim; int32_t head_dim_value; int32_t num_queries; int32_t num_keys; - int32_t num_batches; + int32_t num_heads; bool causal; - __device__ void advance_batches(int32_t batch_id) { + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int32_t num_batches; + + int64_t gO_strideB; + int64_t gQ_strideB; + int64_t gK_strideB; + int64_t gV_strideB; + int64_t gO_strideH; + int64_t gQ_strideH; + int64_t gK_strideH; + int64_t gV_strideH; + + CUTLASS_DEVICE void advance_to_block() { constexpr int32_t kAlignLSE = 32; // block size of backward auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - query_ptr += batch_id * head_dim * num_queries; - key_ptr += batch_id * head_dim * num_keys; - value_ptr += batch_id * head_dim_value * num_keys; - logsumexp_ptr += batch_id * lse_dim; - output_ptr += batch_id * head_dim_value * num_queries; - grad_output_ptr += batch_id * head_dim_value * num_queries; - delta_ptr += batch_id * num_queries; - - grad_query_ptr += batch_id * head_dim * num_queries; - grad_key_ptr += batch_id * head_dim * num_keys; - grad_value_ptr += batch_id * head_dim_value * num_keys; + int32_t batch_id = blockIdx.z; + 
int32_t head_id = blockIdx.y; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += (batch_id * num_heads + head_id) * num_queries; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + + gO_strideM = warp_uniform(gO_strideM); + gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); + q_strideM = warp_uniform(q_strideM); + k_strideM = warp_uniform(k_strideM); + v_strideM = warp_uniform(v_strideM); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); } __host__ dim3 getBlocksGrid() const { - return dim3(1, 1, num_batches); + return dim3(1, num_heads, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); @@ -179,7 +244,6 @@ struct AttentionBackwardKernel { attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue - with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = @@ -225,7 +289,6 @@ struct AttentionBackwardKernel { struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul - Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ @@ -601,7 +664,7 @@ struct AttentionBackwardKernel { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; - __device__ __forceinline__ void clear() { + CUTLASS_DEVICE void clear() { gradV.clear(); gradK.clear(); } @@ -614,14 +677,14 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); TORCH_CHECK( - p.head_dim % kMinimumAlignment == 0, - "query/key is not correctly aligned"); + p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); TORCH_CHECK( - p.head_dim_value % kMinimumAlignment == 0, - "value is not correctly aligned"); + p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); + TORCH_CHECK( + p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); } - static __device__ void kernel(Params& p_) { + static CUTLASS_DEVICE void kernel(Params& p_) { // Hint to nvcc to store points & tensor shapes in registers // as we use them a lot register const Params p = p_; @@ -658,7 +721,7 @@ struct AttentionBackwardKernel { __syncthreads(); } - OutputFragments output_frags; + OutputFragments register output_frags; int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; 
key_start += kBlockSizeJ) { @@ -695,7 +758,7 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ void loadDi( + static CUTLASS_DEVICE void loadDi( cutlass::Array& di, Params const& p, int32_t query_start) { @@ -710,7 +773,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void processBlockIJ( + static CUTLASS_DEVICE void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -718,9 +781,9 @@ struct AttentionBackwardKernel { int32_t key_start) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim))); - int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int32_t warp_id = threadIdx.y; - int32_t lane_id = threadIdx.x; + int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + int8_t warp_id = warp_uniform(threadIdx.y); + int8_t lane_id = threadIdx.x; __syncthreads(); loadDi(shared_storage.di(), p, query_start); @@ -734,8 +797,8 @@ struct AttentionBackwardKernel { auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -747,8 +810,8 @@ struct AttentionBackwardKernel { }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); @@ -757,8 +820,8 @@ struct AttentionBackwardKernel { }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); @@ -770,14 +833,14 @@ struct AttentionBackwardKernel { }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -803,16 +866,16 @@ struct AttentionBackwardKernel { // k_j typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -893,14 +956,14 @@ struct AttentionBackwardKernel { num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value + col, + 
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -951,16 +1014,16 @@ struct AttentionBackwardKernel { using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -1057,16 +1120,16 @@ struct AttentionBackwardKernel { num_keys_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradQ::OutputTileIterator( - typename MatmulGradQ::OutputTileIterator::Params{p.head_dim}, - p.grad_query_ptr + query_start * p.head_dim + col, + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); }; // k_j typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1153,8 +1216,8 @@ struct AttentionBackwardKernel { num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim + col, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); @@ -1162,8 +1225,8 @@ struct AttentionBackwardKernel { // q_i typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1236,15 +1299,15 @@ struct AttentionBackwardKernel { kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; auto thread_id = get_thread_id(); typename MatmulQK::Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); @@ -1259,7 +1322,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void writeFragsToGmem( + static CUTLASS_DEVICE void writeFragsToGmem( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -1268,8 +1331,8 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : std::min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value, + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, get_thread_id()); accumulateInGmem( @@ -1279,8 +1342,8 @@ struct AttentionBackwardKernel { true); typename MatmulGradK::OutputTileIterator outputK_it( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim}, get_thread_id()); @@ -1292,7 +1355,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void accumulateInGmem( + static CUTLASS_DEVICE void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, @@ -1334,7 +1397,9 @@ struct AttentionBackwardKernel { } template - static __device__ void computeDelta(Params const& p, int32_t query_start) { + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row @@ -1349,13 +1414,15 @@ struct AttentionBackwardKernel { bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; - const __restrict__ AccessType* grad_output_ptr = - reinterpret_cast( - p.grad_output_ptr + (query_start + laneRow) * p.head_dim_value + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); - const __restrict__ AccessType* output_ptr = - reinterpret_cast( - p.output_ptr + (query_start + laneRow) * p.head_dim_value + + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); static constexpr int64_t kMaxIters = @@ -1430,13 +1497,13 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ int8_t get_lane_id() { + static CUTLASS_DEVICE int8_t get_lane_id() { return threadIdx.x; } - static __device__ __forceinline__ int8_t get_warp_id() { + static CUTLASS_DEVICE int8_t get_warp_id() { return threadIdx.y; } - static __device__ __forceinline__ int16_t get_thread_id() { + static CUTLASS_DEVICE int16_t get_thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } }; @@ -1457,8 +1524,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) #define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ _ATTENTION_KERNEL_BACKWARD_BEGIN( \ AttentionBackwardKernel) \ - auto batch_id = blockIdx.z; \ - p.advance_batches(batch_id); \ + p.advance_to_block(); \ Kernel::kernel(p); \ _ATTENTION_KERNEL_BACKWARD_END(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 564adb2d51ea..e9f3d5029aa8 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -62,6 +62,15 @@ inline bool check_for_attn_weights(sdp_params params, bool debug) { } return true; } + +inline bool check_for_non_zero_dropout(sdp_params params, bool debug) { + if (params.dropout != 0.0) { + TORCH_CHECK(!debug, "Mem_efficient does not support non_zero dropout. 
Dropout_p: ", params.dropout); + return false; + } + return true; +} + inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { if (!params.query.is_nested()) { return true; @@ -230,7 +239,8 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, - check_for_seq_len_1_nested_tensor}; + check_for_seq_len_1_nested_tensor, + check_for_non_zero_dropout}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/test_transformers.py b/test/test_transformers.py index a9d0d960fb9a..c86b89bed5ef 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -21,8 +21,11 @@ TEST_WITH_ROCM, IS_WINDOWS, slowTest, - set_default_dtype + set_default_dtype, + gradcheck ) + +from torch.testing._internal.common_methods_invocations import wrapper_set_seed from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater if TEST_FAIRSEQ: @@ -860,11 +863,22 @@ def rand_tensor(*shape): actual = torch.ops.aten._scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) - # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. - # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. - if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) + if attn_mask_dim is None: + q = q.double().clone() + k = k.double().clone() + v = v.double().clone() + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + + assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') @torch.no_grad() def test_mask_check_fastpath(self): @@ -1079,6 +1093,28 @@ def rand_tensor(shape): self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): + + batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 + query, key, value = torch.rand((batch_size, seq_len, 3 * num_heads * head_dim), + device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1) + query = query.view(batch_size, -1, num_heads, head_dim) + key = key.view(batch_size, -1, num_heads, head_dim) + value = value.view(batch_size, -1, num_heads, head_dim) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Normally we would transpose the inputs but the fused kernels expect + # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel + # in fp32 + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), + (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 8349a308be35..a0892b32a835 100644 --- 
a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2591,7 +2591,7 @@ - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor self: grad.reshape_symint(self.sym_sizes()) -# Nested Tensor +# NestedTensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" @@ -2612,6 +2612,11 @@ nested_size: non_differentiable nested_strides: non_differentiable +# Transformers +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) + # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 001fd455e82e..3b43b8fb4863 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11944,8 +11944,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), OpInfo( 'nn.functional._scaled_dot_product_attention', - op=lambda inp, *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, inp, *args, **kwargs), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), sample_inputs_func=sample_inputs_scaled_dot_product_attention, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), From 1f88b208acab2cf974849c9161d24f08486f592c Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 15 Nov 2022 01:25:17 +0000 Subject: [PATCH 155/453] Fix cuda/cpu check on NoneType (Unit test) (#88970) Summary: Fix cuda/cpu check on NoneType (unit test) Test Plan: sabdcastle/ github CI/CD Differential Revision: D41208798 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88970 Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch --- test/test_transformers.py | 9 +++++++++ torch/nn/modules/activation.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_transformers.py b/test/test_transformers.py index c86b89bed5ef..93a94a5604c9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1168,6 +1168,15 @@ def make_tensor(*size, device=device, dtype=dtype): self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, torch.ones_like(q), 0.0, False, False)) + # Test failing MHA when bias was NoneType + def test_bias_is_none(self): + x = torch.rand((1, 5, 10)) + model = torch.nn.modules.activation.MultiheadAttention(10, 1, bias=False, batch_first=True) + model.eval() + model(x, x, x) + # completes without error + + # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for # cross device / dtype testing. 
instantiate_parametrized_tests(TestTransformers) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index e6b3b778e5fb..b00da06126a7 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1113,7 +1113,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O why_not_fast_path = "some Tensor argument has_torch_function" elif not all([(x is None or x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]): why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" - elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]): + elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]): why_not_fast_path = ("grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad") if not why_not_fast_path: From 45d2daaf855d4e79f6e09c4d3f85743b955446e6 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 15 Nov 2022 02:32:55 +0000 Subject: [PATCH 156/453] Fix lookup file update in dashboard (#89024) Lookup file should be updated before graphs are generated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89024 Approved by: https://github.com/mlazos, https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index d27763c41b0b..319ff677db4f 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -866,7 +866,7 @@ def generate_comment(self): title = "## Accuracy Regressions ##\n" body = ( "For each relevant compiler, we compare the most recent 2 reports " - "(that run actually the compiler) to find models where previously " + "(that actually run the compiler) to find models where previously " "successful accuracy tests now fail.\n\n" ) dtype = self.args.dtypes[0] @@ -1031,29 +1031,35 @@ def __init__(self, args): self.output_dir = args.output_dir self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") assert os.path.exists(self.lookup_file) + try: + self.update_lookup_file() + except subprocess.CalledProcessError: + print("failed to update lookup file") - def archive(self): + def update_lookup_file(self): dtype = self.args.dtypes[0] - # Copy the folder to archived location - archive( - self.output_dir, - self.args.dashboard_archive_path, - self.args.archive_name, - dtype, - ) day, _ = archive_data(self.args.archive_name) target_dir = ( default_archive_name(dtype) if self.args.archive_name is None else self.args.archive_name ) - # Update lookup csv the folder to arhived logs subprocess.check_call( f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}', shell=True, ) + def archive(self): + dtype = self.args.dtypes[0] + # Copy the folder to archived location + archive( + self.output_dir, + self.args.dashboard_archive_path, + self.args.archive_name, + dtype, + ) + def upload_graphs(self): title = "## Performance graphs ##\n" str_io = io.StringIO() From cbdb683dc843f2d50617ad962d5e57501e5154d4 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 14 Nov 2022 16:51:32 -0500 Subject: [PATCH 157/453] Add test that bias gradient is properly tested in same_two_models (#88995) See https://github.com/pytorch/pytorch/pull/88629#issuecomment-1313850324 for why this got broken. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/88995 Approved by: https://github.com/albanD --- test/dynamo/test_repros.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index fd0fcf9e08bc..503231b4cb12 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1769,6 +1769,21 @@ def forward(self, getitem_1, getitem_2, add): ] self.assertTrue(same_two_models(mod, opt_mod, args)) + def test_optimized_deepcopy(self): + # See https://github.com/pytorch/pytorch/pull/88629 + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True) + + def forward(self, x): + return self.fc(x) + + mod = Foo() + opt_mod = torch._dynamo.optimize("eager")(mod) + args = [torch.randn(1, 2)] + self.assertTrue(same_two_models(mod, opt_mod, args)) + def test_class_member(self): class Foo(torch.nn.Module): a = 4 From 55b88cde0ab0e5457422777971af845842b2689b Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Tue, 15 Nov 2022 03:10:36 +0000 Subject: [PATCH 158/453] [Inductor] Build Shape Padding in Inductor (#88709) Summary: Build shape padding for matmul/bmm/addmm in Inductor Differential Revision: D41071282 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88709 Approved by: https://github.com/bertmaher, https://github.com/Chillee --- torch/_inductor/config.py | 3 + torch/_inductor/decomposition.py | 149 ++++++++++++++++++++++++++++++- torch/_inductor/utils.py | 77 ++++++++++++++++ 3 files changed, 226 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 8f9f2c4f461d..d376fe3e8bf7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -71,6 +71,9 @@ # How to import torchdynamo, either torchdynamo or torch.dynamo dynamo_import = inductor_import.replace("inductor", "dynamo") +# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs +shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1" +alignment_size = 4 # config specific to codegen/cpp.pp class cpp: diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index e8a20c0dbd26..0b29dd524cb7 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -8,8 +8,9 @@ from torch import Tensor from torch._decomp import get_decompositions from torch._prims_common import is_boolean_dtype, is_integer_dtype +from torch.utils._mode_utils import no_dispatch -from . import config +from . 
import config, utils log = logging.getLogger(__name__) aten = torch.ops.aten @@ -135,6 +136,26 @@ def floordiv(a, b): return aten.div.Tensor_mode(a, b, rounding_mode="floor") +def get_padded_length(x): + if x % config.alignment_size == 0: + return 0 + return int((x // config.alignment_size + 1) * config.alignment_size) - x + + +def pad_dim(x, padded_length, dim): + pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :]) + return torch.cat([x, pad], dim=dim) + + +def check_device_dtype(a: Tensor, b: Tensor): + return ( + a.is_cuda + and b.is_cuda + and a.dtype == torch.float32 + and b.dtype == torch.float32 + ) + + @register_decomposition([aten.addmm]) def addmm(input, mat1, mat2, *, beta=1, alpha=1): if config.triton.mm != "aten": @@ -144,8 +165,130 @@ def addmm(input, mat1, mat2, *, beta=1, alpha=1): if not isinstance(beta, numbers.Number) or beta != 1: input = input * beta return input + out - else: - return NotImplemented # go directly to lowering + + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input) + ): + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + + if input is not None and k_padded_length == 0: + if m_padded_length != 0 and input.dim() == 2: + input = pad_dim(input, m_padded_length, 0) + elif n_padded_length != 0: + if input.dim() == 2: + input = pad_dim(input, n_padded_length, 1) + elif input.dim() == 1: + input = pad_dim(input, n_padded_length, 0) + + if k_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha) + elif m_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[ + :-m_padded_length, : + ] + elif n_padded_length != 0: + return torch.ops.aten.addmm(input, mat1, mat2, beta=beta, alpha=alpha)[ + :, :-n_padded_length + ] + + return NotImplemented # go directly to lowering + + +def should_pad_bench(mat1, mat2, op, input=None): + with no_dispatch(): + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + elif op is torch.ops.aten.bmm: + m_padded_length = get_padded_length(mat1.shape[1]) + k_padded_length = get_padded_length(mat1.shape[2]) + n_padded_length = get_padded_length(mat2.shape[2]) + else: + return False + + if m_padded_length == k_padded_length == n_padded_length == 0: + return False + + mat1 = torch.randn_like(mat1) + mat2 = torch.randn_like(mat2) + warmup = 5 + rep = 100 + if op is torch.ops.aten.bmm or op is torch.ops.aten.mm: + ori_time = utils.do_bench( + lambda: op(mat1, mat2), warmup=warmup, rep=rep, fast_flush=True + )[0] + else: + if input is not None: + input = torch.randn_like(input) + ori_time = utils.do_bench( + lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True + )[0] + + mat1_pad = mat1.new_empty([get_padded_length(i) + i for i in mat1.shape]) + mat2_pad = mat2.new_empty([get_padded_length(i) + i for i in mat2.shape]) + if op is torch.ops.aten.addmm: + input_pad = None + if input is not None and input.is_cuda and input.dtype == 
torch.float32: + input_pad = input.new_empty( + [get_padded_length(i) + i for i in input.shape] + ) + pad_time = utils.do_bench( + lambda: op(input_pad, mat1_pad, mat2_pad), + warmup=warmup, + rep=rep, + fast_flush=True, + )[0] + else: + pad_time = utils.do_bench( + lambda: op(mat1_pad, mat2_pad), warmup=warmup, rep=rep, fast_flush=True + )[0] + + # Shape padding introduces addtional memory ops. Based on microbenchmarks, 1.3x for + # aten.mm and aten.addmm and 2x for aten.bmm represent a reasonable tradeoff between + # performance improvement from shape padding and overhead from addtional memory ops + # TODO: Build a learned model which would be better than this heuristic + if op is torch.ops.aten.mm or op is torch.ops.aten.addmm: + return ori_time > pad_time * 1.3 + else: + return ori_time > pad_time * 2 + + +@register_decomposition([aten.bmm]) +def bmm_decomp(mat1, mat2): + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.bmm) + ): + m_padded_length = get_padded_length(mat1.shape[1]) + k_padded_length = get_padded_length(mat1.shape[2]) + n_padded_length = get_padded_length(mat2.shape[2]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 2) + mat2 = pad_dim(mat2, k_padded_length, 1) + return torch.ops.aten.bmm(mat1, mat2) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 1) + return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous() + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 2) + return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous() + + return NotImplemented # go directly to lowering @register_decomposition([aten.rsqrt]) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 5bfda50dd6f7..08e95b9b5cc3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -57,6 +57,83 @@ def conditional_product(*args): return functools.reduce(operator.mul, [x for x in args if x]) +def do_bench( + fn, + warmup=25, + rep=100, + grad_to_none=None, + percentiles=(0.5, 0.2, 0.8), + record_clocks=False, + fast_flush=False, +): + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param percentiles: Performance percentile to return in addition to the median. 
+ :type percentiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + + # Estimate the runtime of the function + fn() + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]) + if percentiles: + percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist() + return tuple(percentiles) + else: + return torch.mean(times).item() + + def sympy_product(it): return functools.reduce(operator.mul, it, sympy.Integer(1)) From ce8a45c282c68abbf37f7af99d4bd7cb53fa020d Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Nov 2022 03:32:00 +0000 Subject: [PATCH 159/453] [vision hash update] update the pinned vision hash (#89026) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89026 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index b9eda365de0c..c9bfe60001af 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -deba056203d009fec6b58afb9fa211f6ee3328c8 +b1f6c9e271368cd84837522af39e68dd4b5768a7 From dd6beca854be6cc0619d0b0693bc2fc558636217 Mon Sep 17 00:00:00 2001 From: Everton Constantino Date: Tue, 15 Nov 2022 04:10:49 +0000 Subject: [PATCH 160/453] Changing the use from ASSERT_EQ to ASSERT_FLOAT_EQ on nn_utils test. (#83693) Changing the use from ASSERT_EQ to ASSERT_FLOAT_EQ on nn_utils.cpp:ClipGradNorm as this is the proper way to compare equality between floating point values. This avoids `test_api` ClipGradNorm failing for WoA. 
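For reference, a minimal standalone gtest sketch of why the ULP-tolerant macro is preferable when the compared value comes from floating-point arithmetic (the test name and values below are made up for illustration and are not part of this change):

```
#include <gtest/gtest.h>

// Accumulating 0.1f ten times does not produce exactly 1.0f, so a
// bitwise-exact ASSERT_EQ is brittle, while ASSERT_FLOAT_EQ (equality
// within 4 ULPs) still passes.
TEST(FloatCompareExample, UlpTolerantEquality) {
  float sum = 0.0f;
  for (int i = 0; i < 10; ++i) {
    sum += 0.1f;  // each addition rounds, so error accumulates
  }
  // ASSERT_EQ(sum, 1.0f);     // may fail: sum differs from 1.0f by ~1 ULP
  ASSERT_FLOAT_EQ(sum, 1.0f);  // passes: difference is within 4 ULPs
}
```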
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83693 Approved by: https://github.com/ngimel, https://github.com/kit1980 --- test/cpp/api/nn_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp index 3d24749a9653..76aab44ac290 100644 --- a/test/cpp/api/nn_utils.cpp +++ b/test/cpp/api/nn_utils.cpp @@ -92,7 +92,7 @@ TEST_F(NNUtilsTest, ClipGradNorm) { ASSERT_LE(norm_after, max_norm); auto scaled = compare_scaling(grads); ASSERT_NEAR(0, scaled.std().item().toFloat(), 1e-7); - ASSERT_EQ(scaled[0].item().toFloat(), 1); + ASSERT_FLOAT_EQ(scaled[0].item().toFloat(), 1); } // should accept a single tensor as input auto p1 = torch::randn({10, 10}); From 73d71ae3d62607f2e480af37c470375ea405eb1c Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 15 Nov 2022 00:21:52 +0000 Subject: [PATCH 161/453] [WIP] Unwrap View in Reinterpret View (#89016) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89016 Approved by: https://github.com/ngimel --- torch/_inductor/ir.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index fdc10c9ca16a..8327fe0d7b52 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -998,11 +998,11 @@ def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=No x.data.decide_layout() return x, x.data.layout if isinstance(x, ReinterpretView): + # making the base of x contiguous or stride_ordered will not necessarily make + # the ReinterpretedView either, so dont pass along those arguments buffer, _ = as_storage_and_layout( x.data, freeze=freeze, - want_contiguous=want_contiguous, - stride_order=stride_order, ) return buffer, x.layout raise NotImplementedError @@ -1402,6 +1402,10 @@ class ReinterpretView(BaseView): layout: "Layout" + def __post_init__(self): + if isinstance(self.data, BaseView): + self.data = self.data.unwrap_view() + def __str__(self): return self.str_helper( [ From 21dd311077d00ff5c3f930295ddc8cf915a262d7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 15 Nov 2022 05:08:26 +0000 Subject: [PATCH 162/453] Add a mode to rerun all disabled tests (without running anything else) (#88646) Rerun all disabled test to gather their latest result so that we can close disabled tickets automatically. When running under this mode (RERUN_DISABLED_TESTS=true), only disabled tests are run while the rest are skipped `` The logic is roughly as follows, the test runs multiple times (n=50) * If the disabled test passes, and it's flaky, do nothing because it's still flaky. In the test report, we'll see the test passes with the following skipped message: ``` ``` * If the disabled test passes every single time, and it is not flaky anymore, mark it so that it can be closed later. We will see the test runs and passes, i.e. ``` ``` * If the disabled test fails after all retries, this is also expected. So only report this but don't fail the job (because we don't care about red signals here), we'll see the test is skipped (without the `flaky` field), i.e. ``` ``` This runs at the same schedule as `mem_leak_check` (daily). The change to update test stats, and (potentially) grouping on HUD will come in separated PRs. 
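For reference, the bookkeeping described above boils down to a classification like the sketch below (a hypothetical helper for illustration only; the actual logic lives in `torch/testing/_internal/common_utils.py` and differs in detail):

```
# Hypothetical sketch, not the shipped implementation.
MAX_NUM_RETRIES = 50  # number of reruns when RERUN_DISABLED_TESTS is set

def classify_disabled_test(num_green: int, num_red: int) -> str:
    """Summarize one disabled test after all reruns finish."""
    if num_green == 0:
        # Failed every rerun: expected for a genuinely broken disabled test,
        # so report it as skipped instead of failing the job.
        return "still failing (report only, keep disabled)"
    if num_red == 0:
        # Passed every rerun: no longer flaky, candidate for closing the
        # disabling issue automatically.
        return "consistently passing (close the disabled ticket)"
    # Mixed results: still flaky, leave the ticket open.
    return "still flaky (do nothing)"
```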
### Testing * pull https://github.com/pytorch/pytorch/actions/runs/3447434434 * trunk https://github.com/pytorch/pytorch/actions/runs/3447434928 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88646 Approved by: https://github.com/clee2000 --- .github/scripts/filter_test_configs.py | 27 +++++++- .github/scripts/test_filter_test_configs.py | 30 ++++++++- .github/workflows/_linux-test.yml | 4 +- .github/workflows/_mac-test.yml | 3 +- .github/workflows/_rocm-test.yml | 4 +- .github/workflows/_win-test.yml | 3 +- test/run_test.py | 5 +- test/test_dataloader.py | 3 + test/test_indexing.py | 7 +- torch/testing/_internal/common_utils.py | 72 +++++++++++++++++++-- 10 files changed, 143 insertions(+), 15 deletions(-) diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 06c8f90441eb..bb5314434e07 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -34,6 +34,13 @@ "xla", }} +# Supported modes when running periodically +SUPPORTED_PERIODICAL_MODES = { + "mem_leak_check", + "rerun_disabled_tests", +} + + def parse_args() -> Any: from argparse import ArgumentParser parser = ArgumentParser("Filter all test configurations and keep only requested ones") @@ -109,6 +116,23 @@ def filter(test_matrix: Dict[str, List[Any]], labels: Set[str]) -> Dict[str, Lis return filtered_test_matrix +def set_periodic_modes(test_matrix: Dict[str, List[Any]]) -> Dict[str, List[Any]]: + """ + Apply all periodic modes when running under a schedule + """ + scheduled_test_matrix: Dict[str, List[Any]] = { + "include": [], + } + + for config in test_matrix.get("include", []): + for mode in SUPPORTED_PERIODICAL_MODES: + cfg = config.copy() + cfg[mode] = mode + scheduled_test_matrix["include"].append(cfg) + + return scheduled_test_matrix + + def set_output(name: str, val: Any) -> None: if os.getenv("GITHUB_OUTPUT"): with open(str(os.getenv("GITHUB_OUTPUT")), "a") as env: @@ -163,8 +187,7 @@ def main() -> None: filtered_test_matrix = test_matrix if args.event_name == "schedule": - for config in filtered_test_matrix.get("include", []): - config["mem_leak_check"] = "mem_leak_check" + filtered_test_matrix = set_periodic_modes(filtered_test_matrix) # Set the filtered test matrix as the output set_output("test-matrix", json.dumps(filtered_test_matrix)) diff --git a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index a043a3535543..55410e846c97 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -4,7 +4,14 @@ import yaml import json from unittest import TestCase, main, mock -from filter_test_configs import get_labels, filter, PREFIX, VALID_TEST_CONFIG_LABELS +from filter_test_configs import ( + get_labels, + filter, + set_periodic_modes, + PREFIX, + VALID_TEST_CONFIG_LABELS, + SUPPORTED_PERIODICAL_MODES +) import requests from requests.models import Response from typing import Any, Dict @@ -86,5 +93,26 @@ def test_filter_with_valid_label(self) -> None: self.assertEqual(case["expected"], json.dumps(filtered_test_matrix)) + def test_set_periodic_modes(self) -> None: + testcases = [ + { + "test_matrix": "{include: []}", + "description": "Empty test matrix", + }, + { + "test_matrix": '{include: [{config: "default", runner: "linux"}, {config: "cfg", runner: "macos"}]}', + "descripion": "Replicate each periodic mode in a different config", + }, + ] + + for case in testcases: + test_matrix = 
yaml.safe_load(case["test_matrix"]) + scheduled_test_matrix = set_periodic_modes(test_matrix) + self.assertEqual( + len(test_matrix["include"]) * len(SUPPORTED_PERIODICAL_MODES), + len(scheduled_test_matrix["include"]) + ) + + if __name__ == '__main__': main() diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index dc1346205e63..6ad30080fd64 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -115,7 +115,8 @@ jobs: DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla - PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0'}} + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} timeout-minutes: 240 run: | set -x @@ -170,6 +171,7 @@ jobs: -e XLA_CUDA \ -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \ + -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --security-opt seccomp=unconfined \ diff --git a/.github/workflows/_mac-test.yml b/.github/workflows/_mac-test.yml index 82dee7b54841..cbc3372e1c42 100644 --- a/.github/workflows/_mac-test.yml +++ b/.github/workflows/_mac-test.yml @@ -129,7 +129,8 @@ jobs: - name: Test id: test env: - PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0'}} + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 0d8ff874ba03..dd1a0830275c 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -97,7 +97,8 @@ jobs: DOCKER_IMAGE: ${{ inputs.docker-image }} XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_JIT_ENABLE_NVFUSER: 1 - PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0'}} + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} timeout-minutes: 270 run: | set -x @@ -148,6 +149,7 @@ jobs: -e SCCACHE_BUCKET \ -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ -e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \ + -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --security-opt seccomp=unconfined \ diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index a0047abbc0f5..0cabb8ec469a 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -124,7 +124,8 @@ jobs: TEST_CONFIG: ${{ matrix.config }} PR_BODY: ${{ github.event.pull_request.body }} TORCH_CUDA_ARCH_LIST: "7.0" - PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0'}} + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} run: | COMMIT_MESSAGES=$(git cherry -v "origin/${GIT_DEFAULT_BRANCH:-master}") diff --git a/test/run_test.py b/test/run_test.py index 59454c6aaa3f..1273ab45c4fb 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -439,8 +439,11 @@ def run_test( if 
options.pytest: unittest_args = [arg if arg != "-f" else "-x" for arg in unittest_args] elif IS_CI: + ci_args = ["--import-slow-tests", "--import-disabled-tests"] + if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1": + ci_args.append("--rerun-disabled-tests") # use the downloaded test cases configuration, not supported in pytest - unittest_args.extend(["--import-slow-tests", "--import-disabled-tests"]) + unittest_args.extend(ci_args) # Extra arguments are not supported with pytest executable = get_executable_command( diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 6a7ff90527d3..347f9be73e8b 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -2716,6 +2716,9 @@ def __getitem__(self, index): @unittest.skipIf(IS_WINDOWS, "Needs fork") +@unittest.skipIf( + TEST_WITH_ASAN, + "This test hangs when running with ASAN, see https://github.com/pytorch/pytorch/issues/75492") class TestConvAfterFork(TestCase): # Tests crash reported in https://github.com/pytorch/pytorch/issues/53565 def test_conv_after_fork(self): diff --git a/test/test_indexing.py b/test/test_indexing.py index 1d5f2ea68ac2..5dc23a3d5465 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -11,7 +11,8 @@ import numpy as np from torch.testing import make_tensor -from torch.testing._internal.common_utils import TestCase, run_tests +from torch.testing._internal.common_utils import ( + TestCase, run_tests, TEST_WITH_TORCHDYNAMO) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA, onlyNativeDeviceTypes) @@ -737,6 +738,10 @@ def test_byte_mask_accumulate(self, device): self.assertEqual(y, torch.ones(size=(10, 10), device=device)) self.assertEqual(len(w), 2) + @unittest.skipIf( + TEST_WITH_TORCHDYNAMO, + "This test causes SIGKILL when running with dynamo, https://github.com/pytorch/pytorch/issues/88472" + ) def test_index_put_accumulate_large_tensor(self, device): # This test is for tensors with number of elements >= INT_MAX (2^31 - 1). N = (1 << 31) + 5 diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 8f497d515eb5..e0b703046c54 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -107,7 +107,6 @@ RETRY_TEST_CASES = os.getenv('PYTORCH_RETRY_TEST_CASES') == '1' OVERRIDE_FLAKY_SIGNAL = os.getenv('PYTORCH_OVERRIDE_FLAKY_SIGNAL') == '1' DISABLE_RUNNING_SCRIPT_CHK = os.getenv('PYTORCH_DISABLE_RUNNING_SCRIPT_CHK') == '1' -MAX_NUM_RETRIES = 3 DEFAULT_DISABLED_TESTS_FILE = '.pytorch-disabled-tests.json' DEFAULT_SLOW_TESTS_FILE = '.pytorch-slow-tests.json' @@ -506,6 +505,7 @@ def _get_test_report_path(): parser.add_argument('--run-parallel', type=int, default=1) parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SLOW_TESTS_FILE) parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE) +parser.add_argument('--rerun-disabled-tests', action='store_true') # Only run when -h or --help flag is active to display both unittest and parser help messages. 
def run_unittest_help(argv): @@ -527,6 +527,9 @@ def run_unittest_help(argv): # infer flags based on the default settings GRAPH_EXECUTOR = cppProfilingFlagsToProfilingMode() +RERUN_DISABLED_TESTS = args.rerun_disabled_tests +# Rerun disabled tests many more times to make sure that they are not flaky anymore +MAX_NUM_RETRIES = 3 if not RERUN_DISABLED_TESTS else 50 SLOW_TESTS_FILE = args.import_slow_tests DISABLED_TESTS_FILE = args.import_disabled_tests @@ -1653,6 +1656,9 @@ def check_if_enable(test: unittest.TestCase): raise unittest.SkipTest("test is slow; run with PYTORCH_TEST_WITH_SLOW to enable test") sanitized_test_method_name = remove_device_and_dtype_suffixes(test._testMethodName) if not IS_SANDCASTLE: + should_skip = False + skip_msg = "" + for disabled_test, (issue_url, platforms) in disabled_tests_dict.items(): disable_test_parts = disabled_test.split() if len(disable_test_parts) > 1: @@ -1687,11 +1693,22 @@ def check_if_enable(test: unittest.TestCase): platforms = list(filter(lambda p: p in platform_to_conditional, platforms)) if platforms == [] or any([platform_to_conditional[platform] for platform in platforms]): + should_skip = True skip_msg = f"Test is disabled because an issue exists disabling it: {issue_url}" \ f" for {'all' if platforms == [] else ''}platform(s) {', '.join(platforms)}. " \ "If you're seeing this on your local machine and would like to enable this test, " \ "please make sure CI is not set and you are not using the flag --import-disabled-tests." - raise unittest.SkipTest(skip_msg) + break + + if should_skip and not RERUN_DISABLED_TESTS: + # Skip the disabled test when not running under --rerun-disabled-tests verification mode + raise unittest.SkipTest(skip_msg) + + if not should_skip and RERUN_DISABLED_TESTS: + skip_msg = "Test is enabled but --rerun-disabled-tests verification mode is set, so only" \ + " disabled tests are run" + raise unittest.SkipTest(skip_msg) + if TEST_SKIP_FAST: if not getattr(test, test._testMethodName).__dict__.get('slow_test', False): raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST") @@ -2039,9 +2056,48 @@ def wrap_with_cuda_memory_check(self, method): def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_red=0, num_green=0): using_unittest = isinstance(result, unittest.TestResult) if num_runs_left == 0: + # The logic when RERUN_DISABLED_TESTS is set to true is as follows: + # |-if the disabled test passes: + # |-- if it's flaky: + # |--- Do nothing because it's still flaky + # |-- elif it isn't flaky anymore: + # |--- Close the disabled ticket (later) + # | + # |- elif the disabled test fails after n retries: + # |-- This is expected, report this but don't fail the job + skipped_msg = { + "num_red": num_red, + "num_green": num_green, + "max_num_retries": MAX_NUM_RETRIES, + "rerun_disabled_test": RERUN_DISABLED_TESTS, + } + + traceback_str = "" + if RERUN_DISABLED_TESTS and using_unittest: + # Hide all failures and errors when RERUN_DISABLED_TESTS is enabled. This is + # a verification check, we don't want more red signals coming from it + if result.failures: + _, traceback_str = result.failures.pop(-1) + if result.errors: + _, traceback_str = result.errors.pop(-1) + + if traceback_str: + skipped_msg["traceback_str"] = traceback_str + + if num_green == 0: + # The disabled test fails, report as skipped but don't fail the job + result.addSkip(self, json.dumps(skipped_msg)) + + if num_red == 0: + # The test passes after re-running multiple times. 
This acts as a signal + # to confirm that it's not flaky anymore + result.addSuccess(self) + if num_green > 0 and num_red > 0 and using_unittest: - result.addSkip(self, f'{{"flaky": {True}, "num_red": {num_red}, "num_green": {num_green},' + - f'"max_num_retries": {MAX_NUM_RETRIES}}}') + skipped_msg["flaky"] = True + # Still flaky, do nothing + result.addSkip(self, json.dumps(skipped_msg)) + return if using_unittest: @@ -2100,9 +2156,13 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re result.addExpectedFailure(self, err) self._run_with_retry(result=result, num_runs_left=num_retries_left, report_only=report_only, num_red=num_red + 1, num_green=num_green) - elif report_only and num_retries_left < MAX_NUM_RETRIES: + elif (RERUN_DISABLED_TESTS or report_only) and num_retries_left < MAX_NUM_RETRIES: + # Always re-run up to MAX_NUM_RETRIES when running under report only or rerun disabled tests modes print(f" {self._testMethodName} succeeded - num_retries_left: {num_retries_left}") - result.addUnexpectedSuccess(self) + if RERUN_DISABLED_TESTS: + result.addSuccess(self) + else: + result.addUnexpectedSuccess(self) self._run_with_retry(result=result, num_runs_left=num_retries_left, report_only=report_only, num_red=num_red, num_green=num_green + 1) elif not report_only and num_retries_left < MAX_NUM_RETRIES: From 68fd8f37063f0011f1c0589e8f38f7606e3f6748 Mon Sep 17 00:00:00 2001 From: Iris Date: Tue, 15 Nov 2022 06:13:15 +0000 Subject: [PATCH 163/453] [BE] [c10d][send] Improve error message on dist.send() with destination rank as itself (#89004) This improves error msg on dist.send() and add corresponding test in test_c10d_common.py(https://github.com/pytorch/pytorch/blob/master/test/distributed/test_c10d_common.py). Context in issue#83912: https://github.com/pytorch/pytorch/issues/83912 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89004 Approved by: https://github.com/H-Huang --- test/distributed/test_c10d_common.py | 3 +++ torch/distributed/distributed_c10d.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 77ee7487a0af..a43b1343923c 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1427,6 +1427,9 @@ def test_send_recv(self): dist.send(input_tensor, (self.rank + 1) % self.world_size) self.assertEqual(input_tensor, torch.zeros(2, 2) + 1) + with self.assertRaises(ValueError): + dist.send(input_tensor, dist.get_rank()) + # test recv input_tensor = torch.zeros(2, 2) dist.recv(input_tensor, (self.rank + 1) % self.world_size) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 4a132d141e00..33569f5169e5 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1179,12 +1179,19 @@ def send(tensor: torch.Tensor, dst: int, group: Optional[ProcessGroup] = None, t Args: tensor (Tensor): Tensor to send. - dst (int): Destination rank. + dst (int): Destination rank. Destination rank should not be the same + as the rank of the current process. group (ProcessGroup, optional): The process group to work on. If None, the default process group will be used. tag (int, optional): Tag to match send with remote recv """ + if get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." 
+ ) + _check_single_tensor(tensor, "tensor") if _rank_not_in_group(group): _warn_not_in_group("send") From 5ed90c40f874359aca13f7f50e6d115524937d02 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 15 Nov 2022 06:16:13 +0000 Subject: [PATCH 164/453] enable index_put test (#89019) Per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/89019 Approved by: https://github.com/desertfire --- test/inductor/test_torchinductor_opinfo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 3880b87c082c..4e706bab0ea6 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -168,6 +168,8 @@ def process(device_type): "__getitem__": {b8, f16, f32, f64, i32, i64}, "addr": {f16}, "allclose": {f16, f32, f64}, + "amax": {f16}, + "amin": {f16}, "angle": {f16, f32, f64}, "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, @@ -204,7 +206,6 @@ def process(device_type): "fft.rfft2": {f32, f64}, "fft.rfftn": {f32, f64}, "index_add": {f16}, - "index_put": {f16, f32, f64}, "index_reduce": {f16, f32, f64}, "istft": {f32, f64}, "linalg.eig": {f32, f64}, @@ -311,7 +312,6 @@ def process(device_type): "fft.rfft": {f16, f32, f64}, "fft.rfft2": {f16, f32, f64}, "fft.rfftn": {f16, f32, f64}, - "index_put": {f16, f32, f64}, "index_reduce": {f16, f32, f64}, "istft": {f32, f64}, "linalg.eig": {f32, f64}, @@ -441,13 +441,15 @@ def wrapper_set_seed(op, *args, **kwargs): inductor_all_samples = { "softmax.with_dtype", "index_add", - "index_put", "index_copy", "scatter_reduce.sum", "select_scatter", "squeeze", "unsqueeze", "sum", + "amax", + "amin", + "all", } @@ -549,7 +551,6 @@ def fn(*args, **kwargs): "check_gradient": requires_grad, } adjusted_kwargs.update(overridden_kwargs) - self.check_model_cuda( fn, args, From 60e59c075561068c7d1fe9e9fc40a2df3cd2d2d7 Mon Sep 17 00:00:00 2001 From: peterjc123 Date: Tue, 15 Nov 2022 06:36:24 +0000 Subject: [PATCH 165/453] Fix get_default_qat_qconfig for PT 1.13 (#88876) See https://github.com/pytorch/pytorch/pull/84329/files#r1019916766 for more context Pull Request resolved: https://github.com/pytorch/pytorch/pull/88876 Approved by: https://github.com/jgong5, https://github.com/vkuzo --- test/quantization/core/test_top_level_apis.py | 32 +++++++++++++++++++ torch/ao/quantization/qconfig.py | 2 +- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/test/quantization/core/test_top_level_apis.py b/test/quantization/core/test_top_level_apis.py index 7343a16040d2..f76db1cd4139 100644 --- a/test/quantization/core/test_top_level_apis.py +++ b/test/quantization/core/test_top_level_apis.py @@ -59,3 +59,35 @@ def test_fake_quants(self) -> None: for observer in self.fake_quants: obs = self._get_observer_ins(observer) obs.forward(t) + + +class TestQConfig(TestCase): + + REDUCE_RANGE_DICT = { + 'fbgemm': (True, False), + 'qnnpack': (False, False), + 'onednn': (False, False), + 'x86': (True, False), + } + + def test_reduce_range_qat(self) -> None: + for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items(): + for version in range(2): + qconfig = torch.ao.quantization.get_default_qat_qconfig(backend, version) + + fake_quantize_activ = qconfig.activation() + self.assertEqual(fake_quantize_activ.activation_post_process.reduce_range, reduce_ranges[0]) + + fake_quantize_weight = qconfig.weight() + 
self.assertEqual(fake_quantize_weight.activation_post_process.reduce_range, reduce_ranges[1]) + + def test_reduce_range(self) -> None: + for backend, reduce_ranges in self.REDUCE_RANGE_DICT.items(): + for version in range(1): + qconfig = torch.ao.quantization.get_default_qconfig(backend, version) + + fake_quantize_activ = qconfig.activation() + self.assertEqual(fake_quantize_activ.reduce_range, reduce_ranges[0]) + + fake_quantize_weight = qconfig.weight() + self.assertEqual(fake_quantize_weight.reduce_range, reduce_ranges[1]) diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py index b75e16ef044f..f52bf713c6f9 100644 --- a/torch/ao/quantization/qconfig.py +++ b/torch/ao/quantization/qconfig.py @@ -339,7 +339,7 @@ def get_default_qat_qconfig(backend='fbgemm', version=1): quant_min=0, quant_max=255), weight=default_per_channel_weight_fake_quant) - if backend == 'x86': + elif backend == 'x86': qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, From 5314af5383e56376cd62da22ae07681656667e71 Mon Sep 17 00:00:00 2001 From: Wenzhe Xue Date: Tue, 15 Nov 2022 07:29:52 +0000 Subject: [PATCH 166/453] Set correct size of `attr::output_layouts` when the graph has multiple outputs in JIT oneDNN fuser (#88496) Bug: Previously, `initOutputLayouts()` was called after creating a graph and before merging other nodes. It is a vector with one element. So when a graph contains multiple outputs, e.g. using AOTAutograd compile in my case, layout_propagation pass try to access out of range elements in the vector. Then it comes to the second bug in `useOpaqueLayout()`, the out of range checks the index with the updated output size instead of the size of the vector. Then used `[]` to access the element, which is out of range. Fixes the above two issues: 1. check the offset is within range with the size of `attr::output_layouts` vector instead of another variable. This check catches the error now. 2. change the place to initial `attr::output_layouts` after node merging. The graph may change with node merging. Thus we moved the initialization in layout_propagation with the complete graph. 
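For context, here is a minimal sketch of the multi-output situation described above. This is an illustrative snippet only, not the test added in this PR: the module, shapes, and op choice are made up, and whether these ops actually land in a single LLGA partition depends on the build and on the fusion passes; it simply shows a TorchScript graph with more than one output, which is what used to index past the end of `attr::output_layouts`.

```python
import torch
import torch.nn as nn

# Hypothetical example module (not from this PR): its scripted graph
# returns two tensors, i.e. has multiple graph outputs.
class TwoOutputs(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        y = torch.relu(self.linear(x))
        z = torch.sigmoid(y)
        return y, z  # multiple outputs

mod = torch.jit.script(TwoOutputs().eval())
x = torch.rand(10, 10)

# "fuser3" selects the oneDNN graph fuser, as in the TestDynamoAOT test below.
# Run a few iterations so the profiling executor triggers the fusion passes.
with torch.jit.fuser("fuser3"):
    for _ in range(3):
        y, z = mod(x)
```

With this change, `attr::output_layouts` is sized from the merged subgraph's outputs inside `LayoutPropagation`, so every output gets a layout entry and `useOpaqueLayout()` checks the offset against the vector's own size.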
Added test time: `Ran 1 test in 0.383s` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88496 Approved by: https://github.com/jgong5, https://github.com/sanchitintel --- test/test_jit_llga_fuser.py | 30 +++++++++++++++++ .../csrc/jit/codegen/onednn/graph_helper.cpp | 33 ++++++++++--------- torch/csrc/jit/codegen/onednn/graph_helper.h | 5 +-- .../jit/codegen/onednn/layout_propagation.cpp | 9 +++++ 4 files changed, 60 insertions(+), 17 deletions(-) diff --git a/test/test_jit_llga_fuser.py b/test/test_jit_llga_fuser.py index 4804a442c1d6..12bd955043b9 100644 --- a/test/test_jit_llga_fuser.py +++ b/test/test_jit_llga_fuser.py @@ -774,6 +774,36 @@ def t3(x, y): self.assertGraphContainsExactly(t_jit_3.graph_for(x, y), LLGA_FUSION_GROUP, 0) +@unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") +@unittest.skip("Enable when integration with dynamo aot_autograd is more stable") +class TestDynamoAOT(JitTestCase): + def test_dynamo_aot_ts_onednn(self): + class Seq(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ) + + def forward(self, x): + return self.layers(x) + + mod = Seq() + + import torch._dynamo + aot_mod = torch._dynamo.optimize("aot_ts", nopython=True)(mod) + + for _ in range(10): + with torch.jit.fuser("fuser3"): + loss = aot_mod(torch.rand([10, 10])).sum() + loss.backward() + + torch._dynamo.reset() + + @unittest.skipIf(IS_AVX512_UNSUPPORTED, "This test fails for BF16 on machines without AVX512.") @unittest.skipIf(LLGA_NOT_ENABLED, "MKL-DNN build is disabled") class TestModel(JitLlgaTestCase): diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.cpp b/torch/csrc/jit/codegen/onednn/graph_helper.cpp index a8a202acf0da..a14dce108dd1 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.cpp +++ b/torch/csrc/jit/codegen/onednn/graph_helper.cpp @@ -505,7 +505,6 @@ Node* LlgaGraphHelper::createSingletonSubgraph(Node* n, AliasDb& aliasDb) { auto group = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing( n, prim::oneDNNFusionGroup, aliasDb); opToOwningPartition_.add(group, partitionId); - LlgaNodeWrapper(group).initOutputLayouts(); return group; } @@ -585,25 +584,29 @@ LlgaNodeWrapper::LlgaNodeWrapper(const Node* node) } void LlgaNodeWrapper::setOpaqueLayout(size_t offset) { - TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); + const auto num_output = n->is(attr::output_layouts).size(); + TORCH_CHECK( + offset < num_output, + "Out of range. (Invalid index ", + offset, + " for attr::output_layouts with size ", + num_output, + ")"); auto& layouts = const_cast&>(n->is(attr::output_layouts)); // NOLINT - layouts.at(offset) = 1; + layouts.at(offset) = OPAQUE_LAYOUT; } bool LlgaNodeWrapper::useOpaqueLayout(size_t offset) const { - TORCH_CHECK(offset < n->outputs().size(), "Invalid output offset ", offset); - return n->is(attr::output_layouts)[offset] == 1; -} - -void LlgaNodeWrapper::initOutputLayouts() { - if (n->hasAttribute(attr::output_layouts)) { - return; - } - - // Init all output layouts as undef - std::vector layouts(n->outputs().size(), 0); - n->is_(attr::output_layouts, layouts); + const auto num_output = n->is(attr::output_layouts).size(); + TORCH_CHECK( + offset < num_output, + "Out of range. 
(Invalid index ", + offset, + " for attr::output_layouts with size ", + num_output, + ")"); + return n->is(attr::output_layouts)[offset] == OPAQUE_LAYOUT; } } // namespace onednn diff --git a/torch/csrc/jit/codegen/onednn/graph_helper.h b/torch/csrc/jit/codegen/onednn/graph_helper.h index 5422a90d9e97..fbb5eaa84aec 100644 --- a/torch/csrc/jit/codegen/onednn/graph_helper.h +++ b/torch/csrc/jit/codegen/onednn/graph_helper.h @@ -10,6 +10,9 @@ namespace jit { namespace fuser { namespace onednn { +#define STRIDED_LAYOUT 0 +#define OPAQUE_LAYOUT 1 + struct OpPartitionMap { void add(uint64_t opId, uint64_t partitionId) { opmap_[opId] = partitionId; @@ -92,8 +95,6 @@ class LlgaNodeWrapper { friend class LlgaGraphHelper; private: - void initOutputLayouts(); - Node* n; }; diff --git a/torch/csrc/jit/codegen/onednn/layout_propagation.cpp b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp index 448e1cf85884..4201282fb083 100644 --- a/torch/csrc/jit/codegen/onednn/layout_propagation.cpp +++ b/torch/csrc/jit/codegen/onednn/layout_propagation.cpp @@ -1,5 +1,6 @@ #include #include +#include namespace torch { namespace jit { @@ -10,6 +11,14 @@ void LayoutPropagation(Node* n) { if (!LlgaGraphHelper::isLlgaSubgraph(n)) return; + // initial attr::output_layouts if undefined + if (!n->hasAttribute(attr::output_layouts)) { + const auto num_output = n->outputs().size(); + GRAPH_DEBUG("Initial output_layouts of size ", num_output); + std::vector layouts(num_output, STRIDED_LAYOUT); + n->is_(attr::output_layouts, layouts); + } + for (auto input : n->inputs()) { auto prev = input->node(); auto offset = input->offset(); From 50c18217a3849c56a0fe5bdb923bd67fa70da31c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 15 Nov 2022 09:37:09 +0000 Subject: [PATCH 167/453] Revert "Add mem efficient backward (#88856)" This reverts commit 35e668b5ced25e735b6e523d557ed7fd60267914. 
Reverted https://github.com/pytorch/pytorch/pull/88856 on behalf of https://github.com/DanilBaibak due to breaking internal builds --- aten/src/ATen/native/native_functions.yaml | 5 - .../native/transformers/cuda/attention.cu | 16 +- .../transformers/cuda/attention_backward.cu | 261 ------------------ .../transformers/cuda/flash_attn/fmha_api.cpp | 4 - .../attention_backward_generic.cu | 166 +++++++++++ .../attention_forward_generic.cu | 232 ++++++++++++++++ .../cuda/mem_eff_attention/find_default_mma.h | 7 +- .../cuda/mem_eff_attention/kernel_backward.h | 250 ++++++----------- .../ATen/native/transformers/cuda/sdp_utils.h | 12 +- test/test_transformers.py | 44 +-- tools/autograd/derivatives.yaml | 7 +- .../_internal/common_methods_invocations.py | 4 +- 12 files changed, 507 insertions(+), 501 deletions(-) delete mode 100644 aten/src/ATen/native/transformers/cuda/attention_backward.cu create mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu create mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9572ccc56653..de087c0b8a89 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13287,11 +13287,6 @@ dispatch: CUDA: _efficient_attention_forward -- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_backward - - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? 
incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 46543d4663fa..f65fedd6d795 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -746,9 +746,7 @@ std::tuple flash_attention_helper_dense_unpacked( std::tuple mem_eff_helper( const Tensor& query, const Tensor& key, - const Tensor& value, - bool compute_log_sumexp, - bool is_causal) { + const Tensor& value){ // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) @@ -756,18 +754,16 @@ std::tuple mem_eff_helper( Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - Tensor attention, log_sumexp; - std::tie(attention, log_sumexp) = at::_efficient_attention_forward( + Tensor attention = std::get<0>(at::_efficient_attention_forward( q_t, k_t, v_t, c10::nullopt, c10::nullopt, c10::nullopt, - compute_log_sumexp, - is_causal); - attention = attention.transpose(1,2); - return std::make_tuple(std::move(attention), Tensor()); + false, + false)).transpose(1,2); + return std::make_tuple(attention, Tensor()); } std::tuple _scaled_dot_product_attention_forward_cuda( @@ -780,7 +776,7 @@ std::tuple _scaled_dot_product_attention_forward_cuda( case sdp::SDPBackend::flash_attention: return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); + return mem_eff_helper(query_, key , value); case sdp::SDPBackend::math: return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); default: diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu deleted file mode 100644 index af005b2669b2..000000000000 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ /dev/null @@ -1,261 +0,0 @@ -#include - -#include - -#include -#include - -#include -#include -#include -#include - -#ifdef USE_FLASH_ATTENTION -#include -#endif - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ - } - -#define DISPATCH_MAXK(func) \ - { \ - const auto maxK = std::max(query.size(3), value.size(3)); \ - if (maxK <= 64) { \ - constexpr int kMaxK = 64; \ - func(); \ - } else if (maxK <= 128) { \ - constexpr int kMaxK = 128; \ - func(); \ - } else { \ - constexpr int kMaxK = std::numeric_limits::max(); \ - func(); \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - 
kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ - } - -namespace at { - -namespace native { - -std::tuple _efficient_attention_backward( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& logsumexp, - const at::Tensor& out, - bool causal) { - #if defined(USE_FLASH_ATTENTION) - if (!grad_out_.defined()) { - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); - } - // ndim - TORCH_CHECK(query.dim() == grad_out_.dim()); - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == value.dim()); - TORCH_CHECK(query.dim() == 4); - - // batch size - TORCH_CHECK(query.size(0) == grad_out_.size(0)); - TORCH_CHECK(query.size(0) == key.size(0)); - TORCH_CHECK(query.size(0) == value.size(0)); - - // seqlen - TORCH_CHECK(key.size(1) == value.size(1)); - TORCH_CHECK(query.size(1) == grad_out_.size(1)); - - // Num heads - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(2) == value.size(2)); - TORCH_CHECK(query.size(2) == grad_out_.size(2)); - - // Embedding per head - TORCH_CHECK(query.size(3) == key.size(3)); - TORCH_CHECK(value.size(3) == grad_out_.size(3)); - - // handle potentially non-contiguous grad_out through a copy - auto grad_out = grad_out_.contiguous(); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); - - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - - at::cuda::CUDAGuard device_guard(query.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t nH = query.size(2); - int64_t K = query.size(3); - - // It does not make sense to use that in practice, - // but let's still make sure we are correct - // As we iterate through keys first, we skip - // keys with no query associated, so they are not - // initialized - bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q, grad_k, grad_v; - if (!grad_kv_needs_init && query.size(1) == key.size(1) && - query.size(3) == value.size(3) && - query.storage().is_alias_of(key.storage()) && - query.storage().is_alias_of(value.storage())) { - // Create one big contiguous chunk - // This is because q, k and v usually come from a single - // output of a linear layer that is chunked. - // Creating the gradients with the right layout saves us - // a `torch.cat` call in the backward pass - at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); - grad_q = chunk.select(2, 0); - grad_k = chunk.select(2, 1); - grad_v = chunk.select(2, 2); - } else { - grad_q = at::empty_like(query); - grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); - } - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - - // TODO: Fuse this into a kernel? - // This is a bottleneck for smaller sequences (M <= 128) - auto delta = Kernel::kKernelComputesDelta - ? 
at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) - : (grad_out.to(at::kFloat) * out.to(at::kFloat)) - .sum(-1) - .transpose(-2, -1) - .contiguous(); - TORCH_INTERNAL_ASSERT(delta.size(0) == B); - TORCH_INTERNAL_ASSERT(delta.size(1) == nH); - TORCH_INTERNAL_ASSERT(delta.size(2) == M); - - typename Kernel::Params p; - p.query_ptr = (scalar_t*)query.data_ptr(); - p.key_ptr = (scalar_t*)key.data_ptr(); - p.value_ptr = (scalar_t*)value.data_ptr(); - p.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); - p.output_ptr = (scalar_t*)out.data_ptr(); - p.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); - p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); - p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); - p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); - p.delta_ptr = (float*)delta.data_ptr(); - p.head_dim = query.size(3); - p.head_dim_value = value.size(3); - p.num_queries = query.size(1); - p.num_keys = key.size(1); - p.num_batches = B; - p.num_heads = nH; - p.causal = causal; - - ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); - - ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); - - ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; - TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); - TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); - TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); - - ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); - - Kernel::check_supported(p); - - constexpr auto kernel_fn = attention_kernel_backward_batched; - - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - } - - // second syntax resulted in the error below on windows - // error C3495: 'kernel_fn': a simple capture must be a variable - // with automatic storage duration declared - // in the reaching scope of the lambda -#ifdef _WIN32 - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - TORCH_INTERNAL_ASSERT( - attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, - "Something went wrong in the build process"); -#else - auto checkBinaryArchMatches = [&]() { - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; - }; - TORCH_INTERNAL_ASSERT( - checkBinaryArchMatches(), "Something went wrong in the build process"); -#endif - - kernel_fn<<>>(p); - }; - - DISPATCH_KERNEL( - 
query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); - #endif - TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); -} - -} // namespace native -} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index 6c86e1ff63b0..a8d6110e951d 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -29,7 +29,6 @@ #ifdef USE_FLASH_ATTENTION #include #include -#include #include #include @@ -186,9 +185,6 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; - // Otherwise the kernel will be launched from cuda:0 device - at::cuda::CUDAGuard device_guard{q.get_device()}; - auto opts = q.options(); auto o = at::empty({ total_q, num_heads, head_size }, opts); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu new file mode 100644 index 000000000000..07c14ad8195d --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu @@ -0,0 +1,166 @@ +#include + +#define DISPATCH_MAXK(func) \ + { \ + const auto maxK = std::max(query.size(2), value.size(2)); \ + if (maxK <= 64) { \ + constexpr int kMaxK = 64; \ + func(); \ + } else if (maxK <= 128) { \ + constexpr int kMaxK = 128; \ + func(); \ + } else { \ + constexpr int kMaxK = std::numeric_limits::max(); \ + func(); \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = \ + AttentionBackwardKernel; \ + bool isAligned = \ + (QUERY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ + KEY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ + VALUE.stride(1) % AlignedAK::kOptimalAlignement == 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })); \ + } + +namespace { +std::tuple +mem_efficient_attention_backward_cutlass( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + const at::Tensor& out, + bool causal) { + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == 3); + + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + TORCH_CHECK(value.size(2) == grad_out_.size(2)); + + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(0) == key.size(0)); + + TORCH_CHECK(query.size(0) == value.size(0)); + TORCH_CHECK(key.size(1) == value.size(1)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + + CHECK_NOSPARSE_CONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(key); + 
CHECK_NOSPARSE_CONTIGUOUS_CUDA(value); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + at::cuda::CUDAGuard device_guard(query.device()); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t K = query.size(2); + + // It does not make sense to use that in practice, + // but let's still make sure we are correct + // As we iterate through keys first, we skip + // keys with no query associated, so they are not + // initialized + bool grad_kv_needs_init = causal && N > M; + at::Tensor grad_q = at::empty_like(query); + at::Tensor grad_k = + grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); + at::Tensor grad_v = + grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? at::empty({B, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)).sum(-1); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == M); + + typename Kernel::Params params; + params.query_ptr = (scalar_t*)query.data_ptr(); + params.key_ptr = (scalar_t*)key.data_ptr(); + params.value_ptr = (scalar_t*)value.data_ptr(); + params.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); + params.output_ptr = (scalar_t*)out.data_ptr(); + params.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); + params.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + params.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + params.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + params.delta_ptr = (float*)delta.data_ptr(); + params.head_dim = query.size(2); + params.head_dim_value = value.size(2); + params.num_queries = query.size(1); + params.num_keys = key.size(1); + params.num_batches = B; + params.causal = causal; + Kernel::check_supported(params); + + constexpr auto kernel_fn = attention_kernel_backward_batched; + + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + auto checkBinaryArchMatches = [&]() { + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); + + kernel_fn<<>>( + params); + }; + + DISPATCH_KERNEL( + query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); +} // namespace + +} // namespace + +// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { +// m.impl( +// TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_cutlass"), +// TORCH_FN(mem_efficient_attention_backward_cutlass)); +// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu new file mode 100644 index 000000000000..59b3637c8a43 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu @@ -0,0 
+1,232 @@ +#include + + +#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ + { \ + if (VALUE_HEAD_DIM <= 64) { \ + constexpr bool kIs64x64 = true; \ + constexpr bool kSingleValueIteration = true; \ + FN(); \ + } else { \ + constexpr bool kIs64x64 = false; \ + if (VALUE_HEAD_DIM <= 128) { \ + constexpr bool kSingleValueIteration = true; \ + FN(); \ + } else { \ + constexpr bool kSingleValueIteration = false; \ + FN(); \ + } \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_BLOCKSIZE( \ + VALUE.size(-1), ([&]() { \ + static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \ + static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = AttentionKernel< \ + scalar_t, \ + ArchTag, \ + true, \ + kQueriesPerBlock, \ + kKeysPerBlock, \ + kSingleValueIteration>; \ + /* Run a more efficient kernel (with `isAligned=True`) \ + if memory is correctly aligned*/ \ + bool isAligned = \ + (QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \ + KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \ + VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \ + /* TODO: Should we warn or log somewhere when we use a \ + less efficient kernel due to wrong alignment? */ \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionKernel< \ + scalar_t, \ + ArchTag, \ + kIsAligned, \ + kQueriesPerBlock, \ + kKeysPerBlock, \ + kSingleValueIteration>; \ + FUNC(); \ + })) \ + })) \ + })); \ + })); \ + } + +namespace { +/* + There are 2 modes for using this function. 
+ (Mode BMHK) With all the heads having the same seqlen + (Mode 1MHK) `batch=1` with all tokens across batches concatenated +*/ +std::tuple efficient_attention_forward_cutlass( + const at::Tensor& query, // [b, seqlen, num_heads, K] + const at::Tensor& key, // [b, seqlen, num_heads, K] + const at::Tensor& value, // [b, seqlen, num_heads, Kv] + // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the + // position of the first query token for batch $b + const c10::optional& cu_seqlens_q, + // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the + // position of the first key token for batch $b + const c10::optional& cu_seqlens_k, + // (Mode 1MHK only) Maximum sequence length across batches + const c10::optional max_seqlen_q_, + bool compute_logsumexp, + bool causal) { + TORCH_CHECK(query.dim() == 4); + TORCH_CHECK(key.dim() == 4); + TORCH_CHECK(value.dim() == 4); + + // Batch sizes + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // Sequence length + TORCH_CHECK(key.size(1) == value.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + + int64_t max_seqlen_q, max_seqlen_k; + TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); + if (cu_seqlens_q.has_value()) { + TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); + CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); + TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); + TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); + TORCH_CHECK(max_seqlen_q_.has_value()); + max_seqlen_q = *max_seqlen_q_; + max_seqlen_k = 0; // Will be set inside the kernel + } else { + max_seqlen_q = query.size(1); + max_seqlen_k = key.size(1); + } + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::cuda::CUDAGuard device_guard(query.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t num_heads = query.size(-2); + int64_t K = query.size(-1); + int64_t Kv = value.size(-1); + + at::Tensor res; + at::Tensor logsumexp; + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + res = at::empty( + {B, M, num_heads, Kv}, + query.options().dtype( + TypeTraits::atScalarType())); + + // NOTE: Should be aligned (by padding) in case M is + // not a good number for loading during backward + constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; + logsumexp = at::empty( + {B, + num_heads, + compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, + query.options().dtype(at::ScalarType::Float)); + + typename Kernel::Params p; + p.query_ptr = (scalar_t*)query.data_ptr(); + p.key_ptr = (scalar_t*)key.data_ptr(); + p.value_ptr = (scalar_t*)value.data_ptr(); + p.logsumexp_ptr = compute_logsumexp + ? 
(typename Kernel::lse_scalar_t*)logsumexp.data_ptr() + : nullptr; + at::Tensor output_accum; + if (Kernel::kNeedsOutputAccumulatorBuffer) { + output_accum = at::empty( + {B, M, num_heads, Kv}, + query.options().dtype( + TypeTraits::atScalarType())); + p.output_accum_ptr = + (typename Kernel::output_accum_t*)output_accum.data_ptr(); + } else { + p.output_accum_ptr = nullptr; + } + p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); + + if (cu_seqlens_q.has_value()) { + p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); + p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); + } + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ + } + + p.num_heads = num_heads; + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = max_seqlen_q; + p.num_keys = max_seqlen_k; + p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B; + p.causal = causal; + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + + constexpr auto kernel_fn = attention_kernel_batched; + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + AT_CUDA_CHECK(cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); + } + Kernel::check_supported(p); + kernel_fn<<>>(p); + }; + // Dispatch to the right kernel + DISPATCH_KERNEL(query, key, value, ([&]() { + launchKernel(Kernel{}, computeCapability); + })); + + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(res, logsumexp); +} +} // namespace + +// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { +// m.impl( +// TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_cutlass"), +// TORCH_FN(efficient_attention_forward_cutlass)); +// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h index b0e7106f3cfc..399593fd0957 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h @@ -1,16 +1,15 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instantiate to run a GEMM with various parameters (see + datastructures to instanciate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instantiation priority rules, it will only create an MmaMultiStage with + instanciation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and - `default_mma_core.h` files and wrapped this template to allow our use case. + `default_mma_core.h` files and wrapped this template to allow our usecase. 
This is really only for the FastF32 case - aka using TensorCores with fp32. */ -#pragma once #include #include diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index c9652c40d38e..e25701a7588a 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1,5 +1,7 @@ #pragma once + #include +#include #include #include @@ -73,113 +75,46 @@ struct AttentionBackwardKernel { struct Params { // Input tensors - scalar_t* query_ptr; // [Mq, nH, K] - scalar_t* key_ptr; // [Mk, nH, K] - scalar_t* value_ptr; // [Mk, nH, Kv] - lse_scalar_t* logsumexp_ptr; // [nH, Mq] - scalar_t* output_ptr; // [Mq, nH, Kv] - scalar_t* grad_output_ptr; // [Mq, nH, Kv] - accum_t* delta_ptr; // [Mq, nH] + scalar_t* query_ptr; // [num_queries, head_dim] + scalar_t* key_ptr; // [num_keys, head_dim] + scalar_t* value_ptr; // [num_keys, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_queries] + scalar_t* output_ptr; // [num_queries, head_dim_value] + scalar_t* grad_output_ptr; // [num_queries, head_dim_value] + accum_t* delta_ptr; // [num_queries] // Output tensors - output_t* grad_query_ptr; // [Mq, nH, K] - output_t* grad_key_ptr; // [Mk, nH, K] - output_t* grad_value_ptr; // [Mk, nH, Kv] + scalar_t* grad_query_ptr; // [num_queries, head_dim] + scalar_t* grad_key_ptr; // [num_keys, head_dim] + scalar_t* grad_value_ptr; // [num_keys, head_dim_value] // Dimensions/strides int32_t head_dim; int32_t head_dim_value; int32_t num_queries; int32_t num_keys; - int32_t num_heads; - bool causal; - - int32_t q_strideM; - int32_t k_strideM; - int32_t v_strideM; - int32_t gO_strideM; - int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise - - CUTLASS_HOST_DEVICE int32_t o_strideM() const { - return head_dim_value * num_heads; - } - CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim; - } - CUTLASS_HOST_DEVICE int32_t gK_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim; - } - CUTLASS_HOST_DEVICE int32_t gV_strideM() const { - return gQKV_strideM_multiplier * num_heads * head_dim_value; - } - - // Everything below is only used in `advance_to_block` - // and shouldn't use registers - int64_t o_strideH; - int32_t q_strideH; - int32_t k_strideH; - int32_t v_strideH; - int64_t o_strideB; - int64_t q_strideB; - int64_t k_strideB; - int64_t v_strideB; int32_t num_batches; + bool causal; - int64_t gO_strideB; - int64_t gQ_strideB; - int64_t gK_strideB; - int64_t gV_strideB; - int64_t gO_strideH; - int64_t gQ_strideH; - int64_t gK_strideH; - int64_t gV_strideH; - - CUTLASS_DEVICE void advance_to_block() { + __device__ void advance_batches(int32_t batch_id) { constexpr int32_t kAlignLSE = 32; // block size of backward auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - int32_t batch_id = blockIdx.z; - int32_t head_id = blockIdx.y; - - query_ptr += batch_id * q_strideB + head_id * q_strideH; - key_ptr += batch_id * k_strideB + head_id * k_strideH; - value_ptr += batch_id * v_strideB + head_id * v_strideH; - logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim; - output_ptr += batch_id * o_strideB + head_id * o_strideH; - grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; - delta_ptr += (batch_id * num_heads + head_id) * num_queries; - - grad_query_ptr += batch_id * gQ_strideB + head_id * 
gQ_strideH; - grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; - grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; - - head_dim = warp_uniform(head_dim); - head_dim_value = warp_uniform(head_dim_value); - num_queries = warp_uniform(num_queries); - num_keys = warp_uniform(num_keys); - num_heads = warp_uniform(num_heads); - - gO_strideM = warp_uniform(gO_strideM); - gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); - q_strideM = warp_uniform(q_strideM); - k_strideM = warp_uniform(k_strideM); - v_strideM = warp_uniform(v_strideM); - - query_ptr = warp_uniform(query_ptr); - key_ptr = warp_uniform(key_ptr); - value_ptr = warp_uniform(value_ptr); - logsumexp_ptr = warp_uniform(logsumexp_ptr); - output_ptr = warp_uniform(output_ptr); - grad_output_ptr = warp_uniform(grad_output_ptr); - delta_ptr = warp_uniform(delta_ptr); - - grad_query_ptr = warp_uniform(grad_query_ptr); - grad_key_ptr = warp_uniform(grad_key_ptr); - grad_value_ptr = warp_uniform(grad_value_ptr); + query_ptr += batch_id * head_dim * num_queries; + key_ptr += batch_id * head_dim * num_keys; + value_ptr += batch_id * head_dim_value * num_keys; + logsumexp_ptr += batch_id * lse_dim; + output_ptr += batch_id * head_dim_value * num_queries; + grad_output_ptr += batch_id * head_dim_value * num_queries; + delta_ptr += batch_id * num_queries; + + grad_query_ptr += batch_id * head_dim * num_queries; + grad_key_ptr += batch_id * head_dim * num_keys; + grad_value_ptr += batch_id * head_dim_value * num_keys; } __host__ dim3 getBlocksGrid() const { - return dim3(1, num_heads, num_batches); + return dim3(1, 1, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); @@ -244,6 +179,7 @@ struct AttentionBackwardKernel { attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = @@ -289,6 +225,7 @@ struct AttentionBackwardKernel { struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ @@ -664,7 +601,7 @@ struct AttentionBackwardKernel { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; - CUTLASS_DEVICE void clear() { + __device__ __forceinline__ void clear() { gradV.clear(); gradK.clear(); } @@ -677,14 +614,14 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); TORCH_CHECK( - p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); + p.head_dim % kMinimumAlignment == 0, + "query/key is not correctly aligned"); TORCH_CHECK( - p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); - TORCH_CHECK( - p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); + p.head_dim_value % kMinimumAlignment == 0, + "value is not correctly aligned"); } - static CUTLASS_DEVICE void kernel(Params& p_) { + static __device__ void kernel(Params& p_) { // Hint to nvcc to store points & tensor shapes in registers // as we use them a lot register const Params p = p_; @@ -721,7 +658,7 @@ struct AttentionBackwardKernel { __syncthreads(); } - OutputFragments register output_frags; + OutputFragments output_frags; int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; key_start += 
kBlockSizeJ) { @@ -758,7 +695,7 @@ struct AttentionBackwardKernel { } } - static CUTLASS_DEVICE void loadDi( + static __device__ __forceinline__ void loadDi( cutlass::Array& di, Params const& p, int32_t query_start) { @@ -773,7 +710,7 @@ struct AttentionBackwardKernel { } template - static CUTLASS_DEVICE void processBlockIJ( + static __device__ __forceinline__ void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -781,9 +718,9 @@ struct AttentionBackwardKernel { int32_t key_start) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim))); - int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int8_t warp_id = warp_uniform(threadIdx.y); - int8_t lane_id = threadIdx.x; + int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + int32_t warp_id = threadIdx.y; + int32_t lane_id = threadIdx.x; __syncthreads(); loadDi(shared_storage.di(), p, query_start); @@ -797,8 +734,8 @@ struct AttentionBackwardKernel { auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( - {int32_t(p.gO_strideM)}, - p.grad_output_ptr + query_start * p.gO_strideM + col, + {int32_t(p.head_dim_value)}, + p.grad_output_ptr + query_start * p.head_dim_value + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -810,8 +747,8 @@ struct AttentionBackwardKernel { }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( - {int32_t(p.k_strideM)}, - p.key_ptr + key_start * p.k_strideM + col, + {int32_t(p.head_dim)}, + p.key_ptr + key_start * p.head_dim + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); @@ -820,8 +757,8 @@ struct AttentionBackwardKernel { }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( - {int32_t(p.q_strideM)}, - p.query_ptr + query_start * p.q_strideM + col, + {int32_t(p.head_dim)}, + p.query_ptr + query_start * p.head_dim + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); @@ -833,14 +770,14 @@ struct AttentionBackwardKernel { }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( - {int32_t(p.gO_strideM)}, - p.grad_output_ptr + query_start * p.gO_strideM, + {int32_t(p.head_dim_value)}, + p.grad_output_ptr + query_start * p.head_dim_value, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B( - {int32_t(p.v_strideM)}, - p.value_ptr + key_start * p.v_strideM, + {int32_t(p.head_dim_value)}, + p.value_ptr + key_start * p.head_dim_value, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -866,16 +803,16 @@ struct AttentionBackwardKernel { // k_j typename Mma::IteratorA iterator_A( - {int32_t(p.k_strideM)}, - p.key_ptr + key_start * p.k_strideM, + {int32_t(p.head_dim)}, + p.key_ptr + key_start * p.head_dim, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.q_strideM)}, - p.query_ptr + query_start * p.q_strideM, + {int32_t(p.head_dim)}, + p.query_ptr + query_start * p.head_dim, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -956,14 +893,14 @@ struct AttentionBackwardKernel { num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( - typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, - p.grad_value_ptr + key_start * p.gV_strideM() + col, + typename 
MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, + p.grad_value_ptr + key_start * p.head_dim_value + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B( - {int32_t(p.gO_strideM)}, - p.grad_output_ptr + query_start * p.gO_strideM + col, + {int32_t(p.head_dim_value)}, + p.grad_output_ptr + query_start * p.head_dim_value + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -1014,16 +951,16 @@ struct AttentionBackwardKernel { using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A( - {int32_t(p.gO_strideM)}, - p.grad_output_ptr + query_start * p.gO_strideM, + {int32_t(p.head_dim_value)}, + p.grad_output_ptr + query_start * p.head_dim_value, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.v_strideM)}, - p.value_ptr + key_start * p.v_strideM, + {int32_t(p.head_dim_value)}, + p.value_ptr + key_start * p.head_dim_value, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -1120,16 +1057,16 @@ struct AttentionBackwardKernel { num_keys_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradQ::OutputTileIterator( - typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, - p.grad_query_ptr + query_start * p.gQ_strideM() + col, + typename MatmulGradQ::OutputTileIterator::Params{p.head_dim}, + p.grad_query_ptr + query_start * p.head_dim + col, {problem_size.m(), problem_size.n()}, thread_id); }; // k_j typename Mma::IteratorB iterator_B( - {int32_t(p.k_strideM)}, - p.key_ptr + key_start * p.k_strideM + col, + {int32_t(p.head_dim)}, + p.key_ptr + key_start * p.head_dim + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1216,8 +1153,8 @@ struct AttentionBackwardKernel { num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( - typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, - p.grad_key_ptr + key_start * p.gK_strideM() + col, + typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, + p.grad_key_ptr + key_start * p.head_dim + col, {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); @@ -1225,8 +1162,8 @@ struct AttentionBackwardKernel { // q_i typename Mma::IteratorB iterator_B( - {int32_t(p.q_strideM)}, - p.query_ptr + query_start * p.q_strideM + col, + {int32_t(p.head_dim)}, + p.query_ptr + query_start * p.head_dim + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1299,15 +1236,15 @@ struct AttentionBackwardKernel { kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; auto thread_id = get_thread_id(); typename MatmulQK::Mma::IteratorA iterator_A( - {int32_t(p.k_strideM)}, - p.key_ptr + key_start * p.k_strideM, + {int32_t(p.head_dim)}, + p.key_ptr + key_start * p.head_dim, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B( - {int32_t(p.q_strideM)}, - p.query_ptr + query_start * p.q_strideM, + {int32_t(p.head_dim)}, + p.query_ptr + query_start * p.head_dim, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); @@ -1322,7 +1259,7 @@ struct AttentionBackwardKernel { } template - static CUTLASS_DEVICE void writeFragsToGmem( + static __device__ __forceinline__ void writeFragsToGmem( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -1331,8 +1268,8 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : std::min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( - typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, - p.grad_value_ptr + key_start * p.gV_strideM(), + typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, + p.grad_value_ptr + key_start * p.head_dim_value, {num_keys_in_block, p.head_dim_value}, get_thread_id()); accumulateInGmem( @@ -1342,8 +1279,8 @@ struct AttentionBackwardKernel { true); typename MatmulGradK::OutputTileIterator outputK_it( - typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, - p.grad_key_ptr + key_start * p.gK_strideM(), + typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, + p.grad_key_ptr + key_start * p.head_dim, {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim}, get_thread_id()); @@ -1355,7 +1292,7 @@ struct AttentionBackwardKernel { } template - static CUTLASS_DEVICE void accumulateInGmem( + static __device__ __forceinline__ void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, @@ -1397,9 +1334,7 @@ struct AttentionBackwardKernel { } template - static CUTLASS_DEVICE void computeDelta( - Params const& p, - int32_t query_start) { + static __device__ void computeDelta(Params const& p, int32_t query_start) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row @@ -1414,15 +1349,13 @@ struct AttentionBackwardKernel { bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; - // on windows, previous syntax __restrict__ AccessType* - // resulted in error: "restrict" is not allowed - const AccessType* __restrict__ grad_output_ptr = - reinterpret_cast( - p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + const __restrict__ AccessType* grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.head_dim_value + laneFirstCol); - const AccessType* __restrict__ output_ptr = - reinterpret_cast( - p.output_ptr + (query_start + laneRow) * p.o_strideM() + + const __restrict__ AccessType* output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.head_dim_value + laneFirstCol); static constexpr int64_t kMaxIters = @@ -1497,13 +1430,13 @@ struct AttentionBackwardKernel { } } - static CUTLASS_DEVICE int8_t get_lane_id() { + static __device__ __forceinline__ int8_t get_lane_id() { return threadIdx.x; } - static CUTLASS_DEVICE int8_t get_warp_id() { + static __device__ __forceinline__ int8_t get_warp_id() { return threadIdx.y; } - static CUTLASS_DEVICE int16_t get_thread_id() { + static __device__ __forceinline__ int16_t get_thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } }; @@ -1524,7 +1457,8 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) #define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ _ATTENTION_KERNEL_BACKWARD_BEGIN( \ AttentionBackwardKernel) \ - p.advance_to_block(); \ + auto batch_id = blockIdx.z; \ + p.advance_batches(batch_id); \ Kernel::kernel(p); \ _ATTENTION_KERNEL_BACKWARD_END(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index e9f3d5029aa8..564adb2d51ea 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -62,15 +62,6 @@ inline bool check_for_attn_weights(sdp_params params, bool debug) { } return true; } - -inline bool check_for_non_zero_dropout(sdp_params params, bool debug) { - if (params.dropout != 0.0) { - TORCH_CHECK(!debug, "Mem_efficient does not support non_zero dropout. 
Dropout_p: ", params.dropout); - return false; - } - return true; -} - inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { if (!params.query.is_nested()) { return true; @@ -239,8 +230,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, - check_for_seq_len_1_nested_tensor, - check_for_non_zero_dropout}; + check_for_seq_len_1_nested_tensor}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/test_transformers.py b/test/test_transformers.py index 93a94a5604c9..939d91e7ee87 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -21,11 +21,8 @@ TEST_WITH_ROCM, IS_WINDOWS, slowTest, - set_default_dtype, - gradcheck + set_default_dtype ) - -from torch.testing._internal.common_methods_invocations import wrapper_set_seed from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater if TEST_FAIRSEQ: @@ -863,22 +860,11 @@ def rand_tensor(*shape): actual = torch.ops.aten._scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) + # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. + # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. + if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) - if attn_mask_dim is None: - q = q.double().clone() - k = k.double().clone() - v = v.double().clone() - q.requires_grad_() - k.requires_grad_() - v.requires_grad_() - - assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), - (q, k, v, attn_mask, dropout_p)) - assert gradcheck(lambda *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), - (q, k, v, attn_mask, dropout_p)) - @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') @torch.no_grad() def test_mask_check_fastpath(self): @@ -1093,28 +1079,6 @@ def rand_tensor(shape): self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) - @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") - @parametrize("contiguous_inputs", [True, False]) - def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): - - batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - query, key, value = torch.rand((batch_size, seq_len, 3 * num_heads * head_dim), - device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1) - query = query.view(batch_size, -1, num_heads, head_dim) - key = key.view(batch_size, -1, num_heads, head_dim) - value = value.view(batch_size, -1, num_heads, head_dim) - - if contiguous_inputs: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - # Normally we would transpose the inputs but the fused kernels expect - # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel - # in fp32 - assert gradcheck(lambda *args, **kwargs: - wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), - (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a0892b32a835..8349a308be35 100644 --- 
a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2591,7 +2591,7 @@ - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor self: grad.reshape_symint(self.sym_sizes()) -# NestedTensor +# Nested Tensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" @@ -2612,11 +2612,6 @@ nested_size: non_differentiable nested_strides: non_differentiable -# Transformers -- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) - output_differentiability: [True, False] - query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) - # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3b43b8fb4863..001fd455e82e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11944,8 +11944,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), OpInfo( 'nn.functional._scaled_dot_product_attention', - op=lambda *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + op=lambda inp, *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, inp, *args, **kwargs), sample_inputs_func=sample_inputs_scaled_dot_product_attention, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), From 8c46a5de3a2e72c5ffbb714fa4e2d44fc2e59951 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 14 Nov 2022 20:16:43 -0800 Subject: [PATCH 168/453] Add debug handle to xnnpack schema (#89033) As title, add three things to the schema 1. debug handle for each node 2. file identifier, so we can sanity check we are getting the xnnpack schema flatbuffers file, instead of other random binary 3. 
extension, so the dumped binary will end up with its own extension like `myschema.xnnpack` (maybe can have a better name) instead of the default extension `.bin` Differential Revision: [D40906970](https://our.internmc.facebook.com/intern/diff/D40906970/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89033 Approved by: https://github.com/mcr229 --- torch/csrc/jit/backends/xnnpack/serialization/schema.fbs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs index 3b4b53debd02..6f72e604d0c4 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs +++ b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs @@ -54,6 +54,8 @@ union ValueUnion { table Node { node:NodeUnion; + // An int which can be linked back to the node in the origin graph + debug_handle:uint; } table Value { From 2452e3f99a072760fc46d3f9025aaa37ca7ea2ab Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 14 Nov 2022 20:16:45 -0800 Subject: [PATCH 169/453] Update xnnpack graph schema to use xnode and xvalue (#89036) There are different nodes definition like [Node in autograd](https://www.internalfb.com/code/fbsource/fbcode/caffe2/torch/csrc/autograd/function.h?lines=108-609&reveal=108-609) and onnxnodes and etc. Understand namespace can be used where nodes from definition are used together, however it's still better to slightly differentiate the name. Differential Revision: [D41002324](https://our.internmc.facebook.com/intern/diff/D41002324/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89036 Approved by: https://github.com/mcr229 --- .../backends/xnnpack/compiler/xnn_compiler.cpp | 16 ++++++++-------- .../backends/xnnpack/serialization/schema.fbs | 16 ++++++++-------- .../xnnpack/serialization/serializer.cpp | 4 ++-- .../backends/xnnpack/serialization/serializer.h | 4 ++-- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index 3bbff2309904..4147edf90e85 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -46,10 +46,10 @@ XNNExecutor XNNCompiler::compileModel(std::string ser_model) { // a new mapping from the old ids to the newly created ones std::unordered_map remapped_ids; - for (auto value : *flatbuffer_graph->values()) { - switch (value->value_type()) { - case fb_xnnpack::ValueUnion::XNNTensorValue: { - auto tensor_value = value->value_as_XNNTensorValue(); + for (auto value : *flatbuffer_graph->xvalues()) { + switch (value->xvalue_type()) { + case fb_xnnpack::XValueUnion::XNNTensorValue: { + auto tensor_value = value->xvalue_as_XNNTensorValue(); const void* data_ptr = nullptr; auto buffer_idx = tensor_value->constant_buffer_idx(); @@ -85,10 +85,10 @@ XNNExecutor XNNCompiler::compileModel(std::string ser_model) { } } - for (auto node : *flatbuffer_graph->nodes()) { - switch (node->node_type()) { - case fb_xnnpack::NodeUnion::XNNAdd: { - auto graph_node = node->node_as_XNNAdd(); + for (auto node : *flatbuffer_graph->xnodes()) { + switch (node->xnode_type()) { + case fb_xnnpack::XNodeUnion::XNNAdd: { + auto graph_node = node->xnode_as_XNNAdd(); status = xnn_define_add2( subgraph_ptr, output_min, diff --git a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs 
b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs index 6f72e604d0c4..87ebe20a825a 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs +++ b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs @@ -44,22 +44,22 @@ table XNNTensorValue { id_out:uint; } -union NodeUnion { +union XNodeUnion { XNNAdd, } -union ValueUnion { +union XValueUnion { XNNTensorValue, } -table Node { - node:NodeUnion; +table XNode { + xnode:XNodeUnion; // An int which can be linked back to the node in the origin graph debug_handle:uint; } -table Value { - value:ValueUnion; +table XValue { + xvalue:XValueUnion; } table XNNAdd { @@ -72,8 +72,8 @@ table XNNAdd { table XNNGraph { // Schema version. version:string; - nodes:[Node]; - values:[Value]; + xnodes:[XNode]; + xvalues:[XValue]; // Ids of external inputs input_ids:[uint]; diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp index 306884a89456..df1ccc791781 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp @@ -20,7 +20,7 @@ void XNNSerializer::serializeAddNode( const auto addNode = CreateXNNAdd(_builder, input1_id, input2_id, output_id, flags); const auto flatbufferNode = - CreateNode(_builder, NodeUnion::XNNAdd, addNode.Union()); + CreateXNode(_builder, XNodeUnion::XNNAdd, addNode.Union()); _nodes.push_back(flatbufferNode); } @@ -61,7 +61,7 @@ void XNNSerializer::serializeTensorValue( id_out); const auto flatbufferValue = - CreateValue(_builder, ValueUnion::XNNTensorValue, tensorValue.Union()); + CreateXValue(_builder, XValueUnion::XNNTensorValue, tensorValue.Union()); _values.push_back(flatbufferValue); } diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h index 3d6927f7678b..6d01571d424d 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -61,10 +61,10 @@ class XNNSerializer { flatbuffers_fbsource::FlatBufferBuilder _builder; // Vector of the serialized xnnpack nodes - std::vector> _nodes; + std::vector> _nodes; // Vector of the serialized xnnpack values - std::vector> _values; + std::vector> _values; std::vector> _constantBuffer; std::vector _bufferSizes; From 63e16216d8830b6340816c873b035e1a31ad4636 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 15 Nov 2022 13:21:39 +0000 Subject: [PATCH 170/453] [c10d] Implement `__instancecheck__` for `c10d::ReduceOp` (#88275) Summary: - Customize the metaclass of `torch.distributed.distributed_c10d.ReduceOp` for the sake of custom `__instancecheck__` - Add `copy.copy`, `copy.deepcopy`, and `pickle` support with tests Rel: - #81272 - #84243 - #87191 - #87303 - #87555 Ref: - https://github.com/pybind/pybind11/issues/2696 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88275 Approved by: https://github.com/wanchaol --- test/distributed/test_c10d_common.py | 32 ++++- test/distributed/test_c10d_nccl.py | 18 ++- torch/_C/_distributed_c10d.pyi | 2 + .../distributed/c10d/ProcessGroupNCCL.cpp | 7 +- torch/csrc/distributed/c10d/Types.hpp | 9 +- torch/csrc/distributed/c10d/init.cpp | 117 ++++++++++++++++-- torch/distributed/distributed_c10d.py | 11 -- 7 files changed, 156 insertions(+), 40 deletions(-) diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 
a43b1343923c..c03a68228990 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2,6 +2,7 @@ import copy import os +import pickle import sys import tempfile import threading @@ -1657,15 +1658,44 @@ def comm_fn(tensor, group=None): class ReduceOpTest(TestCase): + # Ref: https://github.com/pytorch/pytorch/issues/87191 def test_op_isinstance_of_reduceop(self): for reduce_op in ( c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, ): self.assertTrue(isinstance(reduce_op, c10d.ReduceOp)) - for scale in ([torch.tensor(1.0)], 2.0): + for scale in (torch.tensor(1.0), 2.0): self.assertTrue(isinstance(dist._make_nccl_premul_sum(scale), c10d.ReduceOp)) + # Ref: https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700 + def test_reduceop_copyable(self): + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + self.assertEqual(copy.copy(reduce_op), reduce_op) + self.assertEqual(copy.deepcopy(reduce_op), reduce_op) + self.assertEqual(copy.copy(c10d.ReduceOp(reduce_op)), reduce_op) + self.assertEqual(copy.deepcopy(c10d.ReduceOp(reduce_op)), reduce_op) + + for scale in (torch.tensor(1.0), 2.0): + reduce_op = dist._make_nccl_premul_sum(scale) + self.assertEqual(copy.copy(reduce_op), reduce_op) + self.assertEqual(copy.deepcopy(reduce_op), reduce_op) + + def test_reduceop_pickle(self): + for reduce_op in ( + c10d.ReduceOp.SUM, c10d.ReduceOp.AVG, c10d.ReduceOp.PRODUCT, c10d.ReduceOp.MIN, c10d.ReduceOp.MAX, + c10d.ReduceOp.BAND, c10d.ReduceOp.BOR, c10d.ReduceOp.BXOR, + ): + pickle.loads(pickle.dumps(reduce_op)) + orig = c10d.ReduceOp(reduce_op) + self.assertEqual(pickle.loads(pickle.dumps(orig)), orig) + for scale in (torch.tensor(1.0), 2.0): + reduce_op = dist._make_nccl_premul_sum(scale) + self.assertEqual(pickle.loads(pickle.dumps(reduce_op)), reduce_op) + if __name__ == "__main__": assert ( diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c514ea4ab31f..cdc167bc4d1a 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -348,16 +348,14 @@ def allreduce(tensors, op): # Premul Sum if torch.cuda.nccl.version() >= (2, 11, 1): for dtype in torch.half, torch.float, torch.double: - for factor in (3.0, - (torch.tensor([5.0], device=local_device_id, dtype=dtype),)): + for factor in (3.0, torch.tensor([5.0], device=local_device_id, dtype=dtype)): tensors = [torch.tensor([self.rank + 1]).cuda(local_device_id).to(dtype=dtype)] allreduce(tensors, c10d._make_nccl_premul_sum(factor)) - f = factor if isinstance(factor, float) else factor[0] # TODO(#38095): Replace assertEqualIgnoreType. 
See issue #38095 self.assertEqualIgnoreType( - f * torch.tensor([float(self.world_size * (self.world_size + 1) / 2)], device=local_device_id), + factor * torch.tensor([float(self.world_size * (self.world_size + 1) / 2)], device=local_device_id), tensors[0], ) @@ -435,9 +433,9 @@ def reduce(xs, rootRank, rootTensor, op=None): # Premul sum if torch.cuda.nccl.version() >= (2, 11, 1): - for factor in (3.0, (torch.tensor([5.0], device=local_device_id),)): - if isinstance(factor, tuple): - factor_ref = factor[0].cpu().item() + for factor in (3.0, torch.tensor([5.0], device=local_device_id)): + if isinstance(factor, torch.Tensor): + factor_ref = factor.cpu().item() else: factor_ref = factor float_tensors = [ @@ -933,9 +931,9 @@ def perm(n, k): self.assertEqualIgnoreType(expected, output_tensor) if torch.cuda.nccl.version() >= (2, 11, 1): - for factor in (3.0, (torch.tensor([5.0], device=self.rank),),): - if isinstance(factor, tuple): - factor_ref = factor[0].cpu().item() + for factor in (3.0, torch.tensor([5.0], device=self.rank)): + if isinstance(factor, torch.Tensor): + factor_ref = factor.cpu().item() else: factor_ref = factor output = [t.float() for t in output] diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 56b86bd504bf..f16a8ec362f5 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -91,6 +91,8 @@ class DebugLevel(Enum): class ReduceOp: + def __init__(self, op: "RedOpType"): ... + SUM = ... PRODUCT = ... MIN = ... diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1d788a2c2e0c..387fe5eb4dcc 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -83,11 +83,10 @@ ncclRedOpRAII unpackPreMulSum( const auto* preMulSupplement = reinterpret_cast(reduceOp.supplement_.get()); ncclRedOp_t preMulSum; - bool has_tensor = !preMulSupplement->tensor_factors.empty(); + bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; - T* ptr_factor = has_tensor - ? preMulSupplement->tensor_factors[dev_in_group].data_ptr() - : nullptr; + T* ptr_factor = + has_tensor ? preMulSupplement->tensor_factor.data_ptr() : nullptr; T scalar_factor = T(preMulSupplement->double_factor); ncclRedOpCreatePreMulSum( &preMulSum, diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 64fbc45c6588..be20fcadba64 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -8,6 +8,7 @@ #include #include +#include #include namespace c10d { @@ -21,9 +22,11 @@ struct TORCH_API _SupplementBase : torch::CustomClassHolder { // The point of use in ProcessGroupNCCL knows how to unpack it. 
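// Usage sketch (illustrative only): makeNCCLPreMulSum(2.0) fills double_factor,
// while makeNCCLPreMulSum(t) with a single-element at::Tensor t fills
// tensor_factor; the numel() == 1 check in the constructor below guards the
// tensor path.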
struct NCCLPreMulSumSupplement : _SupplementBase { double double_factor{0.0}; - std::vector tensor_factors; + at::Tensor tensor_factor; NCCLPreMulSumSupplement(double f) : double_factor{f} {} - NCCLPreMulSumSupplement(std::vector f) : tensor_factors{std::move(f)} {} + NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} { + TORCH_CHECK_EQ(tensor_factor.numel(), 1); + } }; // Other ReduceOps that need different supplementary data can also @@ -60,7 +63,7 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder { } } - // The heap resource supplement_, if it exists, is managed by a shared_ptr, + // The heap resource supplement_, if it exists, is managed by a c10::intrusive_ptr, // so constructors and operator= can be simple ReduceOp(const ReduceOp& other) : op_(other.op_), supplement_(other.supplement_) {} diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 313aabee7cd9..d39fc322d326 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -235,6 +236,61 @@ void _register_builtin_comm_hook( reducer.register_builtin_comm_hook(comm_hook_type); } +// Customize the metaclass of ::c10d::ReduceOp for the backward compatibility. +// https://github.com/pytorch/pytorch/pull/84243 changed ::c10d::ReduceOp to +// struct from enum, sacrificing some of the Python built-in function supports +// such as `isinstance` (see https://github.com/pytorch/pytorch/issues/87191) +// and `copy` (see +// https://github.com/pytorch/pytorch/pull/87303#discussion_r1002879700). Below, +// we define a custom `isinstance` in CPython/pybind11 +// (`reduceopmeta___instancecheck__`) and modify the default metaclass of +// pybind11 (`GetReduceOpMetaclass`) so that +// `isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)` +// returns :obj:`True` as if `ReduceOp` is enum. 
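+// For example, together with the `__copy__`/`__deepcopy__`/pickle bindings added
+// further below, the following Python sketch (mirroring the new tests in
+// test/distributed/test_c10d_common.py) is expected to hold:
+//   isinstance(torch.distributed.ReduceOp.SUM, torch.distributed.ReduceOp)  # True
+//   copy.deepcopy(torch.distributed.ReduceOp.SUM) == torch.distributed.ReduceOp.SUM
+//   pickle.loads(pickle.dumps(torch.distributed.ReduceOp(torch.distributed.ReduceOp.SUM)))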
+// Ref: +// - https://docs.python.org/3/extending/newtypes_tutorial.html +// - https://docs.python.org/3/c-api/typeobj.html?highlight=tp_methods +// - https://github.com/pybind/pybind11/issues/2696 +static PyObject* reduceopmeta___instancecheck__( + PyObject* self, + PyObject* args) { + if (Py_TYPE(self) == Py_TYPE(args)) { + Py_RETURN_TRUE; + } + if (c10::string_view(args->ob_type->tp_name).find("RedOpType") != + c10::string_view::npos) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} +static PyMethodDef reduceopmeta_methods[] = { + {"__instancecheck__", + (PyCFunction)reduceopmeta___instancecheck__, + METH_O, + "Custom `__instancecheck__` for ReduceOp"}, + {NULL, NULL}}; +PyTypeObject* GetReduceOpMetaclass() { + static auto* metaclass = [] { + PyTypeObject* base_metaclass = + pybind11::detail::get_internals().default_metaclass; + PyType_Slot slots[] = { + {Py_tp_base, base_metaclass}, + {Py_tp_methods, reduceopmeta_methods}, + {0}, + }; + PyType_Spec spec = {}; + spec.name = "torch._C._distributed_c10d._ReduceOpMeta"; + spec.basicsize = base_metaclass->tp_basicsize; + spec.flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + spec.slots = slots; + PyTypeObject* metaclass = (PyTypeObject*)PyType_FromSpec(&spec); + if (!metaclass) + throw py::error_already_set(); + return metaclass; + }(); + return metaclass; +} + PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { C10_LOG_API_USAGE_ONCE("c10d.python.import"); @@ -520,7 +576,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO // making `PREMUL_SUM` callable, i.e., allowing for // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol. // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types - py::class_<::c10d::ReduceOp> reduce_op(module, "ReduceOp", R"( + py::class_<::c10d::ReduceOp> reduce_op( + module, "ReduceOp", py::metaclass((PyObject*)GetReduceOpMetaclass()), R"( An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``, ``MIN``, ``MAX``, ``BAND``, ``BOR``, ``BXOR``, and ``PREMUL_SUM``. @@ -562,14 +619,51 @@ This class does not support ``__members__`` property.)"); [](const ::c10d::ReduceOp& self, const ::c10d::ReduceOp& other) { return self == other.op_; }) - .def("__hash__", [](const ::c10d::ReduceOp& self) { - return static_cast(self.op_); - }); - - // note(crcrpar): Deliberately skip - // [`export_values`](https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types) - // here and manually set values in Python side. 
See note "ReduceOp static - // class attributes to support `isinstance`" + .def( + "__hash__", + [](const ::c10d::ReduceOp& self) { + return static_cast(self.op_); + }) + .def( + "__copy__", + [](const ::c10d::ReduceOp& self) { return ::c10d::ReduceOp(self); }) + .def( + "__deepcopy__", + [](const ::c10d::ReduceOp& self, const py::dict& memo) { + return ::c10d::ReduceOp(self); + }) + .def(py::pickle( + [](const ::c10d::ReduceOp& r) { + // __getstate__ + if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { + return py::make_tuple(r.op_, py::none()); + } + TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp"); + const auto* preMulSupplement = + reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>( + r.supplement_.get()); + if (!preMulSupplement->tensor_factor.defined()) { + return py::make_tuple(r.op_, preMulSupplement->double_factor); + } else { + return py::make_tuple(r.op_, preMulSupplement->tensor_factor); + } + }, + [](const py::tuple t) { + // __setstate__ + TORCH_CHECK(t.size() == 2, "Invalid state"); + const auto op = + static_cast<::c10d::ReduceOp::RedOpType>(t[0].cast()); + if (op != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { + return ::c10d::ReduceOp(op); + } + const auto preMulSupplement_factor = t[1]; + if (py::isinstance(preMulSupplement_factor)) { + return ::c10d::makeNCCLPreMulSum(t[1].cast()); + } else { + return ::c10d::makeNCCLPreMulSum(t[1].cast()); + } + })); + py::enum_<::c10d::ReduceOp::RedOpType>(reduce_op, "RedOpType") .value("SUM", ::c10d::ReduceOp::RedOpType::SUM) .value("AVG", ::c10d::ReduceOp::RedOpType::AVG) @@ -579,7 +673,8 @@ This class does not support ``__members__`` property.)"); .value("BAND", ::c10d::ReduceOp::RedOpType::BAND) .value("BOR", ::c10d::ReduceOp::RedOpType::BOR) .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR) - .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM); + .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM) + .export_values(); // note(crcrpar): This could be removed because users will not pass // `RedOpType` to reduce collective ops Ref: [Implicit @@ -597,7 +692,7 @@ This class does not support ``__members__`` property.)"); py::call_guard()) .def( "_make_nccl_premul_sum", - &::c10d::makeNCCLPreMulSum>, + &::c10d::makeNCCLPreMulSum, py::arg("factor").noconvert(), py::return_value_policy::copy, // seems safest py::call_guard()); diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 33569f5169e5..f46aaaef94ef 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -237,17 +237,6 @@ def register_backend(cls, name, func, extended_api=False): dist_backend = Backend -# NOTE(crcrpar): [ReduceOp static class attributes to support `isinstance`] -# A ReduceOp instance of `PREMUL_SUM` is supposed to be created via `_make_nccl_premul_sum` -# while the other `op`s (meaning RedOpType members) can be directly passed to c10d reduce collectives. 
-# I changed `ReduceOp` to struct from enum class and introduced RedOpType enum class for PREMUL_SUM, -# which broke an implicit contract of ReduceOp being enum-like with which users apply isinstance to -# `op`, for example, `isinstance(ReduceOp.SUM, ReduceOp)`: https://github.com/pytorch/pytorch/issues/87191 -DENY_LIST = ("PREMUL_SUM", ) -for _red_op_name, _red_op_value in ReduceOp.RedOpType.__members__.items(): - setattr(ReduceOp, _red_op_name, _red_op_value if _red_op_name in DENY_LIST else ReduceOp(_red_op_value)) - - class _reduce_op(object): r""" Deprecated enum-like class for reduction operations: ``SUM``, ``PRODUCT``, From 5faa2792fa3c46f2124d1d1c5f7b6a3865d47d7b Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Tue, 15 Nov 2022 01:06:23 +0000 Subject: [PATCH 171/453] Symintify decomps for split and upsample_bilinear; Fix decomp for _softmax_backward_data and native_dropout_backward (#88761) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88761 Approved by: https://github.com/ezyang --- test/dynamo/test_dynamic_shapes.py | 12 --- test/functorch/test_aotdispatch.py | 4 - test/functorch/test_ops.py | 4 + test/functorch/test_vmap.py | 3 + test/inductor/test_torchinductor_opinfo.py | 1 + test/test_decomp.py | 3 + test/test_proxy_tensor.py | 22 +++-- torch/_decomp/decompositions.py | 98 +++++++++++++++---- .../_internal/common_methods_invocations.py | 94 ++++++++++++++---- 9 files changed, 177 insertions(+), 64 deletions(-) diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 294ea9e54952..f3964a777aa8 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -106,23 +106,11 @@ def make_dynamic_cls(cls): DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes ) -# DynamicShapesReproTests -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_eval_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - unittest.expectedFailure( DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes # Unable to cast Python instance to C++ type ) -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_train_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index e31ac58039ec..eb34a3fb7582 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1005,7 +1005,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cholesky_inverse', ''), # could not find kernel xfail('cholesky_solve', ''), # could not find kernel - xfail('chunk', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition @@ -1137,7 +1136,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta... 
xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'bicubic'), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('nn.functional.interpolate', 'bilinear'), # Cannot call sizes() on tensor with symbolic sizes/str... xfail('nn.functional.interpolate', 'linear'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'nearest'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.interpolate', 'trilinear'), # Cannot call sizes() on tensor with symbolic sizes/st... @@ -1164,7 +1162,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function... xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1197,7 +1194,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ... - xfail('split', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('std', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('std_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 2e303922dfa1..643ff0ec862a 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1055,6 +1055,7 @@ def test(): xfail('segment_reduce', 'lengths'), xfail('sparse.sampled_addmm', ''), xfail("native_batch_norm"), + xfail("native_dropout_backward"), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): if not op.supports_autograd: @@ -1220,6 +1221,8 @@ def get_vjp(cotangents, *primals): xfail('segment_reduce', 'offsets'), # NYI: forward-AD for segment_reduce xfail('index_reduce', ''), # NYI: forward-AD for index_reduce xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce + xfail('native_dropout_backward'), # NYI + })) @opsToleranceOverride('TestOperators', 'test_jvpvjp', ( tol1('masked.prod', @@ -1377,6 +1380,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): # input while the running_mean or running_var, which will be updated in # place, were not batched. 
xfail("native_batch_norm"), + xfail('native_dropout_backward',) })) @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index fb8722b8405b..0c38c5101cf8 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3239,6 +3239,7 @@ def test(): xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops xfail('sparse.sampled_addmm'), # sparse xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work + skip('_softmax_backward_data'), skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format # ---------------------------------------------------------------------- @@ -3380,6 +3381,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('bernoulli', ''), xfail('linalg.lu_factor', ''), xfail('nn.functional.feature_alpha_dropout', 'with_train'), + xfail('native_dropout_backward'), xfail('nn.functional.kl_div', ''), xfail('multinomial', ''), xfail('column_stack', ''), @@ -3453,6 +3455,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('equal', ''), xfail('linalg.lu', ''), skip('linalg.ldl_solve', ''), + skip('_softmax_backward_data'), })) def test_op_has_batch_rule(self, device, dtype, op): # needs to be fixed diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 4e706bab0ea6..83d8d40e21ec 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -430,6 +430,7 @@ def wrapper_set_seed(op, *args, **kwargs): "randn": {"assert_equal": False}, ("nn.functional.tanhshrink", "cuda", f16): {"atol": 3e-4, "rtol": 0.001}, ("cummax", "cuda", f16): {"atol": 5e-4, "rtol": 0.002}, + ("_softmax_backward_data", "cuda", f16): {"atol": 0.008, "rtol": 0.002}, "gradient": {"check_gradient": False}, # segfault on check_gradient # Following tests failed, and causing subsequent tests failing with unrecoverable CUDA error "linalg.solve_triangular": {"check_gradient": False}, diff --git a/test/test_decomp.py b/test/test_decomp.py index 67e99d5eb829..a3658792c5e7 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -294,6 +294,9 @@ def normalize_op_input_output(f, sample, requires_grad=True): (None, None, "meshgrid"), # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) (None, None, "diag"), + + # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 + ("cpu", torch.bfloat16, "_softmax_backward_data"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 42ecc3d376ab..894b35693430 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1124,7 +1124,6 @@ def f(a, b, c, d, e): xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel xfail('cdist', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... 
- xfail('chunk', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('column_stack', ''), # Tensors of type TensorImpl do not have numel xfail('combinations', ''), xfail('count_nonzero', ''), # Could not run 'aten::count_nonzero.dim_IntList' with arguments from the 'Meta' ba... @@ -1247,7 +1246,6 @@ def f(a, b, c, d, e): xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco... xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d... - xfail('nn.functional.interpolate', 'bilinear'), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function... xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d... xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... @@ -1267,7 +1265,6 @@ def f(a, b, c, d, e): xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de... xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1313,7 +1310,6 @@ def f(a, b, c, d, e): xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... - xfail('split', ''), # 'torch._C.SymIntNode' and 'int' xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... 
xfail('sum_to_size', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('svd', ''), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition @@ -1439,10 +1435,13 @@ def _fn(t, *args, **kwargs): return _fn def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False): - def f(args, kwargs, extra_args): + def f(args, kwargs, extra_args, extra_kwargs): if extra_args: for i, t in extra_args: args[i] = t.size() + if extra_kwargs: + for k, t in extra_kwargs.items(): + kwargs[k] = t.size() fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op return fn(*args, **kwargs) @@ -1463,23 +1462,26 @@ def f(args, kwargs, extra_args): # - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in # symbolic mode, a no-op otherwise) extra_args = [] + extra_kwargs = {} for i, arg in enumerate(args): if isinstance(arg, torch.Size): - extra_args.append((i, torch.empty((), device="cpu").expand(arg))) - # TODO: support kwargs + extra_args.append((i, torch.empty(arg, device="cpu"))) + for key, value in kwargs.items(): + if isinstance(value, torch.Size): + extra_kwargs[key] = torch.empty(value, device="cpu") try: - new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args) + new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs) except DynamicOutputShapeException as e: self.skipTest("Dynamic output shape operation in trace") for arg in args: if isinstance(arg, torch.Tensor) and arg.dtype == torch.float: arg.uniform_(0, 1) try: - old_out = f(args, kwargs, extra_args) + old_out = f(args, kwargs, extra_args, extra_kwargs) except Exception: continue - new_out = wrapper_set_seed(new_f, args, kwargs, extra_args) + new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs) self.assertEqual(new_out, old_out) class TestProxyTensorOpInfo(TestCase): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 1a2d332e99fd..7c84cb7e2ca8 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4,7 +4,7 @@ from enum import Enum from functools import partial, reduce from itertools import product -from typing import Callable, cast, Iterable, List, Optional, Tuple +from typing import Callable, cast, Iterable, List, Optional, Tuple, Union import torch import torch._prims_common as utils @@ -13,6 +13,7 @@ from torch._decomp import register_decomposition from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper +from torch.fx.experimental.symbolic_shapes import guard_int, sym_float, sym_int from torch.utils._pytree import tree_flatten, tree_map DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined] @@ -696,7 +697,12 @@ def _softmax_backward_data( grad_input = new_grad_output - output * torch.sum( new_grad_output, dim=dim, keepdim=True ) - return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype) + + # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor + # if grad_output.device == torch.device("cpu"): + # return grad_input.contiguous() + + return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous() @register_decomposition(aten._log_softmax_backward_data) @@ -912,9 +918,17 @@ def check_positive(param, param_name, strict=True): @register_decomposition(aten.native_dropout_backward) -@pw_cast_for_opmath def 
native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float): - return grad_output * (mask.type_as(grad_output) * scale) + # According to the CUDA kernel implementation we should have this test; + # but it seems to fail tests! + # utils.check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}") + + # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format + # This different from TensorIterator's behavior + r = (grad_output * (mask.type_as(grad_output) * scale)).clone( + memory_format=utils.suggest_memory_format(grad_output) + ) + return r @register_decomposition(aten.unfold_backward) @@ -1095,8 +1109,9 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]: assert dim_size == 0 return [self] chunks = (dim_size + split_size - 1) // split_size + chunks = guard_int(chunks) split_sizes = [split_size for i in range(chunks)] - split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size) + split_sizes[-1] = split_size - (split_size * chunks - dim_size) return torch.split(self, split_sizes, dim) @@ -1786,29 +1801,74 @@ def norm( return torch.linalg.vector_norm(self, p, dim, keepdim, dtype=dtype) +# aten/src/ATen/native/UpSample.cpp compute_output_size +def upsample_compute_output_size(input_size, output_size, scale_factors): + spatial_dimensions = len(input_size) - 2 + if output_size is not None: + utils.check( + scale_factors is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(output_size) == spatial_dimensions, lambda: "") + return output_size + if scale_factors is not None: + # NB: this isn't necessary lol + utils.check( + output_size is None, + lambda: "Must specify exactly one of output_size and scale_factors", + ) + utils.check(len(scale_factors) == spatial_dimensions, lambda: "") + return [ + # Returning output_size as float. 
We cannot convert it to int directly, + # as latter computation of scale_factor is relying output size being float + sym_float(input_size[i + 2] * scale_factors[i]) + for i in range(spatial_dimensions) + ] + utils.check( + False, lambda: "Must specify exactly one of output_size and scale_factors" + ) + + +def get_scale_value(scales, idx): + if scales is None: + return None + return scales[idx] + + @register_decomposition(torch.ops.aten.upsample_bilinear2d.vec) -@register_decomposition(torch.ops.aten.upsample_bilinear2d.vec, type="pre_autograd") +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@torch.ops.aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd) +def upsample_bilinear2d_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + + # NB: osize could be a list of float when scale_factors is float + # so we cannot redispatch to aten.upsample_bilinear2d.default here + return upsample_bilinear2d(input, osize, align_corners, scale_h, scale_w) + + +@register_decomposition(torch.ops.aten.upsample_bilinear2d.default) +@torch.ops.aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd) @pw_cast_for_opmath -def upsample_bilinear2d_vec( +def upsample_bilinear2d( input: Tensor, - output_size: Optional[List[int]], + output_size: List[Union[int, float]], align_corners: bool, - scale_factors: Optional[List[float]], + scales_h: Optional[float] = None, + scales_w: Optional[float] = None, ) -> Tensor: # get dimensions of original image n_batch, n_channels, in_h, in_w = input.shape - if output_size is not None: - out_h = float(output_size[0]) - out_w = float(output_size[1]) - elif scale_factors is not None: - out_h = in_h * scale_factors[0] - out_w = in_w * scale_factors[1] + out_h = sym_float(output_size[0]) + out_w = sym_float(output_size[1]) # Calculate horizontal and vertical scaling factor + # TODO: Figure out if scales_h/scales_w matters here if out_h > 1: if align_corners: - h_scale_factor = (in_h - 1) / (int(out_h) - 1) + h_scale_factor = (in_h - 1) / (sym_int(out_h) - 1) else: h_scale_factor = in_h / out_h else: @@ -1816,14 +1876,14 @@ def upsample_bilinear2d_vec( if out_w > 1: if align_corners: - w_scale_factor = (in_w - 1) / (int(out_w) - 1) + w_scale_factor = (in_w - 1) / (sym_int(out_w) - 1) else: w_scale_factor = in_w / out_w else: w_scale_factor = 0.0 - i = torch.arange(int(out_h), dtype=input.dtype, device=input.device) - j = torch.arange(int(out_w), dtype=input.dtype, device=input.device) + i = torch.arange(sym_int(out_h), dtype=input.dtype, device=input.device) + j = torch.arange(sym_int(out_w), dtype=input.dtype, device=input.device) if align_corners: x = h_scale_factor * i diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 001fd455e82e..0c59af77736a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -406,6 +406,21 @@ def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): # running_mean and running_var are required in evaluation mode (training: False) but not in training mode yield SampleInput(make_arg((1, 2, 3)), args=(None, None, None, None), kwargs={'training': True}) +def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial( + 
make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + cases = [ + ((S,), 0), + ((S, S), 0), + ((S, M, S), -1), + ] + input_dtypes = [dtype] + if dtype == torch.float and device == 'cuda': + input_dtypes += [torch.float16] + + for (shape, dim), input_dtype in product(cases, input_dtypes): + yield SampleInput(make_arg(shape), make_arg(shape), dim, input_dtype) def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) @@ -1173,7 +1188,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs): cases = ((), (S, S, S), (S,)) for shape in cases: - yield(SampleInput(make_arg(shape))) + yield SampleInput(make_arg(shape)) # TODO: add reduction kwargs def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs): @@ -3745,8 +3760,8 @@ def sample_inputs_upsample(mode, self, device, dtype, requires_grad, **kwargs): def shape(size, rank, with_batch_channel=True): if with_batch_channel: - return tuple([N, C] + ([size] * rank)) - return tuple([size] * rank) + return torch.Size([N, C] + ([size] * rank)) + return torch.Size([size] * rank) make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=-1, high=1) @@ -5794,9 +5809,9 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals if list_args: cases = ( - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), - ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], 2),), - ((S, S, S), ([int(S / 2), S - int(S / 2) * 2, int(S / 2)], -2),) + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), 2),), + ((S, S, S), (torch.Size([int(S / 2), S - int(S / 2) * 2, int(S / 2)]), -2),) ) else: cases = ( # type: ignore[assignment] @@ -5811,10 +5826,10 @@ def sample_inputs_split(op_info, device, dtype, requires_grad, *, list_args=Fals def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - cases = (((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)],)), - ((S, S, S), ([int(S / 3), S - int(S / 3), 0],)), - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], 2)), - ((S, S, S), ([int(S / 3), S - int(S / 3) * 2, int(S / 3)], -2)), + cases = (((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3), 0]),)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), 2)), + ((S, S, S), (torch.Size([int(S / 3), S - int(S / 3) * 2, int(S / 3)]), -2)), ) for shape, args in cases: @@ -6190,7 +6205,7 @@ def sample_inputs_resize_ops(op_info, device, dtype, requires_grad, **kwargs): else: raise ValueError("sample_inputs_resize_ops is being used with incorrect operator") - yield(SampleInput(make_arg(shape, requires_grad=requires_grad), args=args)) + yield SampleInput(make_arg(shape, requires_grad=requires_grad), args=args) def sample_inputs_view_reshape(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6446,7 +6461,7 @@ def sample_inputs_expand(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield(SampleInput(make_arg(shape), args=(args, ))) + yield SampleInput(make_arg(shape), args=(args,)) def 
sample_inputs_conversion(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6469,8 +6484,8 @@ def sample_inputs_expand_as(op_info, device, dtype, requires_grad, **kwargs): ) for shape, shape_other in cases: - yield(SampleInput(make_arg(shape, requires_grad=requires_grad), - args=(make_arg(shape_other, requires_grad=False), ))) + yield SampleInput(make_arg(shape, requires_grad=requires_grad), + args=(make_arg(shape_other, requires_grad=False),)) def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs): @@ -6588,8 +6603,8 @@ def sample_inputs_nonzero(op_info, device, dtype, requires_grad, **kwargs): inputs.append(mixed) for input_t, as_tuple in product(inputs, [False, True]): - yield(SampleInput(input_t.clone().requires_grad_(requires_grad), - kwargs=dict(as_tuple=as_tuple))) + yield SampleInput(input_t.clone().requires_grad_(requires_grad), + kwargs=dict(as_tuple=as_tuple)) def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) @@ -6600,7 +6615,7 @@ def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs): for case in cases: shape, args = case - yield(SampleInput(make_arg(shape), args=args)) + yield SampleInput(make_arg(shape), args=args) def reference_inputs_chunk(op, device, dtype, requires_grad, **kwargs): yield from sample_inputs_chunk(op, device, dtype, requires_grad, **kwargs) @@ -6678,6 +6693,15 @@ def sample_inputs_dropout(op_info, device, dtype, requires_grad, *, yield SampleInput(make_arg(case), p=p, training=training) yield SampleInput(make_arg(case)) +def sample_inputs_dropout_backward(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + make_mask = partial(make_tensor, device=device, dtype=torch.bool, requires_grad=False) + + cases = ((S, S, S, S), (S,), ()) + scale_vals = [0.0, 1.0, 2.0] + + for case, scale in product(cases, scale_vals): + yield SampleInput(make_arg(case), make_mask(case), scale) def sample_inputs_embedding_bag(op_info, device, dtype, requires_grad, **kwargs): def make_input(shape): @@ -8095,7 +8119,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): in_shape = input.shape in_rank = len(in_shape) for d in start_dim, end_dim: - if not((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): + if not ((in_rank == 0 and d in (-1, 0)) or -in_rank <= d < in_rank): raise IndexError(f"Dimension out of range (expected to be in range of [{-in_rank}, {in_rank-1}], but got {d}") end_dim = end_dim if end_dim >= 0 else in_rank + end_dim start_dim = start_dim if start_dim >= 0 else in_rank + start_dim @@ -8424,7 +8448,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): variant_test_name='decomposed', dtypes=all_types_and_complex_and(torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, - *[torch.bfloat16] if(CUDA11OrLater or TEST_WITH_ROCM) else []), + *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []), assert_autodiffed=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -10554,6 +10578,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_out=True), + OpInfo( + '_softmax_backward_data', + op=torch.ops.aten._softmax_backward_data, + aten_name='_softmax_backward_data', + dtypes=floating_types_and(torch.bfloat16), + 
dtypesIfCUDA=floating_types_and(torch.bfloat16, torch.float16), + sample_inputs_func=sample_inputs_softmax_backward_data, + assert_autodiffed=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_out=False, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + ), + ), # `softmin` supports different dtypes based on whether `dtype` argument, # is passed or not. Hence two OpInfo entries, one with dtype and other without. # https://github.com/pytorch/pytorch/issues/68752 @@ -15927,6 +15967,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_dropout, inplace_variant=lambda input, *args, **kwargs: wrapper_set_seed(torch.nn.functional.dropout, input, *args, **kwargs, inplace=True)), + OpInfo( + "native_dropout_backward", + op=torch.ops.aten.native_dropout_backward.default, + aten_name="native_dropout_backward", + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + supports_out=False, + sample_inputs_func=sample_inputs_dropout_backward, + skips=( + DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'), + # Lazy tensor failures + DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), + ), + ), OpInfo( "nn.functional.dropout2d", op=lambda input, *args, **kwargs: From b9029fc4497a9453e76892c9cf56144add89faf7 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Fri, 11 Nov 2022 08:55:40 -0800 Subject: [PATCH 172/453] [ao] quant_type.py fixing public v private (#87519) Summary: made _get_quant_type_to_str private Test Plan: python test/test_public_bindings.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709282](https://our.internmc.facebook.com/intern/diff/D40709282) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87519 Approved by: https://github.com/jcaip --- test/allowlist_for_publicAPI.json | 4 ++-- test/quantization/ao_migration/test_quantization.py | 2 +- test/quantization/fx/test_quantize_fx.py | 4 ++-- torch/ao/quantization/__init__.py | 1 - torch/ao/quantization/fx/custom_config.py | 6 +++--- torch/ao/quantization/quant_type.py | 3 +-- torch/quantization/__init__.py | 2 +- torch/quantization/quant_type.py | 2 +- 8 files changed, 11 insertions(+), 13 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index ba4a2e96df21..94ff57700af6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -795,7 +795,7 @@ "prepare_qat", "propagate_qconfig_", "qconfig_equals", - "quant_type_to_str", + "_get_quant_type_to_str", "quantize", "quantize_dynamic", "quantize_dynamic_jit", @@ -874,7 +874,7 @@ ], "torch.quantization.quant_type": [ "QuantType", - "quant_type_to_str" + "_get_quant_type_to_str" ], "torch.quantization.quantization_mappings": [ "get_default_compare_output_module_list", diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 2617e7a1187d..9c246e1b7cd8 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ 
b/test/quantization/ao_migration/test_quantization.py @@ -118,7 +118,7 @@ def test_package_import_quant_type(self): def test_function_import_quant_type(self): function_list = [ 'QuantType', - 'quant_type_to_str', + '_get_quant_type_to_str', ] self._test_function_import('quant_type', function_list) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 6eb9246c85a7..6721e397180e 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -33,8 +33,8 @@ from torch.ao.quantization import ( QuantType, - quant_type_to_str, ) +from torch.ao.quantization.quant_type import _get_quant_type_to_str from torch.ao.quantization import ( QuantStub, @@ -2636,7 +2636,7 @@ def forward(self, x): } for quant_type in [QuantType.STATIC, QuantType.DYNAMIC]: - key = quant_type_to_str(quant_type) + key = _get_quant_type_to_str(quant_type) qconfig, quantized_module_class, num_observers = test_configs[key] qconfig_dict = {"": qconfig} if key == "static": diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 2e8390c1acc7..1ba2a60ed3d1 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -124,7 +124,6 @@ "prepare_qat", "propagate_qconfig_", "qconfig_equals", - "quant_type_to_str", "quantize", "quantize_dynamic", "quantize_dynamic_jit", diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 0f5f5bfe8d15..9d08853a4126 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -4,7 +4,7 @@ from torch.ao.quantization import QConfigMapping from torch.ao.quantization.backend_config import BackendConfig -from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, quant_type_to_str +from torch.ao.quantization.quant_type import QuantType, _quant_type_from_str, _get_quant_type_to_str __all__ = [ @@ -263,7 +263,7 @@ def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): for quant_type, float_to_observed_mapping in self.float_to_observed_mapping.items(): if FLOAT_TO_OBSERVED_DICT_KEY not in d: d[FLOAT_TO_OBSERVED_DICT_KEY] = {} - d[FLOAT_TO_OBSERVED_DICT_KEY][quant_type_to_str(quant_type)] = float_to_observed_mapping + d[FLOAT_TO_OBSERVED_DICT_KEY][_get_quant_type_to_str(quant_type)] = float_to_observed_mapping if len(self.non_traceable_module_names) > 0: d[NON_TRACEABLE_MODULE_NAME_DICT_KEY] = self.non_traceable_module_names if len(self.non_traceable_module_classes) > 0: @@ -350,7 +350,7 @@ def to_dict(self) -> Dict[str, Any]: for quant_type, observed_to_quantized_mapping in self.observed_to_quantized_mapping.items(): if OBSERVED_TO_QUANTIZED_DICT_KEY not in d: d[OBSERVED_TO_QUANTIZED_DICT_KEY] = {} - d[OBSERVED_TO_QUANTIZED_DICT_KEY][quant_type_to_str(quant_type)] = observed_to_quantized_mapping + d[OBSERVED_TO_QUANTIZED_DICT_KEY][_get_quant_type_to_str(quant_type)] = observed_to_quantized_mapping if len(self.preserved_attributes) > 0: d[PRESERVED_ATTRIBUTES_DICT_KEY] = self.preserved_attributes return d diff --git a/torch/ao/quantization/quant_type.py b/torch/ao/quantization/quant_type.py index 9d2a3a2bdc7b..d3b1d034a1fe 100644 --- a/torch/ao/quantization/quant_type.py +++ b/torch/ao/quantization/quant_type.py @@ -2,7 +2,6 @@ __all__ = [ "QuantType", - "quant_type_to_str", ] # Quantization type (dynamic quantization, static quantization). 
@@ -21,7 +20,7 @@ class QuantType(enum.IntEnum): } # TODO: make this private -def quant_type_to_str(quant_type: QuantType) -> str: +def _get_quant_type_to_str(quant_type: QuantType) -> str: return _quant_type_to_str[quant_type] def _quant_type_from_str(name: str) -> QuantType: diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index df9a75d02264..6e4ede123eb0 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -30,7 +30,7 @@ def default_eval_fn(model, calib_data): # Top level API for graph mode quantization on GraphModule(torch.fx) # 'fuse_fx', 'quantize_fx', # TODO: add quantize_dynamic_fx # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx', - 'QuantType', 'quant_type_to_str', # quantization type + 'QuantType', # quantization type # custom module APIs 'get_default_static_quant_module_mappings', 'get_static_quant_module_class', 'get_default_dynamic_quant_module_mappings', diff --git a/torch/quantization/quant_type.py b/torch/quantization/quant_type.py index cd2e5e020a6a..c7f7cc15dbdd 100644 --- a/torch/quantization/quant_type.py +++ b/torch/quantization/quant_type.py @@ -8,4 +8,4 @@ """ from torch.ao.quantization.quant_type import QuantType -from torch.ao.quantization.quant_type import quant_type_to_str +from torch.ao.quantization.quant_type import _get_quant_type_to_str From b815f1fc502387311a7b4da8c2f52ead56cbfff5 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 15 Nov 2022 13:05:30 +0000 Subject: [PATCH 173/453] Symintify view_as_complex and view_as_real (#89052) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #89052 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89052 Approved by: https://github.com/ezyang --- aten/src/ATen/native/ComplexHelper.h | 31 ++++++++++++++-------------- test/functorch/test_aotdispatch.py | 4 ---- test/test_proxy_tensor.py | 1 - torch/_subclasses/fake_tensor.py | 2 ++ 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/aten/src/ATen/native/ComplexHelper.h b/aten/src/ATen/native/ComplexHelper.h index 8d69f6292772..9533115a7066 100644 --- a/aten/src/ATen/native/ComplexHelper.h +++ b/aten/src/ATen/native/ComplexHelper.h @@ -18,19 +18,18 @@ namespace at { namespace native { // View tensor with new dtype, storage offset, sizes and strides inline Tensor view_tensor( const Tensor &tensor, ScalarType dtype, - int64_t offset, IntArrayRef sizes, IntArrayRef strides) { + c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) { Storage storage = tensor.storage(); auto key_set = tensor.key_set().remove(DispatchKey::Conjugate); auto new_tensor = detail::make_tensor( c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype)); auto * impl = new_tensor.unsafeGetTensorImpl(); - impl->set_storage_offset(offset); - impl->set_sizes_and_strides(sizes, strides); + impl->set_sizes_and_strides(sizes, strides, offset); return new_tensor; } -inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { - DimVector res(oldstride.size() + 1); +inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) { + SymDimVector res(oldstride.size() + 1); for (const auto i : c10::irange(oldstride.size())) { res[i] = oldstride[i] * 2; } @@ -40,13 +39,13 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) { Tensor _view_as_real_physical(const Tensor& self) { TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors"); - auto old_sizes = self.sizes(); - 
DimVector new_sizes(old_sizes.size() + 1); + auto old_sizes = self.sym_sizes(); + SymDimVector new_sizes(old_sizes.size() + 1); std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin()); // last dimension will always have two elements containing the real and imag vals new_sizes.back() = 2; - auto new_strides = computeStrideForViewAsReal(self.strides()); - auto new_storage_offset = 2 * self.storage_offset(); + auto new_strides = computeStrideForViewAsReal(self.sym_strides()); + auto new_storage_offset = self.sym_storage_offset() * 2; const auto float_type = c10::toRealValueType(self.scalar_type()); auto real_tensor = view_tensor(self, float_type, new_storage_offset, new_sizes, new_strides); return real_tensor; @@ -60,11 +59,11 @@ Tensor view_as_real(const Tensor& self) { return _view_as_real_physical(self); } -inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) { +inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) { const int64_t dim = oldstride.size(); TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1"); - DimVector res(dim - 1); + SymDimVector res(dim - 1); for (const auto i : c10::irange(res.size())) { TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension"); res[i] = oldstride[i] / 2; @@ -79,16 +78,16 @@ Tensor view_as_complex(const Tensor& self) { self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf, "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type()); - auto old_sizes = self.sizes(); + auto old_sizes = self.sym_sizes(); TORCH_CHECK(old_sizes.size() != 0, "Input tensor must have one or more dimensions"); TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2"); - DimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); + SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1); - const auto new_strides = computeStrideForViewAsComplex(self.strides()); + const auto new_strides = computeStrideForViewAsComplex(self.sym_strides()); const auto complex_type = c10::toComplexType(self.scalar_type()); - TORCH_CHECK(self.storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); - const auto new_storage_offset = self.storage_offset() / 2; + TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2"); + const auto new_storage_offset = self.sym_storage_offset() / 2; return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides); } diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index eb34a3fb7582..752b03ac9984 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1000,14 +1000,11 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cartesian_prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides - xfail('cdouble'), # RuntimeError: aten.view_as_real.default - couldn't find symbolic meta function/decomposition - xfail('cfloat'), # RuntimeError: aten.view_as_real.default - couldn't find symbolic meta function/decomposition xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cholesky_inverse', ''), # could not find 
kernel xfail('cholesky_solve', ''), # could not find kernel xfail('column_stack', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('combinations', ''), # aten.masked_select.default - xfail('complex', ''), # aten.view_as_real.default - couldn't find symbolic meta function/decomposition xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('cummax', ''), # aten.cummax.default - couldn't find symbolic meta function/decomposition xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition @@ -1211,7 +1208,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('unflatten', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('var', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('var_mean', ''), # Cannot call numel() on tensor with symbolic sizes/strides - xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/deco... xfail('view_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('vsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides } diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 894b35693430..24efcab9e5cb 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1321,7 +1321,6 @@ def f(a, b, c, d, e): xfail('trapz', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('trapezoid', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition - xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 65f571f93ec0..8dec2475df15 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -971,6 +971,8 @@ def cpp_meta_supports_symint(self, func): aten.as_strided_.default, aten.zeros.default, aten.detach.default, + aten.view_as_real.default, + aten.view_as_complex.default, aten.set_.source_Storage_storage_offset, aten._sparse_coo_tensor_with_dims_and_tensors.default, ] From 62ba15e10e875ce088dff26e872605ee70c8c04a Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Mon, 14 Nov 2022 23:26:15 -0800 Subject: [PATCH 174/453] Rewrite assert statement with torch._assert under config (#88246) This diff rewrites assert statement in python with torch._assert under config. 
The resulting graph looks something like: ``` SOURCE CODE: def f(x): assert x[0] == 3 return x.cos() CAPTURED GRAPH: graph(): %arg0 : [#users=2] = placeholder[target=arg0] %getitem : [#users=1] = call_function[target=operator.getitem](args = (%arg0, 0), kwargs = {}) %eq : [#users=1] = call_function[target=operator.eq](args = (%getitem, 3), kwargs = {}) %_assert : [#users=0] = call_function[target=torch._assert](args = (%eq, "assertion_error"), kwargs = {}) %cos : [#users=1] = call_method[target=cos](args = (%arg0,), kwargs = {}) return cos ``` Note that this introduces side-effect as it could error out while executing graph, but the assertion can eliminated via DCE if we choose to ignore it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88246 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 92 ++++++++++++++++++++++++++++++ torch/_dynamo/config.py | 3 + torch/_dynamo/symbolic_convert.py | 94 +++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 503231b4cb12..e30a1275ed13 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1938,6 +1938,98 @@ def fn(x): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, "First dim need to be 3" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + cnt = torch._dynamo.testing.CompileCounter() + + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + self.assertEqual(cnt.op_count, 6) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_not_rewrite_assert_for_other_errors(self): + def f(x): + b = x.sin() + if not x.sum() <= 3: + raise ValueError("input sum needs to be 3") + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + opt_fn = torch._dynamo.optimize("eager")(f) + with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): + opt_fn(*args) + + # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement. 
+ @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_fstring_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, f"First dim need to be {x[0]}" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_without_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_noop(self): + def f(x): + b = x.sin() + assert True + assert x.dtype == torch.float32 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + cnt = torch._dynamo.testing.CompileCounter() + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + # torch._assert shouldn't be in the graph + self.assertEqual(cnt.op_count, 3) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False) + def test_not_rewrite_assert(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 12088383e741..39a1a6433419 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -87,6 +87,9 @@ # if an exception is encountered replay_record_enabled = False +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True + # Show a warning on every graph break print_graph_breaks = False diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d707bee930ee..d5c05f76efb0 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -53,6 +53,7 @@ fake_tensors_available, graph_break_dup_warning_checker, istype, + proxy_args_kwargs, ) from .variables.base import MutableLocal, typestr, VariableTracker from .variables.builder import VariableBuilder, wrap_fx_proxy @@ -121,10 +122,103 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction): return impl +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool +): + # Detect if this jump instruction is assert and normalize the assert + # by pushing dummy error message when nothing is given. 
+ # + # Python 3.9 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # Python 3.8 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if sys.version_info < (3, 9): + if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": + return False + else: + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + if current_instruction_pointer >= len(self.instructions): + return False + + inst = self.instructions[current_instruction_pointer] + has_error_msg = False + # DETECT RAISE_VARARGS or LOAD CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + self.LOAD_CONST(inst) + has_error_msg = True + + # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + if inst.opname != "CALL_FUNCTION": + return False + + # CALL_FUNCTION should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + if not has_error_msg: + # Push dummy value instead of error message + self.push(ConstantVariable("assertion error")) + + return True + + def generic_jump(truth_fn: typing.Callable, push: bool): def inner(self: "InstructionTranslatorBase", inst: Instruction): value: VariableTracker = self.pop() self.output.guards.update(value.guards) + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + self.output.guards.update(error_msg.guards) + # Skip over things like `assert True` + if value.is_python_constant() and bool(value.as_python_constant()): + self.jump(inst) + return + + # Manually insert torch._assert instead of python assert and jump over + # assert related instructions as we don't need them anymore. 
+ self.output.create_proxy( + "call_function", + torch._assert, + *proxy_args_kwargs((value, error_msg), {}), + current_tx=self, + ) + self.jump(inst) + return + if value.is_python_constant(): if truth_fn(value.as_python_constant()): push and self.push(value) From 2819df9a19480feba72f9c613be25e56d4f05142 Mon Sep 17 00:00:00 2001 From: Pruthvi Madugundu Date: Tue, 15 Nov 2022 17:49:00 +0000 Subject: [PATCH 175/453] [ROCm] Enable python ref executor UTs for ROCm (#88981) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88981 Approved by: https://github.com/mruberry --- test/test_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index c688f6521af1..0ef2e4ee6d60 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -57,7 +57,6 @@ onlyCPU, onlyNativeDeviceTypes, OpDTypes, - skipCUDAIfRocm, skipMeta, ) from torch._subclasses.fake_tensor import ( @@ -393,7 +392,6 @@ def test_python_ref_torch_fallback(self, device, dtype, op): @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @onlyCUDA - @skipCUDAIfRocm @ops(python_ref_db) @parametrize('executor', ['aten', 'nvfuser']) @skipIfTorchInductor("Takes too long for inductor") From ee4412381ea3577fbf32858f35f8b76bdc548b49 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Tue, 15 Nov 2022 17:55:29 +0000 Subject: [PATCH 176/453] Allow ROCm runners to have 2 or more gpus (#89011) [This run](https://github.com/pytorch/pytorch/actions/runs/3432340660/jobs/5721731207) failed claiming that it couldn't detect GPUs on the runner. Inspecting the rocminfo output (higher up in logs) show that it in fact had three GPUs, but the workflow is currently setup to expect either 2 or 4 gpus. The workflow files currently have no way of specifying wither it'll get a 2 gpu or a 4 gpu machine, so really 2 is all any test can expect to get. [This old PR](https://github.com/pytorch/pytorch/pull/72142/files) shows that historically ROCm runners only had 4 gpus, then later the logic was extended to expect 2 GPU runners as well. It's not clear how the ROCm runner ended up with 3 gpus instead of 2 or 4 (something for ROCm folks to look into) but there doesn't seem to be a good reason for ROCm workflows to fail if 3 (or 5) gpus ever show up on a machine. 
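For illustration only, a relaxed version of that runner health check could accept any GPU count of two or more. This is a hypothetical sketch of the intent, not the shell shipped in the workflow hunks below (which keep the existing 2-or-4 condition and only improve its error messages):

```bash
# Hypothetical relaxed GPU-count guard for the ROCm runner health check.
ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
if [[ "$ngpu" -eq 0 ]]; then
  echo "Error: Failed to detect any GPUs on the runner"
  exit 1
elif [[ "$ngpu" -lt 2 ]]; then
  echo "Error: Detected only $ngpu GPU on the runner; at least 2 are expected"
  exit 1
fi
# 3, 5, or more GPUs are tolerated: tests can only rely on 2 being present anyway.
```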
This PR makes the workflows resilient to ROCm having these alternate GPU counts Also filed https://github.com/pytorch/pytorch/issues/89012 against the ROCm team to explore why the runner only had 3 gpus Pull Request resolved: https://github.com/pytorch/pytorch/pull/89011 Approved by: https://github.com/huydhn --- .github/actions/setup-rocm/action.yml | 7 ++- .github/templates/common.yml.j2 | 7 ++- ...inux-binary-libtorch-cxx11-abi-nightly.yml | 28 ++++++++-- ...inux-binary-libtorch-pre-cxx11-nightly.yml | 28 ++++++++-- ...nerated-linux-binary-manywheel-nightly.yml | 56 ++++++++++++++++--- 5 files changed, 108 insertions(+), 18 deletions(-) diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index 97dfd22c76ac..d91762eb9a86 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -36,7 +36,12 @@ runs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index a2941546abe1..edb652ff16ce 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -78,7 +78,12 @@ concurrency: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 6a23b85f433a..f9ab6798787f 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -845,7 +845,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -988,7 +993,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. 
Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1131,7 +1141,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1274,7 +1289,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index 27358089ba2d..55e4a19b8e8a 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -845,7 +845,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -988,7 +993,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1131,7 +1141,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. 
Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1274,7 +1289,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index ac9edc252c28..efe3e2c0d17c 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -337,7 +337,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -477,7 +482,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -855,7 +865,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -995,7 +1010,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. 
Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1373,7 +1393,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1513,7 +1538,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -1891,7 +1921,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure @@ -2031,7 +2066,12 @@ jobs: run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') if [[ "x$ngpu" != "x2" && "x$ngpu" != "x4" ]]; then - echo "Failed to detect GPUs on the runner" + if [[ $ngpu -eq 0 ]]; then + echo "Error: Failed to detect any GPUs on the runner" + else + echo "Error: Detected $ngpu GPUs on the runner, when only 2 or 4 were expected" + fi + echo "Please file an issue on pytorch/pytorch reporting the faulty runner. Include a link to the runner logs so the runner can be identified" exit 1 fi - name: Runner health check disconnect on failure From 1db0f735e8fe14245e98e875c15ecf95ed2142ce Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Fri, 11 Nov 2022 16:30:01 -0800 Subject: [PATCH 177/453] [Profiler] Account for caching when assigning IDs (#88917) The python tracer caches information about module and optimizer state. That means that for subsequent calls, the presence of a Tensor in these fields does not imply that the Tensor is still live; just that it was live during the first call. (I should perhaps rename the fields to something like `stale_parameters` to convey this.) Unless we discard subsequent calls ID assignment get tripped up when it see's a Tensor that was already released. 
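The guard added below in `calculateUniqueTensorIDs` records which module / optimizer `self` pointers have already been visited and skips their cached parameter lists on later calls. A rough Python analogue of that idea (the event/field names here are illustrative, not the actual profiler API):

```python
# Sketch only: dedupe cached module/optimizer state so that stale Tensor
# metadata recorded on later profiler calls does not take part in ID assignment.
def collect_live_tensor_metadata(python_events):
    seen_modules, seen_optimizers = set(), set()
    tensors = []
    for event in python_events:
        module, optimizer = event.get("module"), event.get("optimizer")
        # Only the first sighting of a module reflects Tensors that were
        # definitely alive; later sightings may be cached (stale) state.
        if module is not None and module["self"] not in seen_modules:
            seen_modules.add(module["self"])
            tensors.extend(module["parameters"])
        if optimizer is not None and optimizer["self"] not in seen_optimizers:
            seen_optimizers.add(optimizer["self"])
            tensors.extend(optimizer["parameters"])
    return tensors
```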
Differential Revision: [D41226827](https://our.internmc.facebook.com/intern/diff/D41226827/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88917 Approved by: https://github.com/chaekit --- torch/csrc/profiler/data_flow.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/csrc/profiler/data_flow.cpp b/torch/csrc/profiler/data_flow.cpp index 543202786788..dcb3eaffd439 100644 --- a/torch/csrc/profiler/data_flow.cpp +++ b/torch/csrc/profiler/data_flow.cpp @@ -69,6 +69,10 @@ void calculateUniqueTensorIDs( // -------------------------------------------------------------------------- { RawTensors raw_tensors; + + // The python tracer caches values, so it's only safe to use the first case. + ska::flat_hash_set seen_modules; + ska::flat_hash_set seen_optimizers; for (auto& result : sorted_results) { result->visit(c10::overloaded( [&](ExtraFields& torch_op) { @@ -78,7 +82,8 @@ void calculateUniqueTensorIDs( }, [&](ExtraFields& py_call) { // torch.nn.Module - if (py_call.module_.has_value()) { + if (py_call.module_.has_value() && + seen_modules.insert(py_call.module_->self_).second) { for (auto& p : py_call.module_->parameters_) { raw_tensors(p.metadata_); raw_tensors(p.grad_metadata_); @@ -86,7 +91,8 @@ void calculateUniqueTensorIDs( } // torch.optim.Optimizer - if (py_call.optimizer_.has_value()) { + if (py_call.optimizer_.has_value() && + seen_optimizers.insert(py_call.optimizer_->self_).second) { for (auto& p : py_call.optimizer_->parameters_) { raw_tensors(p.metadata_); raw_tensors(p.grad_metadata_); From 279dcce702a56f5b3ce5e864fa4db2f882e01084 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Tue, 15 Nov 2022 19:08:31 +0000 Subject: [PATCH 178/453] disable test that fails in fbcode (#88786) Summary: caffe2/test:torch_cuda - test_advanced_indexing_assignment_lazy (test_view_ops.TestViewOpsLAZY) RuntimeError: TorchScript backend not yet supported in FBCODE/OVRSOURCE builds File "/usr/local/fbcode/platform010/lib/python3.8/unittest/suite.py", line 163, in _handleClassSetUp setUpClass() File "/re_cwd/fbcode/buck-out/opt/gen/caffe2/test/torch_cuda#binary,link-tree/torch/testing/_internal/common_device_type.py", line 506, in setUpClass torch._lazy.ts_backend.init() File "/re_cwd/fbcode/buck-out/opt/gen/caffe2/test/torch_cuda#binary,link-tree/torch/_lazy/ts_backend.py", line 6, in init torch._C._lazy_ts_backend._init() Test Plan: Rely on CI. 
Differential Revision: D41170545 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88786 Approved by: https://github.com/zou3519 --- test/test_view_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_view_ops.py b/test/test_view_ops.py index c4729557c416..3c4376b501f9 100644 --- a/test/test_view_ops.py +++ b/test/test_view_ops.py @@ -9,7 +9,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, + IS_FBCODE, TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, numpy_to_torch_dtype_dict, skipIfTorchDynamo ) from torch.testing._internal.common_device_type import \ @@ -857,6 +857,7 @@ def test_advanced_indexing_nonview(self, device): nv[1, 1] = 0 self.assertNotEqual(t[2, 2], nv[1, 1]) + @unittest.skipIf(IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds") def test_advanced_indexing_assignment(self, device): t = torch.ones(3, 3, device=device) rows = torch.tensor([[0, 0], [2, 2]], device=device) From 2439bc1e9bab3721bb9f1c4853baf03b610c89da Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Fri, 11 Nov 2022 16:30:03 -0800 Subject: [PATCH 179/453] [Profiler] Memory profiler part 2: Config validation (#86853) Memory profiling requires `record_shapes`, `profile_memory`, and `with_stack`. This PR just adds a skeleton endpoint with a good error message if certain flags are missing. Differential Revision: [D39920801](https://our.internmc.facebook.com/intern/diff/D39920801/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/86853 Approved by: https://github.com/chaekit --- test/profiler/test_memory_profiler.py | 22 ++++++++++++++++++++++ torch/profiler/_memory_profiler.py | 13 ++++++++++++- torch/profiler/profiler.py | 24 +++++++++++++++++++++--- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index c725f8bec51a..3fd6b04b8a76 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -12,6 +12,28 @@ torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True ) +@skipIfTorchDynamo("TorchDynamo removes profiler altogether.") +class TestMemoryProfiler(TestCase): + def test_config_check(self) -> None: + with torch.profiler.profile() as prof: + pass + + pattern = r"record_shapes=True, profile_memory=True, with_stack=True" + with self.assertRaisesRegex(ValueError, pattern): + prof._memory_profile() + + with torch.profiler.profile(record_shapes=True, with_stack=True) as prof: + pass + + pattern = r"^profile_memory=True required for memory profiling\.$" + with self.assertRaisesRegex(ValueError, pattern): + prof._memory_profile() + + with profile() as prof: + pass + + self.assertIsInstance(prof._memory_profile(), _memory_profiler.MemoryProfile) + class ScaleLayer(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index cab771931489..355d3322a4e0 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -2,7 +2,13 @@ from typing import Any, Iterator, Optional, Tuple import torch -from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata, RecordScope +from torch._C._autograd import _ProfilerResult +from torch._C._profiler import ( + _EventType, + _ProfilerEvent, + _TensorMetadata, + RecordScope, +) 
@dataclasses.dataclass @@ -112,3 +118,8 @@ def extract_gradients( p_grad_key = TensorKey.from_tensor(p_grad) if p_grad_key is not None: yield TensorKey.from_tensor(p), p_grad_key + + +class MemoryProfile: + def __init__(self, result: _ProfilerResult) -> None: + pass diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index ceca36126dcd..31b85eb26f0f 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -16,10 +16,19 @@ _ExperimentalConfig, _remove_execution_graph_observer, ) -from torch.autograd import ProfilerActivity, kineto_available +from torch.autograd import kineto_available, ProfilerActivity +from torch.profiler import _memory_profiler + + +__all__ = [ + "supported_activities", + "ProfilerAction", + "schedule", + "tensorboard_trace_handler", + "profile", + "ExecutionGraphObserver", +] -__all__ = ['supported_activities', 'ProfilerAction', 'schedule', 'tensorboard_trace_handler', 'profile', - 'ExecutionGraphObserver'] def supported_activities(): """ @@ -208,6 +217,15 @@ def _get_distributed_info(self): "world_size": dist.get_world_size() } + def _memory_profile(self) -> _memory_profiler.MemoryProfile: + required = ("record_shapes", "profile_memory", "with_stack") + missing = [f"{i}=True" for i in required if not getattr(self, i)] + if missing: + raise ValueError(f"{', '.join(missing)} required for memory profiling.") + + assert self.profiler is not None and self.profiler.kineto_results is not None + return _memory_profiler.MemoryProfile(self.profiler.kineto_results) + class ProfilerAction(Enum): """ From 8023c9dc6420bce8e37ad4e4e363cb7bed7f70de Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Fri, 11 Nov 2022 16:30:05 -0800 Subject: [PATCH 180/453] [Profiler] Memory profiler part 3: Schema parsing and mutable arguments (#86854) The appropriate annotation for a block of memory is a function of time: an input can be mutated in-place to become an activation, a clever kernel might steal the memory of a detached input (such as a mask) to use as output memory, etc. We could pessimistically assume that all ops mutate all of their inputs, however inspection of schema allows us to significantly narrow that assumption with minimal effort. Checking schemas also allows us to distinguish between dispatcher ops (which have load bearing semantics) and user annotations with reasonably high precision. 
Differential Revision: [D40220390](https://our.internmc.facebook.com/intern/diff/D40220390/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/86854 Approved by: https://github.com/chaekit --- test/profiler/test_memory_profiler.py | 105 +++++++++++++++++++++++- torch/_C/_profiler.pyi | 1 + torch/csrc/profiler/python/init.cpp | 1 + torch/profiler/_memory_profiler.py | 111 +++++++++++++++++++++++++- 4 files changed, 216 insertions(+), 2 deletions(-) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index 3fd6b04b8a76..6924cb355659 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -1,6 +1,6 @@ # Owner(s): ["oncall: profiler"] import functools -from typing import Iterator, Optional +from typing import Iterator, List, Optional, Tuple import torch from torch._C._profiler import _EventType @@ -12,6 +12,7 @@ torch.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True ) + @skipIfTorchDynamo("TorchDynamo removes profiler altogether.") class TestMemoryProfiler(TestCase): def test_config_check(self) -> None: @@ -242,5 +243,107 @@ def test_extract_gradients_from_module_and_optimizer(self) -> None: ) +class TestDataFlow(TestCase): + @staticmethod + def formatSchemas( + prof: torch.profiler.profile, indent: int = 12 + ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]: + tree = prof.profiler.kineto_results.experimental_event_tree() + out: List[Tuple[str, Tuple[bool, ...]]] = [] + for node in _utils.traverse_dfs(tree): + if node.tag == _EventType.TorchOp: + e = node.extra_fields + schemas = _memory_profiler.SchemaMatcher.match_schemas(e) + name = node.name + if len(schemas) == 1: + name = f"{name}.{schemas[0].overload_name}" + elif len(schemas) > 1: + name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}" + + out.append((name, _memory_profiler.SchemaMatcher.inputs_are_mutable(e))) + return tuple(out) + + def test_match_schemas(self) -> None: + with profile() as prof: + x = torch.ones((1,)).mul(2).add_(2) + _ = torch.sin(x, out=torch.empty_like(x)) + + self.assertEqual( + self.formatSchemas(prof), + ( + ("aten::ones.", (False,) * 5), + ("aten::empty.memory_format", (False,) * 6), + # + # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + ("aten::fill_.Scalar", (True, False)), + ("aten::mul.Tensor", (False, False)), + ("aten::to.dtype", (False,) * 5), + ("aten::_to_copy.", (False,) * 7), + ("aten::empty_strided.", (False,) * 6), + # + # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + ("aten::copy_.", (True, False, False)), + # + # add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + ("aten::add_.Tensor", (True, False, False)), + ("aten::to.dtype", (False,) * 5), + ("aten::_to_copy.", (False,) * 7), + ("aten::empty_strided.", (False,) * 6), + # + # copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + ("aten::copy_.", (True, False, False)), + ("aten::empty_like.", (False,) * 6), + ("aten::empty_strided.", (False,) * 6), + # + # sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) 
+ ("aten::sin.out", (False, True)), + ), + ) + + def test_match_schemas_backward(self) -> None: + x = torch.ones((1,)) + w = torch.ones((1,), requires_grad=True) + with profile() as prof: + torch.mul(x, w).backward() + + self.assertEqual( + self.formatSchemas(prof), + ( + ("aten::mul.Tensor", (False, False)), + ("aten::ones_like.", (False,) * 6), + ("aten::empty_like.", (False,) * 6), + ("aten::empty_strided.", (False,) * 6), + # + # fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + ("aten::fill_.Scalar", (True, False)), + ("autograd::engine::evaluate_function: MulBackward0", ()), + # + # Cannot find schema, all inputs presumed mutable + ("MulBackward0", (True,)), + ("aten::mul.Tensor", (False, False)), + ( + "autograd::engine::evaluate_function: torch::autograd::AccumulateGrad", + (), + ), + # + # Cannot find schema, all inputs presumed mutable + ("torch::autograd::AccumulateGrad", (True,)), + ("aten::detach.", (False,)), + ("detach", (True,)), + ), + ) + + def test_match_schemas_tensorlist(self) -> None: + x = torch.ones((1,)) + y = torch.ones((1,)) + with profile() as prof: + torch.cat([x, y], axis=0) + + self.assertEqual( + self.formatSchemas(prof), + (("aten::cat.", (False, False)),), + ) + + if __name__ == "__main__": run_tests() diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index da0f191e26b5..4a1fe23cec61 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -135,6 +135,7 @@ Scalar = Union[int, float, bool, complex] Input = Optional[Union[_TensorMetadata, List[_TensorMetadata], Scalar]] class _ExtraFields_TorchOp: + name: str sequence_number: int allow_tf32_cublas: bool diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 2a5839fc6a22..d910afe4234a 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -163,6 +163,7 @@ void initPythonBindings(PyObject* module) { using torch_op_t = ExtraFields; py::class_(m, "_ExtraFields_TorchOp") + .def_readonly("name", &torch_op_t::name_) .def_property_readonly( "inputs", [](const torch_op_t& op) { diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 355d3322a4e0..cd652a6a000f 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -1,10 +1,12 @@ import dataclasses -from typing import Any, Iterator, Optional, Tuple +from typing import Any, Iterator, List, Optional, Tuple, Union import torch +from torch._C import FunctionSchema from torch._C._autograd import _ProfilerResult from torch._C._profiler import ( _EventType, + _ExtraFields_TorchOp, _ProfilerEvent, _TensorMetadata, RecordScope, @@ -120,6 +122,113 @@ def extract_gradients( yield TensorKey.from_tensor(p), p_grad_key +class SchemaMatcher: + """Lookup operator schema based on profiled name. + + When profiling we record the operator's name but not the schema. However + some analysis requires that information. Fortunately we can look up + registered schema from the recorded name. We do not, however, record the + overload and so we must compare the profiled arguments with all overloads + to determine viable matches. + + Note: Once https://github.com/pytorch/pytorch/issues/78871 is completed + this code will be obsolete. + """ + + @classmethod + def inputs_are_mutable(cls, t: _ExtraFields_TorchOp) -> Tuple[bool, ...]: + """Determine which inputs may have mutated based on function schema. + + Note that we don't need to resolve down to a single schema to perform + this analysis. 
An input is mutable if it is mutable in any overload. In + practice, however, it is overwhelmingly common to match a single + overload. If we cannot find any valid schema then we must be + conservative and assume all inputs are mutable. + """ + mutable: Optional[List[bool]] = None + for schema in cls.match_schemas(t): + mutable = mutable or [False for _ in schema.arguments] + for i, arg in enumerate(schema.arguments): + mutable[i] |= getattr(arg.alias_info, "is_write", False) + + return tuple(mutable or (True for _ in t.inputs)) + + @classmethod + def match_schemas(cls, t: _ExtraFields_TorchOp) -> Tuple[FunctionSchema, ...]: + signature = tuple( + # Tensor + TensorKey.from_tensor(i) if isinstance(i, _TensorMetadata) + # + # TensorList + else [TensorKey.from_tensor(j) for j in i] if isinstance(i, list) + # + # Scalar and uncaptured inputs. + else i + for i in t.inputs + ) + + def matches(schema) -> bool: + return len(schema.arguments) == len(signature) and all( + cls._types_match(observed, schema_arg.type) + for observed, schema_arg in zip(signature, schema.arguments) + ) + + return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s)) + + @classmethod + def _types_match(cls, observed, schema_type) -> bool: + if isinstance(schema_type, torch._C.OptionalType): + schema_type = schema_type.getElementType() + return observed is None or cls._types_match(observed, schema_type) + + if isinstance(schema_type, torch._C.AnyType): + return True + + if schema_type.isSubtypeOf(torch._C.ListType.ofTensors()): + return isinstance(observed, list) and all( + isinstance(i, TensorKey) for i in observed + ) + + type_map: Tuple[Tuple[Any, Union[type, Tuple[type, ...]]], ...] = ( + (torch._C.TensorType, TensorKey), + (torch._C.NoneType, type(None)), + (torch._C.BoolType, bool), + (torch._C.IntType, int), + (torch._C.FloatType, float), + (torch._C.ComplexType, complex), + (torch._C.NumberType, (bool, int, float, complex)), + ) + + for jit_type, py_types in type_map: + if isinstance(schema_type, jit_type): + return isinstance(observed, py_types) + + # Profiler only records a subset of possible argument types. If we + # reach this point then the schema must call for a type that profiler + # does not record. Thus, the schema can only be a match if `observed` + # is also None. + return observed is None + + @staticmethod + def lookup_schemas(name: str) -> Optional[Tuple[FunctionSchema, ...]]: + # TODO(robieta): + # _jit_get_schemas_for_operator is quite expensive. (~100us / call) + # Consider adding `functools.lru_cache` if that becomes an issue. + + try: + # Schema lookup will throw if `name` is malformed. (For example, + # schemas must be namespaced and schema lookup will fail if name + # does not include "::".) We simply catch the exception and return + # `None` to denote that `name` cannot be an operator name. + # + # Note that record_function annotations also go through this path, + # so it is expected that some names will not correspond to PyTorch + # operators. + return tuple(torch._C._jit_get_schemas_for_operator(name)) + except RuntimeError: + return None + + class MemoryProfile: def __init__(self, result: _ProfilerResult) -> None: pass From f5df68509097c65263ccf100e5df6b1057e9a2fa Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 15 Nov 2022 19:25:53 +0000 Subject: [PATCH 181/453] Enable channels_last_3d on SyncBatchNorm (#88401) This PR enabled the use of fast channels_last kernels on SyncBatchNorm with channels_last_3d memory format. 
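For context, a minimal usage sketch of the layout this targets (this is not the benchmark script linked below; the shapes, module sizes, and `local_rank` are illustrative, and a CUDA device is assumed):
```
import torch

# A 5-D activation (N, C, D, H, W) laid out as channels_last_3d -- the layout
# for which the fast channels_last kernels are now selected.
x = torch.randn(2, 64, 4, 32, 32, device="cuda").to(memory_format=torch.channels_last_3d)
assert x.is_contiguous(memory_format=torch.channels_last_3d)

# Under DDP, SyncBatchNorm layers are typically obtained by converting a model.
# The conversion itself does not need an initialized process group; only the
# synchronized forward/backward does.
model = torch.nn.Sequential(
    torch.nn.Conv3d(64, 64, kernel_size=3, padding=1),
    torch.nn.BatchNorm3d(64),
).cuda()
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
# out = model(x)  # inside an initialized process group
```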
With a small benchmark script here https://github.com/pytorch/pytorch/issues/88021#issuecomment-1299059859, on V100, I got master: ``` DDP channels_last=False, run_forward_backward, time: 0.8945400714874268 sec DDP channels_last=True, run_forward_backward, time: 1.4736433029174805 sec ``` This PR: ``` DDP channels_last=False, run_forward_backward, time: 0.8927242755889893 sec DDP channels_last=True, run_forward_backward, time: 0.48697471618652344 sec ``` This PR is a follow-up of https://github.com/pytorch/pytorch/pull/46906 Close https://github.com/pytorch/pytorch/issues/88021 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88401 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/Normalization.cu | 7 +++++-- test/test_nn.py | 10 +++++----- torch/nn/modules/_functions.py | 10 ++++++++-- .../testing/_internal/distributed/distributed_test.py | 11 ++++++++--- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 3b27ebfc7d92..df460447464b 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -48,8 +48,11 @@ bool is_mixed_type(const Tensor& input, const Args&... parameters) { } inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) { - return (self.is_contiguous(at::MemoryFormat::ChannelsLast) || - (self.is_contiguous() && self.strides()[1] == 1)); + return ( + self.is_contiguous(at::MemoryFormat::ChannelsLast) || + self.is_contiguous(at::MemoryFormat::ChannelsLast3d) || + (self.is_contiguous() && self.strides()[1] == 1) + ); } enum class Impl { diff --git a/test/test_nn.py b/test/test_nn.py index b07793e79f48..2b96838e3601 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -10283,16 +10283,16 @@ def test_sync_batchnorm_accuracy_cuda(self): # fwd: torch.batch_norm_stats, torch.batch_norm_gather_stats_with_counts, torch.batch_norm_elemt # bwd: torch.batch_norm_backward_reduce, torch.batch_norm_backward_elemt - def _batch_norm_stats(data): + def _batch_norm_stats(data, memory_format, mean_axes): mean1, _ = torch.batch_norm_stats(data, 1e-5) - mean2, _ = torch.batch_norm_stats(data.to(memory_format=torch.channels_last), 1e-5) - mean_ref = torch.mean(data, (0, 2, 3), keepdim=False) + mean2, _ = torch.batch_norm_stats(data.to(memory_format=memory_format), 1e-5) + mean_ref = torch.mean(data, mean_axes, keepdim=False) self.assertEqual(mean_ref, mean1) self.assertEqual(mean_ref, mean2) - data = torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda') - _batch_norm_stats(data) + _batch_norm_stats(torch.randn(1, 96, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last, (0, 2, 3)) + _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='cuda'), torch.channels_last_3d, (0, 2, 3, 4)) def test_flatten(self): tensor_input = torch.randn(2, 1, 2, 3) diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 66200345cbc2..464c56a548a6 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -7,7 +7,10 @@ class SyncBatchNorm(Function): @staticmethod def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size): - if not input.is_contiguous(memory_format=torch.channels_last): + if not ( + input.is_contiguous(memory_format=torch.channels_last) or + input.is_contiguous(memory_format=torch.channels_last_3d) + ): input = 
input.contiguous() if weight is not None: weight = weight.contiguous() @@ -104,7 +107,10 @@ def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, @staticmethod def backward(self, grad_output): - if not grad_output.is_contiguous(memory_format=torch.channels_last): + if not ( + grad_output.is_contiguous(memory_format=torch.channels_last) or + grad_output.is_contiguous(memory_format=torch.channels_last_3d) + ): grad_output = grad_output.contiguous() saved_input, weight, mean, invstd, count_tensor = self.saved_tensors grad_input = grad_weight = grad_bias = None diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 43a49b0489dc..c67dfc7c40a3 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -5324,6 +5324,10 @@ def test_post_localSGD_optimizer_step_reload(self): ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self): + self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last) + self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(torch.channels_last_3d) + + def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(self, memory_format): group, group_id, rank = self._init_global_test() num_processes = dist.get_world_size() local_bs = 2 @@ -5336,14 +5340,15 @@ def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self): model_gpu, device_ids=[rank] ) - memory_format = torch.channels_last + shapes = [global_bs, 2, 4, 4] + ([] if memory_format is torch.channels_last else [4]) + input_gpu = ( - torch.randn(global_bs, 2, 4, 4, dtype=torch.float) + torch.randn(*shapes, dtype=torch.float) .cuda(rank) .to(memory_format=memory_format) ) target_gpu = ( - torch.randn(global_bs, 2, 4, 4, dtype=torch.float) + torch.randn(*shapes, dtype=torch.float) .cuda(rank) .to(memory_format=memory_format) ) From d60abe4b9521e235c0e9beb00cda0d6c5673f4e0 Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Tue, 15 Nov 2022 19:34:38 +0000 Subject: [PATCH 182/453] [Inductor] Build FX Linear + Permute Vertical Fusion in Inductor (#88859) Summary: Build fx-based linear/matmul/bmm + permute/transpose vertical fusion in Inductor For an internal Ads model: **1.15x -> 1.36x speedup** Differential Revision: D41071665 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88859 Approved by: https://github.com/jianyuh, https://github.com/jansel --- test/inductor/test_torchinductor.py | 106 +++++++++++++++ torch/_inductor/config.py | 4 + torch/_inductor/overrides.py | 199 ++++++++++++++++++++++++++++ 3 files changed, 309 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 23fb2f7712e0..b64f40377995 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10,6 +10,7 @@ import typing import unittest import weakref +from typing import Any, Callable from unittest.mock import patch import torch @@ -18,6 +19,7 @@ from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, @@ -39,6 +41,14 @@ from torch._inductor import codecache, config, metrics from torch._inductor.compile_fx import compile_fx, 
complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing + from torch._inductor.overrides import ( + linear_permute_fusion, + linear_transpose, + permute_linear_fusion, + permute_matmul_fusion, + transpose_linear, + transpose_matmul, + ) from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.utils import has_torchvision_roi_align, timed @@ -113,6 +123,29 @@ def maybe_test(*args, **kwargs): return wrap_test +PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] + + +def chain_passes(*passes: PassFunc) -> PassFunc: + def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module) + return module + + return parent_pass + + +def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: + return sum( + [ + 1 if (n.op == "call_function" and n.target == target_op) else 0 + for n in module.graph.nodes + ] + ) + + class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -1586,6 +1619,79 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) + def test_linear_permute_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + a0 = torch.nn.functional.linear(input, self.weight, self.bias) + b0 = a0.permute(0, 2, 1) + return b0 + + m, k, n = 16, 8, 4 + trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, m, k) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_linear_transpose = count_call_function(traced, linear_transpose) + self.assertEqual(num_linear, 0) + self.assertEqual(num_linear_transpose, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_linear_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.nn.functional.linear(input1, self.weight, self.bias) + return output + + m, k, n = 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, k, m) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_transpose_linear = count_call_function(traced, transpose_linear) + self.assertEqual(num_linear, 0) + self.assertEqual(num_transpose_linear, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_bmm_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, batch: int, k: int, n: int): + super().__init__() + self.other = torch.randn(batch, k, n) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.bmm(input1, self.other) + return output + + batch, m, k, n = 6, 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) + module = TestModule(batch, k, n).eval() + input = torch.randn(batch, k, m) + traced = trace_func(module, [input]) + num_bmm = 
count_call_function(traced, torch.bmm) + num_transpose_matmul = count_call_function(traced, transpose_matmul) + self.assertEqual(num_bmm, 0) + self.assertEqual(num_transpose_matmul, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + def test_slice1(self): def fn(a): return ( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d376fe3e8bf7..c552101c1cae 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -75,6 +75,10 @@ shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1" alignment_size = 4 +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + + # config specific to codegen/cpp.pp class cpp: # set to torch.get_num_threads() diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 3a95aa7ce880..cf2cd5f60f51 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -19,6 +19,8 @@ from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.overrides import TorchFunctionMode +from . import config + log = logging.getLogger(__name__) @@ -425,6 +427,14 @@ def check_node_is_add_inplace(node): def fuse_fx(gm: torch.fx.GraphModule, example_inputs): + if config.permute_fusion: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm).propagate(*example_inputs) + gm = linear_permute_fusion(gm) + gm = permute_linear_fusion(gm) + gm = permute_matmul_fusion(gm) + # make sure the autograd is disabled. if torch.is_grad_enabled(): return gm @@ -528,6 +538,195 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["weight"] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] + else: + return self.node.kwargs["bias"] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["other"] + + +def check_permute(node: torch.fx.Node): + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if ( + 
node.op == "call_method" + and node.target == "permute" + and check_permute(node) + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and node.target == torch.nn.functional.linear: + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target == torch.bmm or node.target == torch.matmul + ): + normalized = NormalizedMatmulNode(node) + A = normalized.get_input() + B = normalized.get_other() + Atrans = Btrans = False + if A.op == "call_method" and A.target == "permute" and check_permute(A): + Atrans = True + if len(A.args) > 0: + A = A.args[0] + else: + A = A.kwargs["input"] + + if B.op == "call_method" and B.target == "permute" and check_permute(B): + Btrans = True + if len(B.args) > 0: + B = B.args[0] + else: + B = B.kwargs["input"] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(A, B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) + + def replace_and_fuse_for_binary( computation_node, node, fuse_func, attr, modules, index_node, index_pointwise ): From ff6d2a6d1b8245563c8122849144dddaa276483a Mon Sep 17 00:00:00 2001 From: Driss 
Guessous Date: Tue, 15 Nov 2022 20:22:54 +0000 Subject: [PATCH 183/453] Add mem efficient backward (#88856) # Registers the derivative for mem efficient backward - Use gradcheck to test correctness. The kernel is not implemented for fp64 so run checks with bumped tolerances in fp32 - I also made updates based off of Xformer main branch and flash-attention cutlass branch. - This will enable the fused backward to be called for scaled dot product attention Pull Request resolved: https://github.com/pytorch/pytorch/pull/88856 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 5 + .../native/transformers/cuda/attention.cu | 16 +- .../transformers/cuda/attention_backward.cu | 261 ++++++++++++++++++ .../transformers/cuda/flash_attn/fmha_api.cpp | 4 + .../attention_backward_generic.cu | 166 ----------- .../attention_forward_generic.cu | 232 ---------------- .../cuda/mem_eff_attention/find_default_mma.h | 7 +- .../cuda/mem_eff_attention/kernel_backward.h | 250 +++++++++++------ .../ATen/native/transformers/cuda/sdp_utils.h | 12 +- test/test_transformers.py | 44 ++- tools/autograd/derivatives.yaml | 7 +- .../_internal/common_methods_invocations.py | 4 +- 12 files changed, 501 insertions(+), 507 deletions(-) create mode 100644 aten/src/ATen/native/transformers/cuda/attention_backward.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu delete mode 100644 aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index de087c0b8a89..9572ccc56653 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13287,6 +13287,11 @@ dispatch: CUDA: _efficient_attention_forward +- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? 
incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index f65fedd6d795..46543d4663fa 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -746,7 +746,9 @@ std::tuple flash_attention_helper_dense_unpacked( std::tuple mem_eff_helper( const Tensor& query, const Tensor& key, - const Tensor& value){ + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { // Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head) // Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head) // Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head) @@ -754,16 +756,18 @@ std::tuple mem_eff_helper( Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); - Tensor attention = std::get<0>(at::_efficient_attention_forward( + Tensor attention, log_sumexp; + std::tie(attention, log_sumexp) = at::_efficient_attention_forward( q_t, k_t, v_t, c10::nullopt, c10::nullopt, c10::nullopt, - false, - false)).transpose(1,2); - return std::make_tuple(attention, Tensor()); + compute_log_sumexp, + is_causal); + attention = attention.transpose(1,2); + return std::make_tuple(std::move(attention), Tensor()); } std::tuple _scaled_dot_product_attention_forward_cuda( @@ -776,7 +780,7 @@ std::tuple _scaled_dot_product_attention_forward_cuda( case sdp::SDPBackend::flash_attention: return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value); + return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); case sdp::SDPBackend::math: return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); default: diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu new file mode 100644 index 000000000000..af005b2669b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -0,0 +1,261 @@ +#include + +#include + +#include +#include + +#include +#include +#include +#include + +#ifdef USE_FLASH_ATTENTION +#include +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ + } + +#define DISPATCH_MAXK(func) \ + { \ + const auto maxK = std::max(query.size(3), value.size(3)); \ + if (maxK <= 64) { \ + constexpr int kMaxK = 64; \ + func(); \ + } else if (maxK <= 128) { \ + constexpr int kMaxK = 128; \ + func(); \ + } else { \ + constexpr int kMaxK = std::numeric_limits::max(); \ + func(); \ + } \ + } + +#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ + { \ + cudaDeviceProp* properties = \ + at::cuda::getDeviceProperties(QUERY.device().index()); \ + const int computeCapability = properties->major * 10 + properties->minor; \ + DISPATCH_MAXK(([&] { \ + DISPATCH_TYPES( \ + QUERY, ([&]() { \ + DISPATCH_ARCHTAG( \ + computeCapability, ([&]() { \ + using AlignedAK = \ + AttentionBackwardKernel; \ + bool isAligned = \ + (QUERY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + KEY.stride(2) % AlignedAK::kOptimalAlignement == 0 && \ + VALUE.stride(2) % AlignedAK::kOptimalAlignement == 0); \ + DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ + using Kernel = AttentionBackwardKernel< \ + ArchTag, \ + scalar_t, \ + kIsAligned, \ + 
kMaxK>; \ + FUNC(); \ + })) \ + })) \ + })) \ + })); \ + } + +namespace at { + +namespace native { + +std::tuple _efficient_attention_backward( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& logsumexp, + const at::Tensor& out, + bool causal) { + #if defined(USE_FLASH_ATTENTION) + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + // ndim + TORCH_CHECK(query.dim() == grad_out_.dim()); + TORCH_CHECK(query.dim() == key.dim()); + TORCH_CHECK(query.dim() == value.dim()); + TORCH_CHECK(query.dim() == 4); + + // batch size + TORCH_CHECK(query.size(0) == grad_out_.size(0)); + TORCH_CHECK(query.size(0) == key.size(0)); + TORCH_CHECK(query.size(0) == value.size(0)); + + // seqlen + TORCH_CHECK(key.size(1) == value.size(1)); + TORCH_CHECK(query.size(1) == grad_out_.size(1)); + + // Num heads + TORCH_CHECK(query.size(2) == key.size(2)); + TORCH_CHECK(query.size(2) == value.size(2)); + TORCH_CHECK(query.size(2) == grad_out_.size(2)); + + // Embedding per head + TORCH_CHECK(query.size(3) == key.size(3)); + TORCH_CHECK(value.size(3) == grad_out_.size(3)); + + // handle potentially non-contiguous grad_out through a copy + auto grad_out = grad_out_.contiguous(); + CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); + + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); + CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); + + at::cuda::CUDAGuard device_guard(query.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + int64_t B = query.size(0); + int64_t M = query.size(1); + int64_t N = key.size(1); + int64_t nH = query.size(2); + int64_t K = query.size(3); + + // It does not make sense to use that in practice, + // but let's still make sure we are correct + // As we iterate through keys first, we skip + // keys with no query associated, so they are not + // initialized + bool grad_kv_needs_init = causal && N > M; + at::Tensor grad_q, grad_k, grad_v; + if (!grad_kv_needs_init && query.size(1) == key.size(1) && + query.size(3) == value.size(3) && + query.storage().is_alias_of(key.storage()) && + query.storage().is_alias_of(value.storage())) { + // Create one big contiguous chunk + // This is because q, k and v usually come from a single + // output of a linear layer that is chunked. + // Creating the gradients with the right layout saves us + // a `torch.cat` call in the backward pass + at::Tensor chunk = at::empty({B, M, 3, nH, K}, query.options()); + grad_q = chunk.select(2, 0); + grad_k = chunk.select(2, 1); + grad_v = chunk.select(2, 2); + } else { + grad_q = at::empty_like(query); + grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); + grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); + } + + auto launchKernel = [&](auto _k, int computeCapability) { + using Kernel = decltype(_k); + using scalar_t = typename Kernel::scalar_t; + (void)_k; + + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + + // TODO: Fuse this into a kernel? + // This is a bottleneck for smaller sequences (M <= 128) + auto delta = Kernel::kKernelComputesDelta + ? 
at::empty({B, nH, M}, query.options().dtype(at::ScalarType::Float)) + : (grad_out.to(at::kFloat) * out.to(at::kFloat)) + .sum(-1) + .transpose(-2, -1) + .contiguous(); + TORCH_INTERNAL_ASSERT(delta.size(0) == B); + TORCH_INTERNAL_ASSERT(delta.size(1) == nH); + TORCH_INTERNAL_ASSERT(delta.size(2) == M); + + typename Kernel::Params p; + p.query_ptr = (scalar_t*)query.data_ptr(); + p.key_ptr = (scalar_t*)key.data_ptr(); + p.value_ptr = (scalar_t*)value.data_ptr(); + p.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); + p.output_ptr = (scalar_t*)out.data_ptr(); + p.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); + p.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); + p.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); + p.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); + p.delta_ptr = (float*)delta.data_ptr(); + p.head_dim = query.size(3); + p.head_dim_value = value.size(3); + p.num_queries = query.size(1); + p.num_keys = key.size(1); + p.num_batches = B; + p.num_heads = nH; + p.causal = causal; + + ASSIGN_CHECK_OVERFLOW(p.gO_strideB, grad_out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideM, grad_out.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.gO_strideH, grad_out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.o_strideB, out.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.o_strideH, out.stride(2)); + + ASSIGN_CHECK_OVERFLOW(p.gQ_strideB, grad_q.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideB, grad_k.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideB, grad_v.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); + p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; + TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); + TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); + TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); + + ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); + ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); + ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); + ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); + + Kernel::check_supported(p); + + constexpr auto kernel_fn = attention_kernel_backward_batched; + + if (smem_bytes > 0xc000) { + TORCH_INTERNAL_ASSERT( + computeCapability >= 70, + "This kernel requires too much shared memory on this machine!"); + cudaFuncSetAttribute( + kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + + // second syntax resulted in the error below on windows + // error C3495: 'kernel_fn': a simple capture must be a variable + // with automatic storage duration declared + // in the reaching scope of the lambda +#ifdef _WIN32 + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + TORCH_INTERNAL_ASSERT( + attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability, + "Something went wrong in the build process"); +#else + auto checkBinaryArchMatches = [&]() { + cudaFuncAttributes attr; + AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); + return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; + }; + TORCH_INTERNAL_ASSERT( + checkBinaryArchMatches(), "Something went wrong in the build process"); +#endif + + kernel_fn<<>>(p); + }; + + DISPATCH_KERNEL( + 
query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_q, grad_k, grad_v); + #endif + TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); +} + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index a8d6110e951d..aaf7d833fe83 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -29,6 +29,7 @@ #ifdef USE_FLASH_ATTENTION #include #include +#include #include #include @@ -185,6 +186,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; bool loop = max_seqlen_k > blocksize_c; + // Otherwise the kernel will be launched from cuda:0 device + at::cuda::CUDAGuard device_guard{q.device()}; + auto opts = q.options(); auto o = at::empty({ total_q, num_heads, head_size }, opts); diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu deleted file mode 100644 index 07c14ad8195d..000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_backward_generic.cu +++ /dev/null @@ -1,166 +0,0 @@ -#include - -#define DISPATCH_MAXK(func) \ - { \ - const auto maxK = std::max(query.size(2), value.size(2)); \ - if (maxK <= 64) { \ - constexpr int kMaxK = 64; \ - func(); \ - } else if (maxK <= 128) { \ - constexpr int kMaxK = 128; \ - func(); \ - } else { \ - constexpr int kMaxK = std::numeric_limits::max(); \ - func(); \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_MAXK(([&] { \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = \ - AttentionBackwardKernel; \ - bool isAligned = \ - (QUERY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - KEY.stride(1) % AlignedAK::kOptimalAlignement == 0 && \ - VALUE.stride(1) % AlignedAK::kOptimalAlignement == 0); \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionBackwardKernel< \ - ArchTag, \ - scalar_t, \ - kIsAligned, \ - kMaxK>; \ - FUNC(); \ - })) \ - })) \ - })) \ - })); \ - } - -namespace { -std::tuple -mem_efficient_attention_backward_cutlass( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& logsumexp, - const at::Tensor& out, - bool causal) { - TORCH_CHECK(query.dim() == grad_out_.dim()); - TORCH_CHECK(query.dim() == key.dim()); - TORCH_CHECK(query.dim() == 3); - - TORCH_CHECK(query.size(0) == grad_out_.size(0)); - TORCH_CHECK(query.size(1) == grad_out_.size(1)); - TORCH_CHECK(value.size(2) == grad_out_.size(2)); - - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(0) == key.size(0)); - - TORCH_CHECK(query.size(0) == value.size(0)); - TORCH_CHECK(key.size(1) == value.size(1)); - - // handle potentially non-contiguous grad_out through a copy - auto grad_out = grad_out_.contiguous(); - - CHECK_NOSPARSE_CONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(key); - 
CHECK_NOSPARSE_CONTIGUOUS_CUDA(value); - CHECK_NOSPARSE_CONTIGUOUS_CUDA(grad_out); - - at::cuda::CUDAGuard device_guard(query.device()); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t K = query.size(2); - - // It does not make sense to use that in practice, - // but let's still make sure we are correct - // As we iterate through keys first, we skip - // keys with no query associated, so they are not - // initialized - bool grad_kv_needs_init = causal && N > M; - at::Tensor grad_q = at::empty_like(query); - at::Tensor grad_k = - grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - at::Tensor grad_v = - grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - - // TODO: Fuse this into a kernel? - // This is a bottleneck for smaller sequences (M <= 128) - auto delta = Kernel::kKernelComputesDelta - ? at::empty({B, M}, query.options().dtype(at::ScalarType::Float)) - : (grad_out.to(at::kFloat) * out.to(at::kFloat)).sum(-1); - TORCH_INTERNAL_ASSERT(delta.size(0) == B); - TORCH_INTERNAL_ASSERT(delta.size(1) == M); - - typename Kernel::Params params; - params.query_ptr = (scalar_t*)query.data_ptr(); - params.key_ptr = (scalar_t*)key.data_ptr(); - params.value_ptr = (scalar_t*)value.data_ptr(); - params.logsumexp_ptr = (typename Kernel::lse_scalar_t*)logsumexp.data_ptr(); - params.output_ptr = (scalar_t*)out.data_ptr(); - params.grad_output_ptr = (scalar_t*)grad_out.data_ptr(); - params.grad_query_ptr = (scalar_t*)grad_q.data_ptr(); - params.grad_key_ptr = (scalar_t*)grad_k.data_ptr(); - params.grad_value_ptr = (scalar_t*)grad_v.data_ptr(); - params.delta_ptr = (float*)delta.data_ptr(); - params.head_dim = query.size(2); - params.head_dim_value = value.size(2); - params.num_queries = query.size(1); - params.num_keys = key.size(1); - params.num_batches = B; - params.causal = causal; - Kernel::check_supported(params); - - constexpr auto kernel_fn = attention_kernel_backward_batched; - - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - } - - auto checkBinaryArchMatches = [&]() { - cudaFuncAttributes attr; - AT_CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel_fn)); - return attr.binaryVersion >= Kernel::ArchTag::kMinComputeCapability; - }; - TORCH_INTERNAL_ASSERT( - checkBinaryArchMatches(), "Something went wrong in the build process"); - - kernel_fn<<>>( - params); - }; - - DISPATCH_KERNEL( - query, key, value, ([&] { launchKernel(Kernel{}, computeCapability); })); - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(grad_q, grad_k, grad_v); -} // namespace - -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_backward_cutlass"), -// TORCH_FN(mem_efficient_attention_backward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu deleted file mode 100644 index 59b3637c8a43..000000000000 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/attention_forward_generic.cu +++ /dev/null @@ 
-1,232 +0,0 @@ -#include - - -#define DISPATCH_BLOCKSIZE(VALUE_HEAD_DIM, FN) \ - { \ - if (VALUE_HEAD_DIM <= 64) { \ - constexpr bool kIs64x64 = true; \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kIs64x64 = false; \ - if (VALUE_HEAD_DIM <= 128) { \ - constexpr bool kSingleValueIteration = true; \ - FN(); \ - } else { \ - constexpr bool kSingleValueIteration = false; \ - FN(); \ - } \ - } \ - } - -#define DISPATCH_KERNEL(QUERY, KEY, VALUE, FUNC) \ - { \ - cudaDeviceProp* properties = \ - at::cuda::getDeviceProperties(QUERY.device().index()); \ - const int computeCapability = properties->major * 10 + properties->minor; \ - DISPATCH_BLOCKSIZE( \ - VALUE.size(-1), ([&]() { \ - static constexpr int64_t kQueriesPerBlock = kIs64x64 ? 64 : 32; \ - static constexpr int64_t kKeysPerBlock = kIs64x64 ? 64 : 128; \ - DISPATCH_TYPES( \ - QUERY, ([&]() { \ - DISPATCH_ARCHTAG( \ - computeCapability, ([&]() { \ - using AlignedAK = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - true, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - /* Run a more efficient kernel (with `isAligned=True`) \ - if memory is correctly aligned*/ \ - bool isAligned = \ - (QUERY.stride(2) % AlignedAK::kAlignmentQ == 0 && \ - KEY.stride(2) % AlignedAK::kAlignmentK == 0 && \ - VALUE.stride(2) % AlignedAK::kAlignmentV == 0); \ - /* TODO: Should we warn or log somewhere when we use a \ - less efficient kernel due to wrong alignment? */ \ - DISPATCH_BOOL(isAligned, kIsAligned, ([&]() { \ - using Kernel = AttentionKernel< \ - scalar_t, \ - ArchTag, \ - kIsAligned, \ - kQueriesPerBlock, \ - kKeysPerBlock, \ - kSingleValueIteration>; \ - FUNC(); \ - })) \ - })) \ - })); \ - })); \ - } - -namespace { -/* - There are 2 modes for using this function. 
- (Mode BMHK) With all the heads having the same seqlen - (Mode 1MHK) `batch=1` with all tokens across batches concatenated -*/ -std::tuple efficient_attention_forward_cutlass( - const at::Tensor& query, // [b, seqlen, num_heads, K] - const at::Tensor& key, // [b, seqlen, num_heads, K] - const at::Tensor& value, // [b, seqlen, num_heads, Kv] - // (Mode 1MHK only) [b+1]: cu_seqlens_q[b] contains the - // position of the first query token for batch $b - const c10::optional& cu_seqlens_q, - // (Mode 1MHK only) [b+1]: cu_seqlens_k[b] contains the - // position of the first key token for batch $b - const c10::optional& cu_seqlens_k, - // (Mode 1MHK only) Maximum sequence length across batches - const c10::optional max_seqlen_q_, - bool compute_logsumexp, - bool causal) { - TORCH_CHECK(query.dim() == 4); - TORCH_CHECK(key.dim() == 4); - TORCH_CHECK(value.dim() == 4); - - // Batch sizes - TORCH_CHECK(query.size(0) == key.size(0)); - TORCH_CHECK(query.size(0) == value.size(0)); - - // Sequence length - TORCH_CHECK(key.size(1) == value.size(1)); - - // Num heads - TORCH_CHECK(query.size(2) == key.size(2)); - TORCH_CHECK(query.size(2) == value.size(2)); - - // Embedding per head - TORCH_CHECK(query.size(3) == key.size(3)); - - int64_t max_seqlen_q, max_seqlen_k; - TORCH_CHECK(cu_seqlens_q.has_value() == cu_seqlens_k.has_value()); - if (cu_seqlens_q.has_value()) { - TORCH_CHECK(cu_seqlens_q->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_k->scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cu_seqlens_q->dim() == 1 && cu_seqlens_k->dim() == 1); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_q)); - CHECK_NOSPARSE_CONTIGUOUS_CUDA((*cu_seqlens_k)); - TORCH_CHECK(cu_seqlens_q->size(0) == cu_seqlens_k->size(0)); - TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1"); - TORCH_CHECK(max_seqlen_q_.has_value()); - max_seqlen_q = *max_seqlen_q_; - max_seqlen_k = 0; // Will be set inside the kernel - } else { - max_seqlen_q = query.size(1); - max_seqlen_k = key.size(1); - } - - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key); - CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(value); - - at::cuda::CUDAGuard device_guard(query.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - int64_t B = query.size(0); - int64_t M = query.size(1); - int64_t N = key.size(1); - int64_t num_heads = query.size(-2); - int64_t K = query.size(-1); - int64_t Kv = value.size(-1); - - at::Tensor res; - at::Tensor logsumexp; - - auto launchKernel = [&](auto _k, int computeCapability) { - using Kernel = decltype(_k); - using scalar_t = typename Kernel::scalar_t; - (void)_k; - - res = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - - // NOTE: Should be aligned (by padding) in case M is - // not a good number for loading during backward - constexpr decltype(M) kAlignLSE = Kernel::kAlignLSE; - logsumexp = at::empty( - {B, - num_heads, - compute_logsumexp ? ceil_div(max_seqlen_q, kAlignLSE) * kAlignLSE : 0}, - query.options().dtype(at::ScalarType::Float)); - - typename Kernel::Params p; - p.query_ptr = (scalar_t*)query.data_ptr(); - p.key_ptr = (scalar_t*)key.data_ptr(); - p.value_ptr = (scalar_t*)value.data_ptr(); - p.logsumexp_ptr = compute_logsumexp - ? 
(typename Kernel::lse_scalar_t*)logsumexp.data_ptr() - : nullptr; - at::Tensor output_accum; - if (Kernel::kNeedsOutputAccumulatorBuffer) { - output_accum = at::empty( - {B, M, num_heads, Kv}, - query.options().dtype( - TypeTraits::atScalarType())); - p.output_accum_ptr = - (typename Kernel::output_accum_t*)output_accum.data_ptr(); - } else { - p.output_accum_ptr = nullptr; - } - p.output_ptr = (typename Kernel::output_t*)res.data_ptr(); - - if (cu_seqlens_q.has_value()) { - p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); - p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); - } - -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ - } - - p.num_heads = num_heads; - p.head_dim = query.size(3); - p.head_dim_value = value.size(3); - p.num_queries = max_seqlen_q; - p.num_keys = max_seqlen_k; - p.num_batches = cu_seqlens_q.has_value() ? cu_seqlens_q->size(0) - 1 : B; - p.causal = causal; - - ASSIGN_CHECK_OVERFLOW(p.q_strideB, query.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.k_strideB, key.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.v_strideB, value.stride(0)); - ASSIGN_CHECK_OVERFLOW(p.q_strideM, query.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.k_strideM, key.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.v_strideM, value.stride(1)); - ASSIGN_CHECK_OVERFLOW(p.q_strideH, query.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.k_strideH, key.stride(2)); - ASSIGN_CHECK_OVERFLOW(p.v_strideH, value.stride(2)); - - constexpr auto kernel_fn = attention_kernel_batched; - size_t smem_bytes = sizeof(typename Kernel::SharedStorage); - if (smem_bytes > 0xc000) { - TORCH_INTERNAL_ASSERT( - computeCapability >= 70, - "This kernel requires too much shared memory on this machine!"); - AT_CUDA_CHECK(cudaFuncSetAttribute( - kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)); - } - Kernel::check_supported(p); - kernel_fn<<>>(p); - }; - // Dispatch to the right kernel - DISPATCH_KERNEL(query, key, value, ([&]() { - launchKernel(Kernel{}, computeCapability); - })); - - AT_CUDA_CHECK(cudaGetLastError()); - return std::make_tuple(res, logsumexp); -} -} // namespace - -// TORCH_LIBRARY_IMPL(xformers, CUDA, m) { -// m.impl( -// TORCH_SELECTIVE_NAME("xformers::efficient_attention_forward_cutlass"), -// TORCH_FN(efficient_attention_forward_cutlass)); -// } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h index 399593fd0957..b0e7106f3cfc 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/find_default_mma.h @@ -1,15 +1,16 @@ /*! \file \brief Cutlass provides helper template functions to figure out the right - datastructures to instanciate to run a GEMM with various parameters (see + datastructures to instantiate to run a GEMM with various parameters (see `cutlass/gemm/threadblock/default_mma.h`). However, due to template - instanciation priority rules, it will only create an MmaMultiStage with + instantiation priority rules, it will only create an MmaMultiStage with kStages=3 (otherwise creates an MmePipelined - which is not compatible with FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, so we just copy-pasted some code from `default_mma.h` and - `default_mma_core.h` files and wrapped this template to allow our usecase. + `default_mma_core.h` files and wrapped this template to allow our use case. 
This is really only for the FastF32 case - aka using TensorCores with fp32. */ +#pragma once #include #include diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h index e25701a7588a..c9652c40d38e 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h @@ -1,7 +1,5 @@ #pragma once - #include -#include #include #include @@ -75,46 +73,113 @@ struct AttentionBackwardKernel { struct Params { // Input tensors - scalar_t* query_ptr; // [num_queries, head_dim] - scalar_t* key_ptr; // [num_keys, head_dim] - scalar_t* value_ptr; // [num_keys, head_dim_value] - lse_scalar_t* logsumexp_ptr; // [num_queries] - scalar_t* output_ptr; // [num_queries, head_dim_value] - scalar_t* grad_output_ptr; // [num_queries, head_dim_value] - accum_t* delta_ptr; // [num_queries] + scalar_t* query_ptr; // [Mq, nH, K] + scalar_t* key_ptr; // [Mk, nH, K] + scalar_t* value_ptr; // [Mk, nH, Kv] + lse_scalar_t* logsumexp_ptr; // [nH, Mq] + scalar_t* output_ptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr; // [Mq, nH, Kv] + accum_t* delta_ptr; // [Mq, nH] // Output tensors - scalar_t* grad_query_ptr; // [num_queries, head_dim] - scalar_t* grad_key_ptr; // [num_keys, head_dim] - scalar_t* grad_value_ptr; // [num_keys, head_dim_value] + output_t* grad_query_ptr; // [Mq, nH, K] + output_t* grad_key_ptr; // [Mk, nH, K] + output_t* grad_value_ptr; // [Mk, nH, Kv] // Dimensions/strides int32_t head_dim; int32_t head_dim_value; int32_t num_queries; int32_t num_keys; - int32_t num_batches; + int32_t num_heads; bool causal; - __device__ void advance_batches(int32_t batch_id) { + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + int32_t gO_strideM; + int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH; + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int64_t o_strideB; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int32_t num_batches; + + int64_t gO_strideB; + int64_t gQ_strideB; + int64_t gK_strideB; + int64_t gV_strideB; + int64_t gO_strideH; + int64_t gQ_strideH; + int64_t gK_strideH; + int64_t gV_strideH; + + CUTLASS_DEVICE void advance_to_block() { constexpr int32_t kAlignLSE = 32; // block size of backward auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; - query_ptr += batch_id * head_dim * num_queries; - key_ptr += batch_id * head_dim * num_keys; - value_ptr += batch_id * head_dim_value * num_keys; - logsumexp_ptr += batch_id * lse_dim; - output_ptr += batch_id * head_dim_value * num_queries; - grad_output_ptr += batch_id * head_dim_value * num_queries; - delta_ptr += batch_id * num_queries; - - grad_query_ptr += batch_id * head_dim * num_queries; - grad_key_ptr += batch_id * head_dim * num_keys; - grad_value_ptr += batch_id * head_dim_value * num_keys; + int32_t batch_id = blockIdx.z; + 
int32_t head_id = blockIdx.y; + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + logsumexp_ptr += (batch_id * num_heads + head_id) * lse_dim; + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += (batch_id * num_heads + head_id) * num_queries; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + + gO_strideM = warp_uniform(gO_strideM); + gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier); + q_strideM = warp_uniform(q_strideM); + k_strideM = warp_uniform(k_strideM); + v_strideM = warp_uniform(v_strideM); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = warp_uniform(grad_value_ptr); } __host__ dim3 getBlocksGrid() const { - return dim3(1, 1, num_batches); + return dim3(1, num_heads, num_batches); } __host__ dim3 getThreadsGrid() const { return dim3(kWarpSize, kNumWarpsPerBlock, 1); @@ -179,7 +244,6 @@ struct AttentionBackwardKernel { attn_T = k_j @ q_i.transpose(-2, -1) # matmul attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, -1)).exp() # epilogue - with attn_T.shape = (kBlockSizeJ, kBlockSizeI) */ using ThreadblockShape = @@ -225,7 +289,6 @@ struct AttentionBackwardKernel { struct MatmulGradV { /* grad_v[j_start:j_end] += attn_T @ do_i # matmul - Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) (we might need to iterate multiple times on K) */ @@ -601,7 +664,7 @@ struct AttentionBackwardKernel { typename MatmulGradV::Mma::FragmentC gradV; typename MatmulGradK::Mma::FragmentC gradK; - __device__ __forceinline__ void clear() { + CUTLASS_DEVICE void clear() { gradV.clear(); gradK.clear(); } @@ -614,14 +677,14 @@ struct AttentionBackwardKernel { CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); TORCH_CHECK( - p.head_dim % kMinimumAlignment == 0, - "query/key is not correctly aligned"); + p.q_strideH % kMinimumAlignment == 0, "query is not correctly aligned"); TORCH_CHECK( - p.head_dim_value % kMinimumAlignment == 0, - "value is not correctly aligned"); + p.k_strideH % kMinimumAlignment == 0, "key is not correctly aligned"); + TORCH_CHECK( + p.v_strideH % kMinimumAlignment == 0, "value is not correctly aligned"); } - static __device__ void kernel(Params& p_) { + static CUTLASS_DEVICE void kernel(Params& p_) { // Hint to nvcc to store points & tensor shapes in registers // as we use them a lot register const Params p = p_; @@ -658,7 +721,7 @@ struct AttentionBackwardKernel { __syncthreads(); } - OutputFragments output_frags; + OutputFragments register output_frags; int32_t key_start = 0; int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ; for (; key_start < key_end; 
key_start += kBlockSizeJ) { @@ -695,7 +758,7 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ void loadDi( + static CUTLASS_DEVICE void loadDi( cutlass::Array& di, Params const& p, int32_t query_start) { @@ -710,7 +773,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void processBlockIJ( + static CUTLASS_DEVICE void processBlockIJ( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -718,9 +781,9 @@ struct AttentionBackwardKernel { int32_t key_start) { cutlass::MatrixCoord no_offset{0, 0}; accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim))); - int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; - int32_t warp_id = threadIdx.y; - int32_t lane_id = threadIdx.x; + int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x; + int8_t warp_id = warp_uniform(threadIdx.y); + int8_t lane_id = threadIdx.x; __syncthreads(); loadDi(shared_storage.di(), p, query_start); @@ -734,8 +797,8 @@ struct AttentionBackwardKernel { auto prologueGradV = [&](int col) { typename MatmulGradV::Mma::IteratorB iterator_dO( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -747,8 +810,8 @@ struct AttentionBackwardKernel { }; auto prologueGradQ = [&](int col) { typename MatmulGradQ::Mma::IteratorB iterator_K( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {num_keys_in_block, p.head_dim - col}, thread_id, no_offset); @@ -757,8 +820,8 @@ struct AttentionBackwardKernel { }; auto prologueGradK = [&](int col) { typename MatmulGradK::Mma::IteratorB iterator_Q( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {num_queries_in_block, p.head_dim - col}, thread_id, no_offset); @@ -770,14 +833,14 @@ struct AttentionBackwardKernel { }; auto prologueDOV = [&]() { typename MatmulDOIVJ::Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); typename MatmulDOIVJ::Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -803,16 +866,16 @@ struct AttentionBackwardKernel { // k_j typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {problem_size.m(), problem_size.k()}, thread_id, no_offset); // q_i.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -893,14 +956,14 @@ struct AttentionBackwardKernel { num_keys_in_block, p.head_dim_value - col, num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradV::OutputTileIterator( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value + col, + 
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, {num_keys_in_block, p.head_dim_value - col}, thread_id); }; typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value + col, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, {num_queries_in_block, p.head_dim_value - col}, thread_id, no_offset); @@ -951,16 +1014,16 @@ struct AttentionBackwardKernel { using Mma = typename MatmulDOIVJ::Mma; // do_i typename Mma::IteratorA iterator_A( - {int32_t(p.head_dim_value)}, - p.grad_output_ptr + query_start * p.head_dim_value, + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, {num_queries_in_block, p.head_dim_value}, thread_id, no_offset); // v_j.transpose(-2, -1) typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim_value)}, - p.value_ptr + key_start * p.head_dim_value, + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, {p.head_dim_value, num_keys_in_block}, thread_id, no_offset); @@ -1057,16 +1120,16 @@ struct AttentionBackwardKernel { num_keys_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradQ::OutputTileIterator( - typename MatmulGradQ::OutputTileIterator::Params{p.head_dim}, - p.grad_query_ptr + query_start * p.head_dim + col, + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, {problem_size.m(), problem_size.n()}, thread_id); }; // k_j typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim + col, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1153,8 +1216,8 @@ struct AttentionBackwardKernel { num_queries_in_block); auto createEpilogueIter = [&]() { return typename MatmulGradK::OutputTileIterator( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim + col, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, thread_id); @@ -1162,8 +1225,8 @@ struct AttentionBackwardKernel { // q_i typename Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim + col, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, {problem_size.k(), problem_size.n()}, thread_id, no_offset); @@ -1236,15 +1299,15 @@ struct AttentionBackwardKernel { kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; auto thread_id = get_thread_id(); typename MatmulQK::Mma::IteratorA iterator_A( - {int32_t(p.head_dim)}, - p.key_ptr + key_start * p.head_dim, + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, {p.num_keys - key_start, p.head_dim}, thread_id, cutlass::MatrixCoord{0, 0}); typename MatmulQK::Mma::IteratorB iterator_B( - {int32_t(p.head_dim)}, - p.query_ptr + query_start * p.head_dim, + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, {p.head_dim, p.num_queries - query_start}, thread_id, cutlass::MatrixCoord{0, 0}); @@ -1259,7 +1322,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void writeFragsToGmem( + static CUTLASS_DEVICE void writeFragsToGmem( SharedStorage& shared_storage, OutputFragments& output_frags, Params const& p, @@ -1268,8 +1331,8 @@ struct AttentionBackwardKernel { ? MatmulQK::Mma::Shape::kM : std::min((int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); typename MatmulGradV::OutputTileIterator outputV_it( - typename MatmulGradV::OutputTileIterator::Params{p.head_dim_value}, - p.grad_value_ptr + key_start * p.head_dim_value, + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), {num_keys_in_block, p.head_dim_value}, get_thread_id()); accumulateInGmem( @@ -1279,8 +1342,8 @@ struct AttentionBackwardKernel { true); typename MatmulGradK::OutputTileIterator outputK_it( - typename MatmulGradK::OutputTileIterator::Params{p.head_dim}, - p.grad_key_ptr + key_start * p.head_dim, + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), {num_keys_in_block, false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim}, get_thread_id()); @@ -1292,7 +1355,7 @@ struct AttentionBackwardKernel { } template - static __device__ __forceinline__ void accumulateInGmem( + static CUTLASS_DEVICE void accumulateInGmem( typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, typename MatmulT::Mma::FragmentC const& accum, typename MatmulT::OutputTileIterator output_it, @@ -1334,7 +1397,9 @@ struct AttentionBackwardKernel { } template - static __device__ void computeDelta(Params const& p, int32_t query_start) { + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start) { // Each thread computes one value for Delta // Depending on warp configuration, we might have multiple // threads of the same warp working on the same row @@ -1349,13 +1414,15 @@ struct AttentionBackwardKernel { bool rowPred = (query_start + laneRow) < p.num_queries; bool pred = rowPred; - const __restrict__ AccessType* grad_output_ptr = - reinterpret_cast( - p.grad_output_ptr + (query_start + laneRow) * p.head_dim_value + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + laneFirstCol); - const __restrict__ AccessType* output_ptr = - reinterpret_cast( - p.output_ptr + (query_start + laneRow) * p.head_dim_value + + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + laneFirstCol); static constexpr int64_t kMaxIters = @@ -1430,13 +1497,13 @@ struct AttentionBackwardKernel { } } - static __device__ __forceinline__ int8_t get_lane_id() { + static CUTLASS_DEVICE int8_t get_lane_id() { return threadIdx.x; } - static __device__ __forceinline__ int8_t get_warp_id() { + static CUTLASS_DEVICE int8_t get_warp_id() { return threadIdx.y; } - static __device__ __forceinline__ int16_t get_thread_id() { + static CUTLASS_DEVICE int16_t get_thread_id() { return threadIdx.x + threadIdx.y * blockDim.x; } }; @@ -1457,8 +1524,7 @@ __global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) #define INSTANTIATE_ATTENTION_KERNEL_BACKWARD(ARCH, ...) \ _ATTENTION_KERNEL_BACKWARD_BEGIN( \ AttentionBackwardKernel) \ - auto batch_id = blockIdx.z; \ - p.advance_batches(batch_id); \ + p.advance_to_block(); \ Kernel::kernel(p); \ _ATTENTION_KERNEL_BACKWARD_END(); diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 564adb2d51ea..e9f3d5029aa8 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -62,6 +62,15 @@ inline bool check_for_attn_weights(sdp_params params, bool debug) { } return true; } + +inline bool check_for_non_zero_dropout(sdp_params params, bool debug) { + if (params.dropout != 0.0) { + TORCH_CHECK(!debug, "Mem_efficient does not support non_zero dropout. 
Dropout_p: ", params.dropout); + return false; + } + return true; +} + inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { if (!params.query.is_nested()) { return true; @@ -230,7 +239,8 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, - check_for_seq_len_1_nested_tensor}; + check_for_seq_len_1_nested_tensor, + check_for_non_zero_dropout}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/test/test_transformers.py b/test/test_transformers.py index 939d91e7ee87..93a94a5604c9 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -21,8 +21,11 @@ TEST_WITH_ROCM, IS_WINDOWS, slowTest, - set_default_dtype + set_default_dtype, + gradcheck ) + +from torch.testing._internal.common_methods_invocations import wrapper_set_seed from torch.testing._internal.common_cuda import TEST_CUDA, SM80OrLater if TEST_FAIRSEQ: @@ -860,11 +863,22 @@ def rand_tensor(*shape): actual = torch.ops.aten._scaled_dot_product_attention( query, key, value, attn_mask, dropout_p, need_attn_weights, is_causal) - # freeze_rng_state() doesn't seem to work outside of CPU, so dropout makes the results incomparable. - # TODO: Do this skipping in a nicer way once the granular test skipping logic lands. - if dropout_p == 0.0 or device == 'cpu': self.assertEqual(actual, expected) + if attn_mask_dim is None: + q = q.double().clone() + k = k.double().clone() + v = v.double().clone() + q.requires_grad_() + k.requires_grad_() + v.requires_grad_() + + assert gradcheck(lambda *args, **kwargs: wrapper_set_seed(sdp_ref, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (q, k, v, attn_mask, dropout_p)) + @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref') @torch.no_grad() def test_mask_check_fastpath(self): @@ -1079,6 +1093,28 @@ def rand_tensor(shape): self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3) + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): + + batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 + query, key, value = torch.rand((batch_size, seq_len, 3 * num_heads * head_dim), + device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1) + query = query.view(batch_size, -1, num_heads, head_dim) + key = key.view(batch_size, -1, num_heads, head_dim) + value = value.view(batch_size, -1, num_heads, head_dim) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + # Normally we would transpose the inputs but the fused kernels expect + # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel + # in fp32 + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), + (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 8349a308be35..a0892b32a835 100644 --- 
a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2591,7 +2591,7 @@ - name: _mkldnn_reshape(Tensor self, int[] shape) -> Tensor self: grad.reshape_symint(self.sym_sizes()) -# Nested Tensor +# NestedTensor - name: _nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor list: "grad.defined()? at::unbind(grad) : std::vector(list.size())" @@ -2612,6 +2612,11 @@ nested_size: non_differentiable nested_strides: non_differentiable +# Transformers +- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) + # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor self: fft_r2c_backward(grad, dim, normalization, onesided, self.sym_size(dim.back())) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0c59af77736a..5e60eff2865e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11984,8 +11984,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ), OpInfo( 'nn.functional._scaled_dot_product_attention', - op=lambda inp, *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, inp, *args, **kwargs), + op=lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), sample_inputs_func=sample_inputs_scaled_dot_product_attention, dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), From 18c1f2f82eee51bf0e0061dc08d5416b6a7fe0cf Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Tue, 15 Nov 2022 20:35:34 +0000 Subject: [PATCH 184/453] [torch] [analytics] add pytorch event logger callsites to transformers and encoder/decoders (#88896) Differential Revision: D41227275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88896 Approved by: https://github.com/mikekgfb --- torch/nn/modules/transformer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 37e8823edf2c..5f1bc7bb2785 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -56,6 +56,7 @@ def __init__(self, d_model: int = 512, nhead: int = 8, num_encoder_layers: int = device=None, dtype=None) -> None: factory_kwargs = {'device': device, 'dtype': dtype} super(Transformer, self).__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") if custom_encoder is not None: self.encoder = custom_encoder @@ -186,6 +187,7 @@ class TransformerEncoder(Module): def __init__(self, encoder_layer, num_layers, norm=None, enable_nested_tensor=True, mask_check=True): super(TransformerEncoder, self).__init__() + torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers self.norm = norm @@ -307,6 +309,7 @@ class TransformerDecoder(Module): def __init__(self, decoder_layer, num_layers, norm=None): super(TransformerDecoder, self).__init__() + 
torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers self.norm = norm From d8466964b348b6172317f70b8e52de02402bad54 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 15 Nov 2022 20:35:48 +0000 Subject: [PATCH 185/453] Add range check to multi margin loss target (#89008) Fixes https://github.com/pytorch/pytorch/issues/88724 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89008 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/MultiMarginLoss.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/cuda/MultiMarginLoss.cu b/aten/src/ATen/native/cuda/MultiMarginLoss.cu index 15e6d1e9dc0c..26f21cfa59a2 100644 --- a/aten/src/ATen/native/cuda/MultiMarginLoss.cu +++ b/aten/src/ATen/native/cuda/MultiMarginLoss.cu @@ -31,6 +31,7 @@ __global__ void MultiMarginLoss_forward_kernel( scalar_t *input_k = input + k*dim; scalar_t *output_k = output + k; int target_k = static_cast(target[k]); + CUDA_KERNEL_ASSERT(target_k >= 0 && target_k < dim && "target index is out of bounds"); scalar_t input_target_k = input_k[target_k]; int i_start = threadIdx.x; From 3e2ba60ac0598c6d85ea83a25fd15df855b9f2f9 Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Tue, 15 Nov 2022 20:36:13 +0000 Subject: [PATCH 186/453] [torch] [analytics] add pytorch event logger callsites to torch.save and torch.load (#89003) Summary: as title. Differential Revision: D41239419 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89003 Approved by: https://github.com/ezyang, https://github.com/dzhulgakov --- torch/serialization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/serialization.py b/torch/serialization.py index 3078e57587be..5f9eda67648b 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -427,6 +427,7 @@ def save( >>> buffer = io.BytesIO() >>> torch.save(x, buffer) """ + torch._C._log_api_usage_once("torch.save") _check_dill_version(pickle_module) _check_save_filelike(f) @@ -760,6 +761,7 @@ def load( # Load a module with 'ascii' encoding for unpickling >>> torch.load('module.pt', encoding='ascii') """ + torch._C._log_api_usage_once("torch.load") UNSAFE_MESSAGE = ( "Weights only load failed. Re-running `torch.load` with `weights_only` set to `False`" " will likely succeed, but it can result in arbitrary code execution." From edd2dea859613a9792cfd08a77cf6ae56a531644 Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Tue, 15 Nov 2022 20:46:00 +0000 Subject: [PATCH 187/453] [torch] [analytics] add dynamo to analytics (#88915) Summary: as title. Differential Revision: D41237602 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88915 Approved by: https://github.com/jansel --- torch/_dynamo/eval_frame.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8d9e3b7b6aa1..cb3cffaa73d1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -351,6 +351,7 @@ def optimize( def toy_example(a, b): ... 
""" + torch._C._log_api_usage_once("torch._dynamo.optimize") if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": return _NullDecorator() if sys.platform == "win32": @@ -451,6 +452,7 @@ def guard_export_print(guards): def export( f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs ): + torch._C._log_api_usage_once("torch._dynamo.export") if decomposition_table is not None or tracing_mode != "real": assert ( aten_graph From 9262d18e1bc1f31479677cbd2c121770f3f36522 Mon Sep 17 00:00:00 2001 From: Fabio Rocha Date: Mon, 14 Nov 2022 10:47:32 +0000 Subject: [PATCH 188/453] [inductor] Introduce CSEVariable type and use it to track if Triton variables are scalar (#88347) This fixes https://github.com/pytorch/torchdynamo/issues/1515 To fix it, we need to keep track of whether a Triton variable is a scalar (so we can not use a mask when doing indirect loads through them). This requires a way of annotating variable names generated by CSE with properties. So now CSE will use CSEVariable class to keep track of variables and let backends subclass it so they can annotate them with whatever information they want. TritonCSEVariable is such a subclass that track the `is_scalar` property. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88347 Approved by: https://github.com/jgong5, https://github.com/ngimel --- test/inductor/test_torchinductor.py | 14 ++++++++++ torch/_inductor/codegen/common.py | 41 ++++++++++++++++++++++++----- torch/_inductor/codegen/triton.py | 23 +++++++++++++++- 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b64f40377995..f43a333d1f09 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5011,6 +5011,20 @@ def forward(pred_objectness_logits_3_: torch.Tensor): result = forward(*args) assert same(result, torch.sort(args[0], descending=True, dim=1)[0]) + @requires_cuda() + def test_scalar_triton_index(self): + # The indirect indexing via a scalar like below used to lead to + # bad triton code that made triton segfault when compiling. + # See https://github.com/pytorch/torchdynamo/issues/1515 + def fn(a): + zero = torch.zeros((16,), device=a.device, dtype=torch.int64) + return (a[zero],) + + a = torch.randn((8,), dtype=torch.float32, device="cuda") + + fn_optimized = torch._dynamo.optimize("inductor")(fn) + assert same(fn(a), fn_optimized(a)) + class TritonCodeGenTests(TestCase): from torch._inductor.triton_ops.autotune import CachingAutotuner diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 932e8c91bc7d..2803970295cc 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -34,7 +34,8 @@ class ExprPrinter(Printer): @staticmethod def paren(string): if ( - re.match(r"^[a-z0-9_.]+$", string, re.I) + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.I) or re.match(r"^\([^)]*\)$", string, re.I) or string == "" ): @@ -405,6 +406,21 @@ def _is_removed(name, buffers): ) +class CSEVariable: + """A CSEVariable is just a name for an expression but it is useful to be able to annotate them on a backend dependent basis. + The backends can inherit from this class and overload the "create_cse_var" Kernel to do that. 
+ The "update_on_args" method gives you a hook for annotations, see example of TritonCSEVariable in triton.py.""" + + def __init__(self, name): + self.name = name + + def __str__(self): + return self.name + + def update_on_args(self, args, kwargs): + pass + + class CSE: """Common subexpression elimination""" @@ -425,6 +441,7 @@ def __init__( self.reduction_cache = reduction_cache or {} self.iter_buffer_ids = iter_buffers or itertools.count() self.invalidated_stores = set() + self.varname_map = {} def invalidate(self, keep_vars: typing.Set[str]): for name, tmp in list(self.store_cache.items()): @@ -442,9 +459,11 @@ def clone(self): self.store_cache, ) - def generate(self, buffer: IndentedBuffer, expr: str, write=True): - assert isinstance(expr, str), expr - if expr.startswith(self.name_prefix) and re.match(r"^[a-z0-9]+$", expr): + def generate( + self, buffer: IndentedBuffer, expr: typing.Union[str, CSEVariable], write=True + ) -> CSEVariable: + assert isinstance(expr, (str, CSEVariable)), type(expr) + if isinstance(expr, CSEVariable): return expr if expr not in self.cache: var = self.newvar() @@ -454,8 +473,11 @@ def generate(self, buffer: IndentedBuffer, expr: str, write=True): buffer.writeline(f"{self.prefix}{var} = {expr}{self.suffix}") return self.cache[expr] - def newvar(self): - return f"{self.name_prefix}{next(self.iter_buffer_ids)}" + def newvar(self) -> CSEVariable: + var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" + var = V.kernel.create_cse_var(var_name) + self.varname_map[var_name] = var + return var class CodeGen: @@ -539,9 +561,11 @@ class CSEProxy: @staticmethod def __getattr__(name): def inner(*args, **kwargs): - return self.cse.generate( + csevar = self.cse.generate( self.compute, getattr(parent_handler, name)(*args, **kwargs) ) + csevar.update_on_args(args, kwargs) + return csevar return inner @@ -598,3 +622,6 @@ def rename_indexing(self, index) -> sympy.Expr: x: self.args.size(x) for x in sorted_symbols if x.name.startswith("s") } return sympy_subs(index, replacements) + + def create_cse_var(self, *args, **kwargs): + return CSEVariable(*args, **kwargs) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 88a0ad4977be..b79b03232a8a 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -24,6 +24,7 @@ ) from ..virtualized import ops, V from .common import ( + CSEVariable, DeferredLine, ExprPrinter, IndentedBuffer, @@ -109,6 +110,17 @@ def triton_constant(value): return repr(value) +class TritonCSEVariable(CSEVariable): + def __init__(self, name): + super().__init__(name) + self.is_scalar = False + + def update_on_args(self, args, kwargs): + self.is_scalar = all( + not (isinstance(arg, TritonCSEVariable)) or arg.is_scalar for arg in args + ) + + class TritonOverrides(OpOverrides): """Map element-wise ops to Triton""" @@ -752,7 +764,13 @@ def indexing( # https://github.com/openai/triton/issues/633 mask = ["None"] - return index_str, " & ".join(mask) + if ( + index_str in self.cse.varname_map + and self.cse.varname_map[index_str].is_scalar + ): + mask = ["None"] + + return index_str, " & ".join(map(str, mask)) def var_ranges(self): return dict( @@ -1106,6 +1124,9 @@ def call_kernel(self, code, name: str): f"{name}.run({call_args}, grid=grid({', '.join(grid)}), stream={stream_name})" ) + def create_cse_var(self, *args, **kwargs): + return TritonCSEVariable(*args, **kwargs) + class TritonScheduling: def __init__(self, scheduler): From d47b94fa8e17ad805f1283943dd2b1bc46b309b8 Mon Sep 17 
00:00:00 2001 From: Fabio Rocha Date: Mon, 14 Nov 2022 10:47:34 +0000 Subject: [PATCH 189/453] [inductor] Added bucketize to decomp table (#88348) These are the benchmark results vs eager ``` [--------------------------- bucketize ----------------------------] | eager | decomp 32 threads: -------------------------------------------------------- ((16384, 1024), (16,)), (True, True) | 600 | 464 ((16384, 1024), (16,)), (True, False) | 542 | 464 ((16384, 1024), (16,)), (False, True) | 780 | 731 ((16384, 1024), (16,)), (False, False) | 777 | 731 ((16384, 1024), (64,)), (True, True) | 624 | 515 ((16384, 1024), (64,)), (True, False) | 603 | 515 ((16384, 1024), (64,)), (False, True) | 789 | 718 ((16384, 1024), (64,)), (False, False) | 786 | 718 ((16384, 1024), (256,)), (True, True) | 878 | 820 ((16384, 1024), (256,)), (True, False) | 891 | 830 ((16384, 1024), (256,)), (False, True) | 897 | 900 ((16384, 1024), (256,)), (False, False) | 900 | 900 ((16384, 1024), (1024,)), (True, True) | 2000 | 1890 ((16384, 1024), (1024,)), (True, False) | 1950 | 1892 ((16384, 1024), (1024,)), (False, True) | 1990 | 1962 ((16384, 1024), (1024,)), (False, False) | 1990 | 2060 ((16384, 1024), (4096,)), (True, True) | 3405 | 3155 ((16384, 1024), (4096,)), (True, False) | 3244 | 3154 ((16384, 1024), (4096,)), (False, True) | 3282 | 3219 ((16384, 1024), (4096,)), (False, False) | 3278 | 3220 ((16384, 1024), (16384,)), (True, True) | 4626 | 4672 ((16384, 1024), (16384,)), (True, False) | 4629 | 4671 ((16384, 1024), (16384,)), (False, True) | 4662 | 4829 ((16384, 1024), (16384,)), (False, False) | 4665 | 4824 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88348 Approved by: https://github.com/ngimel --- torch/_inductor/decomposition.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 0b29dd524cb7..44bfd46505a2 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -104,6 +104,7 @@ aten.upsample_nearest2d_backward, aten.softplus, aten.softplus_backward, + aten.bucketize, ] ) From da2afcb1e0006354f78d5e56d2933382d7af9ebf Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 15 Nov 2022 21:05:59 +0000 Subject: [PATCH 190/453] Add test for out-of-bounds Tensor access on GPU (#39211) Since CUDA context can not recover safely from on-device assert, use `torch.multiprocessing.spawn` to execute a method in another context and verify that it raises unrecoverable error. As those types of tests are pretty slow (6 seconds on powerful linux box with one GPU) run it only in the slow shard. 
Closes https://github.com/pytorch/pytorch/issues/38944 Pull Request resolved: https://github.com/pytorch/pytorch/pull/39211 Approved by: https://github.com/ezyang --- test/test_cuda.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/test_cuda.py b/test/test_cuda.py index fada440a7293..59f379487c43 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1576,6 +1576,38 @@ def test_multinomial_invalid_probs_cuda(self): self._spawn_test_multinomial_invalid_probs_cuda([1., -inf, 1.]) self._spawn_test_multinomial_invalid_probs_cuda([1., 1., nan]) + @staticmethod + def _mute_init(): + os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno()) + + def _spawn_method(self, method, arg): + ctx = torch.multiprocessing.get_context("spawn") + with ctx.Pool(1, initializer=self._mute_init) as pool: + errors = pool.map(method, [arg]) + for e in errors: + if 'device-side assert triggered' not in str(e): + self.fail(e) + + @staticmethod + def _test_index_bounds_cuda(idx): + x = torch.arange(10, device="cuda") + try: + y = x[torch.tensor([idx])] + return f"x[torch.tensor([{idx})]={y}" + except RuntimeError as err: + return err + + @slowTest + @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ + don't support multiprocessing with spawn start method") + @skipIfRocm + def test_index_out_of_bounds_exception_cuda(self): + test_method = TestCuda._test_index_bounds_cuda + # Test in-bound access works fine + self.assertEqual(test_method(1), "x[torch.tensor([1)]=tensor([1], device='cuda:0')") + # Test that indexing out of bounds causes assert + self._spawn_method(test_method, 11) + @slowTest @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory") def test_huge_index(self): From 8dc3353b0b1c12f64ba790c7be85cfbc99448cb4 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Tue, 15 Nov 2022 21:16:15 +0000 Subject: [PATCH 191/453] add `to(dtype)` support for all sparse compressed formats (#89055) Fixes [#88419](https://github.com/pytorch/pytorch/issues/88419) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89055 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/TensorConversions.cpp | 84 +++++++++++----------- test/test_sparse_csr.py | 14 ++++ 2 files changed, 58 insertions(+), 40 deletions(-) diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index ec699bf1bf7f..96275bde8299 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -244,48 +244,52 @@ Tensor _to_copy( // memory_format is handled separately due to MemoryFormat::Preserve logic options = self.options().merge_in(options).memory_format(c10::nullopt); auto memory_format = optional_memory_format.value_or(MemoryFormat::Preserve); + // TODO: Use the dispatcher for this. // Currently there are unenumerated extensibility issues preventing this. 
- if (self.is_sparse_csr()) { - TORCH_CHECK( - memory_format == MemoryFormat::Preserve, - "sparse_csr only supports memory format Preserve, but got ", - memory_format, - " instead."); - - auto new_values = at::native::to( - self.values(), - dtype, - c10::kStrided, // values are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - auto new_crow_indices = at::native::to( - self.crow_indices(), - self.crow_indices().scalar_type(), // indices are integral - c10::kStrided, // indices are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - auto new_col_indices = at::native::to( - self.col_indices(), - self.col_indices().scalar_type(), // indices are integral - c10::kStrided, // indices are strided - device, - pin_memory, - non_blocking, - true, // force copy since we're in _to_copy - memory_format); - - return at::native::_sparse_csr_tensor_unsafe( - new_crow_indices, - new_col_indices, + if (at::sparse_csr::is_sparse_compressed(self)) { + TORCH_CHECK( + memory_format == MemoryFormat::Preserve, + "to(options): ", at::sparse_csr::layoutToString(self.layout()), + " only supports memory format Preserve, but got ", memory_format, + " instead."); + + Tensor compressed_indices, plain_indices; + std::tie(compressed_indices, plain_indices) = at::sparse_csr::getCompressedPlainIndices(self); + + const auto new_values = at::native::to( + self.values(), + dtype, + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + const auto new_compressed_indices = at::native::to( + compressed_indices, + compressed_indices.scalar_type(), + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + const auto new_plain_indices = at::native::to( + plain_indices, + plain_indices.scalar_type(), + c10::kStrided, + device, + pin_memory, + non_blocking, + true, // force copy since we are in _to_copy + memory_format); + + return at::native::_sparse_compressed_tensor_unsafe( + new_compressed_indices, + new_plain_indices, new_values, self.sizes(), new_values.scalar_type(), diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index d2e3c5fc3851..e83616489fc2 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -934,6 +934,20 @@ def test_dim(self, layout): self.assertEqual(sparse.dense_dim(), dense_dim) + @skipMeta + @all_sparse_compressed_layouts() + @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16)) + def test_to_dtype(self, layout, device, dtype): + # to_dense does not support hybrid inputs + input_gen = self._generate_small_inputs(layout, device=device, enable_hybrid=False) + for compressed_indices, plain_indices, values, size in input_gen: + sparse = torch.sparse_compressed_tensor(compressed_indices, plain_indices, values, size, + dtype=dtype, layout=layout, device=device) + for to_dtype in all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16): + sparse_to_dtype = sparse.to(to_dtype) + dense_to_dtype = sparse.to_dense().to(to_dtype) + self.assertEqual(sparse_to_dtype.to_dense(), dense_to_dtype) + def _npref_block_addmm_addmv(c, a, b, alpha, beta): return alpha * (a @ b) + beta * c From 175b7e1cde0eaaef0465aa9c760842e5ea07e104 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 15 Nov 2022 21:27:14 +0000 Subject: [PATCH 192/453] print xpass (#89020) Print unexpected success as XPASS. 
I will submit a PR to test-infra so that the log classifier can find these Ex: https://github.com/pytorch/pytorch/actions/runs/3466368885/jobs/5790424173 ``` test_import_hipify (__main__.TestHipify) ... ok (0.000s) test_check_onnx_broadcast (__main__.TestONNXUtils) ... ok (0.000s) test_prepare_onnx_paddings (__main__.TestONNXUtils) ... ok (0.000s) test_load_standalone (__main__.TestStandaloneCPPJIT) ... ok (16.512s) ====================================================================== XPASS [4.072s]: test_smoke (__main__.TestCollectEnv) ---------------------------------------------------------------------- ---------------------------------------------------------------------- Ran 31 tests in 24.594s FAILED (skipped=7, unexpected successes=1) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89020 Approved by: https://github.com/huydhn, https://github.com/seemethere --- torch/testing/_internal/common_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e0b703046c54..fa3eda3758e4 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -770,6 +770,9 @@ def addSkip(self, test, reason): # it stands for `verbose_str` captured in the closure c.cell_contents = f"skip: {reason}" + def printErrors(self) -> None: + super().printErrors() + self.printErrorList("XPASS", self.unexpectedSuccesses) test_report_path = get_report_path() verbose = '--verbose' in argv or '-v' in argv if verbose: From 67af734adeebf448c54bbc294e115244c5c32f35 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Tue, 15 Nov 2022 21:33:38 +0000 Subject: [PATCH 193/453] skip test that is broken in head (#88759) Test Plan: Rely on CI. 
Differential Revision: D41156351 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88759 Approved by: https://github.com/zou3519 --- test/mobile/test_lite_script_type.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/mobile/test_lite_script_type.py b/test/mobile/test_lite_script_type.py index 8769a4b2363a..44eb6d4778e8 100644 --- a/test/mobile/test_lite_script_type.py +++ b/test/mobile/test_lite_script_type.py @@ -4,6 +4,7 @@ import torch.utils.bundled_inputs import io from typing import Dict, List, NamedTuple +import unittest from torch.jit.mobile import _load_for_lite_interpreter from torch.testing._internal.common_utils import TestCase, run_tests @@ -34,6 +35,7 @@ def forward(self, a: torch.Tensor): ) + @unittest.skip("T137512434") def test_typing_dict_with_namedtuple(self): class Foo(NamedTuple): id: torch.Tensor From d0130cd21ee419fcb33a9ceefa3583aac1e736e1 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 14 Nov 2022 14:47:15 +0000 Subject: [PATCH 194/453] Enable test_ops for inductor (#88994) Summary: skip several unsupported test cases Pull Request resolved: https://github.com/pytorch/pytorch/pull/88994 Approved by: https://github.com/Krovatkin --- .jenkins/pytorch/test.sh | 2 +- test/test_ops.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 6bbda7f4d707..5fa54f538f35 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -249,7 +249,7 @@ test_inductor_distributed() { } test_inductor() { - python test/run_test.py --include test_modules --verbose + python test/run_test.py --include test_modules test_ops --verbose # TODO: investigate "RuntimeError: CUDA driver API confirmed a leak" # seen intest_ops_gradients.py # pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64" diff --git a/test/test_ops.py b/test/test_ops.py index 0ef2e4ee6d60..11d659e5cd2b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1069,6 +1069,7 @@ def _test_inplace_preserve_storage(samples, variants): # Reference testing for operations in complex32 against complex64. # NOTE: We test against complex64 as NumPy doesn't have a complex32 equivalent dtype. 
@ops(op_db, allowed_dtypes=(torch.complex32,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_complex_half_reference_testing(self, device, dtype, op): if not op.supports_dtype(torch.complex32, device): unittest.skip("Does not support complex32") @@ -1098,6 +1099,7 @@ def test_complex_half_reference_testing(self, device, dtype, op): @ops(op_db, allowed_dtypes=(torch.bool,)) @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") + @skipIfTorchInductor("Inductor does not support view with dtype yet") def test_non_standard_bool_values(self, device, dtype, op): # Test boolean values other than 0x00 and 0x01 (gh-54789) def convert_boolean_tensors(x): @@ -1497,6 +1499,7 @@ def clone_and_perform_view(input, **kwargs): self.assertEqual(tensor.grad, cloned1_tensor.grad) @ops(ops_and_refs, allowed_dtypes=(torch.cfloat,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_conj_view(self, device, dtype, op): if not op.test_conjugated_samples: self.skipTest("Operation doesn't support conjugated inputs.") @@ -1519,6 +1522,7 @@ def test_conj_view(self, device, dtype, op): ) @ops(ops_and_refs, allowed_dtypes=(torch.double,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_neg_view(self, device, dtype, op): if not op.test_neg_view: self.skipTest("Operation not tested with tensors with negative bit.") @@ -1538,6 +1542,7 @@ def test_neg_view(self, device, dtype, op): ) @ops(ops_and_refs, allowed_dtypes=(torch.cdouble,)) + @skipIfTorchInductor("Inductor does not support complex dtype yet") def test_neg_conj_view(self, device, dtype, op): if not op.test_neg_view: self.skipTest("Operation not tested with tensors with negative bit.") From 35093fc1ab9749e6b763acead007e56b54c6375b Mon Sep 17 00:00:00 2001 From: Michael Wootton Date: Tue, 15 Nov 2022 21:40:43 +0000 Subject: [PATCH 195/453] Enable correct supported activities for kineto on rocm (#88207) A compile time guard was preventing ActivityType::CUDA from being available on rocm. This caused both the GPU_FALLBACK and CUDA modes to be active at the same time. So operators were being charged gpu time for the hipEventRecord ranges and the actual kernel execution times. This caused incorrect (and often negative) cuda times, in e.g. table(). Pull Request resolved: https://github.com/pytorch/pytorch/pull/88207 Approved by: https://github.com/malfet, https://github.com/jeffdaily --- torch/csrc/autograd/init.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index ee963232d316..6bfd4bd4bfed 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -279,8 +279,9 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ActivityType::CPU}; -#if defined(USE_KINETO) && !defined(LIBKINETO_NOCUPTI) - if (at::getNumGPUs() > 0 && !at::hasHIP()) { +#if defined(USE_KINETO) && \ + (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) + if (at::getNumGPUs() > 0) { activities.insert(ActivityType::CUDA); } #endif From ee05f47bddfb97b4b292808543d928b3526fc0ca Mon Sep 17 00:00:00 2001 From: Charlie Yan Date: Tue, 15 Nov 2022 18:03:53 +0000 Subject: [PATCH 196/453] Rebase and re-land thread PG (#88795) The previous PR (https://github.com/pytorch/pytorch/pull/88627) has been reverted due to a failed check. After rebasing and rerun, all checks passed. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88795 Approved by: https://github.com/huydhn, https://github.com/wanchaol --- test/distributed/test_multi_threaded_pg.py | 45 +++ test/test_testing.py | 1 + torch/testing/_internal/common_distributed.py | 149 +++++++-- .../distributed/multi_threaded_pg.py | 288 ++++++++++++++++++ 4 files changed, 457 insertions(+), 26 deletions(-) create mode 100644 test/distributed/test_multi_threaded_pg.py create mode 100644 torch/testing/_internal/distributed/multi_threaded_pg.py diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py new file mode 100644 index 000000000000..6a0fe33cd8ad --- /dev/null +++ b/test/distributed/test_multi_threaded_pg.py @@ -0,0 +1,45 @@ +# Owner(s): ["oncall: distributed"] + +import sys +import torch.distributed as dist + +if not dist.is_available(): + print("Distributed not available, skipping tests", file=sys.stderr) + sys.exit(0) + +from torch.testing._internal.common_distributed import ( + spawn_threads_and_init_comms, + MultiThreadedTestCase + +) +from torch.testing._internal.common_utils import TestCase, run_tests + +DEFAULT_WORLD_SIZE = 4 + +class TestObjectCollectivesWithWrapper(TestCase): + @spawn_threads_and_init_comms(world_size=4) + def test_broadcast_object_list(self): + val = 99 if dist.get_rank() == 0 else None + object_list = [val] * dist.get_world_size() + + dist.broadcast_object_list(object_list=object_list) + self.assertEqual(99, object_list[0]) + +class TestObjectCollectivesWithBaseClass(MultiThreadedTestCase): + @property + def world_size(self): + return 4 + + def test_broadcast_object_list(self): + val = 99 if dist.get_rank() == 0 else None + object_list = [val] * dist.get_world_size() + print(f"{dist.get_rank()} -> {dist.get_world_size()}") + + dist.broadcast_object_list(object_list=object_list) + self.assertEqual(99, object_list[0]) + + def test_something_else(self): + pass + +if __name__ == "__main__": + run_tests() diff --git a/test/test_testing.py b/test/test_testing.py index 8fe66043e5a1..f05883919f17 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1806,6 +1806,7 @@ def test_circular_dependencies(self) -> None: # And these both end up with transitive dependencies on distributed ignored_modules.append("torch.nn.parallel._replicated_tensor_ddp_interop") ignored_modules.append("torch.testing._internal.common_fsdp") + ignored_modules.append("torch.testing._internal.common_distributed") torch_dir = os.path.dirname(torch.__file__) for base, folders, files in os.walk(torch_dir): diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 607211087ddc..883a48a5a5fe 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -2,10 +2,10 @@ import logging import multiprocessing import os +import subprocess import sys import tempfile import threading -import subprocess import time import traceback import types @@ -14,11 +14,7 @@ from dataclasses import dataclass from datetime import timedelta from enum import Enum -from functools import ( - partial, - reduce, - wraps -) +from functools import partial, reduce, wraps from io import StringIO from typing import NamedTuple, Optional, Union @@ -26,16 +22,17 @@ import torch.cuda.nccl import torch.distributed as c10d from torch.testing._internal.common_utils import ( - TestCase, - TEST_WITH_ROCM, - TEST_WITH_TSAN, FILE_SCHEMA, find_free_port, - 
retry_on_connect_failures, IS_SANDCASTLE, - sandcastle_skip_if, + retry_on_connect_failures, sandcastle_skip, + sandcastle_skip_if, + TEST_WITH_ROCM, + TEST_WITH_TSAN, + TestCase, ) +from torch.testing._internal.distributed.multi_threaded_pg import run_with_threaded_pg logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -67,11 +64,10 @@ class TestSkip(NamedTuple): "generic": TestSkip( 86, "Test skipped at subprocess level, look at subprocess log for skip reason" ), - "importerror": TestSkip( - 88, "Test skipped due to missing import" - ), + "importerror": TestSkip(88, "Test skipped due to missing import"), } + @dataclass class DistTestCases: # Backends that do not support a specific collective @@ -93,6 +89,7 @@ class DistTestCases: def skip_if_no_gpu(func): """Skips if the world size exceeds the number of GPUs, ensuring that if the test is run, each rank has its own GPU via ``torch.cuda.device(rank)``.""" + @wraps(func) def wrapper(*args, **kwargs): if not torch.cuda.is_available(): @@ -116,6 +113,7 @@ def wrapper(*args, **kwargs): return wrapper + def skip_if_odd_worldsize(func): @wraps(func) def wrapper(*args, **kwargs): @@ -126,6 +124,7 @@ def wrapper(*args, **kwargs): return wrapper + def require_n_gpus_for_nccl_backend(n, backend): def decorator(func): @wraps(func) @@ -139,12 +138,17 @@ def wrapper(*args, **kwargs): return decorator + def import_transformers_or_skip(): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): try: - from transformers import BertConfig, AutoModelForMaskedLM # noqa: Unused + from transformers import ( # noqa: Unused + AutoModelForMaskedLM, + BertConfig, + ) + return func(*args, **kwargs) except ImportError: sys.exit(TEST_SKIPS["importerror"].exit_code) @@ -153,6 +157,7 @@ def wrapper(*args, **kwargs): return decorator + def skip_if_lt_x_gpu(x): def decorator(func): @wraps(func) @@ -191,10 +196,13 @@ def verify_ddp_error_logged(model_DDP, err_substr): logging_err = ddp_logging_data["error"] # Remove C++ stacktrace if needed. actual = ( - err_substr if err_substr.find("\nException raised from ") == -1 + err_substr + if err_substr.find("\nException raised from ") == -1 else err_substr.split("\nException raised from ")[0] ) - assert actual in logging_err, f"Did not find expected {actual} in ddp logging data error: {logging_err}" + assert ( + actual in logging_err + ), f"Did not find expected {actual} in ddp logging data error: {logging_err}" def with_nccl_blocking_wait(func): @@ -319,7 +327,7 @@ def wrapper(*args, **kwargs): def skip_if_win32(): return sandcastle_skip_if( - sys.platform == 'win32', + sys.platform == "win32", "This unit test case is not supportted on Windows platform", ) @@ -352,13 +360,14 @@ def create_tcp_store( # TSAN runs much slower. 
TIMEOUT_DEFAULT = 500 else: - TIMEOUT_DEFAULT = int(os.getenv('DISTRIBUTED_TESTS_DEFAULT_TIMEOUT', '300')) + TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300")) TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400} # https://github.com/pytorch/pytorch/issues/75665 if TEST_WITH_ROCM: TIMEOUT_OVERRIDE["test_join_kwargs"] = 200 + def create_device(interface=None): if sys.platform == "win32" or interface is None: return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1") @@ -449,9 +458,7 @@ def init_multigpu_helper(world_size: int, backend: str): if world_size > nGPUs: nGPUs_per_process = nGPUs // world_size rank_to_GPU = { - i: list( - visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process] - ) + i: list(visible_devices[i * nGPUs_per_process : (i + 1) * nGPUs_per_process]) for i in range(world_size) } return rank_to_GPU @@ -482,6 +489,9 @@ def cleanup_temp_dir() -> None: tmp_dir.cleanup() +# Most tests operate with this worldsize +DEFAULT_WORLD_SIZE = 4 + # [How does MultiProcessTestCase work?] # Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by # default `world_size()` returns 4. Let's take `test_rpc_spawn.py` as an @@ -508,7 +518,7 @@ def _should_stop_test_suite(self) -> bool: @property def world_size(self) -> int: - return 4 + return DEFAULT_WORLD_SIZE def join_or_run(self, fn): @wraps(fn) @@ -607,7 +617,10 @@ def _event_listener(parent_pipe, signal_pipe, rank: int): @classmethod def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None: # Enable DDP + ReplicatedTensor - from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor + from torch.nn.parallel._replicated_tensor_ddp_utils import ( + _set_ddp_with_replicated_tensor, + ) + _set_ddp_with_replicated_tensor(True) self = cls(test_name) @@ -815,16 +828,20 @@ def _check_return_codes(self, elapsed_time) -> None: self.assertEqual( first_process.exitcode, 0, - msg="Expected zero exit code but got {} for pid: {}".format(first_process.exitcode, first_process.pid) + msg="Expected zero exit code but got {} for pid: {}".format( + first_process.exitcode, first_process.pid + ), ) @property def is_master(self) -> bool: return self.rank == 0 + # Cannot use functools.cache as it requires python 3.9 EFA_PROBE_RESULT = None + def has_efa() -> bool: """ If shell command `fi_info -p efa -t FI_EP_RDM` returns exit code 0 then we assume that the machine has @@ -836,7 +853,9 @@ def has_efa() -> bool: return EFA_PROBE_RESULT try: - EFA_PROBE_RESULT = subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0 + EFA_PROBE_RESULT = ( + subprocess.run(["fi_info", "-p", "efa", "-t", "FI_EP_RDM"]).returncode == 0 + ) except FileNotFoundError: EFA_PROBE_RESULT = False return EFA_PROBE_RESULT @@ -850,3 +869,81 @@ def tp_transports(): see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022 """ return ["shm", "uv"] if has_efa() else None + + +def _run_test_with_mt_pg(self, timeout, world_size, callback): + failed_ranks = run_with_threaded_pg(world_size, timeout, callback) + for rank, exc_info in failed_ranks: + print(f"Rank {rank} raised:") + for line in traceback.format_exception(*exc_info): + sys.stdout.write(line) + self.assertEqual([], failed_ranks, "Some ranks failed") + + +def spawn_threads_and_init_comms( + func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE +): + """ + Wrapper to use with a test method + """ + if func is None: + return 
partial( + spawn_threads_and_init_comms, timeout=timeout, world_size=world_size + ) + + @wraps(func) + def wrapper(self, *args, **kwargs): + _run_test_with_mt_pg( + self, timeout, world_size, lambda: func(self, *args, **kwargs) + ) + + return wrapper + + +class MultiThreadedTestCase(TestCase): + """ + Simple test runner that executes all tests with the in-proc process group. + + A single instance of the TestCase object for all threads. + + Difference from regular test runner: + Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown) + Not sure what these two would be good for though. + No global state possible + How bad of a limitation is this? + """ + + def __init__(self, method_name: str = "runTest") -> None: + super().__init__(method_name) + self._test_method = getattr(self, method_name, None) + setattr(self, method_name, self.threaded_run_test) + if TestCase.setUp != type(self).setUp: + raise RuntimeError( + f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead" + ) + if TestCase.tearDown != type(self).tearDown: + raise RuntimeError( + f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead" + ) + + def threaded_run_test(self): + self.perThreadSetUp() + try: + _run_test_with_mt_pg( + self=self, + timeout=TIMEOUT_DEFAULT, + world_size=self.world_size, + callback=self._test_method, + ) + finally: + self.perThreadTearDown() + + def perThreadSetUp(self): + pass + + def perThreadTearDown(self): + pass + + @property + def world_size(self) -> int: + raise RuntimeError("world size not implemented") diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py new file mode 100644 index 000000000000..7e18f870f2e7 --- /dev/null +++ b/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -0,0 +1,288 @@ +import queue +import sys +import threading +import time +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import _create_work_from_future, Store +from torch.futures import Future +from torch.utils._pytree import tree_flatten + +""" +TODO: +Lots of missing collectives. +Collectives validation. +Make timeout robust by making collectives respect the test deadline. +Make tests robuts by making collectives interruptible. +We need some synchronization around cleanup to ensure that timedout ranks don't cause spurious failures. 
+ +""" + + +def flatten_list(lst): + return tree_flatten(lst)[0] + + +def ret_work(ret): + fut = Future() + fut.set_result(ret) + return _create_work_from_future(fut) + + +class AllGather: + def work(self, data): + for src_rank in range(len(data)): + in_tensor_list = data[src_rank][1] + # Can't handle all_gather with multiple tensors + assert len(in_tensor_list) == 1 + src_tensor = in_tensor_list[0] + + for dest in data: + dest_tensor = dest[0][0][src_rank] + with torch.no_grad(): + dest_tensor.copy_(src_tensor) + + +class Broadcast: + def __init__(self, src): + self.src = src + + def work(self, data): + in_tensor_list = flatten_list(data[self.src]) + for i in range(len(data)): + out_tensor_list = flatten_list(data[i]) + for j in range(len(in_tensor_list)): + with torch.no_grad(): + out_tensor_list[j].copy_(in_tensor_list[j]) + + +class Collective: + def __init__(self, world_size, collective): + self._world_size = world_size + self._collective = collective + + self._start_cond = threading.Condition() + self._done_cond = threading.Condition() + + self._data = [None] * world_size + self._count = 0 + self._done = False + + def join(self, rank, data): + with self._start_cond: + self._data[rank] = data + self._count += 1 + + # notify rank 0 + if self._count == self._world_size: + if rank > 0: + self._start_cond.notify() + + if rank == 0: + while self._count < self._world_size: + self._start_cond.wait() + + with self._done_cond: + # wait for rank 0 to finish + if rank > 0: + while not self._done: + self._done_cond.wait() + else: + # copy data around + self._collective.work(self._data) + self._done = True + self._done_cond.notify_all() + return ret_work(data) + + +class ProcessLocalGroup(dist.ProcessGroup): + _pg_lock = threading.Lock() + _pg_list = [] + _count = 0 + _ready = False + + _coll_lock = threading.Lock() + _cur_coll = None + + @classmethod + def _register(cls, pg): + with cls._pg_lock: + while len(cls._pg_list) <= pg._rank: + cls._pg_list.append(None) + cls._pg_list[pg._rank] = pg + cls._count += 1 + if cls._count == pg._world: + cls._ready = True + + @classmethod + def _start_coll(cls, world_size, collective): + with cls._coll_lock: + if not cls._ready: + raise Exception( + f"world not ready, only {cls._count} PG's registered but world has {world_size} ranks" + ) + if cls._cur_coll is None: + cls._cur_coll = Collective(world_size, collective) + return cls._cur_coll + + @classmethod + def _end_coll(cls, collective): + # This is racily called by all ranks, so only one will work + with cls._coll_lock: + if cls._cur_coll == collective: + cls._cur_coll = None + + def allgather(self, output_tensors, input_tensor, options): + coll = ProcessLocalGroup._start_coll(self._world, AllGather()) + res = coll.join(self._rank, (output_tensors, input_tensor)) + ProcessLocalGroup._end_coll(coll) + return res + + def broadcast(self, tensor_list, opts): + coll = ProcessLocalGroup._start_coll(self._world, Broadcast(opts.rootRank)) + res = coll.join(self._rank, tensor_list) + ProcessLocalGroup._end_coll(coll) + return res + + def __init__(self, rank, world): + super(ProcessLocalGroup, self).__init__(rank, world) + self._rank = rank + self._world = world + ProcessLocalGroup._register(self) + + def size(self): + return self._world + + def getBackendName(self): + return "local" + + def __repr__(self): + return f"PLG w:{self._world} r:{self._rank}" + + +def _create_threaded_pg(prefix_store, rank, world_size, timeout): + return ProcessLocalGroup(rank, world_size) + + +dist.Backend.register_backend("threaded", 
_create_threaded_pg) + + +@dataclass +class WorldData: + default_pg: dist.ProcessGroup + pg_map: Dict[dist.ProcessGroup, Tuple[str, Optional[Store]]] + pg_names: Dict[dist.ProcessGroup, str] + pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]] + group_count: int + + +class ThreadLocalWorld: + _world = threading.local() + + def _get_world(self) -> WorldData: + if not hasattr(ThreadLocalWorld._world, "world"): + ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, 0) + return ThreadLocalWorld._world.world + + @property + def default_pg(self): + return self._get_world().default_pg + + @default_pg.setter + def default_pg(self, value): + self._get_world().default_pg = value + + @property + def pg_map(self): + return self._get_world().pg_map + + @property + def pg_names(self): + return self._get_world().pg_names + + @property + def pg_group_ranks(self): + return self._get_world().pg_group_ranks + + @property + def group_count(self) -> int: + return self._get_world().group_count + + @group_count.setter + def group_count(self, value): + self._get_world().group_count = value + + +_old_pg_world = None + + +def _install_threaded_pg(): + global _old_pg_world + _old_pg_world = dist.distributed_c10d._world + dist.distributed_c10d._world = ThreadLocalWorld() + return dist.distributed_c10d._world + + +def _uninstall_threaded_pg(): + dist.distributed_c10d._world = _old_pg_world + + +def run_with_threaded_pg(world_size, timeout, callback): + """ + Run ``callback`` with ``world_size`` threads using the in-proc process group + """ + world = _install_threaded_pg() + + def world_is_valid(): + return world == dist.distributed_c10d._world + + global_store = dist.HashStore() + exception_queue = queue.Queue() + + def worker(rank): + if not world_is_valid(): + raise TimeoutError("Invalid world") + dist.init_process_group( + backend="threaded", rank=rank, world_size=world_size, store=global_store + ) + try: + callback() + except BaseException as ex: + exception_queue.put((rank, sys.exc_info())) + finally: + if world_is_valid(): + dist.destroy_process_group() + + try: + threads = [ + threading.Thread(target=worker, args=(rank,)) for rank in range(world_size) + ] + for thread in threads: + thread.start() + + deadline = time.time() + timeout + for idx, thread in enumerate(threads): + thread.join(max(0, deadline - time.time())) + if thread.is_alive(): + exception_queue.put( + ( + idx, + ( + TimeoutError, + TimeoutError( + f"Rank failed to join in under {timeout} seconds" + ), + None, + ), + ) + ) + failed_ranks = [] + while not exception_queue.empty(): + failure = exception_queue.get() + failed_ranks.append(failure) + return failed_ranks + finally: + _uninstall_threaded_pg() From 60ffeb986648420810098cba6ac0ad1cee06bd95 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 16 Nov 2022 00:08:34 +0000 Subject: [PATCH 197/453] Don't iterate over graph when adding graph input (#89084) helps with https://github.com/pytorch/torchdynamo/issues/1803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89084 Approved by: https://github.com/jansel --- torch/_dynamo/output_graph.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ee5079581be7..4578fb98dfcb 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -110,6 +110,10 @@ def __init__( self.tensor_id_to_sym_shape_ref = {} self.intermediary_symbols = {} + # Enables creating unique node names 
by tracking + # all current placeholder node names + self.name_to_input = collections.OrderedDict() + @property def output(self): return self @@ -147,6 +151,7 @@ def restore_graphstate(self, state): del node.meta["example_value"] self.graph.erase_node(node) self.real_value_cache.pop(node, None) + self.name_to_input.pop(node.name, None) def count_calls(self): return count_calls(self.graph) @@ -162,22 +167,22 @@ def get_submodule(self, keys): return obj def create_graph_input(self, name, type_expr=None): - placeholders = [n for n in self.graph.nodes if n.op == "placeholder"] - # unique - used_names = {n.target for n in placeholders} - if name in used_names: + if name in self.name_to_input: for i in itertools.count(): - if f"{name}_{i}" not in used_names: + if f"{name}_{i}" not in self.name_to_input: name = f"{name}_{i}" break - if placeholders: - ctx = self.graph.inserting_after(placeholders[-1]) + if self.name_to_input: + prev_name = next(reversed(self.name_to_input)) + ctx = self.graph.inserting_after(self.name_to_input[prev_name]) else: ctx = self.graph.inserting_before(None) with ctx: - return self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) + self.name_to_input[name] = proxy.node + return proxy def new_var(self, name="tmp"): existing = set(self.code_options["co_varnames"]) @@ -490,6 +495,7 @@ def remove_unused_graphargs(self): del node.meta["example_value"] self.graph.erase_node(node) self.real_value_cache.pop(node, None) + self.name_to_input.pop(node.name, None) self.graphargs = [arg for arg in self.graphargs if arg.uses > 0] @@ -525,6 +531,7 @@ def cleanup(self): if "example_value" in node.meta: del node.meta["example_value"] self.real_value_cache.clear() + self.name_to_input.clear() def create_proxy( self, From a13433940c4e8d7cc54d4fa5b3a9c0ff28fc0e8b Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Wed, 16 Nov 2022 00:29:08 +0000 Subject: [PATCH 198/453] allow loading model from a path in torchbench (#89028) Sometimes it's really convenient to run simple models thru the torchbench.py script rather than those from pytorch/benchmark. This PR add the ability to run any model from a specified path by overloading the --only argument. This PR is split out from #88904 Here is the usage: Specify the path and class name of the model in format like: --only=path:,class: Due to the fact that dynamo changes current working directory, the path should be an absolute path. The class should have a method get_example_inputs to return the inputs for the model. An example looks like ``` class LinearModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) def forward(self, x): return self.linear(x) def get_example_inputs(self): return (torch.randn(2, 10),) ``` Test command: ``` # python benchmarks/dynamo/torchbench.py --performance --only=path:/pytorch/myscripts/model_collection.py,class:LinearModel --backend=eager WARNING:common:torch.cuda.is_available() == False, using CPU cpu eval LinearModel 0.824x p=0.00 ``` Content of model_collection.py ``` from torch import nn import torch class LinearModel(nn.Module): """ AotAutogradStrategy.compile_fn ignore graph with at most 1 call nodes. Make sure this model calls 2 linear layers to avoid being skipped. 
""" def __init__(self, nlayer=2): super().__init__() layers = [] for _ in range(nlayer): layers.append(nn.Linear(10, 10)) self.layers = nn.Sequential(*layers) def forward(self, x): return self.layers(x) def get_example_inputs(self): return (torch.randn(2, 10),) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89028 Approved by: https://github.com/jansel --- benchmarks/dynamo/common.py | 93 +++++++++++++++++++++++++++++++------ 1 file changed, 80 insertions(+), 13 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 198877e0313d..a6e66c4281b6 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -4,6 +4,7 @@ import copy import csv import functools +import importlib import io import logging import os @@ -164,6 +165,42 @@ ] +def model_specified_by_path(path_and_class_str): + return ":" in path_and_class_str + + +def load_model_from_path(path_and_class_str): + configs = {} + for kvstr in path_and_class_str.split(","): + k, v = kvstr.split(":") + configs[k] = v + + for name in ["path", "class"]: + if name not in configs: + raise RuntimeError( + "Invalid --only arguments. Check help message for the correct format" + ) + + path = configs["path"] + class_name = configs["class"] + + if path[:1] != "/": + raise RuntimeError( + "Use absolute path since dynamo may change the current working directory which makes using relative path tricky" + ) + + spec = importlib.util.spec_from_file_location("module_name", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + model_class = getattr(module, class_name) + assert issubclass(model_class, torch.nn.Module) + model = model_class() + assert hasattr(model, "get_example_inputs") + inputs = model.get_example_inputs() + return model, inputs + + def output_csv(filename, headers, row): assert filename existed = os.path.exists(filename) @@ -1393,7 +1430,31 @@ def parse_args(args=None): parser.add_argument( "--fast", "-f", action="store_true", help="skip slow benchmarks" ) - parser.add_argument("--only", help="Run just one model") + parser.add_argument( + "--only", + help="""Run just one model from torchbench. Or + specify the path and class name of the model in format like: + --only=path:,class: + + Due to the fact that dynamo changes current working directory, + the path should be an absolute path. + + The class should have a method get_example_inputs to return the inputs + for the model. 
An example looks like + ``` + class LinearModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + def get_example_inputs(self): + return (torch.randn(2, 10),) + ``` + """, + ) parser.add_argument( "--training", action="store_true", @@ -1885,19 +1946,25 @@ def run(runner, args, original_dir=None): batch_size = read_batch_size_from_file( args, args.batch_size_file, model_name ) - try: - device, name, model, example_inputs, batch_size = runner.load_model( - device, - model_name, - batch_size=batch_size, - ) - except NotImplementedError as e: - print(e) - import traceback + if model_specified_by_path(args.only): + model, example_inputs = load_model_from_path(args.only) + name = model.__class__.__name__ + model = model.to(device=device) + example_inputs = tree_map(lambda x: x.to(device=device), example_inputs) + else: + try: + device, name, model, example_inputs, batch_size = runner.load_model( + device, + model_name, + batch_size=batch_size, + ) + except NotImplementedError as e: + print(e) + import traceback - print(traceback.format_exc()) - logging.warn(f"{args.only} failed to load") - continue # bad benchmark implementation + print(traceback.format_exc()) + logging.warn(f"{args.only} failed to load") + continue # bad benchmark implementation current_name = name current_device = device From 0ce22574b1aee4688e6ef56f66d6dfb31ae33b04 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Nov 2022 00:45:41 +0000 Subject: [PATCH 199/453] Revert "Enable correct supported activities for kineto on rocm (#88207)" This reverts commit 35093fc1ab9749e6b763acead007e56b54c6375b. Reverted https://github.com/pytorch/pytorch/pull/88207 on behalf of https://github.com/kit1980 due to Broke test_kineto on trunk / win-vs2019-cuda11.6-py3 / test (default, 4, 5, windows.8xlarge.nvidia.gpu) --- torch/csrc/autograd/init.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6bfd4bd4bfed..ee963232d316 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -279,9 +279,8 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { m.def("_supported_activities", []() { std::set activities{ActivityType::CPU}; -#if defined(USE_KINETO) && \ - (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) - if (at::getNumGPUs() > 0) { +#if defined(USE_KINETO) && !defined(LIBKINETO_NOCUPTI) + if (at::getNumGPUs() > 0 && !at::hasHIP()) { activities.insert(ActivityType::CUDA); } #endif From 2268a3215cdadbbbd561100a6368704ba9ef5f0d Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Mon, 14 Nov 2022 11:00:15 -0800 Subject: [PATCH 200/453] [functorch] add switch to enable autograd.Function (#88784) This is mostly a debug or "if you know what you're doing" switch for now. It is not public API. 
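For illustration, a minimal sketch of how the switch is exercised (mirroring the new test added in this PR; `MySin` and the use of functorch's `grad` come from that test rather than any new public API):
```
import torch
from functorch import grad

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.sin()

    @staticmethod
    def backward(ctx, gy):
        x, = ctx.saved_tensors
        return gy * x.cos()

x = torch.randn([])

# By default, calling an autograd.Function under a functorch transform raises.
# The debug switch lifts that restriction; restore the previous state when done.
try:
    torch._C._functorch.set_autograd_function_allowed(True)
    y = grad(MySin.apply)(x)  # gradient of sin at x, i.e. cos(x)
finally:
    torch._C._functorch.set_autograd_function_allowed(False)
```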
Test Plan: - new tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/88784 Approved by: https://github.com/samdow, https://github.com/soulitzer --- aten/src/ATen/functorch/DynamicLayer.cpp | 13 ++++++++++- aten/src/ATen/functorch/DynamicLayer.h | 6 +++++ test/functorch/test_eager_transforms.py | 29 ++++++++++++++++++++++++ torch/csrc/functorch/init.cpp | 6 +++++ 4 files changed, 53 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index 8a2668fe748b..bea9e6e3a2f4 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -101,7 +101,7 @@ class FuncTorchTLS : public FuncTorchTLSBase { } int64_t checkSupportsAutogradFunction() const override { - TORCH_CHECK(dynamicLayerStack.size() == 0, + TORCH_CHECK(dynamicLayerStack.size() == 0 || getAutogradFunctionAllowed(), "functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. ", "Please rewrite your function to not use autograd.Function while we work on fixing this"); return 0; @@ -128,6 +128,7 @@ class FuncTorchTLS : public FuncTorchTLSBase { std::vector dynamicLayerStack; bool allow_inplace_requires_grad_ = false; + bool allow_autograd_function_ = false; }; static FuncTorchTLS* getRawFunctorchTLS() { @@ -151,6 +152,16 @@ bool getInplaceRequiresGradAllowed() { return functorch_tls->allow_inplace_requires_grad_; } +void setAutogradFunctionAllowed(bool allowed) { + auto* functorch_tls = getRawFunctorchTLS(); + functorch_tls->allow_autograd_function_ = allowed; +} + +bool getAutogradFunctionAllowed() { + auto* functorch_tls = getRawFunctorchTLS(); + return functorch_tls->allow_autograd_function_; +} + static std::vector& dynamicLayerStackAccessor() { return getRawFunctorchTLS()->dynamicLayerStack; } diff --git a/aten/src/ATen/functorch/DynamicLayer.h b/aten/src/ATen/functorch/DynamicLayer.h index 576a9621651a..737620e54ae6 100644 --- a/aten/src/ATen/functorch/DynamicLayer.h +++ b/aten/src/ATen/functorch/DynamicLayer.h @@ -113,6 +113,12 @@ TORCH_API Tensor unwrapIfDead(const Tensor& tensor); TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer); TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector& dynamicLayerStack); +// While a functorch transform is active, autograd.Function is disabled +// by default. The following two APIs are APIs for enabling +// autograd.Function. These are not user-facing APIs. +TORCH_API void setAutogradFunctionAllowed(bool allowed); +TORCH_API bool getAutogradFunctionAllowed(); + // While a functorch grad transform is active, Tensor.requires_grad_() gets // disabled. These two functions are the mechanism to controlling that. 
TORCH_API void setInplaceRequiresGradAllowed(bool allowed); diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index e88e8007e77e..2dc52d3af085 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -2388,6 +2388,35 @@ def f(x): with self.assertRaises(RuntimeError): grad(f)(x) + def test_autograd_function_debug_switch(self, device): + class MySin(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.sin() + + @staticmethod + def backward(ctx, gy): + x, = ctx.saved_tensors + return gy * x.cos() + + x = torch.randn([]) + + # by default, autograd.Function is disabled in a functorch transform + with self.assertRaisesRegex(RuntimeError, "autograd.Function"): + grad(MySin.apply)(x) + + # we have a debug switch to allow it + self.assertFalse(torch._C._functorch.get_autograd_function_allowed()) + try: + torch._C._functorch.set_autograd_function_allowed(True) + self.assertTrue(torch._C._functorch.get_autograd_function_allowed()) + y = grad(MySin.apply)(x) + finally: + torch._C._functorch.set_autograd_function_allowed(False) + self.assertFalse(torch._C._functorch.get_autograd_function_allowed()) + self.assertEqual(y, x.cos()) + @parametrize('transform', [ 'vmap', 'grad', 'jacrev', 'jacfwd', 'grad_and_value', 'hessian', 'functionalize' ]) diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index b1f696ee3c7d..5248da36baa5 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -438,6 +438,12 @@ void initFuncTorchBindings(PyObject* module) { m.def( "get_inplace_requires_grad_allowed", &at::functorch::getInplaceRequiresGradAllowed); + m.def( + "set_autograd_function_allowed", + &at::functorch::setAutogradFunctionAllowed); + m.def( + "get_autograd_function_allowed", + &at::functorch::getAutogradFunctionAllowed); m.def("dlevel", &dlevel, "dlevel"); m.def("dump_tensor", &dump_tensor, "dump_tensor"); m.def("reshape_dim_into", &at::functorch::reshape_dim_into); From 3bc327993f7182f3305b0aae854a26c83458c5a6 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 15 Nov 2022 08:12:03 -0800 Subject: [PATCH 201/453] PyDispatcher integration with functorch (#88785) This PR teaches PyDispatcher and PyOperator about functorch transforms. It is important that PyDispatcher/PyOperator dispatch with functorch transforms, because this is our plan for higher-order operators (operators that accept functions as arguments). Examples of these include: - functorch transforms over the existing cond operator (control flow) - autograd.Function support for functorch (which I am working towards), - AOTDispatcher (should be a higher order operator) Concretely, the problem with teaching PyDispatcher/PyOperator about functorch is that the stack-based dispatching logic (DynamicLayerStack) is hidden inside the fallbacks for two dispatch keys (DynamicLayer{Front, Back}). PyDispatcher doesn't know about C++ boxed fallbacks, our plan on record for that is that we need to reimplement all of them in Python (but can call helper functions in C++ to make our lives easier). Instead of exposing all of what DynamicLayer{Front, Back} do to python, this PR takes the approach of re-implementing part of the stack-based dispatching in Python. The motivation is that this is more sane and follows what the "ideal" implementation of functorch would have been: - each transform should be a "mode" - there should be no TLS dispatch key set hackery. 
functorch needs to do this hackery today to re-use VariableType implementations. This PR: - exposes the DynamicLayerStack to Python - The DynamicLayerStack is a stack of Interpreters. These get exposed to Python as well. - Interpreters can run operations (Interpreter.process) or lower them to the next interpreter in the stack (Interpreter.lower) - To use a PyOperator with functorch transforms, a developer needs to register a rule for each transform (vmap, grad, jvp, ...). - The PyOperator API is NOT user-facing. Things like autograd.Function support for functorch will end up going through the autograd.Function API. Question for reviewers: - Does this design make sense? - I'm trying to split up the "functorch support for autograd.Function" work into logical pieces. Would it be better if I didn't? (the full thing is a bit long - 1000-2000 LOC). Test Plan: - new tests that construct PyOperator and compose them with functorch transforms Pull Request resolved: https://github.com/pytorch/pytorch/pull/88785 Approved by: https://github.com/samdow, https://github.com/soulitzer --- aten/src/ATen/functorch/ADInterpreters.cpp | 10 +- aten/src/ATen/functorch/ADInterpreters.h | 6 +- aten/src/ATen/functorch/DynamicLayer.cpp | 4 +- aten/src/ATen/functorch/DynamicLayer.h | 3 + test/functorch/test_eager_transforms.py | 124 ++++++++++++++++++ torch/_C/_functorch.pyi | 34 +++++ torch/_functorch/__init__.py | 0 torch/_functorch/pyfunctorch.py | 142 +++++++++++++++++++++ torch/_functorch/utils.py | 14 ++ torch/_ops.py | 25 +++- torch/csrc/functorch/init.cpp | 35 +++++ torch/csrc/utils/python_dispatch.cpp | 20 +-- torchgen/model.py | 2 + 13 files changed, 398 insertions(+), 21 deletions(-) create mode 100644 torch/_functorch/__init__.py create mode 100644 torch/_functorch/pyfunctorch.py create mode 100644 torch/_functorch/utils.py diff --git a/aten/src/ATen/functorch/ADInterpreters.cpp b/aten/src/ATen/functorch/ADInterpreters.cpp index 46c134f59d61..174949bbc3b4 100644 --- a/aten/src/ATen/functorch/ADInterpreters.cpp +++ b/aten/src/ATen/functorch/ADInterpreters.cpp @@ -28,7 +28,7 @@ static void checkForInvalidMutationOnCaptures( "as inputs."); } -static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) { +Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_level) { if (!tensor.defined()) { return tensor; } @@ -44,6 +44,14 @@ static Tensor materializeGradWrappers(const Tensor& tensor, int64_t current_leve return makeTensorWrapper(tensor, current_level, /*is_immutable=*/true); } +Tensor GradInterpreterPtr::lift(const Tensor& tensor) const { + return materializeGradWrappers(tensor, level()); +} + +Tensor JvpInterpreterPtr::lift(const Tensor& tensor) const { + return materializeGradWrappers(tensor, level()); +} + static void autogradBasedTransformProcess( const c10::OperatorHandle& op, torch::jit::Stack* stack, diff --git a/aten/src/ATen/functorch/ADInterpreters.h b/aten/src/ATen/functorch/ADInterpreters.h index b8ad638c5aee..6ec1cca065d6 100644 --- a/aten/src/ATen/functorch/ADInterpreters.h +++ b/aten/src/ATen/functorch/ADInterpreters.h @@ -7,7 +7,7 @@ namespace at { namespace functorch { // (grad, vjp and jvp). // See NOTE: [functorch interpreter stack] for more details. 
-struct GradInterpreterPtr { +struct TORCH_API GradInterpreterPtr { explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); } TransformType key() const { return base_->key(); } int64_t level() const { return base_->level(); } @@ -16,11 +16,12 @@ struct GradInterpreterPtr { bool prevGradMode() const { return c10::get(base_->meta()).prevGradMode_; } + Tensor lift(const Tensor& tensor) const; private: const Interpreter* base_; }; -struct JvpInterpreterPtr { +struct TORCH_API JvpInterpreterPtr { explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); } TransformType key() const { return base_->key(); } int64_t level() const { return base_->level(); } @@ -29,6 +30,7 @@ struct JvpInterpreterPtr { bool prevFwdGradMode() const { return c10::get(base_->meta()).prevFwdGradMode_; } + Tensor lift(const Tensor& tensor) const; private: const Interpreter* base_; }; diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index bea9e6e3a2f4..d152f3c08c2d 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -214,7 +214,7 @@ bool areTransformsActive() { return !data.empty(); } -static DynamicLayer popDynamicLayer() { +DynamicLayer popDynamicLayer() { auto& dynamicLayerStack = dynamicLayerStackAccessor(); TORCH_INTERNAL_ASSERT(dynamicLayerStack.size() > 0); auto result = dynamicLayerStack.back(); @@ -232,7 +232,7 @@ static DynamicLayer popDynamicLayer() { return result; } -static int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) { +int64_t pushDynamicLayer(DynamicLayer&& dynamic_layer) { auto& dynamicLayerStack = dynamicLayerStackAccessor(); int64_t layerId = 1 + dynamicLayerStack.size(); TORCH_INTERNAL_ASSERT(layerId == dynamic_layer.layerId()); diff --git a/aten/src/ATen/functorch/DynamicLayer.h b/aten/src/ATen/functorch/DynamicLayer.h index 737620e54ae6..6c7139f5c01e 100644 --- a/aten/src/ATen/functorch/DynamicLayer.h +++ b/aten/src/ATen/functorch/DynamicLayer.h @@ -124,5 +124,8 @@ TORCH_API bool getAutogradFunctionAllowed(); TORCH_API void setInplaceRequiresGradAllowed(bool allowed); TORCH_API bool getInplaceRequiresGradAllowed(); +TORCH_API DynamicLayer popDynamicLayer(); +TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer); + } } // namespace at diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 2dc52d3af085..e123da0d9d3c 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -37,6 +37,8 @@ ) from functorch._src.eager_transforms import enable_fwd_grad, _slice_argnums from functorch.experimental import functionalize +from torch._ops import PyOperator +from torch._functorch.utils import enable_autograd_function # NB: numpy is a testing dependency! 
import numpy as np @@ -3543,6 +3545,123 @@ def forward(self, x_1): """) +def construct_sum_pyop(): + mysum = PyOperator("mysum") + + @mysum.py_impl(torch._C._functorch.TransformType.Vmap) + def mysum_batch_rule(interpreter, x, dim): + if not torch._C._functorch.is_batchedtensor(x): + with interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test the dispatch + return mysum(x, dim) + + bdim = torch._C._functorch.maybe_get_bdim(x) + value = torch._C._functorch.get_unwrapped(x) + + with interpreter.lower(): + value = value.movedim(bdim, 0) + result = mysum(value, dim + 1) + + return torch._C._functorch._add_batch_dim(result, 0, interpreter.level()) + + @mysum.py_impl(torch._C._functorch.TransformType.Grad) + def mysum_grad_rule(interpreter, x, dim): + level = interpreter.level() + + class MySum(torch.autograd.Function): + @staticmethod + def forward(ctx, x, dim): + ctx.x_shape = x.shape + ctx.dim = dim + x = torch._C._functorch._unwrap_for_grad(x, level) + with torch.enable_grad(), interpreter.lower(): + x = x.view_as(x) # unnecessary, just here to test the dispatch + y = mysum(x, dim) + + y = torch._C._functorch._wrap_for_grad(y, level) + return y + + @staticmethod + def backward(ctx, gy): + return gy.unsqueeze(ctx.dim).expand(ctx.x_shape), None + + with enable_autograd_function(): + return MySum.apply(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCPU) + def mysum_autograd_cpu(x, dim): + return torch.sum(x, dim) + + @mysum.py_impl(torch._C.DispatchKey.AutogradCUDA) + def mysum_autograd_cuda(x, dim): + return torch.sum(x, dim) + + return mysum + +sum_pyop = construct_sum_pyop() + +class TestPyOperatorInteraction(TestCase): + + def test_basic_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = sum_pyop(x, 1) + self.assertEqual(result, torch.sum(x, 1)) + + def test_vmap_sum(self, device): + x = torch.randn(2, 3, 4, device=device) + result = vmap(sum_pyop, (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 1)) + + result = vmap(vmap(sum_pyop, (0, None)), (0, None))(x, 0) + self.assertEqual(result, torch.sum(x, 2)) + + def test_grad_sum(self, device): + x = torch.randn(3, device=device) + gx = grad(sum_pyop)(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_grad_grad_sum(self, device): + x = torch.randn(3, requires_grad=True, device=device) + + def f(x): + # higher order grad. 
Requires a non-linearity + return sum_pyop(x.sin(), 0) + + def grad_f_sum(x): + return grad(f)(x).sum() + + ggx = grad(grad_f_sum)(x) + self.assertEqual(ggx, -x.sin()) + + def test_vmap_grad_sum(self, device): + x = torch.randn(2, 3, device=device) + gx = vmap(grad(sum_pyop), (0, None))(x, 0) + self.assertEqual(gx, torch.ones_like(x)) + + def test_no_grad_outside_grad(self, device): + x = torch.randn(3, device=device, requires_grad=True) + with torch.no_grad(): + y = grad(sum_pyop)(x, 0) + self.assertEqual(y, torch.ones_like(x)) + self.assertFalse(y.requires_grad) + + def test_no_grad_inside_grad(self, device): + def f(x): + with torch.no_grad(): + shift = sum_pyop(x ** 2, 0) + return sum_pyop(x ** 2, 0) - shift + + x = torch.randn(3, device=device) + y = grad(f)(x) + self.assertEqual(y, 2 * x) + y = grad(lambda x: grad(f)(x).sum())(x) + self.assertEqual(y, torch.full_like(x, 2)) + + x = torch.randn(3, device=device, requires_grad=True) + y = grad(f)(x) + z, = torch.autograd.grad(y.sum(), x) + self.assertEqual(z, torch.full_like(x, 2)) + only_for = ("cpu", "cuda") instantiate_device_type_tests( @@ -3585,6 +3704,11 @@ def forward(self, x_1): globals(), only_for=only_for, ) +instantiate_device_type_tests( + TestPyOperatorInteraction, + globals(), + only_for=only_for, +) instantiate_device_type_tests( TestFunctionalize, globals(), diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index 6ab5f91b78f1..bb9649daadcb 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -1,4 +1,5 @@ from torch import Tensor +from enum import Enum # Defined in torch/csrc/functorch/init.cpp @@ -10,3 +11,36 @@ def is_functorch_wrapped_tensor(tensor: Tensor) -> bool: ... def is_gradtrackingtensor(tensor: Tensor) -> bool: ... def maybe_get_bdim(tensor: Tensor) -> int: ... def maybe_get_level(tensor: Tensor) -> int: ... + +def set_autograd_function_allowed(allowed: bool) -> None: ... +def get_autograd_function_allowed() -> bool: ... + +# Defined in aten/src/ATen/functorch/Interpreter.h +class TransformType(Enum): + Torch: TransformType = ... + Vmap: TransformType = ... + Grad: TransformType = ... + Jvp: TransformType = ... + Functionalize: TransformType = ... + +class CInterpreter: + def key(self) -> TransformType: ... + def level(self) -> int: ... + +class CGradInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def lift(self, Tensor) -> Tensor: ... + def prevGradMode(self) -> bool: ... + +class CVmapInterpreterPtr: + def __init__(self, interpreter: CInterpreter): ... + def key(self) -> TransformType: ... + def level(self) -> int: ... + def batchSize(self) -> int: ... + +class DynamicLayer: + pass + +def peek_interpreter_stack() -> CInterpreter: ... +def pop_dynamic_layer_stack() -> DynamicLayer: ... +def push_dynamic_layer_stack(dl: DynamicLayer) -> int: ... diff --git a/torch/_functorch/__init__.py b/torch/_functorch/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py new file mode 100644 index 000000000000..1ada5b4e1977 --- /dev/null +++ b/torch/_functorch/pyfunctorch.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +import contextlib +from typing import Any +import torch +import torch.utils._pytree as pytree +from torch._C._functorch import ( + TransformType, + CInterpreter, + CGradInterpreterPtr, + CVmapInterpreterPtr, + pop_dynamic_layer_stack, + push_dynamic_layer_stack, +) + +""" +This file contains the functorch integration with PyDispatcher. 
+ +PyDispatcher does not understand functorch's DynamicLayerStack dispatching +logic because it is entirely implemented in C++ in the fallbacks for two +dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable +to directly reuse C++ boxed fallbacks). + +Instead of trying to hammer PyDispatcher into understanding those fallbacks, +we re-implement the logic of peeking the top of the stack for an interpreter, +selecting the interpreter to dispatch on, etc, in Python. This leads to a +simpler design. + +The main difference between C++ functorch and PyDispatcher's functorch logic +is that: +- C++ functorch needs to manually tweak dispatch keys to ping-pong between + DynamicLayerFrontMode and DynamicLayerBackMode. +- PyDispatcher's functorch logic pops an Interpreter from the top of the stack + and asks it to execute the rule associated with the Interpreter. + +In C++ we do the ping-pong because e.g. vmap rules are associated with the +batched DispatchKey, but in PyDispatcher we are able to avoid this by asking +the user to register a batching rule directly to a transform that an +interpreter then invokes. +""" + + +# FuncTorchInterpreter is the Python version of Interpreter (recall that +# the DynamicLayerStack is a stack of interpreters). +# It is a wrapper around the actual C++ Interpreter object. +# +# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h +class FuncTorchInterpreter(ABC): + def __init__(self, cptr: Any): + self._cptr = cptr + + # Process an operation. eg for vmap, this is invoking a batching rule. + # Conceptually this is analogous to Interpreter::process in C++ + @abstractmethod + def process(self, op, args, kwargs): + pass + + # lower an operation from this Interpreter to the next Interpreter on the stack. + # Concretely, this involves temporarily popping the current Interpreter. + # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++ + def lower(self): + return temporarily_pop_interpreter_stack() + + def level(self): + return self._cptr.level() + + def key(self): + return self._cptr.key() + + +@contextlib.contextmanager +def temporarily_pop_interpreter_stack(): + try: + saved = pop_dynamic_layer_stack() + yield + finally: + push_dynamic_layer_stack(saved) + + +class VmapInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Vmap + # NOTE: [Interpreter cdata vs cptr] + # cdata is a generic CInterpreter. 
We wrap it in a CVmapInterpreterPtr + # so that we can access methods specific to the vmap interpreter + self._cdata = cdata + self._cptr = CVmapInterpreterPtr(cdata) + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Vmap] + return kernel(self, *args, **kwargs) + + def batch_size(self): + return self._cptr.batchSize() + + +class GradInterpreter(FuncTorchInterpreter): + def __init__(self, cdata: CInterpreter): + assert cdata.key() == TransformType.Grad + # See NOTE: [Interpreter cdata vs cptr] + self._cdata = cdata + self._cptr = CGradInterpreterPtr(cdata) + + def lift(self, args, kwargs): + args, kwargs = pytree.tree_map_only(torch.Tensor, self._cptr.lift, [args, kwargs]) + return args, kwargs + + def process(self, op, args, kwargs): + kernel = op.functorch_table[TransformType.Grad] + args, kwargs = self.lift(args, kwargs) + return kernel(self, *args, **kwargs) + + # GradInterpreter has custom lower because of the no_grad interaction + # See NOTE [grad and vjp interaction with no_grad] + # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter + def lower(self): + prev_grad_mode = self.prev_grad_mode() + if not self.prev_grad_mode: + return contextlib.nested(torch.no_grad(), super().lower()) + return super().lower() + + def prev_grad_mode(self): + return self._cptr.prevGradMode() + + +def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter: + key = cinterpreter.key() + if key == TransformType.Grad: + return GradInterpreter(cinterpreter) + if key == TransformType.Vmap: + return VmapInterpreter(cinterpreter) + raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}") + + +def retrieve_current_functorch_interpreter(): + interpreter = torch._C._functorch.peek_interpreter_stack() + assert interpreter is not None + return coerce_cinterpreter(interpreter) + + +def dispatch_functorch(op, args, kwargs): + interpreter = retrieve_current_functorch_interpreter() + return interpreter.process(op, args, kwargs) diff --git a/torch/_functorch/utils.py b/torch/_functorch/utils.py new file mode 100644 index 000000000000..c1474ba90fe3 --- /dev/null +++ b/torch/_functorch/utils.py @@ -0,0 +1,14 @@ +import contextlib +from torch._C._functorch import ( + set_autograd_function_allowed, + get_autograd_function_allowed, +) + +@contextlib.contextmanager +def enable_autograd_function(): + try: + prev_state = get_autograd_function_allowed() + set_autograd_function_allowed(True) + yield + finally: + set_autograd_function_allowed(prev_state) diff --git a/torch/_ops.py b/torch/_ops.py index 4c194e9d938b..9163932144d0 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -10,6 +10,7 @@ import torch.jit from torch import _utils_internal +from torch._functorch.pyfunctorch import dispatch_functorch # Query `hasattr` only once. 
@@ -114,6 +115,7 @@ def __init__(self, name): self._name = name self.table = {} self.python_key_mode_table = {} + self.functorch_table = {} # Make _OPNamespace not scream, this whole name based association needs a good hard look self.__name__ = name @@ -122,18 +124,26 @@ def __init__(self, name): def fallthrough(self, dispatch_key): self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key) - def py_impl(self, dispatch_key_or_mode): + def py_impl(self, dispatch_key_or_mode_or_transform): def inner(fn): - if inspect.isclass(dispatch_key_or_mode) and issubclass( - dispatch_key_or_mode, torch.utils._python_dispatch.TorchDispatchMode + if inspect.isclass(dispatch_key_or_mode_or_transform) and issubclass( + dispatch_key_or_mode_or_transform, + torch.utils._python_dispatch.TorchDispatchMode, ): - mode = dispatch_key_or_mode + mode = dispatch_key_or_mode_or_transform assert mode not in self.python_key_mode_table # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys? self.python_key_mode_table[mode] = fn return fn - dispatch_key = dispatch_key_or_mode + if isinstance( + dispatch_key_or_mode_or_transform, torch._C._functorch.TransformType + ): + transform = dispatch_key_or_mode_or_transform + self.functorch_table[transform] = fn + return fn + + dispatch_key = dispatch_key_or_mode_or_transform assert ( dispatch_key != torch._C.DispatchKey.Python ), "Please register a mode for the torch._C.DispatchKey.Python key instead." @@ -147,6 +157,9 @@ def inner(fn): def dispatch(self, dispatch_key, *args, **kwargs): from torch.utils._python_dispatch import _get_current_dispatch_mode + if dispatch_key == torch._C.DispatchKey.FuncTorchDynamicLayerFrontMode: + return dispatch_functorch(self, args, kwargs) + if dispatch_key == torch._C.DispatchKey.Python: # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. curr_mode = type(_get_current_dispatch_mode()) @@ -159,7 +172,7 @@ def dispatch(self, dispatch_key, *args, **kwargs): # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key. 
return self.python_key_mode_table[curr_mode](*args, **kwargs) - assert dispatch_key in self.table + assert dispatch_key in self.table, dispatch_key return self.table[dispatch_key](*args, **kwargs) def __call__(self, *args, **kwargs): diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index 5248da36baa5..65a3b3415b7e 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -467,6 +468,40 @@ void initFuncTorchBindings(PyObject* module) { m.def("is_functorch_wrapped_tensor", [](const Tensor& tensor) { return maybe_get_level(tensor) != -1; }); + m.def("peek_interpreter_stack", []() -> c10::optional { + const auto& stack = getDynamicLayerStack(); + if (stack.size() == 0) { + return c10::nullopt; + } + auto result = stack.back().interpreter(); + return result; + }); + m.def("pop_dynamic_layer_stack", &popDynamicLayer); + m.def("push_dynamic_layer_stack", [](DynamicLayer layer) -> int64_t { + return pushDynamicLayer(std::move(layer)); + }); + py::class_(m, "DynamicLayer"); + + py::enum_(m, "TransformType") + .value("Torch", TransformType::Torch) + .value("Grad", TransformType::Grad) + .value("Jvp", TransformType::Jvp) + .value("Functionalize", TransformType::Functionalize) + .value("Vmap", TransformType::Vmap); + py::class_(m, "CInterpreter") + .def("key", &Interpreter::key) + .def("level", &Interpreter::level); + py::class_(m, "CGradInterpreterPtr") + .def(py::init()) + .def("key", &GradInterpreterPtr::key) + .def("level", &GradInterpreterPtr::level) + .def("lift", &GradInterpreterPtr::lift) + .def("prevGradMode", &GradInterpreterPtr::prevGradMode); + py::class_(m, "CVmapInterpreterPtr") + .def(py::init()) + .def("key", &VmapInterpreterPtr::key) + .def("level", &VmapInterpreterPtr::level) + .def("batchSize", &VmapInterpreterPtr::batchSize); } } // namespace impl diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 0ff1f575a61a..662ab9981a1d 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -479,14 +479,14 @@ void initDispatchBindings(PyObject* module) { #define DEF_ONE(n) .value(#n, c10::DispatchKey::n) - py::enum_(m, "DispatchKey") DEF_ONE(Undefined) - DEF_ONE(CompositeExplicitAutogradNonFunctional) - DEF_ONE(CompositeExplicitAutograd) - DEF_ONE(CompositeImplicitAutogradNestedTensor) - DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther) - DEF_ONE(Autograd) DEF_ONE(BackendSelect) - DEF_ONE(ADInplaceOrView) DEF_ONE(PythonTLSSnapshot) - DEF_ONE(Python) + py::enum_(m, "DispatchKey") DEF_ONE(Undefined) DEF_ONE( + CompositeExplicitAutogradNonFunctional) DEF_ONE(CompositeExplicitAutograd) + DEF_ONE(CompositeImplicitAutogradNestedTensor) + DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther) + DEF_ONE(Autograd) DEF_ONE(BackendSelect) DEF_ONE(ADInplaceOrView) + DEF_ONE(PythonTLSSnapshot) DEF_ONE(Python) + DEF_ONE(FuncTorchDynamicLayerFrontMode) + DEF_ONE(FuncTorchDynamicLayerBackMode) #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) #define DEF_MULTIPLE(fullname, prefix) \ @@ -495,11 +495,11 @@ void initDispatchBindings(PyObject* module) { C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \ DEF_SINGLE(, EndOf##fullname##Backends) - C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) + C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) #undef DEF_MULTIPLE #undef DEF_SINGLE - ; + ; py::class_(m, "DispatchKeySet") .def(py::init()) diff --git a/torchgen/model.py 
b/torchgen/model.py index a2a658d0a59c..d57d3372a159 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -81,6 +81,7 @@ class DispatchKey(Enum): SparseCsrCUDA = auto() Python = auto() + FuncTorchDynamicLayerBackMode = auto() ZeroTensor = auto() BackendSelect = auto() Named = auto() @@ -91,6 +92,7 @@ class DispatchKey(Enum): Autocast = auto() Batched = auto() VmapMode = auto() + FuncTorchDynamicLayerFrontMode = auto() TESTING_ONLY_GenericWrapper = auto() TESTING_ONLY_GenericMode = auto() From 19cacecf34cf46f1c7ca3920979dcd6fd7709a61 Mon Sep 17 00:00:00 2001 From: Salil Desai Date: Wed, 16 Nov 2022 00:56:12 +0000 Subject: [PATCH 202/453] Fix and Re-enable test_quantize_fx_lite_script_module.py (#88897) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: After D35984526 (https://github.com/pytorch/pytorch/commit/416899d1a9fcb9dbc8bb66ed796b86360f573903), ```torch.ao.quantization.quantize_fx.prepare_fx``` requires passing in ```example_args```. This diff fixes the calls to ```prepare_fx``` in this test by adding in ```example_args``` as necessary. Test Plan: ``` buck test caffe2/test:fx_quantization_lite ``` ``` ✓ ListingSuccess: caffe2/test:fx_quantization_lite : 3 tests discovered (39.689) ✓ Pass: caffe2/test:fx_quantization_lite - test_conv2d (mobile.test_quantize_fx_lite_script_module.TestLiteFuseFx) (44.451) ✓ Pass: caffe2/test:fx_quantization_lite - test_embedding (mobile.test_quantize_fx_lite_script_module.TestLiteFuseFx) (45.462) ✓ Pass: caffe2/test:fx_quantization_lite - test_submodule (mobile.test_quantize_fx_lite_script_module.TestLiteFuseFx) (45.933) Summary Pass: 3 ListingSuccess: 1 Finished test run: https://www.internalfb.com/intern/testinfra/testrun/3096224827259146 ``` Differential Revision: D41227335 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88897 Approved by: https://github.com/dagitses --- test/mobile/test_quantize_fx_lite_script_module.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/mobile/test_quantize_fx_lite_script_module.py b/test/mobile/test_quantize_fx_lite_script_module.py index 44beeef818c3..ebc96d17697b 100644 --- a/test/mobile/test_quantize_fx_lite_script_module.py +++ b/test/mobile/test_quantize_fx_lite_script_module.py @@ -47,7 +47,11 @@ def forward(self, indices): for qconfig, node in configs: qconfig_dict = {"": qconfig} - m = prepare_fx(model, qconfig_dict) + m = prepare_fx( + model, + qconfig_dict, + example_inputs=torch.randint(low=0, high=10, size=(20,)), + ) m = convert_fx(m) self._compare_script_and_mobile(m, input=indices) @@ -65,7 +69,7 @@ def forward(self, x): m = M().eval() qconfig_dict = {"": default_qconfig, "module_name": [("conv1", None)]} - m = prepare_fx(m, qconfig_dict) + m = prepare_fx(m, qconfig_dict, example_inputs=torch.randn(1, 1, 1, 1)) data = torch.randn(1, 1, 1, 1) m = convert_fx(m) # first conv is quantized, second conv is not quantized @@ -84,7 +88,11 @@ def test_submodule(self): "": torch.ao.quantization.get_default_qconfig("qnnpack"), **config, } - model = prepare_fx(model, qconfig_dict) + model = prepare_fx( + model, + qconfig_dict, + example_inputs=torch.randn(5, 5), + ) quant = convert_fx(model) x = torch.randn(5, 5) From 49f0be0762e8cac48ccf3b19d1c662be6b271581 Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Tue, 15 Nov 2022 06:32:36 -0800 Subject: [PATCH 203/453] Hide ConvParams struct from ConvUtils.h (#89059) It isn't actually used outside of Convolution.cpp, so no reason to publish it. I intend to turn this into a template, so moving it with the method definitions is very convenient. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89059 Approved by: https://github.com/SherlockNoMad --- aten/src/ATen/native/ConvUtils.h | 55 ---------------------------- aten/src/ATen/native/Convolution.cpp | 55 ++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 55 deletions(-) diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 675f701c8582..b8e2b0842a00 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -80,40 +80,6 @@ static inline bool cudnnv8_use_heur_mode_b() { return cudnnv8_heuristic_mode_b; } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -struct ConvParams { - std::vector stride; - std::vector padding; - std::vector dilation; - bool transposed; - std::vector output_padding; - int groups; - bool benchmark; - bool deterministic; - bool cudnn_enabled; - bool allow_tf32; - - bool is_strided() const; - bool is_dilated() const; - bool is_padded() const; - bool is_output_padding_neg() const; - bool is_output_padding_big() const; - bool is_padding_neg() const; - bool is_stride_nonpos() const; - void view1d_as_2d(); - bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const; - bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const; - bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const; - bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const; - bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const; - bool use_mps(const at::Tensor& input, const at::Tensor& weight) const; - bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; -}; - // Keep in sync with py::enum_ in Module.cpp enum class ConvBackend { CudaDepthwise2d, @@ -140,27 +106,6 @@ enum class ConvBackend { MpsTranspose, }; -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. -TORCH_API ConvBackend _select_conv_backend( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias_opt, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -// For BC reasons, have a copy that does not require bias_opt -TORCH_API ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - // Overload for selecting the convolution backend from the full set of convolution inputs. 
// This overload is exposed to python for testing, etc. TORCH_API ConvBackend select_conv_backend( diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 109f0ac05922..e87d98357ca9 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -82,6 +82,61 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +struct ConvParams { + std::vector stride; + std::vector padding; + std::vector dilation; + bool transposed; + std::vector output_padding; + int groups; + bool benchmark; + bool deterministic; + bool cudnn_enabled; + bool allow_tf32; + + bool is_strided() const; + bool is_dilated() const; + bool is_padded() const; + bool is_output_padding_neg() const; + bool is_output_padding_big() const; + bool is_padding_neg() const; + bool is_stride_nonpos() const; + void view1d_as_2d(); + bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const; + bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const; + bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const; + bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const; + bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const; + bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const; + bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const; + bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt) const; + bool use_mps(const at::Tensor& input, const at::Tensor& weight) const; + bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; +}; + +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. +ConvBackend _select_conv_backend( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias_opt, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + DEFINE_DISPATCH(conv_depthwise2d_backward_stub); DEFINE_DISPATCH(conv_depthwise3d_backward_stub); DEFINE_DISPATCH(cudnn_convolution_backward_stub); From 431642111f74a22ebb5edc98e32b1449b4b3e46b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Nov 2022 06:41:53 -0800 Subject: [PATCH 204/453] Move ConvParams methods directly on struct (#89062) This reduces boilerplate. Also, I plan to add a template parameter to ConvParams; without moving the methods onto the struct, I would have to manually template every method. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89062 Approved by: https://github.com/SherlockNoMad --- aten/src/ATen/native/Convolution.cpp | 734 +++++++++++++-------------- 1 file changed, 352 insertions(+), 382 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index e87d98357ca9..29b2ce804c80 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -82,359 +82,6 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -struct ConvParams { - std::vector stride; - std::vector padding; - std::vector dilation; - bool transposed; - std::vector output_padding; - int groups; - bool benchmark; - bool deterministic; - bool cudnn_enabled; - bool allow_tf32; - - bool is_strided() const; - bool is_dilated() const; - bool is_padded() const; - bool is_output_padding_neg() const; - bool is_output_padding_big() const; - bool is_padding_neg() const; - bool is_stride_nonpos() const; - void view1d_as_2d(); - bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const; - bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const; - bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const; - bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const; - bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const; - bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const; - bool use_mps(const at::Tensor& input, const at::Tensor& weight) const; - bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; -}; - -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. 
-ConvBackend _select_conv_backend( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias_opt, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -// For BC reasons, have a copy that does not require bias_opt -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -DEFINE_DISPATCH(conv_depthwise2d_backward_stub); -DEFINE_DISPATCH(conv_depthwise3d_backward_stub); -DEFINE_DISPATCH(cudnn_convolution_backward_stub); -DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub); -DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub); -DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub); -DEFINE_DISPATCH(miopen_convolution_backward_stub); -DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub); -DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub); -DEFINE_DISPATCH(mkldnn_convolution_backward_stub); -DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub); -DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub); -DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub); -REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub); -REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub); -REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub); -REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); -REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); - -std::ostream& operator<<(std::ostream & out, const ConvParams& params) { - out << "ConvParams {" - << " stride = " << IntArrayRef{params.stride} - << " padding = " << IntArrayRef{params.padding} - << " dilation = " << IntArrayRef{params.dilation} - << " transposed = " << params.transposed - << " output_padding = " << IntArrayRef{params.output_padding} - << " groups = " << params.groups - << " benchmark = " << params.benchmark - << " deterministic = " << params.deterministic - << " cudnn_enabled = " << params.cudnn_enabled - << " allow_tf32 = " << params.allow_tf32 - << "}"; - return out; -} - -auto ConvParams::is_strided() const -> bool { - bool is_strided = false; - for (auto s : stride) { - is_strided |= (s != 1); - } - return is_strided; -} - -auto ConvParams::is_dilated() const -> bool { - bool is_dilated = false; - for (auto d : dilation) { - is_dilated |= (d != 1); - } - return is_dilated; -} - -auto ConvParams::is_padded() const -> bool { - bool is_padded = false; - for (auto p : padding) { - is_padded |= (p != 0); - } - return is_padded; -} - -auto ConvParams::is_output_padding_neg() const -> bool { - bool is_non_neg = false; - for (auto p : output_padding) { - is_non_neg |= (p < 0); - } - return is_non_neg; -} - -auto ConvParams::is_output_padding_big() const -> bool { - bool is_big = false; - for (auto i: c10::irange(output_padding.size())) { - is_big |= (output_padding[i] >= stride[i]); - } - return is_big; -} - -auto ConvParams::is_padding_neg() const -> bool { - bool is_non_neg = false; - for (auto p : padding) { - is_non_neg |= (p < 0); - } - return is_non_neg; -} - -auto ConvParams::is_stride_nonpos() const -> bool { - bool is_nonpos = false; - for (auto s : stride) { - is_nonpos |= (s <= 0); - } - return is_nonpos; -} - -auto ConvParams::view1d_as_2d() -> void { - if (stride.size() == 1) { - stride.insert(stride.begin(), 1); - padding.insert(padding.begin(), 0); - 
dilation.insert(dilation.begin(), 1); - output_padding.insert(output_padding.begin(), 0); - } -} - -auto ConvParams::use_cpu_depthwise3x3_winograd( - const at::Tensor& input, - const at::Tensor& weight, - const c10::optional& bias) const -> bool { -#if defined(__ARM_NEON__) - // Currently only 3x3 depthwise convolutions on tensors of float are supported. - return (input.ndimension() == 4) && - (input.size(1) == groups) && - (weight.ndimension() == 4 ) && - (weight.size(0) % input.size(1) == 0) && - (weight.size(1) == 1) && - (weight.size(2) == 3) && - (weight.size(3) == 3) && - (input.device().is_cpu()) && - (input.scalar_type() == at::kFloat) && - input.is_contiguous() && - (weight.device().is_cpu()) && - (weight.scalar_type() == at::kFloat) && - weight.is_contiguous() && - (!bias.has_value() || bias->is_contiguous()) && - !is_strided() && - !is_dilated() && - !transposed; -#else - return false; -#endif -} - -auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const -> bool { - constexpr int64_t int_max = std::numeric_limits::max(); - int64_t numel_input = input.numel(); - // empty input - if (numel_input == 0) { - return false; - } - // input size can not be reduced to the range of int by splitting the batch dim - int64_t n = input.size(0); - if (numel_input / n > int_max) { - return true; - } - // output size can not be reduced to the range of int by splitting the batch dim - int64_t outsize = 1; - if (transposed) { - std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); - outsize = c10::multiply_integers(o.begin() + 1, o.end()); - } else { - std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); - outsize = c10::multiply_integers(o.begin() + 1, o.end()); - } - return outsize > int_max; -} - -auto ConvParams::use_cudnn(const at::Tensor& input, const at::Tensor& weight) const -> bool { - -// Note [Mobile check segfaults] -// cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest -// that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) -#if !defined(C10_MOBILE) - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - if (!detail::getCUDAHooks().compiledWithCuDNN()) { - return false; - } - if (!input.is_cuda() || !cudnn_enabled) { - return false; - } - if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { - if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { - return false; - } - } - if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { - // bypass dilation checks for channels_last convolution - if (deterministic && is_dilated()) { - // cudnn doesn't support deterministic dilated convolution fully yet - return false; - } - if (is_dilated()) { - return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big(); - } - } - return !is_output_padding_big(); -#else - return false; -#endif -} - -auto ConvParams::use_mps( const at::Tensor& input, const at::Tensor& weight) const -> bool { - // These checks need to be expanded. Currently we have very limited set of - // checks for MPS. 
-#ifdef USE_MPS - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - if (!input.is_mps()) { - return false; - } - return true; -#else - return false; -#endif -} - -auto ConvParams::use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const -> bool { - if (needs_64bit_indexing_no_split(input, weight)) { - return false; - } - return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16)) - && detail::getCUDAHooks().compiledWithMIOpen() - && input.is_cuda() - && input.dim() <= MIOPEN_DIM_MAX - && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 - && !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16 - && cudnn_enabled - ; -} - -auto ConvParams::use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const -> bool { -#if AT_MKLDNN_ENABLED() - if (!at::globalContext().userEnabledMkldnn()) { - return false; - } - if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) { - return true; - } - return (input.is_mkldnn()) || // input is mkldnn Tensor - (input.device().is_cpu() && - input.scalar_type() == kFloat && // only on CPU Float Tensors - !transposed && // or transposed tensors - // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, - // but THNN is faster when single-threaded. - (is_strided() || is_dilated() || input.size(0) >= 16 || - weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && - (groups > 1 - || (weight.size(-1) > 3 && weight.size(-2) > 3) - || input.size(0) > 1 - || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster - ); - -#endif - return false; -} - -auto ConvParams::use_nnpack(const at::Tensor& input, const at::Tensor& weight) const -> bool { -#if AT_NNPACK_ENABLED() - return at::_nnpack_available() && - input.device().is_cpu() && - input.scalar_type() == kFloat && // only on CPU Float Tensors - !is_dilated() && // or dilation - !transposed && // or transposed tensors - input.ndimension() == 4 && // must be in NCHW format - weight.ndimension() == 4 && - (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 -#if !defined(C10_MOBILE) - && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable -#endif - ; -#endif - return false; -} - -auto ConvParams::use_xnnpack( - const at::Tensor& input, - const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const -> bool { -#if defined(C10_MOBILE) - if (!transposed) { - return (input.size(1) == groups) && - xnnpack::use_convolution2d( - input, - weight, - bias_sizes_opt, - padding, - stride, - dilation, - groups, - transposed); - } -#endif - return false; -} - -// We currently only have depthwise support for the case where groups == -// nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of -// a depthwise multiplier) -auto ConvParams::is_depthwise( - const at::Tensor& input, const at::Tensor& weight) const -> bool { - return input.is_cuda() && - !transposed && - (input.ndimension() == 4 || input.ndimension() == 5) && - input.size(1) == groups && - groups > 1 && // no point if there is only a single group - weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels -} - // Check workload to activate fast depthwise FP16 cudnn conv kernels bool 
check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { int w = input.size(3); // same as h @@ -592,49 +239,372 @@ bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int str return false; } -// Use cudnn for FP16 depthwise convolutions -auto ConvParams::use_cudnn_depthwise( - const at::Tensor& input, const at::Tensor& weight) const -> bool { - if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { - // always use cudnn_depthwise for channels_last format - return true; + +// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +struct ConvParams { + std::vector stride; + std::vector padding; + std::vector dilation; + bool transposed; + std::vector output_padding; + int groups; + bool benchmark; + bool deterministic; + bool cudnn_enabled; + bool allow_tf32; + + bool is_strided() const { + bool is_strided = false; + for (auto s : stride) { + is_strided |= (s != 1); + } + return is_strided; + } + + bool is_dilated() const { + bool is_dilated = false; + for (auto d : dilation) { + is_dilated |= (d != 1); + } + return is_dilated; + } + + bool is_padded() const { + bool is_padded = false; + for (auto p : padding) { + is_padded |= (p != 0); + } + return is_padded; + } + + bool is_output_padding_neg() const { + bool is_non_neg = false; + for (auto p : output_padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_output_padding_big() const { + bool is_big = false; + for (auto i: c10::irange(output_padding.size())) { + is_big |= (output_padding[i] >= stride[i]); + } + return is_big; + } + + bool is_padding_neg() const { + bool is_non_neg = false; + for (auto p : padding) { + is_non_neg |= (p < 0); + } + return is_non_neg; + } + + bool is_stride_nonpos() const { + bool is_nonpos = false; + for (auto s : stride) { + is_nonpos |= (s <= 0); + } + return is_nonpos; + } + + void view1d_as_2d() { + if (stride.size() == 1) { + stride.insert(stride.begin(), 1); + padding.insert(padding.begin(), 0); + dilation.insert(dilation.begin(), 1); + output_padding.insert(output_padding.begin(), 0); + } + } + + bool use_cpu_depthwise3x3_winograd(const at::Tensor& input, const at::Tensor& weight, const c10::optional& bias) const { +#if defined(__ARM_NEON__) + // Currently only 3x3 depthwise convolutions on tensors of float are supported. 
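+    // i.e. contiguous float32 NCHW input and weight on CPU, depthwise grouping
+    // (input channels == groups), 3x3 kernel, unit stride and dilation, not transposed.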
+ return (input.ndimension() == 4) && + (input.size(1) == groups) && + (weight.ndimension() == 4 ) && + (weight.size(0) % input.size(1) == 0) && + (weight.size(1) == 1) && + (weight.size(2) == 3) && + (weight.size(3) == 3) && + (input.device().is_cpu()) && + (input.scalar_type() == at::kFloat) && + input.is_contiguous() && + (weight.device().is_cpu()) && + (weight.scalar_type() == at::kFloat) && + weight.is_contiguous() && + (!bias.has_value() || bias->is_contiguous()) && + !is_strided() && + !is_dilated() && + !transposed; +#else + return false; +#endif + } + + bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const { + constexpr int64_t int_max = std::numeric_limits::max(); + int64_t numel_input = input.numel(); + // empty input + if (numel_input == 0) { + return false; + } + // input size can not be reduced to the range of int by splitting the batch dim + int64_t n = input.size(0); + if (numel_input / n > int_max) { + return true; + } + // output size can not be reduced to the range of int by splitting the batch dim + int64_t outsize = 1; + if (transposed) { + std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); + outsize = c10::multiply_integers(o.begin() + 1, o.end()); + } else { + std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); + outsize = c10::multiply_integers(o.begin() + 1, o.end()); + } + return outsize > int_max; } - if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { - long cudnn_version = detail::getCUDAHooks().versionCuDNN(); - if (cudnn_version >= 8200) { - bool kernel_cond = (use_cudnn(input, weight) && + + bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const { + // Note [Mobile check segfaults] + // cudnn and miopen are guaranteed not to be on mobile, and T102591915 / T110194934 suggest + // that maybe the compiledWithCuDNN() check sometimes segfaults (though I can't imagine how) +#if !defined(C10_MOBILE) + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + if (!detail::getCUDAHooks().compiledWithCuDNN()) { + return false; + } + if (!input.is_cuda() || !cudnn_enabled) { + return false; + } + if (input.scalar_type() == at::kBFloat16 || weight.scalar_type() == at::kBFloat16) { + if (!(detail::getCUDAHooks().supportsBFloat16ConvolutionWithCuDNNv8() && at::native::cudnnv8_enabled_check_debug())) { + return false; + } + } + if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous) { + // bypass dilation checks for channels_last convolution + if (deterministic && is_dilated()) { + // cudnn doesn't support deterministic dilated convolution fully yet + return false; + } + if (is_dilated()) { + return detail::getCUDAHooks().supportsDilatedConvolutionWithCuDNN() && !is_output_padding_big(); + } + } + return !is_output_padding_big(); +#else + return false; +#endif + } + + // Use cudnn for FP16 depthwise convolutions + bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const { + if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous && use_cudnn(input, weight)) { + // always use cudnn_depthwise for channels_last format + return true; + } + if (detail::getCUDAHooks().supportsDepthwiseConvolutionWithCuDNN()) { + long cudnn_version = detail::getCUDAHooks().versionCuDNN(); + if (cudnn_version >= 8200) { + bool kernel_cond = (use_cudnn(input, weight) && + input.scalar_type() == kHalf && // only for FP16 + 
weight.scalar_type() == kHalf && + is_depthwise(input, weight) && + input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks + !is_dilated() && // no dilation supported + (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d + input.size(1) >= 32); // min 32 channels supported) + if (kernel_cond) { + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + } + } + // keep (7600 <= cudnn < 8200) code unchanged + bool kernel_cond = (cudnn_version >= 7600 && + use_cudnn(input, weight) && input.scalar_type() == kHalf && // only for FP16 weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks + weight.size(2) == weight.size(3) && // only square kernels + input.size(2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported - (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d + stride[0] == stride[1] && // equal strides + ((weight.size(3) == 3) || (weight.size(3) == 1)) && input.size(1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + return check_cudnn_depthwise_workload(input, stride[0]); + } else { + return false; } - } - // keep (7600 <= cudnn < 8200) code unchanged - bool kernel_cond = (cudnn_version >= 7600 && - use_cudnn(input, weight) && - input.scalar_type() == kHalf && // only for FP16 - weight.scalar_type() == kHalf && - is_depthwise(input, weight) && - input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - weight.size(2) == weight.size(3) && // only square kernels - input.size(2) >= 7 && // min width/height 7 - !is_dilated() && // no dilation supported - stride[0] == stride[1] && // equal strides - ((weight.size(3) == 3) || (weight.size(3) == 1)) && - input.size(1) >= 32); // min 32 channels supported) - if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); } else { return false; } - } else { + } + + bool use_miopen(const at::Tensor& input, const at::Tensor& weight, bool bias_defined) const { + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16)) + && detail::getCUDAHooks().compiledWithMIOpen() + && input.is_cuda() + && input.dim() <= MIOPEN_DIM_MAX + && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 + && !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16 + && cudnn_enabled + ; + } + bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const { +#if AT_MKLDNN_ENABLED() + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (input.device().is_cpu() && input.scalar_type() == kBFloat16 && mkldnn_bf16_device_check()) { + return true; + } + return (input.is_mkldnn()) || // input is mkldnn Tensor + (input.device().is_cpu() && + input.scalar_type() == kFloat && // only on CPU Float Tensors + !transposed && // or transposed tensors + // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, + // but THNN is faster when single-threaded. 
+ (is_strided() || is_dilated() || input.size(0) >= 16 || + weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && + (groups > 1 + || (weight.size(-1) > 3 && weight.size(-2) > 3) + || input.size(0) > 1 + || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster + ); + +#endif + return false; + } + bool use_nnpack(const at::Tensor& input, const at::Tensor& weight) const { +#if AT_NNPACK_ENABLED() + return at::_nnpack_available() && + input.device().is_cpu() && + input.scalar_type() == kFloat && // only on CPU Float Tensors + !is_dilated() && // or dilation + !transposed && // or transposed tensors + input.ndimension() == 4 && // must be in NCHW format + weight.ndimension() == 4 && + (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 +#if !defined(C10_MOBILE) + && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable +#endif + ; +#endif + return false; + } + bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt) const { +#if defined(C10_MOBILE) + if (!transposed) { + return (input.size(1) == groups) && + xnnpack::use_convolution2d( + input, + weight, + bias_sizes_opt, + padding, + stride, + dilation, + groups, + transposed); + } +#endif return false; } + + bool use_mps(const at::Tensor& input, const at::Tensor& weight) const { + // These checks need to be expanded. Currently we have very limited set of + // checks for MPS. +#ifdef USE_MPS + if (needs_64bit_indexing_no_split(input, weight)) { + return false; + } + if (!input.is_mps()) { + return false; + } + return true; +#else + return false; +#endif + } + + // We currently only have depthwise support for the case where groups == + // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of + // a depthwise multiplier) + bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const { + return input.is_cuda() && + !transposed && + (input.ndimension() == 4 || input.ndimension() == 5) && + input.size(1) == groups && + groups > 1 && // no point if there is only a single group + weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels + } +}; + +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. 
+ConvBackend _select_conv_backend( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias_opt, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + +DEFINE_DISPATCH(conv_depthwise2d_backward_stub); +DEFINE_DISPATCH(conv_depthwise3d_backward_stub); +DEFINE_DISPATCH(cudnn_convolution_backward_stub); +DEFINE_DISPATCH(cudnn_convolution_transpose_backward_stub); +DEFINE_DISPATCH(slow_conv_transpose3d_backward_stub); +DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub); +DEFINE_DISPATCH(miopen_convolution_backward_stub); +DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub); +DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub); +DEFINE_DISPATCH(mkldnn_convolution_backward_stub); +DEFINE_DISPATCH(slow_conv_dilated2d_backward_stub); +DEFINE_DISPATCH(slow_conv_dilated3d_backward_stub); +DEFINE_DISPATCH(slow_conv_transpose2d_backward_stub); +REGISTER_NO_CPU_DISPATCH(conv_depthwise2d_backward_stub); +REGISTER_NO_CPU_DISPATCH(conv_depthwise3d_backward_stub); +REGISTER_NO_CPU_DISPATCH(cudnn_convolution_backward_stub); +REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); +REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); + +std::ostream& operator<<(std::ostream & out, const ConvParams& params) { + out << "ConvParams {" + << " stride = " << IntArrayRef{params.stride} + << " padding = " << IntArrayRef{params.padding} + << " dilation = " << IntArrayRef{params.dilation} + << " transposed = " << params.transposed + << " output_padding = " << IntArrayRef{params.output_padding} + << " groups = " << params.groups + << " benchmark = " << params.benchmark + << " deterministic = " << params.deterministic + << " cudnn_enabled = " << params.cudnn_enabled + << " allow_tf32 = " << params.allow_tf32 + << "}"; + return out; } static void check_shape_forward(const at::Tensor& input, From d96dd8ff09a9e35f8cce6745c3e015eb0082eb1b Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Nov 2022 08:05:31 -0800 Subject: [PATCH 205/453] Add int64_t, SymInt overloads for all binary operators in C++ (#89063) Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89063 Approved by: https://github.com/SherlockNoMad --- c10/core/SymInt.h | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 0c7c69fe9553..9ab72a077680 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -235,6 +235,40 @@ inline c10::SymInt multiply_integers(const C& container) { [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } +inline SymInt operator+(int64_t a, const SymInt& b) { + return c10::SymInt(a) + b; +} +inline SymInt operator-(int64_t a, const SymInt& b) { + return c10::SymInt(a) - b; +} +inline SymInt operator*(int64_t a, const SymInt& b) { + return c10::SymInt(a) * b; +} +inline SymInt operator/(int64_t a, const SymInt& b) { + return c10::SymInt(a) / b; +} +inline SymInt operator%(int64_t a, const SymInt& b) { + return c10::SymInt(a) % b; +} +inline bool operator==(int64_t a, const SymInt& b) { + return c10::SymInt(a) == b; +} +inline bool operator!=(int64_t a, const SymInt& b) { + return c10::SymInt(a) != b; +} +inline bool operator<(int64_t a, const SymInt& b) { + return c10::SymInt(a) < b; +} +inline bool operator<=(int64_t a, const SymInt& b) { + return c10::SymInt(a) <= b; +} +inline bool operator>(int64_t a, const SymInt& b) { + return c10::SymInt(a) > b; +} +inline bool operator>=(int64_t a, const SymInt& b) { + return c10::SymInt(a) >= b; +} + C10_API std::ostream& operator<<(std::ostream& os, const SymInt& s); C10_API SymInt operator-(const SymInt& s); } // namespace c10 From 9f0b2c73f36b0f5276f84cdaaef4d54a60df61f5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Nov 2022 01:13:00 +0000 Subject: [PATCH 206/453] Revert "[Inductor] Build FX Linear + Permute Vertical Fusion in Inductor (#88859)" This reverts commit d60abe4b9521e235c0e9beb00cda0d6c5673f4e0. 
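(For reference, the `linear_transpose` helper removed by this revert rewrites `F.linear` followed by a `(0, 2, 1)` permute into a single transposed matmul. Below is a minimal eager-mode sketch of the identity it relies on — illustrative only, not part of this patch; the shapes are arbitrary:)

    import torch
    import torch.nn.functional as F

    batch, m, k, n = 6, 16, 8, 4
    x = torch.randn(batch, m, k)
    w = torch.randn(n, k)   # weight of an nn.Linear(k, n)
    b = torch.randn(n)

    # Original pattern: linear, then permute the last two dims.
    y_ref = F.linear(x, w, b).permute(0, 2, 1)

    # Fused form: one transposed matmul, with the bias broadcast over the last dim.
    y_fused = torch.matmul(w, x.transpose(-1, -2)) + b.unsqueeze(-1)

    print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
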
Reverted https://github.com/pytorch/pytorch/pull/88859 on behalf of https://github.com/kit1980 due to Broke Mac OS testing, which were clearly shown in CI --- test/inductor/test_torchinductor.py | 106 --------------- torch/_inductor/config.py | 4 - torch/_inductor/overrides.py | 199 ---------------------------- 3 files changed, 309 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f43a333d1f09..dcb01b9ec78c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10,7 +10,6 @@ import typing import unittest import weakref -from typing import Any, Callable from unittest.mock import patch import torch @@ -19,7 +18,6 @@ from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, @@ -41,14 +39,6 @@ from torch._inductor import codecache, config, metrics from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing - from torch._inductor.overrides import ( - linear_permute_fusion, - linear_transpose, - permute_linear_fusion, - permute_matmul_fusion, - transpose_linear, - transpose_matmul, - ) from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.utils import has_torchvision_roi_align, timed @@ -123,29 +113,6 @@ def maybe_test(*args, **kwargs): return wrap_test -PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] - - -def chain_passes(*passes: PassFunc) -> PassFunc: - def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: - for pass_ in passes: - if isinstance(module, torch.fx.GraphModule): - ShapeProp(module).propagate(*input) - module = pass_(module) - return module - - return parent_pass - - -def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: - return sum( - [ - 1 if (n.op == "call_function" and n.target == target_op) else 0 - for n in module.graph.nodes - ] - ) - - class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -1619,79 +1586,6 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) - def test_linear_permute_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - a0 = torch.nn.functional.linear(input, self.weight, self.bias) - b0 = a0.permute(0, 2, 1) - return b0 - - m, k, n = 16, 8, 4 - trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, m, k) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_linear_transpose = count_call_function(traced, linear_transpose) - self.assertEqual(num_linear, 0) - self.assertEqual(num_linear_transpose, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - def test_permute_linear_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - input1 = 
input.permute(0, 2, 1) - output = torch.nn.functional.linear(input1, self.weight, self.bias) - return output - - m, k, n = 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, k, m) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_transpose_linear = count_call_function(traced, transpose_linear) - self.assertEqual(num_linear, 0) - self.assertEqual(num_transpose_linear, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - def test_permute_bmm_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, batch: int, k: int, n: int): - super().__init__() - self.other = torch.randn(batch, k, n) - - def forward(self, input: torch.Tensor): - input1 = input.permute(0, 2, 1) - output = torch.bmm(input1, self.other) - return output - - batch, m, k, n = 6, 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) - module = TestModule(batch, k, n).eval() - input = torch.randn(batch, k, m) - traced = trace_func(module, [input]) - num_bmm = count_call_function(traced, torch.bmm) - num_transpose_matmul = count_call_function(traced, transpose_matmul) - self.assertEqual(num_bmm, 0) - self.assertEqual(num_transpose_matmul, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - def test_slice1(self): def fn(a): return ( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c552101c1cae..d376fe3e8bf7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -75,10 +75,6 @@ shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1" alignment_size = 4 -# Fx-based linear/matmul/bmm + permute/transpose vertical fusion -permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" - - # config specific to codegen/cpp.pp class cpp: # set to torch.get_num_threads() diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index cf2cd5f60f51..3a95aa7ce880 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -19,8 +19,6 @@ from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.overrides import TorchFunctionMode -from . import config - log = logging.getLogger(__name__) @@ -427,14 +425,6 @@ def check_node_is_add_inplace(node): def fuse_fx(gm: torch.fx.GraphModule, example_inputs): - if config.permute_fusion: - # For linear permute fusion, we need to check input info to identify - # and perform proper permutation/transpose - ShapeProp(gm).propagate(*example_inputs) - gm = linear_permute_fusion(gm) - gm = permute_linear_fusion(gm) - gm = permute_matmul_fusion(gm) - # make sure the autograd is disabled. 
if torch.is_grad_enabled(): return gm @@ -538,195 +528,6 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) -class NormalizedLinearNode: - def __init__(self, node: torch.fx.Node) -> None: - assert node.op == "call_function" - assert node.target in [torch.nn.functional.linear] - self.node: torch.fx.Node = node - - def get_input(self) -> torch.fx.Node: - if len(self.node.args) > 0: - return self.node.args[0] - else: - return self.node.kwargs["input"] - - def get_weight(self) -> torch.fx.Node: - if len(self.node.args) > 1: - return self.node.args[1] - else: - return self.node.kwargs["weight"] - - def get_bias(self) -> torch.fx.Node: - if len(self.node.args) > 2: - return self.node.args[2] - else: - return self.node.kwargs["bias"] - - -class NormalizedMatmulNode: - def __init__(self, node: torch.fx.Node) -> None: - assert node.op == "call_function" - assert node.target in [torch.bmm, torch.matmul] - self.node: torch.fx.Node = node - - def get_input(self) -> torch.fx.Node: - if len(self.node.args) > 0: - return self.node.args[0] - else: - return self.node.kwargs["input"] - - def get_other(self) -> torch.fx.Node: - if len(self.node.args) > 1: - return self.node.args[1] - else: - return self.node.kwargs["other"] - - -def check_permute(node: torch.fx.Node): - ranks = len(node.meta["tensor_meta"].shape) - if len(node.args) > 3: - permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] - elif ( - "permutation" in node.kwargs - and node.kwargs["permutation"] is not None - and len(node.kwargs["permutation"]) > 2 - ): - permutation = [i % ranks for i in node.kwargs["permutation"]] - else: - return False - allowed_permutation = list(range(ranks)) - allowed_permutation[-1] = ranks - 2 - allowed_permutation[-2] = ranks - 1 - return permutation == allowed_permutation - - -def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if ( - node.op == "call_method" - and node.target == "permute" - and check_permute(node) - ): - if len(node.args) > 0: - input_node = node.args[0] - else: - input_node = node.kwargs["input"] - if ( - input_node.op == "call_function" - and input_node.target == torch.nn.functional.linear - ): - normalized = NormalizedLinearNode(input_node) - input = normalized.get_input() - weight = normalized.get_weight() - bias = normalized.get_bias() - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - linear_transpose, args=(input, weight, bias) - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -# Y1 = X * W^T + bias -# Y2 = Y1.permute(0, 2, 1) -# ----> -# Y2 = (W * X^T + bias.unsqueeze(-1))^T -def linear_transpose( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) - - -def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if node.op == "call_function" and node.target == torch.nn.functional.linear: - if len(node.args) > 0: - input_node = node.args[0] - else: - input_node = node.kwargs["input"] - if ( - input_node.op == "call_method" - and input_node.target == "permute" - and check_permute(input_node) - ): - normalized = NormalizedLinearNode(node) - if len(input_node.args) > 0: - input = input_node.args[0] - else: - input = input_node.kwargs["input"] - weight = normalized.get_weight() - bias = 
normalized.get_bias() - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - transpose_linear, args=(input, weight, bias) - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: - for node in module.graph.nodes: - if node.op == "call_function" and ( - node.target == torch.bmm or node.target == torch.matmul - ): - normalized = NormalizedMatmulNode(node) - A = normalized.get_input() - B = normalized.get_other() - Atrans = Btrans = False - if A.op == "call_method" and A.target == "permute" and check_permute(A): - Atrans = True - if len(A.args) > 0: - A = A.args[0] - else: - A = A.kwargs["input"] - - if B.op == "call_method" and B.target == "permute" and check_permute(B): - Btrans = True - if len(B.args) > 0: - B = B.args[0] - else: - B = B.kwargs["input"] - - if Atrans or Btrans: - with module.graph.inserting_before(node): - fused_node = module.graph.call_function( - transpose_matmul, - args=(A, B, Atrans, Btrans), - ) - node.replace_all_uses_with(fused_node) - - module.graph.lint() - module.graph.eliminate_dead_code() - module.recompile() - return module - - -# X1 = X.permute(0, 2, 1) -# Y1 = X1 * W1^T + bias1 -# ----> -# Y2 = X1.transpose(-1, -2) * W1^T + bias1 -def transpose_linear( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor -) -> torch.Tensor: - return torch.matmul(input.transpose(-1, -2), weight.t()) + bias - - -def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): - if Atrans: - A = A.transpose(-1, -2) - if Btrans: - B = B.transpose(-1, -2) - return torch.matmul(A, B) - - def replace_and_fuse_for_binary( computation_node, node, fuse_func, attr, modules, index_node, index_pointwise ): From 46ba0150cbfb8d86c378f0f3ce2d816e530a933b Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 16 Nov 2022 02:39:22 +0000 Subject: [PATCH 207/453] Increase slow grad check timeout (#89079) Now that periodic jobs are run under `mem_leak_check` mode with parallelization turning off. It's very easy for `linux-bionic-cuda11.6-py3-gcc7-slow-gradcheck / test` to timeout because one of the shards is very close to the 4h mark. * https://hud.pytorch.org/pytorch/pytorch/commit/2452e3f99a072760fc46d3f9025aaa37ca7ea2ab * https://hud.pytorch.org/pytorch/pytorch/commit/35e668b5ced25e735b6e523d557ed7fd60267914 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89079 Approved by: https://github.com/clee2000 --- .github/workflows/_linux-test.yml | 8 +++++++- .github/workflows/periodic.yml | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 6ad30080fd64..16f25fed9121 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -22,6 +22,12 @@ on: description: | If this is set, our linter will use this to make sure that every other job with the same `sync-tag` is identical. 
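+      # New optional input; it is wired to `timeout-minutes: ${{ inputs.timeout-minutes }}` on the job below.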
+ timeout-minutes: + required: false + type: number + default: 240 + description: | + Set the maximum (in minutes) how long the workflow should take to finish env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} @@ -56,6 +62,7 @@ jobs: matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} fail-fast: false runs-on: ${{ matrix.runner }} + timeout-minutes: ${{ inputs.timeout-minutes }} steps: - name: Setup SSH (Click me for login details) uses: pytorch/test-infra/.github/actions/setup-ssh@main @@ -117,7 +124,6 @@ jobs: XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} - timeout-minutes: 240 run: | set -x diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 6e1722b4b6c0..61302e1a0d61 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -33,6 +33,7 @@ jobs: build-environment: linux-bionic-cuda11.6-py3-gcc7-slow-gradcheck docker-image: ${{ needs.linux-bionic-cuda11_6-py3-gcc7-slow-gradcheck-build.outputs.docker-image }} test-matrix: ${{ needs.linux-bionic-cuda11_6-py3-gcc7-slow-gradcheck-build.outputs.test-matrix }} + timeout-minutes: 300 linux-focal-rocm5_2-py3_8-slow-build: name: linux-focal-rocm5.2-py3.8-slow From 397f10067200d9b77acb92952b4ea3741738c28b Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Tue, 15 Nov 2022 19:19:47 +0000 Subject: [PATCH 208/453] [FSDP] Test `named_parameters()` in forward (`use_orig_params=True`) (#89066) This adds a unit test following the FSDP change in https://github.com/pytorch/pytorch/pull/88781. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89066 Approved by: https://github.com/fegin --- .../fsdp/test_fsdp_use_orig_params.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 0f5ffa564c2d..e61f2e4d96de 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -1006,6 +1006,47 @@ def forward(self, x): fsdp_buffer_names = [n for n, _ in fsdp_model.named_buffers()] self.assertEqual(buffer_names, fsdp_buffer_names) + @skip_if_lt_x_gpu(2) + def test_named_parameters_in_forward(self): + """ + Tests that calling ``named_parameters()`` during forward returns FQNs + and ``Tensor`` s corresponding to the original parameters. 
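+        With ``use_orig_params=True``, the parameter shapes observed inside ``forward`` should
+        match the *unsharded* originals.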
+ """ + param_shapes = [None, None] + assert_equal_fn = self.assertEqual + + class Model(nn.Module): + def __init__(self) -> None: + super().__init__() + self.lin = nn.Linear(5, 5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + nonlocal param_shapes + param_names = [tup[0] for tup in self.named_parameters()] + params = [tup[1] for tup in self.named_parameters()] + assert ( + param_shapes[0] is not None and param_shapes[1] is not None + ), "`param_sizes` should be set" + assert_equal_fn( + param_names, + [ + "lin.weight", + "lin.bias", + ], + ) + assert_equal_fn(params[0].shape, param_shapes[0]) + assert_equal_fn(params[1].shape, param_shapes[1]) + return self.lin(x) + + model = Model().cuda() + # Save the *unsharded* original parameter shapes and check the shapes + # match in the forward pass + param_shapes[0] = model.lin.weight.shape + param_shapes[1] = model.lin.bias.shape + fsdp_model = FSDP(model, use_orig_params=True) + inp = torch.randn((2, 5), device=torch.device("cuda")) + fsdp_model(inp) + instantiate_parametrized_tests(TestFSDPUseOrigParamsMultipleParamGroups) instantiate_parametrized_tests(TestFSDPUseOrigParamsUnshardReshard) From b291c1213ae18e89a5c616913f14b4bb8eda12a8 Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Wed, 16 Nov 2022 03:07:54 +0000 Subject: [PATCH 209/453] Create native function for determining which implementation of SDP to call (#89029) # Summary Creates a callable native function that can determine which implementation of scaled dot product will get called. This allows to bump re-order the runtime dispatch of SDP to enable autograd. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89029 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 5 + .../ATen/native/transformers/attention.cpp | 6 + .../native/transformers/cuda/attention.cu | 13 ++ .../ATen/native/transformers/cuda/sdp_utils.h | 3 +- .../ATen/native/transformers/sdp_utils_cpp.h | 9 ++ docs/source/backends.rst | 2 + test/test_transformers.py | 114 +++++++++++++----- torch/backends/cuda/__init__.py | 17 ++- torchgen/native_function_generation.py | 1 + 9 files changed, 137 insertions(+), 33 deletions(-) create mode 100644 aten/src/ATen/native/transformers/sdp_utils_cpp.h diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9572ccc56653..726a54b5e225 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13239,6 +13239,11 @@ variants: function autogen: _scaled_dot_product_attention.out +- func: _fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> int + dispatch: + CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp + CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda + # Register the math kernel for cpu - func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? 
attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 55c71f9fd064..89a0e4691018 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS @@ -685,6 +686,11 @@ std::tuple _scaled_dot_product_attention( return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); } +int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + return static_cast(sdp::SDPBackend::math); +} + std::tuple _scaled_dot_product_attention_forward_math( const Tensor& query_, const Tensor& key, diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 46543d4663fa..602cf319f74a 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -789,6 +789,19 @@ std::tuple _scaled_dot_product_attention_forward_cuda( } } +int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; + auto backend = select_sdp_backend(kernel_params); + if (backend == sdp::SDPBackend::error) { + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found. ", + "This is likely due to turning off both the math kernel and the fused kernels."); + } + return static_cast(backend); +} + Tensor flash_scaled_dot_product_attention( const Tensor& query, const Tensor& key, diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index e9f3d5029aa8..5d62a6cbd0dc 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -27,8 +28,6 @@ struct sdp_params { bool is_causal; }; -enum class SDPBackend { flash_attention, efficient_attention, math, error }; - template inline bool check_tensor_dtype( sdp_params params, diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h new file mode 100644 index 000000000000..9641a36b33b2 --- /dev/null +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -0,0 +1,9 @@ +#pragma once +namespace sdp { +enum class SDPBackend { + error = -1, + math = 0, + flash_attention = 1, + efficient_attention = 2 +}; +} // namespace sdp \ No newline at end of file diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 80e18f7017a0..2a02b325341f 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -52,6 +52,8 @@ torch.backends.cuda .. autofunction:: torch.backends.cuda.preferred_linalg_library +.. autoclass:: torch.backends.cuda.SDPBackend + .. autofunction:: torch.backends.cuda.flash_sdp_enabled .. 
autofunction:: torch.backends.cuda.enable_mem_efficient_sdp diff --git a/test/test_transformers.py b/test/test_transformers.py index 93a94a5604c9..abb4c71ec19a 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1,14 +1,16 @@ # Owner(s): ["module: nn"] import contextlib +from functools import partial import torch import torch.nn as nn import torch.nn.functional as F import unittest from unittest.mock import patch import math -from torch.backends.cuda import sdp_kernel +from torch.backends.cuda import sdp_kernel, SDPBackend import torch.optim as optim +from torch.testing._internal.common_dtype import floating_types_and_half from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( @@ -936,18 +938,24 @@ def _test_fastpath(model, key_padding_mask, mock_return_value, attn_mask=None, n _test_fastpath(model, aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) _test_fastpath(model, not_aligned_key_padding_mask, nested_tensor_return_value, nested_tensors=True) + def rand_nt(self, shape, device, dtype, requires_grad=False, packed=False): + batch, seq_len, num_heads, head_dim = shape + size = (seq_len, num_heads, head_dim) if not packed else (seq_len, 3 * num_heads * head_dim) + return torch.nested.nested_tensor([ + torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad) + for _ in range(batch)]) + + def rand_tensor(self, shape, device, dtype, requires_grad=False, packed=False): + batch, seq_len, num_heads, head_dim = shape + size = (batch, seq_len, num_heads, head_dim) if not packed else (batch, seq_len, 3 * num_heads * head_dim) + return torch.randn(size, device=device, dtype=dtype, requires_grad=requires_grad) + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels(self, type: str, is_contiguous: bool): - def rand_nt(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.nested.nested_tensor([torch.randn(seq_len, num_heads, head_dim, - device="cuda", dtype=torch.float16) for _ in range(batch)]) - - def rand_tensor(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16) + rand_nt = partial(self.rand_nt, device="cuda", dtype=torch.float16) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16) batch, seq_len, num_heads, head_dim = 32, 64, 16, 64 shape = (batch, seq_len, num_heads, head_dim) @@ -985,14 +993,8 @@ def rand_tensor(shape): @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_fused_kernels_packed(self, type: str, is_contiguous: bool): - def rand_nt(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.nested.nested_tensor([torch.randn(seq_len, 3 * num_heads * head_dim, - device="cuda", dtype=torch.float16) for _ in range(batch)]) - - def rand_tensor(shape): - batch, seq_len, num_heads, head_dim = shape - return torch.randn(batch, seq_len, 3 * num_heads * head_dim, device="cuda", dtype=torch.float16) + rand_nt = partial(self.rand_nt, device="cuda", dtype=torch.float16, packed=True) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 shape = (batch_size, seq_len, num_heads, 
head_dim) @@ -1098,8 +1100,10 @@ def rand_tensor(shape): def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - query, key, value = torch.rand((batch_size, seq_len, 3 * num_heads * head_dim), - device="cuda", dtype=torch.float32, requires_grad=True).chunk(3, -1) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True) + + qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) + query, key, value = qkv.chunk(3, dim=-1) query = query.view(batch_size, -1, num_heads, head_dim) key = key.view(batch_size, -1, num_heads, head_dim) value = value.view(batch_size, -1, num_heads, head_dim) @@ -1116,6 +1120,49 @@ def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) + @parametrize("type", ["dense", "nested"]) + def test_fused_sdp_choice(self, type: str): + device = "cpu" + # Test that cpu and nestedtensor cpu return MATH backend + for dtype in floating_types_and_half(): + make_tensor = partial(self.rand_tensor, device=device, dtype=dtype) + size = (2, 2, 3, 4) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) + assert torch._fused_sdp_choice(q, k, v) == SDPBackend.MATH + + if TEST_CUDA and not TEST_WITH_ROCM and not IS_WINDOWS: + batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64 + shape = (batch_size, seq_len, num_heads, head_dim) + device = "cuda" + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float16, packed=True) + make_nt = partial(self.rand_nt, device=device, dtype=torch.float16, packed=True) + + qkv = make_tensor(shape) if type == "dense" else make_nt(shape) + query, key, value = qkv.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if SM80OrLater: + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION + else: + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION + + # Change dtype to float32 so that efficient attention should get chosen + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float32, packed=True) + make_nt = partial(self.rand_nt, device=device, dtype=torch.float32, packed=True) + + qkv = make_tensor(shape) if type == "dense" else make_nt(shape) + query, key, value = qkv.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION + + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_sdp_runtime_dispatch(self): # We will test all the constraints that we know will cause a failure @@ -1123,12 +1170,15 @@ def test_sdp_runtime_dispatch(self): # will fail on CI/CD becuase it is not compiled with the right flags device = 'cuda' dtype = torch.float16 - - def make_tensor(*size, device=device, dtype=dtype): - return torch.randn(size, device=device, dtype=dtype) + make_tensor = partial(self.rand_tensor, device=device, dtype=dtype) with sdp_kernel(enable_flash=False, enable_math=False, 
enable_mem_efficient=False): - q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4) + size = (2, 3, 4) + q = torch.randn(size, device=device, dtype=dtype) + k = torch.randn(size, device=device, dtype=dtype) + v = torch.randn(size, device=device, dtype=dtype) + self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", + lambda: torch._fused_sdp_choice(q, k, v)) self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.", lambda: torch.nn.functional._scaled_dot_product_attention(q, k, v)) @@ -1136,29 +1186,33 @@ def make_tensor(*size, device=device, dtype=dtype): # Failures for invalid input # Dim is not 4 - q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4) + q = torch.randn(size, device=device, dtype=dtype) + k = torch.randn(size, device=device, dtype=dtype) + v = torch.randn(size, device=device, dtype=dtype) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) # Xformers can now cover this case but will add back in next PR # Invalid last_dim size - q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4) + size = (2, 2, 3, 4) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) # Invalid dtype - q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float64), make_tensor( - 2, 2, 3, 16, dtype=torch.float64), make_tensor(2, 2, 3, 16, dtype=torch.float64) + size = (2, 2, 3, 16) + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float64) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) - q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float32), make_tensor( - 2, 2, 3, 16, dtype=torch.float32), make_tensor(2, 2, 3, 16, dtype=torch.float32) + make_tensor = partial(self.rand_tensor, device=device, dtype=torch.float32) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( q, k, v, None, 0.0, False, False)) # Failures for unsupported SDP args - q, k, v = make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16) + q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) # Needs attention weights self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index dd05535d3935..50735e125ec3 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -1,11 +1,12 @@ import sys import torch import contextlib +from enum import IntEnum from typing import Union __all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager", - "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "enable_flash_sdp", + "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "SDPBackend", "enable_flash_sdp", "flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled", "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"] @@ -164,6 +165,20 @@ def preferred_linalg_library(backend: Union[None, str, torch._C._LinalgBackend] return 
torch._C._get_linalg_preferred_backend() +class SDPBackend(IntEnum): + r"""Enum class for the scaled dot product attention backends. + + .. warning:: This flag is experimental and subject to change.' + + This class needs to stay inline with the enum defined in: + pytorch/aten/src/ATen/native/transformers/sdp_utils_cpp.h + """ + ERROR = -1 + MATH = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + + def flash_sdp_enabled(): r""" .. warning:: This flag is experimental and subject to change. diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 89314c1dd18d..657a133c31c7 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -73,6 +73,7 @@ "record_stream", # no return "sparse_dim", # returns an int "_nested_tensor_offsets", # returns a vector of ints + "_fused_sdp_choice", # returns an int ] INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ From ce2f8700bafcf44850402a39188ec121ba8b5486 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Tue, 15 Nov 2022 21:02:44 +0000 Subject: [PATCH 210/453] Symintify numel(), infer_size, prims.elementwise_meta (#88956) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956 Approved by: https://github.com/ezyang --- aten/src/ATen/ExpandUtils.cpp | 10 ++++---- aten/src/ATen/ExpandUtils.h | 2 ++ test/test_proxy_tensor.py | 25 ++++++++++++++++--- torch/_prims/__init__.py | 16 +++++++++--- torch/_refs/__init__.py | 4 +-- torch/_subclasses/fake_tensor.py | 6 +---- torch/csrc/autograd/input_metadata.h | 4 +-- .../python_torch_functions_manual.cpp | 2 +- torch/fx/experimental/symbolic_shapes.py | 3 --- torch/fx/traceback.py | 2 +- 10 files changed, 48 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp index a44005a2ef81..ee846c9b82e3 100644 --- a/aten/src/ATen/ExpandUtils.cpp +++ b/aten/src/ATen/ExpandUtils.cpp @@ -13,8 +13,8 @@ TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) { namespace { // NOTE: are_expandable did a similar check, please keep them sync if change is needed -template -Container infer_size_impl(IntArrayRef a, IntArrayRef b) { +template +Container infer_size_impl(ArrayType a, ArrayType b) { size_t dimsA = a.size(); size_t dimsB = b.size(); size_t ndim = dimsA > dimsB ? dimsA : dimsB; @@ -25,8 +25,8 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ptrdiff_t offset = ndim - 1 - i; ptrdiff_t dimA = dimsA - 1 - offset; ptrdiff_t dimB = dimsB - 1 - offset; - int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; - int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; + auto sizeA = (dimA >= 0) ? a[dimA] : 1; + auto sizeB = (dimB >= 0) ? b[dimB] : 1; TORCH_CHECK( sizeA == sizeB || sizeA == 1 || sizeB == 1, @@ -35,7 +35,7 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ") at non-singleton dimension ", i); // 1s map to the other size (even 0). - expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + expandedSizes[i] = sizeA == 1 ? 
std::move(sizeB) : std::move(sizeA); } return expandedSizes; diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 786cbf132cd7..9e48421e540f 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -21,6 +21,8 @@ namespace at { TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); +TORCH_API SymDimVector +infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); // Named type instead of a pair/tuple so that we can be sure to // construct the vectors in place and get NRVO. diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 24efcab9e5cb..59b08eea8dce 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -13,6 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException from torch._decomp import decomposition_table +from torch.fx.experimental.symbolic_shapes import sym_float from torch.testing._internal.common_device_type import ops from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy @@ -719,7 +720,6 @@ def deco(cls): @skipIfNoSympy @xfail_inherited_tests([ - "test_mode_tracing_factory_function", "test_make_fx_overloads", "test_trace_subclasses", ]) @@ -961,8 +961,27 @@ def f(x): # happened afterwards self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3) - - + def test_elementwise_meta_with_sym_numbers(self): + def f(x, offset, as_sym_float=False): + x0 = x.size()[0] + if as_sym_float: + x0 = sym_float(x0) + return torch.add(x0, offset) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.int64) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) def test_return_symint(self): def f(x): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index c40960a22445..da8d9af723ac 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -31,6 +31,7 @@ ) from torch._prims_common.wrappers import backwards_not_supported from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.symbolic_shapes import sym_float from torch.overrides import handle_torch_function, has_torch_function from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -390,11 +391,18 @@ def _elementwise_meta( return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype) # Number case - # NOTE: this case is not currently exercised # TODO: fix number type promotion (bool, complex->float) - assert not isinstance(number, torch.SymInt), "NYI" - assert not isinstance(number, torch.SymFloat), "NYI" - return TensorMeta(number) + + # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) + seen_float = False + if isinstance(number, (torch.SymInt, 
torch.SymFloat)): + for a in args: + assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" + seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) + if seen_float: + number = sym_float(number) + + return TensorMeta(number) # type: ignore[arg-type] def _complex_only_elementwise_meta(*args, **kwargs): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index f2817f0331ac..a0916c3f8268 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -723,10 +723,10 @@ def nan_to_num( nan = 0.0 if posinf is None: - posinf = prims.maximum_value(a.dtype) + posinf = torch.finfo(a.dtype).max if neginf is None: - neginf = prims.minimum_value(a.dtype) + neginf = torch.finfo(a.dtype).min result = where(isnan(a), nan, a) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 8dec2475df15..5d3d3a0e32fe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -861,11 +861,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # and ensure that Meta kernels are dispatched to (see) # Fake Tensor Dispatch Keys # TODO - we should be use the prim aten impl - if ( - "prims::" in func._schema.name - and len(flat_arg_fake_tensors) != 0 - and hasattr(func, "prim_meta_impl") - ): + if "prims::" in func._schema.name and hasattr(func, "prim_meta_impl"): with self: return func.prim_meta_impl(*args, **kwargs) diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 7cb9e8aedb19..8060c11ac457 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -125,13 +125,13 @@ struct InputMetadata { if (grad.is_nested()) { ss << at::native::get_nested_size_tensor(grad); } else { - ss << grad.sizes(); + ss << grad.sym_sizes(); } ss << " but expected shape compatible with "; if (is_nested_tensor()) { ss << shape_as_tensor(); } else { - ss << c10::asIntArrayRefSlow(shape_as_dim_vector()); + ss << shape_as_dim_vector(); } return ss; } diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 562f5a427d38..2c4999c971ea 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -692,7 +692,7 @@ static PyObject* THPVariable_numel( } if (r.idx == 0) { - return wrap(r.tensor(0).numel()); + return wrap(r.tensor(0).sym_numel()); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index ae4427e2320e..bd52760502c6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -261,9 +261,6 @@ def eval(cls, base, divisor): 'floordiv': lambda a, b: FloorDiv(a, b), } -def _nyi(): - raise NotImplementedError() - magic_methods = { **reflectable_magic_methods, 'eq': lambda a, b: sympy.Eq(a, b), diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a07b36b997bd..cee7626e5c83 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -54,7 +54,7 @@ def format_stack() -> List[str]: return current_stack.copy() else: # fallback to traceback.format_stack() - return traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) @compatibility(is_backward_compatible=False) From 8ebbd5a89a66bf84d7358f4d353ec2708d6c5429 Mon Sep 17 00:00:00 2001 From: Johannes Pitz Date: Wed, 16 Nov 2022 04:38:30 +0000 Subject: [PATCH 211/453] Easier to understand 
event_dim computation (#81396) Fixes #81254 Only easier to understand, not a real fix. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81396 Approved by: https://github.com/fritzo, https://github.com/kit1980 --- .../distributions/transformed_distribution.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 9d7bd6fbd690..a3bab3e836a3 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -57,26 +57,29 @@ def __init__(self, base_distribution, transforms, validate_args=None): base_shape = base_distribution.batch_shape + base_distribution.event_shape base_event_dim = len(base_distribution.event_shape) transform = ComposeTransform(self.transforms) - domain_event_dim = transform.domain.event_dim - if len(base_shape) < domain_event_dim: + if len(base_shape) < transform.domain.event_dim: raise ValueError("base_distribution needs to have shape with size at least {}, but got {}." - .format(domain_event_dim, base_shape)) - shape = transform.forward_shape(base_shape) - expanded_base_shape = transform.inverse_shape(shape) + .format(transform.domain.event_dim, base_shape)) + forward_shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(forward_shape) if base_shape != expanded_base_shape: base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim] base_distribution = base_distribution.expand(base_batch_shape) - reinterpreted_batch_ndims = domain_event_dim - base_event_dim + reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim if reinterpreted_batch_ndims > 0: base_distribution = Independent(base_distribution, reinterpreted_batch_ndims) self.base_dist = base_distribution # Compute shapes. - event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0) - assert len(shape) >= event_dim - cut = len(shape) - event_dim - batch_shape = shape[:cut] - event_shape = shape[cut:] + transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim + event_dim = max( + transform.codomain.event_dim, # the transform is coupled + base_event_dim + transform_change_in_event_dim # the base dist is coupled + ) + assert len(forward_shape) >= event_dim + cut = len(forward_shape) - event_dim + batch_shape = forward_shape[:cut] + event_shape = forward_shape[cut:] super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): From e2f0648750f2d0d0ac648728ce4c514db178cfa1 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Wed, 16 Nov 2022 05:07:51 +0000 Subject: [PATCH 212/453] Add an option to include actual license terms to the output (#85624) When building products using PyTorch, it is often required to display license terms for all dependencies. The feature itself has been implemented in #81500 but it seems there are no options to enable it. This PR implements the option. 
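A quick usage sketch of the new flag (hedged: only the script path, `--out-file`, and `--include-files` are taken from this diff; the output filename below is illustrative, not the script's default):

```bash
# Rebuild the bundled-licenses file, embedding the actual license terms of
# each third-party dependency via the new --include-files flag.
python third_party/build_bundled.py \
    --out-file bundled_licenses.txt \
    --include-files
```

Leaving out `--include-files` keeps the previous behavior, since the flag defaults to False.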
cc/ @mattip @rgommers Pull Request resolved: https://github.com/pytorch/pytorch/pull/85624 Approved by: https://github.com/rgommers, https://github.com/seemethere --- third_party/build_bundled.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/third_party/build_bundled.py b/third_party/build_bundled.py index 4da1b84a6f32..d60a2c1354fd 100644 --- a/third_party/build_bundled.py +++ b/third_party/build_bundled.py @@ -181,9 +181,14 @@ def squeeze(t): ), help="location to output new bundled licenses file", ) - + parser.add_argument( + "--include-files", + action="store_true", + default=False, + help="include actual license terms to the output", + ) args = parser.parse_args() fname = args.out_file print(f"+ Writing bundled licenses to {args.out_file}") with open(fname, 'w') as fid: - create_bundled(third_party, fid) + create_bundled(third_party, fid, args.include_files) From 7e66d1d6cdb4e8d854a8da160daeb910783f069d Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Wed, 16 Nov 2022 06:27:13 +0000 Subject: [PATCH 213/453] [Inductor] Support Shape Padding for aten.mm in Inductor (#89086) Summary: Support shape padding for aten.mm in Inductor (originally from [#88709](https://github.com/pytorch/pytorch/pull/88709)) Differential Revision: D41315078 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89086 Approved by: https://github.com/jianyuh --- torch/_inductor/decomposition.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 44bfd46505a2..3254f174b495 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -267,6 +267,31 @@ def should_pad_bench(mat1, mat2, op, input=None): return ori_time > pad_time * 2 +@register_decomposition([aten.mm]) +def mm_decomp(mat1, mat2): + if ( + config.shape_padding + and check_device_dtype(mat1, mat2) + and should_pad_bench(mat1, mat2, torch.ops.aten.mm) + ): + m_padded_length = get_padded_length(mat1.shape[0]) + k_padded_length = get_padded_length(mat1.shape[1]) + n_padded_length = get_padded_length(mat2.shape[1]) + + if k_padded_length != 0: + mat1 = pad_dim(mat1, k_padded_length, 1) + mat2 = pad_dim(mat2, k_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2) + elif m_padded_length != 0: + mat1 = pad_dim(mat1, m_padded_length, 0) + return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :] + elif n_padded_length != 0: + mat2 = pad_dim(mat2, n_padded_length, 1) + return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length] + + return NotImplemented # go directly to lowering + + @register_decomposition([aten.bmm]) def bmm_decomp(mat1, mat2): if ( From 59ba15f37407294eed3ecdb9986b02c5c2d52a70 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 16 Nov 2022 07:44:41 +0000 Subject: [PATCH 214/453] Upload CSV test reports from inductor (#89112) Inductor test report artifacts are now on HUD but its files are in CSV format instead of the default XML files from pytest or unittest that we expect. 
So this PR uploads both suffixes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89112 Approved by: https://github.com/desertfire --- .github/actions/upload-test-artifacts/action.yml | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 624c4895155a..9fd2342601f1 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -34,7 +34,7 @@ runs: run: | # Remove any previous test reports if they exist rm -f test-reports-*.zip - zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml' + zip -r "test-reports-${FILE_SUFFIX}.zip" test -i '*.xml' -i '*.csv' - name: Zip usage log for upload if: runner.os != 'Windows' && !inputs.use-gha @@ -67,7 +67,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | # -ir => recursive include all files in pattern - 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' + 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' -ir'!test\*.csv' - name: Zip usage log for upload if: runner.os == 'Windows' && !inputs.use-gha @@ -127,8 +127,11 @@ runs: # Add the run attempt, see [Artifact run attempt] name: test-reports-runattempt${{ github.run_attempt }}-${{ inputs.file-suffix }}.zip retention-days: 14 - if-no-files-found: error - path: test/**/*.xml + # Don't want to fail the workflow here because not all workflows have csv files + if-no-files-found: ignore + path: | + test/**/*.xml + test/**/*.csv - name: Store Usage Logs on Github uses: actions/upload-artifact@v3 From 370fc5cb421f54fc9513237390e09cca0e06e01b Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 08:04:37 +0000 Subject: [PATCH 215/453] [dtensor] PART 1: move DeviceMesh and placement to core distributed (#88549) This PR creates the `torch.distributed._tensor` package and moves DeviceMesh and PlacementTypes to it, as part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88549 Approved by: https://github.com/fduwjj --- torch/distributed/_tensor/__init__.py | 0 torch/distributed/_tensor/device_mesh.py | 506 +++++++++++++++++++ torch/distributed/_tensor/placement_types.py | 432 ++++++++++++++++ 3 files changed, 938 insertions(+) create mode 100644 torch/distributed/_tensor/__init__.py create mode 100644 torch/distributed/_tensor/device_mesh.py create mode 100644 torch/distributed/_tensor/placement_types.py diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py new file mode 100644 index 000000000000..5ca3f8c6159b --- /dev/null +++ b/torch/distributed/_tensor/device_mesh.py @@ -0,0 +1,506 @@ +# Copyright (c) Meta Platforms, Inc.
and affiliates +import warnings +from typing import List, Optional, Sequence, TypeVar, Union +import torch +from torch.distributed.distributed_c10d import ( + all_gather, + all_reduce, + broadcast, + get_rank, + get_world_size, + get_global_rank, + ReduceOp, + GroupMember, + scatter, + _get_default_group, + reduce_scatter, + new_group, + ProcessGroup, + all_to_all, + Work, +) + +_global_device_mesh: Optional["DeviceMesh"] = None + + +def get_global_device_mesh() -> "DeviceMesh": + global _global_device_mesh + assert ( + _global_device_mesh is not None + ), "Could not get a default device mesh!" + return _global_device_mesh + + +def set_global_device_mesh(mesh: Optional["DeviceMesh"]) -> None: + global _global_device_mesh + _global_device_mesh = mesh + + +# We want a type for "can be passed to torch.as_tensor()"; +# this is a recursive sequence type, which isn't fully supported +# yet in python. This construct simulates that up to depth 7. +T = TypeVar("T") +_L = Union[T, Sequence[T]] +NDIntList = _L[_L[_L[_L[_L[_L[_L[int]]]]]]] + +MeshExprT = Union[ + torch.Tensor, + NDIntList, +] + + +class DeviceMesh(object): + """ + DeviceMesh represents a mesh of devices, where layout of devices could be + represented as a n-d dimension array, and each value of the n-d dimensional + array is the global id of the default process group ranks. + + DeviceMesh could be used to describe the layout of devices across the cluster, + and serves as a proxy for communication among the device lists within the cluster. + + We use the default ProcessGroup in this DeviceMesh class to implement proper + communications. Note that we also add collective wrappers in this class. This is + used to decouple detailed communication backend with the underlying + DTensor implementation. + + DeviceMesh can be used as a context manager. + Args: + device_type (str): device type of the mesh. Currently supports: cpu, cuda. + mesh (ndarray): could be a multi-dimension array or an integer tensor that + describes the layout of devices, the ids are global ids of the + default process group. + dim_groups (List[ProcessGroup], optional): The ProcessGroup used per mesh + dimension. + + Returns: + A :class:`DeviceMesh` object + + Example (2 host with 4 GPUs each): + ``` + # The following program runs on each process/rank in SPMD manner. + # initialized default world + torch.distributed.init_process_group(backend="nccl", world_size=8) + # initialize device mesh as (2, 4) to represent the topology + # of cross-host(dim 0), and within-host (dim 1) + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ]) + ``` + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7) + + """ + + device_type: str + mesh: torch.Tensor + _backend: str + + def __init__( + self, + device_type: str, + mesh: MeshExprT, + dim_groups: Optional[List[ProcessGroup]] = None, + ) -> None: + self.device_type = device_type + self.mesh = ( + mesh.detach() + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, dtype=torch.int) + ) + default_pg = _get_default_group() + self._backend = default_pg._get_backend_name() + # TODO: if user want to pass pg_options, offer a way to do it + # check default pg backend, should support device_type + if device_type == "cpu": + assert ( + self._backend == "gloo" + ), f"ProcessGroup backend: {self._backend} not supporting CPU!" 
+ elif device_type == "cuda": + if self._backend == "gloo": + warnings.warn( + "We recommend using nccl backend for cuda device type, gloo backend might only have partial support!" + ) + assert self._backend == "gloo" or self._backend == "nccl" + else: + raise RuntimeError( + f"DeviceMesh only support cpu or cuda device type, but got {device_type}" + ) + + world_size = get_world_size() + if self.mesh.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!" + ) + + unique_mesh_values = self.mesh.unique(sorted=True) + if unique_mesh_values.numel() != self.mesh.numel(): + raise RuntimeError( + f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}" + ) + + # coordinates of this rank on the mesh + rank_coords = (self.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + self._coordinate_on_dim: Optional[List[int]] = ( + rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + ) + + # groups created by dimension, each dimension should have exact + # one valid process group per rank + self._dim_groups: List[ProcessGroup] = [] + if dim_groups is not None: + # if user hand creating dimension based groups + # we just take it and use it for communication + if not isinstance(dim_groups, list): + raise RuntimeError( + "dim_groups expected to be Optional[List[ProcessGroup]]" + ) + + for group in dim_groups: + if not isinstance(group, ProcessGroup): + raise RuntimeError( + f"found object in dim_groups that is not a ProcessGroup: {group}" + ) + + if self.get_rank() in self.mesh: + if len(dim_groups) != self.mesh.ndim: + raise RuntimeError( + f"length of dim_groups ({len(dim_groups)}) expected to be equal to mesh.ndim ({self.mesh.ndim})" + ) + else: + if len(dim_groups) != 0: + raise RuntimeError( + f"length of dim_groups ({len(dim_groups)}) expected to be equal to 0 on rank {self.get_rank()} " + f"for mesh {self.mesh}" + ) + + self._dim_groups = dim_groups + return + + if self.mesh.ndim == 1 and unique_mesh_values[-1] == world_size - 1: + # if the mesh is the same as world_pg, we just append the default + # pg to the first dim goups, as new_group cannot have the exact + # same ranks as world + self._dim_groups.append(default_pg) + else: + # create sub pgs base on the mesh argument specified + # handle multi-dim mesh, create subgroups by + # looping over the pg_ranks_by_dim for each dim + for dim in range(self.mesh.ndim): + # swap the current dim to the last dim + # then reshape to flatten out other dims + pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( + -1, self.mesh.size(dim) + ) + + # multi-dim mesh, create subgroups by + # looping over the pg_ranks for each dim + # and append the groups + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + # call new_group regardless of the current rank in the + # pg or not, it's required that all ranks participate + # in subgroup construction + new_subgroup = new_group( + ranks=subgroup_ranks, backend=self._backend + ) + # only add to dim_groups if the current rank in the subgroup + if self.get_rank() in subgroup_ranks: + if len(self._dim_groups) > dim: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {self.get_rank} " + f"in {subgroup_ranks}!" + ) + self._dim_groups.append(new_subgroup) + + def __enter__(self) -> "DeviceMesh": + # set global device_mesh to this instance + set_global_device_mesh(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. 
+ def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # unset global device mesh + set_global_device_mesh(None) + + def __repr__(self) -> str: + return f"DeviceMesh:({self.mesh.tolist()})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceMesh): + return False + if id(self) == id(other): + return True + return self.mesh.equal(other.mesh) + + def get_dim_groups(self) -> List[ProcessGroup]: + return self._dim_groups + + # pyre-fixme[3]: Return type must be annotated. + def size(self, dim: int = 0): + return self.mesh.size(dim) + + @property + def ndim(self) -> int: + return self.mesh.ndim + + def backend(self) -> str: + return self._backend + + def get_rank(self) -> int: + return get_rank() + + def get_coordinate_on_dim(self, dim: int) -> Optional[int]: + """ + Return the relative index of this rank relative to a given + dimension of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim[dim] if self._coordinate_on_dim else None + + def scatter( + self, + output: torch.Tensor, + scatter_list: List[torch.Tensor], + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor lista on rank + 2 to rank 2/3. + + Args: + tensor (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + dim_group = self._dim_groups[mesh_dim] + # src need to be global rank + src_for_dim = 0 + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + + if src_for_dim == get_rank(): + fut = scatter( + output, + scatter_list=scatter_list, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + else: + fut = scatter( + output, + scatter_list=None, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + + return fut + + def broadcast( + self, + tensor: torch.Tensor, + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + dim_group = self._dim_groups[mesh_dim] + # src need to be global rank + src_for_dim = 0 + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + + return broadcast( + tensor, src=src_for_dim, group=dim_group, async_op=async_op + ) + + def all_gather( + self, + tensor_list: List[torch.Tensor], + tensor: torch.Tensor, + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + """ + all_gather the tensor on each rank to the tensor_list on a + device mesh dimension. 
+ + Args: + tensor_list (List[torch.Tensor]): The gathered tensor list. + tensor (torch.Tensor): tensor to be gathered on each rank. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + dim_group = self._dim_groups[mesh_dim] + return all_gather( + tensor_list, tensor, group=dim_group, async_op=async_op + ) + + def all_reduce( + self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment] + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + """ + all_reduce the tensor on each rank on a device mesh dimension, and + return an output tensor on each rank after all_reduce. + + Args: + tensor (torch.Tensor): tensor to be all_reduced on each rank. + op (:class:`torch.distributed.distributed_c10d.ReduceOp, optional): + the reduction op of all_reduce (i.e. ReduceOp.SUM) + mesh_dim (int, optional): indicate which mesh dimension we want + to reduce on. + + Returns: + A :class:`Work` object + """ + dim_group = self._dim_groups[mesh_dim] + return all_reduce(tensor, op=op, group=dim_group, async_op=async_op) + + def reduce_scatter( + self, + output: torch.Tensor, + input_list: List[torch.Tensor], + op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment] + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + """ + reduce the input_list on each rank on a device mesh dimension, and scatter + the results to the output tensor on each rank. + + Args: + output (torch.Tensor): tensor to receive the scattered result. + input_list (List[torch.Tensor]): tensor list to be reduced and scattered + and scattered on each rank. + op (:class:`torch.distributed.distributed_c10d.ReduceOp, optional): + the reduction op of reduce_scatter (i.e. ReduceOp.SUM) + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on. + + Returns: + A :class:`Work` object + """ + if self._backend == "nccl": + dim_group = self._dim_groups[mesh_dim] + fut = reduce_scatter( + output, input_list, op=op, group=dim_group, async_op=async_op + ) + + elif self._backend == "gloo": + # it's gloo, which does not have reduce_scatter + # we have to do all_reduce + scatter + warnings.warn( + "ProcessGroupGloo does not support reduce_scatter, falling back with all reduce!" + ) + my_coordinate = self.get_coordinate_on_dim(mesh_dim) + # TODO: what should happen if rank is not in the mesh? + # see issue https://github.com/pytorch/tau/pull/492 + assert ( + my_coordinate is not None + ), "Rank if not part of mesh" # TODO: figure out behavior here + fut = None + flattened_list = [] + offset_list = [] + + offset = 0 + for input in input_list: + offset_list.append(offset) + offset += input.numel() + flattened_list.append(input.flatten()) + + # all reduce since gloo does not support reduce_scatter + flat_tensor = torch.cat(flattened_list).clone( + memory_format=torch.contiguous_format + ) + fut = self.all_reduce( + flat_tensor, op=op, mesh_dim=mesh_dim, async_op=async_op + ) + # scatter the tensor + output_offset = offset_list[my_coordinate] + output.copy_( + flat_tensor[ + output_offset : output_offset + output.numel() + ].view(output.shape) + ) + else: + raise RuntimeError( + f"backend {self._backend} does not support reduce_scatter!" 
+ ) + return fut + + # TODO: test uneven split on GLOO and NCCL + def all_to_all( + self, + output_tensor_list: List[torch.Tensor], + input_tensor_list: List[torch.Tensor], + mesh_dim: int = 0, + async_op: bool = False, + ) -> Optional[Work]: + dim_group = self._dim_groups[mesh_dim] + + work = None + # no direct dist.all_to_all support on 'gloo' so we manually do scatters + if self.backend() == "gloo": + # TODO: pull the handle of uneven case in #492 + dim_group_size = get_world_size(dim_group) + for i in range(dim_group_size): + # src need to be global rank + src_for_dim = i + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, i) + + work = scatter( + output_tensor_list[i], + input_tensor_list if self.get_rank() == src_for_dim else [], + group=dim_group, + src=src_for_dim, + async_op=async_op, + ) + + elif self.backend() == "nccl": + work = all_to_all( + output_tensor_list, + input_tensor_list, + dim_group, + async_op=async_op, + ) + else: + raise RuntimeError( + f"DeviceMesh does not support all-to-all collective operations on {self.backend()} backend." + ) + return work diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py new file mode 100644 index 000000000000..f2df183b046d --- /dev/null +++ b/torch/distributed/_tensor/placement_types.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass +from typing import Optional, List, Sequence, Tuple, cast + +import torch +import torch.distributed.distributed_c10d as c10d +from torch.distributed._spmd.comm_tensor import CommTensor + +from torch.distributed._tensor.device_mesh import DeviceMesh + + +class Placement(object): + # base class Placement type + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + if dim is not None and isinstance(self, Shard): + return self.dim == dim + else: + return isinstance(self, Shard) + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, _Partial) + + +@dataclass +class Shard(Placement): + # shard placement, shard on a dim + dim: int + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], int]: + # NOTE: For with_padding option, we pad the tensor on each rank before calling + # the collectives (i.e. scatter/all_gather, etc.). 
This is because for gloo + # backend, it does not support uneven collectives, nccl supports some, but + # it might be slow compared to even size collective, we need to pad tensor + # before really calling the collective, and unpad/narrow it afterwards + # TODO: consider if we should remove this logic once ProcessGroupGloo + # support uneven list, and collective perfomance on par + assert ( + self.dim <= tensor.ndim + ), f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + assert ( + tensor.size(self.dim) >= num_chunks + ), f"Tensors to be sharded on dim {self.dim} must be at least as large as " + f"the number of devices in that dimension {num_chunks}" + # split tensor over dimension `dim` into n slices with padding if necessary + tensor_list = list(tensor.tensor_split(num_chunks, self.dim)) + idx_start_to_pad = tensor.size(self.dim) % num_chunks + if with_padding or contiguous: + shard_list = [] + for i, shard in enumerate(tensor_list): + if ( + with_padding + and idx_start_to_pad != 0 + and i >= idx_start_to_pad + ): + shard = self._pad_tensor(shard) + # input tensors are expected to be congtiguous by the collective backend + shard = shard.contiguous() if contiguous else shard + shard_list.append(shard) + return shard_list, idx_start_to_pad + else: + return tensor_list, idx_start_to_pad + + def _pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + # pad tensor by 1 on the shard dim + pad = [0, 0] * (tensor.ndim - self.dim) + pad[-1] = 1 + return torch.nn.functional.pad(tensor, pad) + + def _unpad_tensor(self, tensor: torch.Tensor) -> torch.Tensor: + # unpad tensor by 1 on the shard dim + return tensor.narrow( + self.dim, start=0, length=tensor.size(self.dim) - 1 + ) + + def _local_shard_size_on_dim( + self, + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + assert ( + size_on_dim >= num_chunks + ), f"Size to be sharded on dim {self.dim} must be at least as large as the number of devices in that dimension {num_chunks}" + split_size, pad_idx = divmod(size_on_dim, num_chunks) + local_shard_size = ( + split_size + 1 if pad_idx != 0 and rank < pad_idx else split_size + ) + local_offset_on_dim = -1 + if return_offset: + local_offset_on_dim = ( + rank * split_size + pad_idx if rank >= pad_idx else rank + ) + return (local_shard_size, local_offset_on_dim) + + def _shard_tensor( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + shard and scatter a tensor on a mesh dimension (use coordinate + 0 on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate_on_dim(mesh_dim) + num_chunks = mesh.size(dim=mesh_dim) + # TODO: what should happen if rank is not in the mesh? 
+ # see issue https://github.com/pytorch/tau/pull/492 + assert ( + my_coordinate is not None + ), "Rank if not part of mesh" # TODO: figure out behavior here + scatter_list, pad_idx = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + output = torch.empty_like(scatter_list[my_coordinate]) + mesh.scatter(output, scatter_list, mesh_dim=mesh_dim) + + if pad_idx != 0 and my_coordinate >= pad_idx: + output = self._unpad_tensor(output) + return output + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: c10d.ReduceOp, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate_on_dim(mesh_dim) + num_chunks = mesh.size(dim=mesh_dim) + # TODO: what should happen if rank is not in the mesh? + # see issue https://github.com/pytorch/tau/pull/492 + assert ( + my_coordinate is not None + ), "Rank if not part of mesh" # TODO: figure out behavior here + scattered_list, pad_idx = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + # wrap with comm tensor + scattered_list = [CommTensor(t) for t in scattered_list] + output = torch.empty_like(scattered_list[my_coordinate]) + mesh.reduce_scatter( + CommTensor(output), + scattered_list, # pyre-ignore[6] + op=reduce_op, + mesh_dim=mesh_dim, + ) + if pad_idx != 0 and my_coordinate >= pad_idx: + output = self._unpad_tensor(output) + return output + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + size: torch.Size, + mesh: DeviceMesh, + mesh_dim: int, + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + my_coordinate = mesh.get_coordinate_on_dim(mesh_dim) + num_chunks = mesh.size(dim=mesh_dim) + # TODO: what should happen if rank is not in the mesh? + # see issue https://github.com/pytorch/tau/pull/492 + assert ( + my_coordinate is not None + ), "Rank if not part of mesh" # TODO: figure out behavior here + # check if it needs to pad input tensor before all_gather + pad_idx = size[self.dim] % num_chunks + if pad_idx != 0 and my_coordinate >= pad_idx: + local_tensor = self._pad_tensor(local_tensor).contiguous() + + gathered_list = [] + # N.B. CommTensor does not change eager mode behavior. During tracing, it + # makes sure communication result is properly waited before subsequent + # read operations. 
+ for _ in range(num_chunks): + gathered_list.append( + CommTensor( + torch.empty_like( + local_tensor, + memory_format=torch.contiguous_format, + ) + ) + ) + + mesh.all_gather(gathered_list, CommTensor(local_tensor.contiguous()), mesh_dim=mesh_dim) # type: ignore[arg-type] + # unpad the tensor if the input tensor was padded + if pad_idx != 0: + gathered_list = [ + self._unpad_tensor(gathered_tensor) # type: ignore[misc] + if i >= pad_idx + else gathered_tensor + for i, gathered_tensor in enumerate(gathered_list) + ] + return torch.cat(gathered_list, dim=self.dim) # type: ignore[arg-type] + + +@dataclass +class Replicate(Placement): + # replicate placement + pass + + +@dataclass +class _Partial(Placement): + # This is a default partial placement with element-wise reduce op + # when doing reduction it follows the contract of `_to_replicate` + # and `_to_shard` to do the reduction and convert the local tensor + # to the corresponding state (replicate or shard) + # + # We can implement custom reductions as needed by subclassing this + # class and override those contracts. + reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.RedOpType.SUM # type: ignore[attr-defined] + + def _to_replicate( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # out-of-place all_reduce to replicate, since the current partial DTensor + # might get used by other ops as well, so we can't inplace modify it + cloned_local = CommTensor( + tensor.clone(memory_format=torch.contiguous_format) + ) + mesh.all_reduce( + cloned_local, c10d.ReduceOp(self.reduce_op), mesh_dim=mesh_dim # type: ignore[call-arg] + ) + return cloned_local + + def _to_shard( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by default call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor( + tensor, mesh, c10d.ReduceOp(self.reduce_op), mesh_dim # type: ignore[call-arg] + ) + + +# used internally to propagate the placements +@dataclass +class DTensorSpec(object): + mesh: DeviceMesh + placements: Sequence[Placement] + # shape of the current dist tensor, this will be set upon + # construction of the DTensor, prop rule could read it, and + # would need to set in output spec when calculate the output + # sharding + shape: torch.Size + # ndim of the current dist tensor, if passed in, this would be + # validated with shape, if not passed in, will be generated from + # the shape + ndim: int = -1 + + def __post_init__(self) -> None: + if self.ndim == -1: + self.ndim = len(self.shape) + + @property + def dim_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 1, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. 
+ """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def sums(self) -> List[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @property + def local_shape(self) -> Tuple[int, ...]: + """ + Compute the shape of a local shard of the given DTensor on its current + coordinate of the mesh. + """ + assert ( + self.shape is not None + ), "DTensorSpec does not contain global shape." + local_shape = list(self.shape) # start with global shape + for idx, placement in enumerate(self.placements): + mesh_dim_size = self.mesh.size(idx) + my_coordinate = self.mesh.get_coordinate_on_dim(idx) + assert my_coordinate is not None, "Rank not part of mesh!" + if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < self.ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {self.ndim}" + local_shard_size, _ = placement._local_shard_size_on_dim( + local_shape[shard_dim], mesh_dim_size, my_coordinate + ) + assert isinstance(local_shard_size, int) + local_shape[shard_dim] = local_shard_size + return tuple(local_shape) + + @property + def local_offsets(self) -> Tuple[int, ...]: + """ + Compute the offsets of a local shard of the given DTensor on its current + global rank. This is mostly used by distributed checkpointing to know the + exact offsets of the local shard. + """ + assert ( + self.shape is not None + ), "DTensorSpec does not contain global shape." + local_offsets = [0] * self.ndim + local_shape = list(self.shape) + + for idx, placement in enumerate(self.placements): + mesh_dim_size = self.mesh.size(idx) + my_coordinate = self.mesh.get_coordinate_on_dim(idx) + assert my_coordinate is not None, "Rank not part of mesh!" + if isinstance(placement, Shard): + shard_dim = placement.dim + assert ( + shard_dim < self.ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {self.ndim}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate, + return_offset=True, + ) + local_shape[shard_dim] = shard_size + local_offsets[shard_dim] = shard_offset + return tuple(local_offsets) + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: List[int], + sums: List[int], + shape: torch.Size, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + shape (torch.Size): shape of the DTensor associated with this spec. 
+ + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = _Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, placements, shape=shape, ndim=len(dim_map)) From 4b945967de2ae9a3c6df579a1541b822de46110c Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 08:04:38 +0000 Subject: [PATCH 216/453] [dtensor] PART 2: move DTensor abstraction and APIs to core distributed (#88176) This PR moves the core DTensor abstraction and high level APIs to torch.distributed._tensor folder, which includes the following: 1. DTensor class 2. high level APIs (distribute_tensor/module) 3. dispatching logic 4. redistribute logic part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88176 Approved by: https://github.com/fduwjj --- torch/distributed/_tensor/README.md | 3 + torch/distributed/_tensor/__init__.py | 189 +++++++++++ torch/distributed/_tensor/api.py | 393 ++++++++++++++++++++++ torch/distributed/_tensor/dispatch.py | 301 +++++++++++++++++ torch/distributed/_tensor/redistribute.py | 236 +++++++++++++ torch/distributed/_tensor/utils.py | 53 +++ 6 files changed, 1175 insertions(+) create mode 100644 torch/distributed/_tensor/README.md create mode 100644 torch/distributed/_tensor/api.py create mode 100644 torch/distributed/_tensor/dispatch.py create mode 100644 torch/distributed/_tensor/redistribute.py create mode 100644 torch/distributed/_tensor/utils.py diff --git a/torch/distributed/_tensor/README.md b/torch/distributed/_tensor/README.md new file mode 100644 index 000000000000..9bbd71b764e5 --- /dev/null +++ b/torch/distributed/_tensor/README.md @@ -0,0 +1,3 @@ +# Distributed Tensor + +This is a prototype distributed tensor implementation that implements most of the basic parts in the RFC https://docs.google.com/document/d/15R3fmoPbzedlKSjtpQ97HFPidp9QTXLEap6gyIvRrMY/edit# diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index e69de29bb2d1..ba09f2fbb690 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Optional, Sequence, Callable, cast + +import torch +import torch.nn as nn +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh +from torch.distributed._tensor.placement_types import Placement, Shard, Replicate + + +# Import all builtin dist tensor ops +# import torch.distributed._tensor.ops + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Distribute a torch.Tensor to the `device_mesh` according to the `placements` + specified. The rank of `device_mesh` and `placements` must be the same. 
+ + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use `torch.tensor_split` + semantic to shard the tensor and scatter the shards. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as `device_mesh.ndim`. If not specified, we will + by default replicate the tensor across the `device_mesh` from the + first rank of each dimension of the `device_mesh`. + + Returns: + A :class:`DTensor` object + """ + # get default device mesh if there's nothing specified + device_mesh = ( + get_global_device_mesh() if device_mesh is None else device_mesh + ) + # convert tensor to the correponding device type if it's not in that device type + tensor = tensor.to(device_mesh.device_type) + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we just need to check if the + # device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != placements: + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor + + # distribute the tensor according to the placements. + for idx, placement in enumerate(placements): + if placement.is_shard(): + placement = cast(Shard, placement) + output = placement._shard_tensor(local_tensor, device_mesh, idx) + # scatter call could not return a tensor with correct requires_grad + # field, as ProcessGroupNCCL refuse to take a tensor with requires_grad + # to do inplace update! So we manually set it here + output.requires_grad_(tensor.requires_grad) + local_tensor = output + elif placement.is_replicate(): + local_tensor = local_tensor.contiguous() + device_mesh.broadcast(local_tensor, mesh_dim=idx) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + + assert local_tensor is not None, "distributing a tensor should not be None" + return DTensor( + local_tensor, + device_mesh, + placements, + size=tensor.size(), + requires_grad=tensor.requires_grad, + ) + + +def distribute_module( + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, + input_fn: Optional[Callable[..., None]] = None, + output_fn: Optional[Callable[..., None]] = None, +) -> nn.Module: + """ + This function converts all module parameters to :class:`DTensor` parameters + according to the `partition_fn` specified. 
It could also control the input or + output of the module by specifying the `input_fn` and `output_fn`. (i.e. convert + the input to :class:`DTensor`, convert the output back to torch.Tensor) + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the `device_mesh`). If `partition_fn` is not specified, + by default we replicate all module parameters of `module` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. `input_fn` will be installed as a module + `forward_pre_hook` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. output_fn will be + installed as a module `forward_hook` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all `DTensor`s. + """ + + if device_mesh is None: + device_mesh = get_global_device_mesh() + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter( + distribute_tensor(param.data, mesh, full_replicate) + ), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor( + buffer, mesh, full_replicate + ) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for name, submod in module.named_modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[misc] + # register input_fn as module forward hook + if output_fn is not None: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[misc] + ) + + return module + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "DeviceMesh", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", +] diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py new file mode 100644 index 000000000000..bf5514cc7d4e --- /dev/null +++ b/torch/distributed/_tensor/api.py @@ -0,0 +1,393 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import copy +import warnings +import torch +from torch.utils._pytree import tree_flatten +from typing import Dict, Callable, Optional, Sequence, cast +from torch.distributed._tensor.device_mesh import get_global_device_mesh, DeviceMesh +from torch.distributed._tensor.placement_types import ( + Placement, + Shard, + Replicate, + _Partial, + DTensorSpec, +) +from torch.distributed._tensor.redistribute import Redistribute + +from torch.distributed._tensor.dispatch import operator_dispatch, OpSchema, OutputSharding + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure our DTensor +# works together with torch.Tensor within autograd engine. This +# allows DistributedTensor to exist on part of the module hierarchy +# and still able to calculate gradients across the torch.Tensor and +# DistributedTensor boundary. +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DistributedTensor params, we would need to make the folloing +# flow to work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input -> Sharded Module B -> DTensor output +# -> output (torch.Tensor) -> Module C -> output (torch.Tensor) +# +# We need the conversion from Module A to DTensor input, which is +# `from_local`, and conversion from DTensor output to output, which +# is `to_local`, thus these two functions must be Autograd functions. +# +class ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward(ctx, input: "DTensor"): # type: ignore[override] + ctx.dtensor_device_mesh = input.device_mesh + ctx.dtensor_placements = input.placements + ctx.dtensor_shape = input.shape + ctx.dtensor_requires_grad = input.requires_grad + return input._local_tensor.detach() + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + device_mesh = ctx.dtensor_device_mesh + placements = ctx.dtensor_placements + return DTensor( + grad_output, + device_mesh, + placements, + size=ctx.dtensor_shape, + requires_grad=grad_output.requires_grad, + ) + + +class FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + run_check: bool, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if run_check: + # TODO: by default check tensor metas across rank + # TODO: See if we need to make this run_check logic + # have a corresponding backward. 
+ for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + device_mesh.broadcast(input, mesh_dim=idx) + + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + tensor_shape = list(input.size()) + for idx, placement in enumerate(placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + local_dim_size = tensor_shape[shard_dim] + tensor_shape[shard_dim] = local_dim_size * device_mesh.size(idx) + + dist_tensor = DTensor( + input, + device_mesh, + placements, + size=torch.Size(tensor_shape), + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + grad_output = Redistribute.apply( + grad_output, previous_device_mesh, previous_placement + ) + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None + + +class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # class attribute that handles operator placements propagation + # rules, keyed by aten op name, value is propagation func + _op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]] = {} + + # class attribute that handles custom registered ops, all handled + # custom ops should appear in this table, and overriding the default + # operators that's been covered by _op_to_rules or fallbacks. + # (custom operator is the highest priority when dispatching). + # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + _custom_dispatch_ops: Dict[str, Callable] = {} + + @staticmethod + def __new__( + cls, + local_tensor: torch.Tensor, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + *, + size: torch.Size, + requires_grad: bool = False, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + Note: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using `DTensor.from_local`, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using `distribute_tensor`. 
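For contrast between the two construction paths mentioned above, a hedged sketch of `DTensor.from_local` and `to_local`; the two-rank mesh and the shard shape are made-up values for illustration.

    import torch
    from torch.distributed._tensor import DeviceMesh, DTensor, Shard

    # assumed setup: 2 ranks, each rank already owns its own local shard
    mesh = DeviceMesh("cuda", [0, 1])
    local_shard = torch.randn(4, 8, requires_grad=True)

    # wrap the per-rank local tensor; the logical (global) shape becomes (8, 8)
    dt = DTensor.from_local(local_shard, mesh, placements=[Shard(0)])

    # to_local() hands the local shard back on the current rank
    assert dt.to_local().shape == (4, 8)
    assert dt.shape == (8, 8)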
+ """ + # recover tensor strides from local tensor strides and global size info + # in the case of sharding + # TODO: we should try to use meta tensor for shape and stride calculation + tensor_stride = list(local_tensor.stride()) + local_size = list(local_tensor.size()) + for placement in placements: + if isinstance(placement, Shard): + shard_dim = placement.dim + # recover tensor stride by modifying the stride that larger than + # the current stride on the shard_dim + for i in range(len(tensor_stride)): + if ( + i != shard_dim + and tensor_stride[i] >= tensor_stride[shard_dim] + ): + # rescale the stride by the shard size + tensor_stride[i] = ( + tensor_stride[i] // local_size[shard_dim] + ) * size[shard_dim] + elif not isinstance(placement, (Replicate, _Partial)): + raise RuntimeError( + f"placement type {type(placement)} not supported!" + ) + + if requires_grad != local_tensor.requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + size, + strides=tensor_stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + # deepcopy and set spec + r._spec = DTensorSpec( + device_mesh, copy.deepcopy(placements), shape=r.size() + ) + # detach local tensor from autograd graph as we initialize the + # distributed tensor and autograd will be working on top of + # the wrapper tensor directly instead of local torch.Tensor + r._local_tensor = local_tensor.detach() + return r + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + @classmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + # if we find nn.functional name in dispatch op, dispatch to it instead, + # this allow us to override some python level behaviors that wouldn't be + # possible in __torch_dispatch__ level. + if func.__name__ in DTensor._custom_dispatch_ops: + # dispatch to the same table as the name should be different between + # torch_function and torch_dispatch + return DTensor._custom_dispatch_ops[func.__name__](*args, **kwargs) + else: + # if not, just do nothing here + return super().__torch_function__(func, types, args, kwargs) + + @classmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + # check that we are not getting mixed vanilla and Distributed tensors + arg_list, _ = tree_flatten(args) + for arg in arg_list: + if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor): + raise RuntimeError( + f"{func}: got mixed distributed and non-distributed tensors." 
+ ) + + if kwargs is None: + kwargs = {} + + return operator_dispatch( + func, + args, + kwargs, + DTensor._op_to_rules, + DTensor._custom_dispatch_ops, + ) + + @classmethod + def from_local( + cls, + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + run_check: bool = True, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the `device_mesh` and `placements` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as `device_mesh.ndim`. If not + specified, we will by default replicate the tensor across the + `device_mesh` from the first rank of each dimension of the `device_mesh`. + run_check (bool, optional): indicate whether to run check across ranks + to check meta information and data. if have :class:`Replicate` in + `placements`, the data on first rank of the device mesh dimension + will be broadcasted to other ranks. + + Returns: + A :class:`DTensor` object + + .. note:: `from_local` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = ( + get_global_device_mesh() if device_mesh is None else device_mesh + ) + # convert the local tensor to desired device base on device mesh's device_type + local_tensor = local_tensor.to(device_mesh.device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, device_mesh, placements, run_check + ) + + def to_local(self) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Returns: + A :class:`torch.Tensor` object that represents the local tensor of its current rank. + + .. note:: `to_local` is differentiable, the `requires_grad` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + return ToTorchTensor.apply(self) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + ) -> "DTensor": + """ + `redistribute` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from is current DeviceMesh + to a new DeviceMesh. i.e. 
we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as `device_mesh.ndim`. + + Returns: + A :class:`DTensor` object + + .. note:: `redistribute` is differentiable. + """ + # This API perform necessary transformations and get + # a new DTensor with the new spec. i.e. for + # sharding it's a reshard behavior. + # Note that redistribute currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + device_mesh = ( + get_global_device_mesh() if device_mesh is None else device_mesh + ) + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + for placement in placements: + if placement.is_partial(): + raise RuntimeError( + "Can not redistribute to _Partial, _Partial is for internal use only!" + ) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply(self, device_mesh, placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: device_mesh is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> Sequence[Placement]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: placements is a read-only property, it can not be set. + """ + return self._spec.placements diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py new file mode 100644 index 000000000000..8c9e5a22efb8 --- /dev/null +++ b/torch/distributed/_tensor/dispatch.py @@ -0,0 +1,301 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from dataclasses import dataclass +from typing import List, Callable, Dict, Tuple, Optional, cast + +import torch +from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten +from torchgen.model import FunctionSchema, SchemaKind + +import torch.distributed._tensor.api as dtensor +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.redistribute import redistribute_dtensor +from torch.distributed._tensor.utils import ( + ArgKwargsType, + OutputSpecType, + unwrap_local_tensor, + unwrap_schema, + wrap, +) + + +""" +If _ENABLE_FALLBACK set to False, dispatch will fail when an op doesn't +have a sharding rule registered. +""" +_ENABLE_FALLBACK = False + + +""" +Print information on ops input shape and sharding for debugging purposes. +""" +_DEBUG_VERBOSE = False + + +@dataclass +class OpSchema(object): + """ + OpSchema is a data class that describes an operator input schemas, it + includes DTensor DTensorSpecs and non-tensor args/kwargs (positional order + preserved). It is mainly used by the dispatching logic below to run things like + sharding propagation. 
+ + Sharding propagation rules registered could utilize this data class and + do inplace update some fields (when necessary, i.e shape related ops) to make + sure the args/kwargs are legit before passing to the local tensor operator. + This is the main reason that we don't freeze this dataclass. + + NOTE: greater access to the operator inputs comes with greater responsibility. + Here are some basic rules about what can be used and what can be changed. + + Args: + func_schema: the function schema of the operator + args_schema: contains args except that the DTensor args have been replaced + with its DTensorSpec + kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced + with its DTensorSpec + + What can be used: + - every attribute within this class could be read to conduct + sharding propagation. + What can be changed: + - only the args_schema and kwargs_schema could be changed. + - every non-tensor args could be changed to accomodate for local tensor + operations (i.e. for ops like view/reshape/...) + - every "DTensorSpec" attribute inside `args_schema`, `kwargs_schema` and + `args_spec` SHOULD NOT be updated! DTensorSpec are read only and sharding + propagation shouldn't inplace update them, otherwise the input DTensor + placements will get implicitly changed and it's error-prone. + """ + + func_schema: FunctionSchema + args_schema: Tuple[object, ...] + kwargs_schema: Dict[str, object] + is_inplace: bool = False + is_out_variant: bool = False + + def __post_init__(self) -> None: + schema_kind = self.func_schema.kind() + self.is_inplace = ( + schema_kind + == SchemaKind.inplace # pyre-ignore [16] pyre bad at enum + ) + self.is_out_variant = ( + schema_kind == SchemaKind.out # pyre-ignore [16] pyre bad at enum + ) + + @property + def args_spec(self) -> Tuple[DTensorSpec, ...]: + """ + args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list + with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) + mainly used by sharding propagation to propagate the output spec + """ + # filter out non-relavant values from args schema to get a clean spec list + # this would mainly be used by sharding propagation rules + return tuple( + item for item in self.args_schema if isinstance(item, DTensorSpec) + ) + + def __repr__(self) -> str: + return ( + f"OpSchema(func_schema={self.func_schema}," + f" args_schema={self.args_schema}," + f" kwargs_schema={self.kwargs_schema})" + ) + + +@dataclass +class OutputSharding: + """ + OutputSharding is a data class that is used by the sharding propagation + rules, it could set the output_spec upon successful propagation, and if + it failed, output_spec would become None and sharding propagation rules + could give a list of suggestions for inputs to reshard. 
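To make the two dataclasses concrete, here is a minimal sketch of a propagation rule that produces an `OutputSharding`. The op it stands in for is hypothetical; the real rules registered later in this patch (via `register_prop_rule` in the `ops/` modules) follow the same shape.

    from typing import cast

    from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
    from torch.distributed._tensor.placement_types import DTensorSpec

    def my_unary_rule(op_schema: OpSchema) -> OutputSharding:
        # hypothetical rule for a unary elementwise-style op: the output
        # simply inherits the sharding spec of its single tensor input
        input_spec = cast(DTensorSpec, op_schema.args_schema[0])
        return OutputSharding(output_spec=input_spec)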
+ + NOTE: the schema_suggestion generated by sharding propagation should be + exactly the same as the operator OpSchema, except the DTensor DTensorSpecs + """ + + output_spec: OutputSpecType + schema_suggestions: Optional[List[OpSchema]] = None + failed_reason: Optional[str] = None + + +def pack_args_kwargs_with_local_tensor( + args: ArgKwargsType, + args_schema: ArgKwargsType, + redistribute_with_schema: bool = False, +) -> ArgKwargsType: + flatten_args, args_tree_spec = tree_flatten(args) + flatten_args_schema, _ = tree_flatten(args_schema) + + for i, arg in enumerate(flatten_args): + if isinstance(arg, dtensor.DTensor): + if redistribute_with_schema: + target_spec = flatten_args_schema[i] + arg = redistribute_dtensor( + arg, target_spec.mesh, target_spec.placements + ) + + # reuse the schema list and update it with local tensor + flatten_args_schema[i] = arg._local_tensor + + return tree_unflatten(flatten_args_schema, args_tree_spec) + + +def _reshape_alias( + x: torch.Tensor, shape: Tuple[int, ...], strides: Tuple[int, ...] +) -> torch.Tensor: + return torch.ops.aten.view(x, shape) + + +_CURRENT_DECOMPOSITION_TABLE: Dict[ + Callable[..., object], Callable[..., object] +] = {torch.ops.aten._reshape_alias.default: _reshape_alias} + + +def propagate_input_sharding( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]], +) -> Tuple[OpSchema, bool, Optional[OutputSharding]]: + # parse the operator schema + func_schema = FunctionSchema.parse(str(op_call._schema)) + # unwrap the args/kwargs schema + args_schema = tree_map(unwrap_schema, args) + kwargs_schema = tree_map(unwrap_schema, kwargs) + + op_schema = OpSchema(func_schema, args_schema, kwargs_schema) + + if _DEBUG_VERBOSE and torch.distributed.get_rank() == 0: + print(f"{op_call}({op_schema})") + local_shapes = tree_map( + lambda t: t.to_local().shape + if isinstance(t, dtensor.DTensor) + else None, + args, + ) + print(f" local shapes: {local_shapes}") + + op_key = str(op_call) + sharding_prop_func = op_to_rules.get(op_key, None) + + if sharding_prop_func is None: + # step 1. If there's not even one sharding rule + # implemented for the operator, we fall back to + # local tensor compute, this is wront currently + # we will change the behavior to reshard to full + # replicate and do the computatation + if not _ENABLE_FALLBACK: + raise NotImplementedError( + f"Operator {op_key} does not have a DistributedTensor rule registered." + ) + else: + return op_schema, False, None + + # step 2. there's sharding propagation rule, run + # sharding propagation to get output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_key}.\n" + f"Input schema: {op_schema}.\n" + f"Error: {e}" + ) from e + + # step 3. if can't get output_spec from sharding + # propagation (i.e. 
no rules apply for input + # placements), we do auto redistribute on inputs + # to get an eligble input, which we will pick a + # target schema base on the redistribute cost + # TODO: implement full auto distribute with a + # simple cost estimation model + if output_sharding.output_spec is None: + # do auto distributed/boxing here + if output_sharding.schema_suggestions is not None: + # pick the first suggestion for now, + target_schema = output_sharding.schema_suggestions[0] + # run sharding propagation again with target schema + output_sharding = sharding_prop_func(target_schema) + + return target_schema, True, output_sharding + + else: + raise RuntimeError( + f"Sharding propagation failed on op {op_key}!" + f"Input schema: {op_schema}." + f"Failed reason: {output_sharding.failed_reason}" + ) + else: + return op_schema, False, output_sharding + + +def operator_dispatch( + op_call: torch._ops.OpOverload, + args: Tuple[object, ...], + kwargs: Dict[str, object], + op_to_rules: Dict[str, Callable[[OpSchema], OutputSharding]], + custom_dispatch_ops: Dict[str, Callable[..., object]], +) -> object: + # first we need to lift some private aten aliases to public calls + if op_call in _CURRENT_DECOMPOSITION_TABLE: + return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs) + + # STEP 0. See if threre're user defined custom aten operator + # implementations. Custom operators take the highest priority + if str(op_call) in custom_dispatch_ops: + # dispatch to user defined custom distributed tensor ops + return custom_dispatch_ops[str(op_call)](*args, **kwargs) + + target_schema, redistribute, output_sharding = propagate_input_sharding( + op_call, args, kwargs, op_to_rules + ) + + if output_sharding is None: + # default to local tensor ops, this is wrong + # but we use it now to enable more tensor point-wise ops + # TODO: delete this and use replicate (all_gather) as + # the default fallback. + tensor_args = tree_map(unwrap_local_tensor, args) + tensor_kwargs = tree_map(unwrap_local_tensor, kwargs) + local_results = op_call(*tensor_args, **tensor_kwargs) + return wrap(local_results, target_schema.args_spec[0]) + + local_tensor_args = pack_args_kwargs_with_local_tensor( + args, + target_schema.args_schema, + redistribute_with_schema=redistribute, + ) + local_tensor_kwargs = pack_args_kwargs_with_local_tensor( + kwargs, + target_schema.kwargs_schema, + redistribute_with_schema=redistribute, + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(Tuple[object, ...], local_tensor_args) + local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + if target_schema.is_inplace: + # inplace op should return self instead of re-wrapping + self = cast(dtensor.DTensor, args[0]) + self._spec = cast(DTensorSpec, output_sharding.output_spec) + return self + elif target_schema.is_out_variant: + # out variant could possibly have multiple out args (i.e. 
lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + for i, out in enumerate(target_schema.func_schema.arguments.out): + out_dt = cast(dtensor.DTensor, kwargs[out.name]) + out_dt._spec = cast(DTensorSpec, output_specs[i]) + out_dts.append(out_dt) + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + return wrap(local_results, output_sharding.output_spec) diff --git a/torch/distributed/_tensor/redistribute.py b/torch/distributed/_tensor/redistribute.py new file mode 100644 index 000000000000..ab36cd408903 --- /dev/null +++ b/torch/distributed/_tensor/redistribute.py @@ -0,0 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Dict, List, Sequence, Tuple, cast + +import torch +import torch.distributed._tensor.api as dtensor +from torch.distributed._tensor.placement_types import Placement, _Partial, Shard, Replicate +from torch.distributed._tensor.device_mesh import DeviceMesh + + +_PlacementItem = Tuple[int, Tuple[Placement, Placement]] + + +def _replicate_then_shard(val: _PlacementItem) -> int: + """ + Replicate from inner to outer dimension. + Shard from outer to inner dimension. + """ + i, (current, target) = val + if (target.is_replicate() or target.is_partial()) and current.is_shard(): + return -i + elif (current.is_replicate() or current.is_partial()) and target.is_shard(): + return i + else: + return 0 + + +def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]: + """ + Decompose Si -> Sj into Si -> R -> Sj + There's 2 ways a shardings can differ within a mesh dimension: + 1) sharding on different tensor dimensions, e.g. Shard(0) -> Shard(1) + 2) different sub-shards of a repeated shard ("mis-aligned sharding") + (Shard(0), Shard(0)) -> (Replicate(), Shard(0)) + Here the Shard(0) -> Shard(0) for mesh dimension 2 is actually + a reshard, because in the first case it's a sub-sharding of an already tensor dimension 0, + and in the second case, it's the first sharding on tensor dimesnion 0. 
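A small illustration of that decomposition, with made-up placements; the expected result is what the function body below computes for this input.

    from torch.distributed._tensor.placement_types import Replicate, Shard

    # one mesh dimension (index 0) moving its sharding from tensor dim 0 to dim 1
    val = [(0, (Shard(0), Shard(1)))]
    # _decompose_reshard(val) splits this into two steps:
    #   [(0, (Shard(0), Replicate())), (0, (Replicate(), Shard(1)))]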
+ """ + # detect mis-aligned repeated shardings + from collections import defaultdict + + repeat_dim_current: Dict[int, int] = defaultdict(int) + repeat_dim_target: Dict[int, int] = defaultdict(int) + + output: List[_PlacementItem] = [] + + for i, (current, target) in val: + # detect mis-aligned sharding + if current.is_shard(): + repeat_dim_current[cast(Shard, current).dim] += 1 + if target.is_shard(): + repeat_dim_target[cast(Shard, target).dim] += 1 + if ( + isinstance(current, Shard) + and isinstance(target, Shard) + and ( + current.dim != target.dim + or repeat_dim_current[current.dim] + != repeat_dim_target[target.dim] + ) + ): + # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j) + output.append((i, (current, Replicate()))) + output.append((i, (Replicate(), target))) + else: + output.append((i, (current, target))) + + return output + + +# Intentionally expose this API to trace ops on local tensors +def _redistribute_with_local_tensor( + local_tensor: torch.Tensor, + size: torch.Size, + device_mesh: DeviceMesh, + current_placements: Sequence[Placement], + target_placements: Sequence[Placement], +) -> torch.Tensor: + new_local_tensor = None + + sorted_placements = list( + enumerate(zip(current_placements, target_placements)) + ) + sorted_placements = _decompose_reshard(sorted_placements) + sorted_placements.sort(key=_replicate_then_shard) + + for i, (current, target) in sorted_placements: + my_coordinate = device_mesh.get_coordinate_on_dim(i) + num_chunks = device_mesh.size(dim=i) + # TODO: what should happen if rank is not in the mesh? + # see issue https://github.com/pytorch/tau/pull/492 + assert ( + my_coordinate is not None + ), "Rank if not part of mesh" # TODO: figure out behavior here + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(_Partial, current) + new_local_tensor = partial_spec._to_replicate( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, size, device_mesh, i + ) + else: + raise RuntimeError( + f"redistribute from {current_placements} to {target_placements} not supported yet" + ) + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + if current.is_partial(): + partial_spec = cast(_Partial, current) + new_local_tensor = partial_spec._to_shard( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + shards, _ = target_placement._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + new_local_tensor = shards[my_coordinate].clone() + else: + # NOTE: this case shouldn't hit _decompose_sharding, decompose sharding should + # decompose Shard(0) -> Shard(1) into Shard(0) -> Replicate -> Shard(1) + assert ( + current.is_shard() + ), f"Current placement should be shard but found {current}" + shard_spec = cast(Shard, current) + if shard_spec.dim != target_placement.dim: + # TODO: enable this with all_to_all + raise NotImplementedError( + "Changing sharding dim is not supported yet!" 
+ ) + + elif target.is_partial(): + if current.is_replicate(): + # For replicate -> partial, we zero out all other ranks of the current mesh dim + # and leave only 1 rank have the data, to perform a "zero cost" reshard. + if my_coordinate is not None and my_coordinate != 0: + new_local_tensor = local_tensor.zero_() + else: + new_local_tensor = local_tensor + else: + raise RuntimeError( + f"redistribute from {current_placements} to {target_placements} not supported yet" + ) + + assert new_local_tensor is not None + local_tensor = new_local_tensor + + assert new_local_tensor is not None, "redistribute failed!" + + return new_local_tensor + + +def redistribute_dtensor( + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: Sequence[Placement], +) -> "dtensor.DTensor": + if input.device_mesh != device_mesh: + # TODO: alltoall reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + local_tensor = input._local_tensor + new_local_tensor = _redistribute_with_local_tensor( + local_tensor, + input.size(), + device_mesh, + input.placements, + placements, + ) + + return dtensor.DTensor( + new_local_tensor, + device_mesh, + placements, + size=input.size(), + requires_grad=local_tensor.requires_grad, + ) + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: List[Placement], + ): + ctx.previous_placement = input.placements + ctx.previous_device_mesh = input.device_mesh + return redistribute_dtensor(input, device_mesh, placements) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + # When we run backward pass of redistribute (i.e. manual redistribute from + # user code instead of torch_dispatch), we scan first and see if we need + # to change the target placement for one special case: + # replicate -> partial. + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is acutally useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. + # TODO: see if this make sense for all cases. + target_placements: List[Placement] = [] + for current, target in zip(grad_output.placements, previous_placement): + if current.is_replicate() and target.is_partial(): + # keep target placement to replicate instead of partial in this case + target_placements.append(current) + else: + target_placements.append(target) + + return ( + redistribute_dtensor( + grad_output, previous_device_mesh, target_placements + ), + None, + None, + ) diff --git a/torch/distributed/_tensor/utils.py b/torch/distributed/_tensor/utils.py new file mode 100644 index 000000000000..bb56f488d81f --- /dev/null +++ b/torch/distributed/_tensor/utils.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import torch +from typing import Union, Dict, Tuple, Optional, Sequence + +import torch.distributed._tensor.api as dtensor +from torch.distributed._tensor.placement_types import DTensorSpec + +ArgKwargsType = Union[Tuple[object, ...], Dict[str, object]] +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould +# be the same set of possiblities. +OutputSpecType = Optional[Union[DTensorSpec, Sequence[DTensorSpec]]] + + +def unwrap_local_tensor(e: "dtensor.DTensor") -> torch.Tensor: + return e._local_tensor if isinstance(e, dtensor.DTensor) else e + + +def unwrap_schema(e: object) -> object: + return e._spec if isinstance(e, dtensor.DTensor) else e + + +def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + assert spec is not None and isinstance( + spec, DTensorSpec + ), f"output spec does not match with output! Expected DTensorSpec, got {spec}." + return dtensor.DTensor( + res, + spec.mesh, + spec.placements, + size=spec.shape, + requires_grad=res.requires_grad, + ) + elif isinstance(res, list): + assert spec is not None and isinstance( + spec, list + ), f"output spec does not match with output! Expected list, got {spec}." + return list( + dtensor.DTensor(e, s.mesh, s.placements, size=s.shape) + for e, s in zip(res, spec) + ) + elif isinstance(res, tuple): + assert spec is not None and isinstance( + spec, tuple + ), f"output spec does not match with output! Expected tuple, got {spec}" + return tuple( + dtensor.DTensor(e, s.mesh, s.placements, size=s.shape) + for e, s in zip(res, spec) + ) + else: + # if the res contains only non tensor values, we simply return it without rewrapping + return res From 2dcf0978a249ae136c39e396200e5ed51407471d Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 08:04:38 +0000 Subject: [PATCH 217/453] [dtensor] PART 3: move most DTensor ops to core distributed (#88177) This PR moves most DTensor ops to torch.distributed._tensor. We will add all tests in the following PRs. part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88177 Approved by: https://github.com/fduwjj --- torch/distributed/_tensor/ops/__init__.py | 7 + torch/distributed/_tensor/ops/common_rules.py | 376 ++++++++++++++ torch/distributed/_tensor/ops/math_ops.py | 141 +++++ torch/distributed/_tensor/ops/matrix_ops.py | 129 +++++ .../distributed/_tensor/ops/pointwise_ops.py | 396 ++++++++++++++ torch/distributed/_tensor/ops/tensor_ops.py | 481 ++++++++++++++++++ .../_tensor/ops/tp_sharding_ops.py | 55 ++ torch/distributed/_tensor/ops/utils.py | 81 +++ 8 files changed, 1666 insertions(+) create mode 100644 torch/distributed/_tensor/ops/__init__.py create mode 100644 torch/distributed/_tensor/ops/common_rules.py create mode 100644 torch/distributed/_tensor/ops/math_ops.py create mode 100644 torch/distributed/_tensor/ops/matrix_ops.py create mode 100644 torch/distributed/_tensor/ops/pointwise_ops.py create mode 100644 torch/distributed/_tensor/ops/tensor_ops.py create mode 100644 torch/distributed/_tensor/ops/tp_sharding_ops.py create mode 100644 torch/distributed/_tensor/ops/utils.py diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/_tensor/ops/__init__.py new file mode 100644 index 000000000000..5012768ee051 --- /dev/null +++ b/torch/distributed/_tensor/ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from .matrix_ops import * # noqa: F403 +from .math_ops import * # noqa: F403 +from .tensor_ops import * # noqa: F403 +from .tp_sharding_ops import * # noqa: F403 +from .pointwise_ops import * # noqa: F403 +# from .view_ops import * # noqa: F403 diff --git a/torch/distributed/_tensor/ops/common_rules.py b/torch/distributed/_tensor/ops/common_rules.py new file mode 100644 index 000000000000..29925c8a52c7 --- /dev/null +++ b/torch/distributed/_tensor/ops/common_rules.py @@ -0,0 +1,376 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import torch +from typing import List, Sequence, Dict, Tuple, Optional, cast +from torch.distributed._tensor.dispatch import OpSchema, OutputSharding +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.ops.utils import prod + + +def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: + return string[:idx] + new_char + string[idx + 1 :] + + +def _inplace_rewrap_schema_suggestion( + suggestion: OpSchema, input_schema: OpSchema +) -> None: + suggestion_args_spec = suggestion.args_spec + new_arg_schema: List[object] = [] + idx_of_args_spec = 0 + for arg in input_schema.args_schema: + if isinstance(arg, DTensorSpec): + new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) + idx_of_args_spec += 1 + else: + new_arg_schema.append(arg) + suggestion.args_schema = tuple(new_arg_schema) + suggestion.kwargs_schema = input_schema.kwargs_schema + + +def _gen_reshard_suggestions( + op_schema: OpSchema, + input_dims: List[str], + input_specs: Tuple[DTensorSpec, ...], + dim_to_sharding: Dict[str, int], + pending_sum: List[int], +) -> OutputSharding: + suggested_arg_specs: List[DTensorSpec] = [] + for input_dim, input_spec in zip(input_dims, input_specs): + dim_map = [dim_to_sharding[dim] for dim in input_dim] + suggested_arg_specs.append( + DTensorSpec.from_dim_map( + mesh=input_spec.mesh, + dim_map=dim_map, + sums=pending_sum, + shape=input_spec.shape, + ) + ) + suggested_schema = OpSchema( + op_schema.func_schema, tuple(suggested_arg_specs), {} + ) + _inplace_rewrap_schema_suggestion(suggested_schema, op_schema) + return OutputSharding( + None, + schema_suggestions=[suggested_schema], + failed_reason="Input placements op sharding propagation failed, need to reshard!", + ) + + +def einop_rule( + equation: str, + op_schema: OpSchema, + *, + linearity: bool = False, + enforce_sharding: Optional[Dict[str, int]] = None, +) -> OutputSharding: + """ + Propagate the sharding of inputs to output for ops whose data + moves according to einsum notation. This is mostly borrowed + from @zdevito's sharding simulator. Examples: + mk,kn->mn - einsum + ij,ij->ij - addition + ij,j->ij - broadcasted addition + ij->i - reduction + Other ops could use this propagation algorithm when applied, note + that einsum propagation only deal with list of specs (DTensor specs) + as it only works on list of tensors! + + linearity in einop_rule means that the calling op `f` follows this rule: + f(a + b) = f(a) + f(b) + + In this case we can propagate the partial sum, note that linearity in einop + only applies to partial sum, not other operations like min/max (which are + associative but not linear). 
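A worked illustration of the notation, using the `dim_map` convention from `DTensorSpec` (-1 means replicated on that mesh dimension, a non-negative value names the mesh dimension a tensor dim is sharded on); the concrete maps below are made up for the example.

    equation "mk,kn->mn":
      A dim_map [0, -1]   ("m" sharded on mesh dim 0, "k" replicated)
      B dim_map [-1, -1]  (fully replicated)
        -> output dim_map [0, -1], no pending sums

      A dim_map [-1, 0]   (contracting dim "k" sharded on mesh dim 0)
      B dim_map [0, -1]   (the matching "k" sharding on the other operand)
        -> output dim_map [-1, -1] with a pending sum on mesh dim 0,
           i.e. the local matmul produces a partial result that still
           needs a reduction across that mesh dimension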
+ """ + # parse einop equation and extract arg specs + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + input_specs = op_schema.args_spec + # NOTE: only support single output unless needed in future + output_dim = output_dims[0] + + dim_to_sharding: Dict[str, int] = {} + dim_to_size: Dict[str, int] = {} + # record pending sum, key is mesh dimension, value is pending sum + # counter across input specs + pending_sums_counter: Dict[int, int] = {} + seen_shardings: Dict[int, str] = {} + needs_reshard = False + + def merge_sharding(dim: str, a: int, b: int) -> int: + # merge the sharding of inputs if it's able to merge, i.e. we can merge + # replicate and shard to shard, but this will trigger an reshard operation + if a != b: + if a == -1 or b == -1: + # reshard the replicate to match the sharded one + nonlocal needs_reshard + needs_reshard = True + return a if a != -1 else b + else: + # TODO: further merge the sharding properly (i.e. reshard one input to replicate) + raise RuntimeError( + f"{equation}: dim {dim} sharded two different ways: {a} and {b}" + ) + else: + return a + + for input_dim, input_spec in zip(input_dims, input_specs): + # deal with partial sums + input_sums = input_spec.sums + for sum_dim in input_sums: + if sum_dim not in pending_sums_counter: + seen_shardings[sum_dim] = "+" + # update pending sum counter for pending sum mesh + # dimension with the occurance from each input + pending_sums_counter[sum_dim] = ( + pending_sums_counter.get(sum_dim, 0) + 1 + ) + + for idx, (dim, mesh_dim) in enumerate( + zip(input_dim, input_spec.dim_map) + ): + if enforce_sharding and dim in enforce_sharding: + if enforce_sharding[dim] != mesh_dim: + needs_reshard = True + dim_to_sharding[dim] = enforce_sharding[dim] + dim_to_size[dim] = input_spec.shape[idx] + elif dim not in dim_to_sharding: + dim_to_sharding[dim] = mesh_dim + dim_to_size[dim] = input_spec.shape[idx] + else: + dim_to_sharding[dim] = merge_sharding( + dim, dim_to_sharding[dim], mesh_dim + ) + assert dim_to_size[dim] == input_spec.shape[idx] + + # after merging sharding, we check if there're multiple + # sharding on the same mesh dim. + merged_sharding_for_dim = dim_to_sharding[dim] + if merged_sharding_for_dim != -1: + if ( + merged_sharding_for_dim in seen_shardings + and dim != seen_shardings[merged_sharding_for_dim] + ): + needs_reshard = True + seen_shardings[merged_sharding_for_dim] += dim + else: + seen_shardings[merged_sharding_for_dim] = dim + + if pending_sums_counter and not linearity: + # return reshard suggestion with no pending sum, because we already properly + # merge the sharding, this reshard suggestion is legit to use + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, [] + ) + else: + # It's a op that support linearity, but not all input arguments are partial + # we fail the sharding propagation with suggestion to make all inputs be + # partial on the corresponding mesh dim (all inputs should be partial for + # the mesh dims in order to execute locally and delay the sum reduction) + for value in pending_sums_counter.values(): + if value != len(input_specs): + needs_reshard = True + + for mesh_dim, dims in seen_shardings.items(): + if len(dims) > 1: + # we found different input dims are being sharded on the same mesh dim + # in order to perform local op computation, we need to reshard inputs + # base on some simple heuristics, now we simply pick the one with least comm + # volume. (i.e. 
the input with least size) + # TODO: consider a more advanced heuristic to pick the best sharding + costs = [] + for d in dims: + cost = 0 + for input_dim, input_spec in zip(input_dims, input_specs): + if ( + d in input_dim + and input_spec.dim_map[input_dim.index(d)] == mesh_dim + ): + cost += prod( + input_spec.local_shape + ) * input_spec.mesh.size(mesh_dim) + costs.append(cost) + d_to_keep_sharding = dims[costs.index(max(costs))] + for d in dims: + # update dim_to_sharding to keep the sharding of the dim with + # highest comm and make the rest of the dims to replicate + if d != d_to_keep_sharding: + dim_to_sharding[d] = -1 + + pending_sums = list(pending_sums_counter.keys()) + if needs_reshard: + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, pending_sums + ) + + # generate output pending sum if a dim is sharded, and it appears in input + # but not output + for dim, shard_on_mesh in dim_to_sharding.items(): + if dim not in output_dims[0] and shard_on_mesh != -1: + pending_sums.append(shard_on_mesh) + + # if no need to reshard, we directly generate the output sharding + output_dim_map = [] + output_shape = [] + for dim in output_dim: + if dim == "1": + # find output dim that is a singleton dimension, mark sharding and shape + output_dim_map.append(-1) + output_shape.append(1) + else: + output_dim_map.append(dim_to_sharding[dim]) + output_shape.append(dim_to_size[dim]) + + return OutputSharding( + DTensorSpec.from_dim_map( + input_specs[0].mesh, + output_dim_map, + pending_sums, + shape=torch.Size(output_shape), + ) + ) + + +def pointwise_rule( + op_schema: OpSchema, linearity: bool = False +) -> OutputSharding: + """ + Propagate the sharding for pointwise operations. Examples: + ij,ij->ij - addition/mul + ij,j->ij - broadcasted addition + """ + alphabet = "abcdefghijklmnopqrstuvwxyz" + # find the max_dim first in case we need to broadcasting + input_specs = op_schema.args_spec + max_dim = max(input.ndim for input in input_specs) + dimchars = [] + singleton_counter: List[int] = [0] * max_dim + for input in input_specs: + start_dim = max_dim - input.ndim + p = alphabet[start_dim:max_dim] + # handle the "broadcasting to a common shape case" + # see https://pytorch.org/docs/stable/notes/broadcasting.html + # If any of the dimensions is singleton dimension (i.e. 1). + # we mark the dim char as a special "1" to distinguish with + # the non-singleton dimension, so that sharding propagation + # should just ignore the singleton dimension. + if len(input_specs) > 1: + for i in range(max_dim): + if i < start_dim: + # treat the leading miss dim chars as singleton + singleton_counter[i] += 1 + elif input.shape[i - start_dim] == 1: + # mark singleton dim char as a special "1" in einop rule + singleton_counter[i] += 1 + p = _replace_char_in_str(p, "1", (i - start_dim)) + + dimchars.append(p) + out_dimchars = alphabet[:max_dim] + # check if we replace the all inputs dim char with singleton dimension, + # if we replace all inputs, we also need to replace the output dimension. 
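+    # (illustrative example: for input shapes (1, 4) and (3, 1) the dim chars
+    #  become "1b" and "a1", and the output keeps "ab"; for shapes (1, 4) and
+    #  (1, 4) every input marks dim 0 as singleton, so the output's leading
+    #  char is rewritten to "1" as well)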
+ for output_dim_idx in range(len(out_dimchars)): + out_dimchar = out_dimchars[output_dim_idx] + if singleton_counter[output_dim_idx] == len(input_specs): + out_dimchars = _replace_char_in_str( + out_dimchars, "1", output_dim_idx + ) + + fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}" + + enforce_sharding: Dict[str, int] = {} + if op_schema.is_inplace: + # inplace op should keep the input sharding it writes to + for out_dimchar, mesh_dim in zip(out_dimchars, input_specs[0].dim_map): + enforce_sharding[out_dimchar] = mesh_dim + elif op_schema.is_out_variant: + out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) + for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map): + enforce_sharding[out_dimchar] = mesh_dim + + return einop_rule( + fmt, + op_schema, + linearity=linearity, + enforce_sharding=enforce_sharding, + ) + + +def linear_pointwise_rule(op_schema: OpSchema) -> OutputSharding: + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_rule(op_schema, linearity=True) + + +def reduction_rule( + op_schema: OpSchema, + *, + dims: Optional[Sequence[int]] = None, + keep_dim: bool = False, + reduction_linear: bool = False, +) -> OutputSharding: + """ + Propagate the sharding for reduction operations. Examples: + ij->i - sum on dim + + reduction_linear means that the reduction `f` follows this rule: + f([f(a), f(b)]) = f([a, b]) + + reduction linear should be super set of linearity. + """ + alphabet = "abcdefghijklmnopqrstuvwxyz" + # reduction op usually begin with a single tensor + input_spec = cast(DTensorSpec, op_schema.args_schema[0]) + reduce_dims = range(input_spec.ndim) if dims is None else dims + + if not reduction_linear: + # if the reduction is not linear, we need to clear the pending sum + # on the input spec, also replicate the reducing dimension if the + # reducing dimension is sharded, then suggest a resharding + reshard_dim_map = input_spec.dim_map + needs_reshard = False + for dim in reduce_dims: + if input_spec.dim_map[dim] != -1: + needs_reshard = True + reshard_dim_map[dim] = -1 + needs_reshard = needs_reshard or len(input_spec.sums) > 0 + + if needs_reshard: + no_partial_spec = DTensorSpec.from_dim_map( + input_spec.mesh, reshard_dim_map, [], input_spec.shape + ) + schema_suggestion = OpSchema( + op_schema.func_schema, (no_partial_spec,), {} + ) + _inplace_rewrap_schema_suggestion(schema_suggestion, op_schema) + return OutputSharding( + output_spec=None, schema_suggestions=[schema_suggestion] + ) + + input_chars = alphabet[: input_spec.ndim] + + if dims is None and not keep_dim: + # reducing to a single scalar tensor, we just mark output as empty + out_dimchars = "" + else: + # if keep the reduction dim, we need to keep the dim char by marking + # it as a singleton "1" in the out_dimchars + reduce_dim_char = ord("1") if keep_dim else None + out_dimchars = input_chars.translate( + {ord(alphabet[dim]): reduce_dim_char for dim in reduce_dims} + ) + fmt = f"{input_chars}->{out_dimchars}" + + enforce_sharding: Dict[str, int] = {} + if op_schema.is_out_variant: + out_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) + for out_dimchar, mesh_dim in zip(out_dimchars, out_spec.dim_map): + enforce_sharding[out_dimchar] = mesh_dim + + return einop_rule( + fmt, + op_schema, + linearity=reduction_linear, + enforce_sharding=enforce_sharding, + ) diff --git 
a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py new file mode 100644 index 000000000000..eb4cd86ed5c6 --- /dev/null +++ b/torch/distributed/_tensor/ops/math_ops.py @@ -0,0 +1,141 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import cast, Optional, Sequence + +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.distributed._tensor.dispatch import OpSchema, OutputSharding +from torch.distributed._tensor.ops.common_rules import reduction_rule, pointwise_rule +from torch.distributed._tensor.ops.utils import register_prop_rule, as_list, normalize_dims + + +def _infer_reduction_dims( + dims_arg: object, ndim: int +) -> Optional[Sequence[int]]: + if dims_arg is None: + return None + dims = cast(Sequence[int], as_list(dims_arg)) + dims = normalize_dims(dims, ndim) + empty_dims = [[0], [-1], []] + if ndim == 0 and dims_arg in empty_dims: + return None + return dims + + +@register_prop_rule("aten.all.default") +def default_reduction_rule(op_schema: OpSchema) -> OutputSharding: + return reduction_rule(op_schema, reduction_linear=True) + + +def sum_rule(op_schema: OpSchema) -> OutputSharding: + args_schema = op_schema.args_schema + input_spec = cast(DTensorSpec, args_schema[0]) + dims = None + if len(args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_spec.ndim) + + keep_dim = len(args_schema) > 2 and bool(args_schema[2]) + return reduction_rule( + op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=True + ) + + +sum_ops = [ + "aten.sum.default", + "aten.sum.dim_IntList", +] +for sum_op in sum_ops: + DTensor._op_to_rules[sum_op] = sum_rule + + +@register_prop_rule("aten._softmax.default") +def softmax_rule(op_schema: OpSchema) -> OutputSharding: + input_spec, softmax_dim, _ = op_schema.args_schema + input_spec = cast(DTensorSpec, input_spec) + softmax_dim = cast(int, softmax_dim) + dim_map = input_spec.dim_map + if softmax_dim < len(dim_map) and dim_map[softmax_dim] >= 0: + raise RuntimeError("Cannot run softmax on sharding dimension!") + return OutputSharding(input_spec) + + +@register_prop_rule("aten._softmax_backward_data.default") +def softmax_bwd_rule(op_schema: OpSchema) -> OutputSharding: + grad_out_spec, out_spec, softmax_dim, _ = op_schema.args_schema + grad_out_spec = cast(DTensorSpec, grad_out_spec) + out_spec = cast(DTensorSpec, out_spec) + softmax_dim = cast(int, softmax_dim) + grad_out_dim_map = grad_out_spec.dim_map + out_dim_map = out_spec.dim_map + if softmax_dim < len(grad_out_dim_map) and ( + grad_out_dim_map[softmax_dim] >= 0 or out_dim_map[softmax_dim] >= 0 + ): + raise RuntimeError( + "Cannot run _softmax_backward_data on sharding dimension!" 
+ ) + return pointwise_rule(op_schema) + + +def mean_rule(op_schema: OpSchema) -> OutputSharding: + args_schema = op_schema.args_schema + input_spec = cast(DTensorSpec, args_schema[0]) + dims = None + # if length of args > 1, we check args to find dims + if len(args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_spec.ndim) + + keep_dim = len(args_schema) > 2 and bool(args_schema[2]) + return reduction_rule( + op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False + ) + + +mean_ops = [ + "aten.mean.default", + "aten.mean.dim", + "aten.mean.out", +] + +for mean_op in mean_ops: + DTensor._op_to_rules[mean_op] = mean_rule + + +def var_rule(op_schema: OpSchema) -> OutputSharding: + args_schema = op_schema.args_schema + input_spec = cast(DTensorSpec, args_schema[0]) + dims = None + # if length of args > 1, we check args to find dims, note that + # var.default have unbias arg as the first argument, so we want + # to check if it's not bool + if len(args_schema) > 1 and not isinstance(args_schema[1], bool): + dims = _infer_reduction_dims(args_schema[1], input_spec.ndim) + + keep_dim = len(args_schema) > 3 and bool(args_schema[3]) + return reduction_rule( + op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False + ) + + +var_ops = [ + "aten.var.default", + "aten.var.dim", + "aten.var.out", +] + +for var_op in var_ops: + DTensor._op_to_rules[var_op] = var_rule + + +@register_prop_rule("aten.var.correction") +@register_prop_rule("aten.var.correction_out") +def var_correction_rule(op_schema: OpSchema) -> OutputSharding: + args_schema = op_schema.args_schema + input_spec = cast(DTensorSpec, args_schema[0]) + dims = None + if len(args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_spec.ndim) + + # keep_dim is a kwarg instead of arg for var.correction + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return reduction_rule( + op_schema, dims=dims, keep_dim=keep_dim, reduction_linear=False + ) diff --git a/torch/distributed/_tensor/ops/matrix_ops.py b/torch/distributed/_tensor/ops/matrix_ops.py new file mode 100644 index 000000000000..47988799282e --- /dev/null +++ b/torch/distributed/_tensor/ops/matrix_ops.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from torch.distributed._tensor.dispatch import OpSchema, OutputSharding +from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule +from torch.distributed._tensor.ops.utils import register_prop_rule + + +def _update_schema_suggestion_for_addmm( + output_sharding: OutputSharding, + op_schema: OpSchema, + pointwise_add_update: bool = True, +) -> OutputSharding: + # schema suggestion coming from output sharding could be: + # 1. pointwise add sharding input suggestion + # 2. 
mm sharding input suggestion + # inplace update schema suggestion to return addmm suggestion + assert output_sharding.schema_suggestions is not None + suggestion = output_sharding.schema_suggestions[0] + if pointwise_add_update: + # update with pointwise suggestion + args_schema = ( + suggestion.args_schema[0], + op_schema.args_schema[1], + op_schema.args_schema[2], + ) + else: + # update with mm suggestion + args_schema = ( + op_schema.args_schema[0], + suggestion.args_schema[0], + suggestion.args_schema[1], + ) + + output_sharding.schema_suggestions = [ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=args_schema, + kwargs_schema=op_schema.kwargs_schema, + ) + ] + return output_sharding + + +@register_prop_rule("aten.mm.default") +def mm_rules(op_schema: OpSchema) -> OutputSharding: + return einop_rule("mk,kn->mn", op_schema, linearity=False) + + +@register_prop_rule("aten.addmm.default") +def addmm_rules(op_schema: OpSchema) -> OutputSharding: + input_spec, mat1_spec, mat2_spec = op_schema.args_spec + mm_out_sharding = mm_rules( + OpSchema(op_schema.func_schema, (mat1_spec, mat2_spec), {}) + ) + if mm_out_sharding.output_spec is None: + # non-eligible input, suggest addmm input specs + if mm_out_sharding.schema_suggestions is not None: + # TODO: add more suggestions for resharding + return _update_schema_suggestion_for_addmm( + mm_out_sharding, + op_schema, + pointwise_add_update=False, + ) + else: + return OutputSharding(None) + + # run point wise rule on input + (mm_out) with linearity + output_sharding = pointwise_rule( + OpSchema( + op_schema.func_schema, (input_spec, mm_out_sharding.output_spec), {} + ), + linearity=True, + ) + # if propagation failed, edit the schema suggestion from pointwise rules + # to return addmm suggestion instead as it's a chained suggestion. + if ( + output_sharding.output_spec is None + and output_sharding.schema_suggestions is not None + ): + return _update_schema_suggestion_for_addmm(output_sharding, op_schema) + + return output_sharding + + +@register_prop_rule("aten.t.default") +def transpose_rule(op_schema: OpSchema) -> OutputSharding: + return einop_rule("ij->ji", op_schema, linearity=True) + + +@register_prop_rule("aten.bmm.default") +def bmm_rules(op_schema: OpSchema) -> OutputSharding: + return einop_rule("bmk,bkn->bmn", op_schema, linearity=False) + + +@register_prop_rule("aten.baddbmm.default") +def baddbmm_rules(op_schema: OpSchema) -> OutputSharding: + input_spec, mat1_spec, mat2_spec = op_schema.args_spec + bmm_output_sharding = bmm_rules( + OpSchema(op_schema.func_schema, (mat1_spec, mat2_spec), {}) + ) + if bmm_output_sharding.output_spec is None: + # TODO: add more suggestions + if bmm_output_sharding.schema_suggestions is not None: + return _update_schema_suggestion_for_addmm( + bmm_output_sharding, + op_schema, + pointwise_add_update=False, + ) + else: + return OutputSharding(None) + + # run point wise rule on input + (bmm_out) with linearity + output_sharding = pointwise_rule( + OpSchema( + op_schema.func_schema, + (input_spec, bmm_output_sharding.output_spec), + {}, + ), + linearity=True, + ) + # if propagation failed, edit the schema suggestion from pointwise rules + # to return baddbmm suggestion instead as it's a chained suggestion. 
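+    # (i.e. the returned suggestion is rewritten to match baddbmm's own
+    #  three-tensor argument layout rather than the two-argument layout of
+    #  the internal pointwise add)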
+ if ( + output_sharding.output_spec is None + and output_sharding.schema_suggestions is not None + ): + return _update_schema_suggestion_for_addmm(output_sharding, op_schema) + + return output_sharding diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py new file mode 100644 index 000000000000..6c92eacd1b8b --- /dev/null +++ b/torch/distributed/_tensor/ops/pointwise_ops.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import cast + +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.ops.common_rules import linear_pointwise_rule, pointwise_rule +from torch.distributed._tensor.placement_types import DTensorSpec, Replicate, _Partial +from torch.distributed._tensor.dispatch import OpSchema, OutputSharding +from torch.distributed._tensor.ops.utils import register_prop_rule + +# leave the remaining pointwise_ops list here for convenience, +# Below ops are some pointwise ops that are yet to be supported, +# they might not be a complete list. +# pointwise_ops = [ +# "fake_quantize_per_channel_affine", +# "fake_quantize_per_tensor_affine", +# "floor_divide", # floor_divide is deprecated +# "frexp", # multiple output pointwise op, need to add support +# "gradient", # need investigation on this op +# "imag", # complex data type only +# "quantized_batch_norm", +# "quantized_max_pool1d", +# "quantized_max_pool2d", +# "real", # complex data type only +# ] + + +linear_pointwise_ops = [ + "aten.div.Scalar", # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + "aten.to.dtype", +] + + +pointwise_ops = [ + # please keep the entries below alphabetically sorted + "aten.abs.default", + "aten.acos.default", + "aten.acos.out", + "aten.acos_.default", + "aten.acosh.default", + "aten.acosh.out", + "aten.acosh_.default", + "aten.add.Scalar", + "aten.add.Tensor", + "aten.add.out", + "aten.add_.Scalar", + "aten.add_.Tensor", + "aten.addcdiv.default", + "aten.addcdiv.out", + "aten.addcdiv_.default", + "aten.addcmul.default", + "aten.addcmul.out", + "aten.addcmul_.default", + "aten.angle.default", + "aten.angle.out", + "aten.asin.default", + "aten.asin.out", + "aten.asin_.default", + "aten.asinh.default", + "aten.asinh.out", + "aten.asinh_.default", + "aten.atan.default", + "aten.atan.out", + "aten.atan2.default", + "aten.atan2.out", + "aten.atan2_.default", + "aten.atan_.default", + "aten.atanh.default", + "aten.atanh.out", + "aten.atanh_.default", + "aten.bitwise_and.Scalar", + "aten.bitwise_and.Scalar_Tensor", + "aten.bitwise_and.Scalar_out", + "aten.bitwise_and.Tensor", + "aten.bitwise_and.Tensor_out", + "aten.bitwise_and_.Scalar", + "aten.bitwise_and_.Tensor", + "aten.bitwise_left_shift.Scalar_Tensor", + "aten.bitwise_left_shift.Tensor", + "aten.bitwise_left_shift.Tensor_Scalar", + "aten.bitwise_left_shift.Tensor_Scalar_out", + "aten.bitwise_left_shift.Tensor_out", + "aten.bitwise_left_shift_.Tensor", + "aten.bitwise_left_shift_.Tensor_Scalar", + "aten.bitwise_not.default", + "aten.bitwise_not.out", + "aten.bitwise_not_.default", + "aten.bitwise_or.Scalar", + "aten.bitwise_or.Scalar_Tensor", + "aten.bitwise_or.Scalar_out", + "aten.bitwise_or.Tensor", + "aten.bitwise_or.Tensor_out", + "aten.bitwise_or_.Scalar", + "aten.bitwise_or_.Tensor", + "aten.bitwise_right_shift.Scalar_Tensor", + "aten.bitwise_right_shift.Tensor", + "aten.bitwise_right_shift.Tensor_Scalar", + "aten.bitwise_right_shift.Tensor_Scalar_out", + 
"aten.bitwise_right_shift.Tensor_out", + "aten.bitwise_right_shift_.Tensor", + "aten.bitwise_right_shift_.Tensor_Scalar", + "aten.bitwise_xor.Scalar", + "aten.bitwise_xor.Scalar_Tensor", + "aten.bitwise_xor.Scalar_out", + "aten.bitwise_xor.Tensor", + "aten.bitwise_xor.Tensor_out", + "aten.bitwise_xor_.Scalar", + "aten.bitwise_xor_.Tensor", + "aten.ceil.default", + "aten.ceil.out", + "aten.ceil_.default", + "aten.clamp.default", + "aten.clamp.out", + "aten.clamp_.default", + "aten.clip.default", + "aten.clip.out", + "aten.clip_.default", + "aten.conj_physical.default", + "aten.conj_physical.out", + "aten.conj_physical_.default", + "aten.copy_sign.Scalar", + "aten.copy_sign.Scalar_out", + "aten.copy_sign.Tensor", + "aten.copy_sign.out", + "aten.copy_sign_.Scalar", + "aten.copy_sign_.Tensor", + "aten.cos.default", + "aten.cos.out", + "aten.cos_.default", + "aten.cosh.default", + "aten.cosh.out", + "aten.cosh_.default", + "aten.deg2rad.default", + "aten.deg2rad.out", + "aten.deg2rad_.default", + "aten.digamma.default", + "aten.digamma.out", + "aten.digamma_.default", + "aten.div.Tensor", + "aten.div.Tensor_mode", + "aten.div.out", + "aten.div.out_mode", + "aten.div_.Tensor", + "aten.div_.Tensor_mode", + "aten.eq.Tensor", + "aten.eq.Tensor_out", + "aten.eq.Scalar", + "aten.eq.Scalar_out", + "aten.equal.default", + "aten.erf.default", + "aten.erf.out", + "aten.erf_.default", + "aten.erfc.default", + "aten.erfc.out", + "aten.erfc_.default", + "aten.erfinv.default", + "aten.erfinv.out", + "aten.erfinv_.default", + "aten.exp.default", + "aten.exp.out", + "aten.exp2.default", + "aten.exp2.out", + "aten.exp2_.default", + "aten.exp_.default", + "aten.expm1.default", + "aten.expm1.out", + "aten.expm1_.default", + "aten.float_power.Scalar", + "aten.float_power.Scalar_out", + "aten.float_power.Tensor_Scalar", + "aten.float_power.Tensor_Scalar_out", + "aten.float_power.Tensor_Tensor", + "aten.float_power.Tensor_Tensor_out", + "aten.float_power_.Scalar", + "aten.float_power_.Tensor", + "aten.floor.default", + "aten.floor.out", + "aten.floor_.default", + "aten.fmod.Scalar", + "aten.fmod.Scalar_out", + "aten.fmod.Tensor", + "aten.fmod.Tensor_out", + "aten.fmod_.Scalar", + "aten.fmod_.Tensor", + "aten.frac.default", + "aten.frac.out", + "aten.frac_.default", + "aten.ge.Scalar", + "aten.ge.Tensor", + "aten.gelu.default", + "aten.gt.Scalar", + "aten.gt.Tensor", + "aten.hypot.default", + "aten.hypot.out", + "aten.hypot_.default", + "aten.i0.default", + "aten.i0.out", + "aten.i0_.default", + "aten.igamma.default", + "aten.igamma.out", + "aten.igamma_.default", + "aten.igammac.default", + "aten.igammac.out", + "aten.igammac_.default", + "aten.isnan.default", + "aten.ldexp.default", + "aten.ldexp.out", + "aten.ldexp_.default", + "aten.le.Scalar", + "aten.le.Tensor", + "aten.lerp.Scalar", + "aten.lerp.Scalar_out", + "aten.lerp.Tensor", + "aten.lerp.Tensor_out", + "aten.lerp_.Scalar", + "aten.lerp_.Tensor", + "aten.lgamma.default", + "aten.lgamma.out", + "aten.lgamma_.default", + "aten.log.default", + "aten.log.out", + "aten.log10.default", + "aten.log10.out", + "aten.log10_.default", + "aten.log1p.default", + "aten.log1p.out", + "aten.log1p_.default", + "aten.log2.default", + "aten.log2.out", + "aten.log2_.default", + "aten.log_.default", + "aten.logaddexp.default", + "aten.logaddexp.out", + "aten.logaddexp2.default", + "aten.logaddexp2.out", + "aten.logical_and.default", + "aten.logical_and.out", + "aten.logical_and_.default", + "aten.logical_not.default", + "aten.logical_not.out", + "aten.logical_not_.default", + 
"aten.logical_or.default", + "aten.logical_or.out", + "aten.logical_or_.default", + "aten.logical_xor.default", + "aten.logical_xor.out", + "aten.logical_xor_.default", + "aten.logit.default", + "aten.logit.out", + "aten.logit_.default", + "aten.masked_fill.Scalar", + "aten.mul.Scalar", + "aten.mul.Tensor", + "aten.mul.out", + "aten.mul_.Scalar", + "aten.mul_.Tensor", + "aten.mvlgamma.default", + "aten.mvlgamma.out", + "aten.mvlgamma_.default", + "aten.native_dropout_backward.default", + "aten.native_dropout_backward.out", + "aten.nan_to_num.default", + "aten.nan_to_num.out", + "aten.nan_to_num_.default", + "aten.ne.Scalar", + "aten.neg.default", + "aten.neg.out", + "aten.neg_.default", + "aten.nextafter.default", + "aten.nextafter.out", + "aten.nextafter_.default", + "aten.polygamma.default", + "aten.polygamma.out", + "aten.polygamma_.default", + "aten.positive.default", + "aten.pow.Scalar", + "aten.pow.Scalar_out", + "aten.pow.Tensor_Scalar", + "aten.pow.Tensor_Scalar_out", + "aten.pow.Tensor_Tensor", + "aten.pow.Tensor_Tensor_out", + "aten.pow_.Scalar", + "aten.pow_.Tensor", + "aten.reciprocal.default", + "aten.reciprocal.out", + "aten.reciprocal_.default", + "aten.red2deg.default", + "aten.red2deg.out", + "aten.red2deg_.default", + "aten.relu.default", + "aten.relu_.default", + "aten.remainder.Scalar", + "aten.remainder.Scalar_Tensor", + "aten.remainder.Scalar_out", + "aten.remainder.Tensor", + "aten.remainder.Tensor_out", + "aten.remainder_.Scalar", + "aten.remainder_.Tensor", + "aten.round.decimals", + "aten.round.decimals_out", + "aten.round.default", + "aten.round.out", + "aten.round_.decimals", + "aten.round_.default", + "aten.rsqrt.default", + "aten.rsqrt.out", + "aten.rsqrt_.default", + "aten.rsub.Scalar", + "aten.sgn.default", + "aten.sgn.out", + "aten.sgn_.default", + "aten.sigmoid.default", + "aten.sigmoid.out", + "aten.sigmoid_.default", + "aten.sign.default", + "aten.sign.out", + "aten.sign_.default", + "aten.signbit.default", + "aten.signbit.out", + "aten.sin.default", + "aten.sin.out", + "aten.sin_.default", + "aten.sinc.default", + "aten.sinc.out", + "aten.sinc_.default", + "aten.sinh.default", + "aten.sinh.out", + "aten.sinh_.default", + "aten.sqrt.default", + "aten.sqrt.out", + "aten.sqrt_.default", + "aten.square.default", + "aten.square.out", + "aten.square_.default", + "aten.sub.Scalar", + "aten.sub.Tensor", + "aten.sub.out", + "aten.sub_.Scalar", + "aten.sub_.Tensor", + "aten.tan.default", + "aten.tan.out", + "aten.tan_.default", + "aten.tanh.default", + "aten.tanh.out", + "aten.tanh_.default", + "aten.true_divide.Tensor", + "aten.trunc.default", + "aten.trunc.out", + "aten.trunc_.default", + "aten.where.self", + "aten.xlogy.OutScalar_Self", + "aten.xlogy.OutTensor", + "aten.xlogy.Scalar_other", + "aten.xlogy.Scalar_self", + "aten.xlogy.Tensor", + "aten.xlogy_.OutScalar_Other", + "aten.xlogy_.Scalar_other", + "aten.xlogy_.Tensor", + "prims.convert_element_type.default", + # backward point-wise ops + # please keep the entries below alphabetically sorted + "aten.gelu_backward.default", + "aten.sigmoid_backward.default", + "aten.tanh_backward.default", + "aten.threshold_backward.default", +] + + +for op in linear_pointwise_ops: + DTensor._op_to_rules[op] = linear_pointwise_rule + + +for op in pointwise_ops: + DTensor._op_to_rules[op] = pointwise_rule + + +@register_prop_rule("aten.native_dropout.default") +def dropout_rule(op_schema: OpSchema) -> OutputSharding: + self_spec = cast(DTensorSpec, op_schema.args_schema[0]) + + # TODO: We are specializing dropout_rule now 
because it's
+    # a non-deterministic algorithm, and replication does not
+    # support non-deterministic ops yet. We should remove
+    # this rule and make dropout use the pointwise rule instead
+    # once we support non-deterministic ops.
+    replicate_or_partial = False
+    for placement in self_spec.placements:
+        if isinstance(placement, (Replicate, _Partial)):
+            replicate_or_partial = True
+            break
+
+    if replicate_or_partial:
+        return OutputSharding(
+            None, failed_reason="Dropout with replication is not supported yet!"
+        )
+    else:
+        return OutputSharding(self_spec)
diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
new file mode 100644
index 000000000000..f386e1fdb9fd
--- /dev/null
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -0,0 +1,481 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import torch
+from torch.distributed._tensor.api import (
+    DTensor,
+    DTensorSpec,
+    Placement,
+    Replicate,
+    Shard,
+    _Partial,
+)
+from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
+from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops.utils import register_prop_rule
+from typing import List, Optional, Sequence, Tuple, cast
+
+
+# NOTE: the default propagation rule should apply for
+# any operator that does not return a DTensor, i.e.
+# for operators that only return int/float/bool, we by
+# default still propagate the spec. This is to ensure
+# that we only return None for the case where the sharding
+# propagation failed, and we should do auto-redistribute
+def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
+    # by default prop the first arg spec
+    return OutputSharding(op_schema.args_spec[0])
+
+
+def prop_create_like(op_schema: OpSchema) -> OutputSharding:
+    # For operators that create tensors with same shape as input but
+    # with specific content that does not depend on the input, we
+    # can propagate Sharding, but we have to make sure we move from
+    # partial to replicated.
+    input_spec = op_schema.args_spec[0]
+    output_spec = DTensorSpec(
+        mesh=input_spec.mesh,
+        placements=tuple(
+            Replicate() if isinstance(p, _Partial) else p
+            for p in input_spec.placements
+        ),
+        ndim=input_spec.ndim,
+        shape=input_spec.shape,
+    )
+    return OutputSharding(output_spec=output_spec)
+
+
+# some tensor ops should not support shard, e.g. local_scalar_dense
+# shouldn't work for shard as it requires numel == 1
+def no_shard_prop_rule(op_schema: OpSchema) -> OutputSharding:
+    # by default prop the first arg spec
+    tensor_spec = op_schema.args_spec[0]
+    for placement in tensor_spec.placements:
+        if placement.is_shard():
+            return OutputSharding(
+                None,
+                failed_reason=f"Op does not support input placements "
+                f"with `Shard`, but found placements: "
+                f"{tensor_spec.placements}",
+            )
+    # otherwise default prop the first arg spec
+    return OutputSharding(tensor_spec)
+
+
+def new_factory_rule(op_schema: OpSchema) -> OutputSharding:
+    # this op would benefit from backward sharding propagation!
+ # Since we cannot do that yet, just return replicated + input = op_schema.args_schema[0] + size = torch.Size(cast(Sequence[int], op_schema.args_schema[1])) + assert isinstance(input, DTensorSpec) + + return OutputSharding( + output_spec=DTensorSpec( + mesh=input.mesh, + placements=[Replicate()] * input.mesh.ndim, + shape=size, + ndim=len(size), + ) + ) + + +default_prop_ops = [ + "aten._to_copy.default", + "aten.clone.default", + "aten.contiguous.default", + "aten.copy_.default", + "aten.detach.default", + "aten.is_same_size.default", + "aten.new_empty_strided.default", +] + +create_like_ops = [ + "aten.empty_like.default", + "aten.fill_.Scalar", + "aten.full_like.default", + "aten.ones_like.default", + "aten.zero_.default", + "aten.zeros_like.default", +] + +new_factory_ops = [ + "aten.new_full.default", + "aten.new_ones.default", + "aten.new_zeros.default", +] + +no_shard_prop_ops = ["aten._local_scalar_dense.default"] + +for op in default_prop_ops: + DTensor._op_to_rules[op] = default_prop_rule + +for op in create_like_ops: + DTensor._op_to_rules[op] = prop_create_like + +for op in no_shard_prop_ops: + DTensor._op_to_rules[op] = no_shard_prop_rule + +for op in new_factory_ops: + DTensor._op_to_rules[op] = new_factory_rule + + +@register_prop_rule("aten.bucketize.Tensor") +def prop_bucketize(op_schema: OpSchema) -> OutputSharding: + """ + Point-wise on the first input (just propagate input sharding). + Expect replicated for second input. + """ + input_schema, boundaries = op_schema.args_schema + assert isinstance(input_schema, DTensorSpec) + assert isinstance(boundaries, DTensorSpec) + + if all(isinstance(p, Replicate) for p in boundaries.placements): + return OutputSharding(output_spec=input_schema) + else: + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=( + input_schema, + DTensorSpec( + mesh=boundaries.mesh, + placements=[Replicate()] + * len(boundaries.placements), + ndim=boundaries.ndim, + shape=boundaries.shape, + ), + ), + kwargs_schema=op_schema.kwargs_schema, + ) + ], + ) + + +def unshard_tensor_dim( + placements: Sequence[Placement], dim: int +) -> Sequence[Placement]: + """Disallow the given tensor dimension to be sharded""" + return tuple( + p if (not isinstance(p, Shard) or p.dim != dim) else Replicate() + for p in placements + ) + + +def _prop_all_but_dim( + op_schema: OpSchema, dim: int, out_shape: torch.Size +) -> OutputSharding: + """ + Considering an op that takes its input as first argument, forwards all shardings + except for the given dimension. 
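+
+    For illustration (hypothetical placements): with dim=0 and an input placed
+    as [Shard(0), Shard(1)], the rule below emits a schema suggestion asking
+    for the input to be resharded to [Replicate(), Shard(1)], while shardings
+    on all other dims are forwarded unchanged.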
+ """ + input_spec = op_schema.args_schema[0] + assert isinstance(input_spec, DTensorSpec) + + output_placements = unshard_tensor_dim(input_spec.placements, dim=dim) + output_spec = DTensorSpec( + mesh=input_spec.mesh, + placements=output_placements, + shape=out_shape, + ndim=input_spec.ndim, + ) + + if input_spec.placements == output_placements: + out = OutputSharding(output_spec=output_spec) + else: + suggested_input_spec = DTensorSpec( + mesh=input_spec.mesh, + placements=output_placements, + ndim=input_spec.ndim, + shape=input_spec.shape, + ) + out = OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=(suggested_input_spec,) + + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ), + ], + ) + return out + + +@register_prop_rule("aten.slice.Tensor") +def prop_slice(op_schema: OpSchema) -> OutputSharding: + """NOTE: can be further optimized (right now it replicates before slicing on a sharded dimension)""" + defaults = (None, 0, None, None, 1) + input_spec, dim, start, end, step = ( + op_schema.args_schema + defaults[len(op_schema.args_schema) :] + ) + assert isinstance(input_spec, DTensorSpec) + assert isinstance(dim, int) + assert start is None or isinstance(start, int) + assert end is None or isinstance(end, int) + assert isinstance(step, int) + + # normalize arguments + if dim < 0: + dim += input_spec.ndim + if start is None: + start = 0 + if step is None: + step = 1 + if end is None or end > input_spec.shape[dim]: + end = input_spec.shape[dim] + if start < 0: + start += input_spec.shape[dim] + if end < 0: + end += input_spec.shape[dim] + + if start == 0 and end == input_spec.shape[dim] and step == 1: + return OutputSharding(output_spec=input_spec) + + # shape propagation + slice_len = (end - start + step - 1) // step + out_shape = torch.Size( + tuple(input_spec.shape[0:dim]) + + (slice_len,) + + tuple(input_spec.shape[dim + 1 :]) + ) + + return _prop_all_but_dim(op_schema, dim=dim, out_shape=out_shape) + + +@register_prop_rule("aten.slice_scatter.default") +def prop_slice_scatter(op_schema: OpSchema) -> OutputSharding: + # 1. number of dimensions in input and src need to match. + # 2. number of elements on all non-dim need to match between input and src. + # 3. numer of elements in src in dim need to match the slice size. + # Given the above: + # - We suggest for src to follow the sharding of input, except on the scatter dimension, + # where our best bet for now is to make them replicated as a fall-back. + # TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding. + + defaults = (None, None, 0, None, None, 1) + input, src, dim, start, end, step = ( + op_schema.args_schema + defaults[len(op_schema.args_schema) :] + ) + assert isinstance(input, DTensorSpec) + assert isinstance(src, DTensorSpec) + assert isinstance(dim, int) + + if dim < 0: + dim += input.ndim + + # first, we keep the input sharding, except for the input dimension + # also, we cannot allow partial sum anymore. + input_suggestion = tuple( + Replicate() + if isinstance(p, _Partial) or (isinstance(p, Shard) and p.dim == dim) + else p + for p in input.placements + ) + + if input_suggestion == tuple(input.placements) and src.placements == tuple( + input.placements + ): + # if our sharding is correct, the output sharding will be the same as the input. 
+ return OutputSharding( + output_spec=DTensorSpec( + mesh=input.mesh, + placements=input.placements, + shape=input.shape, + ndim=input.ndim, + ) + ) + else: + # otherwise, return the suggestion. + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=( + DTensorSpec( + mesh=input.mesh, + placements=input_suggestion, + shape=input.shape, + ndim=input.ndim, + ), + DTensorSpec( + mesh=src.mesh, + placements=input_suggestion, + shape=src.shape, + ndim=src.ndim, + ), + ) + + op_schema.args_schema[2:], + kwargs_schema=op_schema.kwargs_schema, + ) + ], + ) + + +@register_prop_rule("aten.index_select.default") +def prop_index_select(op_schema: OpSchema) -> OutputSharding: + values_spec, dim, indices_spec = op_schema.args_schema + + assert isinstance(values_spec, DTensorSpec) + assert isinstance(dim, int) + assert isinstance(indices_spec, DTensorSpec) + + all_indices_spec: List[Optional[DTensorSpec]] = [ + indices_spec if dim == i else None for i in range(values_spec.ndim) + ] + + result = prop_index( + OpSchema( + func_schema=op_schema.func_schema, + args_schema=(values_spec, all_indices_spec), + kwargs_schema=op_schema.kwargs_schema, + ) + ) + if result.schema_suggestions: + result.schema_suggestions = [ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=(s.args_schema[0], dim, s.args_schema[1][dim]), + kwargs_schema=op_schema.kwargs_schema, + ) + for s in result.schema_suggestions + ] + return result + + +@register_prop_rule("aten.index.Tensor") +def prop_index(op_schema: OpSchema) -> OutputSharding: + """ + Expect replicated on the first input; _mostly_ pointwise on the second input. + TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first. + """ + # Current sharding constraints: + # For values: + # 1. We currently require that the dimension of values_spec be replicated or partial + # if they are being indexed on. + # 2. Other dimensions of values_spec can remain sharded if they are so. + # For indices: + # Indices can be either sharded or replicated. All index tensors need to be sharded + # in a compatible way, following the pointwise rule (including resolving _Partial + # into either sharded or replicated) + + values_spec, multi_indices_spec = op_schema.args_schema + assert isinstance(values_spec, DTensorSpec) + assert isinstance(multi_indices_spec, list) + multi_indices_spec = cast(List[Optional[DTensorSpec]], multi_indices_spec) + valid_indices_spec: List[Tuple[int, DTensorSpec]] = [ + (i, a) for i, a in enumerate(multi_indices_spec) if a is not None + ] + + # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. + # Here, we piggyback on the pointwise sharding rule for indices. 
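+    # For illustration (hypothetical placements): if every index tensor is
+    # replicated, pointwise_rule below returns a replicated indices_spec and
+    # no reshard on the indices is needed; otherwise its schema suggestion is
+    # folded back into multi_indices_spec further down.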
+ indices_out = pointwise_rule( + OpSchema( + func_schema=op_schema.func_schema, + args_schema=tuple(v[1] for v in valid_indices_spec), + kwargs_schema={}, + ) + ) + need_reshard_on_indices = indices_out.output_spec is None + + if not need_reshard_on_indices: + # this means that our inputs are already sharded properly and we will use that as our indices_spec + assert isinstance(indices_out.output_spec, DTensorSpec) + indices_spec: DTensorSpec = indices_out.output_spec + else: + assert indices_out.schema_suggestions is not None + valid_indices_suggestion = indices_out.schema_suggestions[0] + for i, v in enumerate(valid_indices_suggestion.args_spec): + multi_indices_spec[valid_indices_spec[i][0]] = v + # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then + # use that to compute our ideal values_spec + indices_output_spec = pointwise_rule( + valid_indices_suggestion + ).output_spec + assert isinstance(indices_output_spec, DTensorSpec) + indices_spec = indices_output_spec + + lookup_dims = set(v[0] for v in valid_indices_spec) + + need_reshard_on_values = tuple( + ( + isinstance(vp, Shard) + and (vp.dim in lookup_dims or isinstance(ip, Shard)) + ) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + + if not need_reshard_on_indices and not any(need_reshard_on_values): + + value_placements = values_spec.placements + value_shape = values_spec.shape + + all_dims_consecutive = all( + b[0] - a[0] == 1 + for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) + ) + if all_dims_consecutive: + # if all index vectors are consecutives, insert at the dimension of the first index + insert_dim: int = valid_indices_spec[0][0] + else: + # else, insert on the first dimension + insert_dim = 0 + + def place(vp: Placement, ip: Placement) -> Placement: + if isinstance(vp, Shard): + return Shard( + vp.dim + if vp.dim < insert_dim + # accounts for the offset in output dimensions + else vp.dim + + indices_spec.ndim + - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) + ) + if isinstance(ip, Shard): + return Shard(ip.dim + insert_dim) + # _Partial or Replicated + return vp + + value_placements = tuple( + place(vp, ip) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + value_shape = torch.Size( + tuple(value_shape[:insert_dim]) + + tuple(indices_spec.shape) + + tuple(value_shape[insert_dim + len(valid_indices_spec) :]) + ) + + result = OutputSharding( + output_spec=DTensorSpec( + mesh=values_spec.mesh, + placements=value_placements, + shape=value_shape, + ndim=len(value_shape), + ) + ) + return result + else: + result = OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=( + DTensorSpec( + mesh=values_spec.mesh, + placements=[ + Replicate() if need_reshard_on_values[i] else v + for i, v in enumerate(values_spec.placements) + ], + ndim=values_spec.ndim, + shape=values_spec.shape, + ), + multi_indices_spec, + ), + kwargs_schema=op_schema.kwargs_schema, + ) + ], + ) + return result diff --git a/torch/distributed/_tensor/ops/tp_sharding_ops.py b/torch/distributed/_tensor/ops/tp_sharding_ops.py new file mode 100644 index 000000000000..01db8920e674 --- /dev/null +++ b/torch/distributed/_tensor/ops/tp_sharding_ops.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates
+# implement tensor parallel sharding related ops for distributed tensor
+import torch
+import torch.utils._pytree as pytree
+from typing import List
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.utils import unwrap_local_tensor
+from torch.distributed._tensor.ops.utils import unwrap_single_placement, register_impl
+
+"""
+The ops below were quickly hacked together and need to be polished down the road.
+Although they come with unit tests already, the logic is directly borrowed
+from ShardedTensor. We also need to make them work for all placement types
+of DTensor and all corner cases for sharded distributed tensors.
+"""
+
+
+@register_impl("aten.cat.default")
+def dist_cat(tensor_list: List[DTensor], dim: int = 0) -> DTensor:
+    local_inputs = pytree.tree_map(unwrap_local_tensor, tensor_list)
+    local_tensor = torch.ops.aten.concat(local_inputs, dim=dim)
+    return DTensor.from_local(
+        local_tensor,
+        tensor_list[0].device_mesh,
+        tensor_list[0].placements,
+        run_check=False,
+    )
+
+
+@register_impl("aten.split.Tensor")
+# pyre-fixme[2]: Parameter must be annotated.
+def dist_split(self: DTensor, split_size_or_sections, dim=0) -> List[DTensor]:
+    local_mat = pytree.tree_map(unwrap_local_tensor, self)
+    mat_placement = pytree.tree_map(unwrap_single_placement, self)
+    sharding_dim = mat_placement.dim
+    world_size = self.device_mesh.size(dim=0)
+    if dim < 0:
+        dim = self.dim() + dim
+    if sharding_dim < 0:
+        sharding_dim = self.dim() + sharding_dim
+    if dim == sharding_dim:
+        if type(split_size_or_sections) is list:
+            split_size_or_sections[sharding_dim] //= world_size
+        else:
+            split_size_or_sections //= world_size
+    tensor_list = local_mat.split(split_size_or_sections, dim=dim)
+    return [
+        DTensor.from_local(
+            tensor,
+            self.device_mesh,
+            [mat_placement],
+            run_check=False,
+        )
+        for tensor in tensor_list
+    ]
diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py
new file mode 100644
index 000000000000..42db7142638a
--- /dev/null
+++ b/torch/distributed/_tensor/ops/utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import functools
+import operator
+
+import torch
+from typing import List, Union, Sequence, Iterable
+from torch.distributed._tensor.api import DTensor
+
+
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def unwrap_single_placement(e):
+    if not isinstance(e, DTensor):
+        return None
+    assert len(e.placements) == 1, "more than one placement!"
+    return e.placements[0]
+
+
+# convenient wrapper to register custom operator impls
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def register_impl(func):
+    # pyre-fixme[53]: Captured variable `func` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def wrapper(impl):
+        DTensor._custom_dispatch_ops[func] = impl
+        return impl
+
+    return wrapper
+
+
+# convenient wrapper to register sharding propagation rules
+# pyre-fixme[3]: Return type must be annotated.
+# pyre-fixme[2]: Parameter must be annotated.
+def register_prop_rule(func):
+    # pyre-fixme[53]: Captured variable `func` is not annotated.
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+ def wrapper(impl): + DTensor._op_to_rules[func] = impl + return impl + + return wrapper + + +def as_list( + x: Union[List[object], object] + # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. +) -> Union[List[object], torch.fx.immutable_collections.immutable_list]: + # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, + # which is an object but treated as a list by the tracer. Therefore, keep + # `immutable_list` intact here as well. + if type(x) is list or isinstance( + x, torch.fx.immutable_collections.immutable_list + ): + return x + else: + return [x] + + +def normalize_dim(dim: int, ndim: int) -> int: + return dim if dim >= 0 else dim + ndim + + +def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]: + """ + normalize a dim or a sequence of dims, so that they + are all positive. + """ + if isinstance(dims, int): + dims = (normalize_dim(dims, ndim),) + elif isinstance(dims, list): + dims = [normalize_dim(dim, ndim) for dim in dims] + elif isinstance(dims, tuple): + dims = tuple([normalize_dim(dim, ndim) for dim in dims]) + return dims + + +def prod(xs: Iterable[int]) -> int: + return functools.reduce(operator.mul, xs, 1) From 1b88476320a99680a6e01f8f4afed5c5196cf39d Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 08:04:38 +0000 Subject: [PATCH 218/453] [dtensor] PART 4: move remaining DTensor ops to core distributed (#88550) This PR moves the view related DTensor ops to core distributed, tests will be add in follow up PRs part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88550 Approved by: https://github.com/fduwjj --- torch/distributed/_tensor/__init__.py | 2 +- torch/distributed/_tensor/ops/__init__.py | 2 +- torch/distributed/_tensor/ops/view_ops.py | 707 ++++++++++++++++++++++ 3 files changed, 709 insertions(+), 2 deletions(-) create mode 100644 torch/distributed/_tensor/ops/view_ops.py diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index ba09f2fbb690..32a57146bc93 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -9,7 +9,7 @@ # Import all builtin dist tensor ops -# import torch.distributed._tensor.ops +import torch.distributed._tensor.ops def distribute_tensor( diff --git a/torch/distributed/_tensor/ops/__init__.py b/torch/distributed/_tensor/ops/__init__.py index 5012768ee051..5550b2ffae08 100644 --- a/torch/distributed/_tensor/ops/__init__.py +++ b/torch/distributed/_tensor/ops/__init__.py @@ -4,4 +4,4 @@ from .tensor_ops import * # noqa: F403 from .tp_sharding_ops import * # noqa: F403 from .pointwise_ops import * # noqa: F403 -# from .view_ops import * # noqa: F403 +from .view_ops import * # noqa: F403 diff --git a/torch/distributed/_tensor/ops/view_ops.py b/torch/distributed/_tensor/ops/view_ops.py new file mode 100644 index 000000000000..a8849b2ed14b --- /dev/null +++ b/torch/distributed/_tensor/ops/view_ops.py @@ -0,0 +1,707 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +from dataclasses import dataclass +from typing import ( + Callable, + Dict, + Iterable, + Optional, + Tuple, + Set, + Union, + Sequence, + cast, +) + +import torch +from torch import Tensor + +from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate +from torch.distributed._tensor.api import Shard +from torch.distributed._tensor.dispatch import OpSchema, OutputSharding +from torch.distributed._tensor.ops.utils import ( + normalize_dim, + normalize_dims, + prod, + register_prop_rule, +) + + +Shape = Tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = Tuple[DimSpec, ...] + + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton""" + + pass + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """ + Output dimension is a set of input dimensions flattened, keeping + right-most adjacent elements adjacent in the output. + """ + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + Note that input_dim itself could be a Flattened set of input dims. + """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new( + cls, dim: DimSpec, group_shape: Tuple[int, ...], idx: int + ) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] 
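+            # e.g. (illustrative values) group_shape (1, 3, 1, 2) with idx 3
+            # maps to new_group_shape (3, 2) and new_idx 1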
+ group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implements broadcast on multiple dimensions""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance( + p, InputDim + ), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, Tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return cast(Shape, sizes[0]) # type: ignore[redundant-cast] + else: + raise RuntimeError("Size must be int... or tuple") + + +def dim_flatten(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + return (Flatten.new(tuple(InputDim(i) for i in range(ndim))),) + + +def dim_movedim( + ndim: int, + input: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len( + destination + ), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == -1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert ( + len(sizes) >= ndim + ), f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + Infer the size of this dimension given the total_size. 
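+    e.g. (illustrative) infer_size(24, (-1, 6)) resolves the -1 and
+    returns (4, 6).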
+    """
+    infers = [i for i, s in enumerate(sizes) if s == -1]
+    size = prod(sizes)
+    assert len(infers) <= 1, "can only infer one size"
+    if infers:
+        size = -size
+        missing_size = total_size // size
+        assert (
+            total_size % size == 0
+        ), f"size inferred for -1 is not integral {sizes} should have {total_size} elements."
+        return tuple(s if s != -1 else missing_size for s in sizes)
+    assert size == total_size, f"sizes do not match {total_size} vs {size}"
+    return sizes
+
+
+def view_groups(from_size: Shape, to_size: Shape) -> DimMap:
+    """
+    A view or reshape operation can be decomposed into a set of 3 types of smaller operations:
+    1) Forward a dimension from input to output
+    2) Flatten a set of dimensions into a single dimension
+    3) Split one dimension into multiple dimensions
+
+    view_groups identifies these operations and returns, for each output dimension, what
+    operation was performed on the input dimensions. For example:
+
+        view_groups([2, 3, 4], [2, 12]) -> (
+            InputDim(0),
+            Flatten((InputDim(1), InputDim(2)))
+        )
+
+    - output dimension 0 maps to input dimension 0
+    - output dimension 1 maps to the flattened input dimensions 1 and 2
+
+
+        view_groups([2, 3], [3, 2]) -> (
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0),
+            Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1),
+        )
+
+    - in the above, the input is flattened into a single dimension and then split
+    into two separate dimensions with different sizes from the input.
+    """
+    from_nelem = prod(from_size)
+    to_size = infer_size(from_nelem, normalize_sizes(to_size))
+
+    assert from_nelem == prod(to_size), "Total view shape does not add up"
+
+    from_idx = 0
+    to_idx = 0
+    from_len = len(from_size)
+    to_len = len(to_size)
+
+    result_pp = []
+
+    while from_idx < from_len or to_idx < to_len:
+        from_group_dim, to_group_shape = [], []
+
+        if from_idx >= from_len:
+            f = 1
+        else:
+            f = from_size[from_idx]
+            from_group_dim.append(from_idx)
+            from_idx += 1
+
+        if to_idx >= to_len:
+            t = 1
+        else:
+            t = to_size[to_idx]
+            to_group_shape.append(t)
+            to_idx += 1
+
+        # if any of the groups is singleton, great, we need to backtrack though
+        if f == 1 and t != 1:
+            # produces ([1], [])
+            to_idx -= 1
+            to_group_shape = []
+        elif f != 1 and t == 1:
+            # produces ([], [1])
+            from_idx -= 1
+            from_group_dim = []
+        else:
+            # produces ([1], [1]), ([2], [2]), ([2,3], [6])
+            while f != t:
+                if f < t:
+                    nf = from_size[from_idx]
+                    from_group_dim.append(from_idx)
+                    from_idx += 1
+                    f *= nf
+                else:
+                    nt = to_size[to_idx]
+                    to_group_shape.append(nt)
+                    to_idx += 1
+                    t *= nt
+
+        if len(to_group_shape) > 0:
+            flattened = Flatten.new(
+                tuple(
+                    InputDim(fi) for fi in from_group_dim if from_size[fi] > 1
+                )
+            )
+            result_pp += [
+                Split.new(flattened, tuple(to_group_shape), i)
+                for i in range(len(to_group_shape))
+            ]
+
+    return tuple(result_pp)
+
+
+def dim_tile(ndim: int, dims: Tuple[int, ...]) -> DimMap:
+    if len(dims) < ndim:
+        dims = (1,) * (ndim - len(dims)) + dims
+    return dim_repeat(ndim, dims)
+
+
+def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap:
+    dim1 = normalize_dim(dim1, ndim)
+    dim2 = normalize_dim(dim2, ndim)
+    assert dim1 < ndim
+    assert dim2 < ndim
+    dimmap = list(InputDim(i) for i in range(ndim))
+    swapdim = dimmap[dim1]
+    dimmap[dim1] = dimmap[dim2]
+    dimmap[dim2] = swapdim
+    return tuple(dimmap)
+
+
+def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap:
+    # FIXME: this is wrong when dim=None and one of the dimensions
+    # equals the size of the mesh.
For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not acutally a singleton. + return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_reduction( + ndim: int, dim_or_dims: Optional[Union[int, Sequence[int]]], keepdim: bool +) -> DimMap: + """ + General fallback for reduction ops where _Partial() does not apply. + This will cause incoming tensor to be replicated on the reducing dimensions. + """ + if dim_or_dims is None: + dim_or_dims = tuple(range(ndim)) + if isinstance(dim_or_dims, int): + dim_or_dims = (dim_or_dims,) + dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) + return tuple( + InputDim(i) if i not in dim_or_dims else Singleton() + for i in range(ndim) + if i not in dim_or_dims or keepdim + ) + + +@dataclass +class Op: + dim_map: Callable[..., DimMap] + shape_argnum: Optional[int] = None + + +ops: Dict[Callable[..., torch.Tensor], Op] = { + torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)), + torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)), + torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)), + torch.broadcast_to: Op( + dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1 + ), + Tensor.expand: Op( + dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + shape_argnum=1, + ), + torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), + torch.movedim: Op( + dim_map=lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ) + ), + torch.permute: Op( + dim_map=lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) + ) + ), + torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)), + Tensor.repeat: Op( + dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes) + ), + torch.reshape: Op( + dim_map=lambda input, shape: view_groups(input.shape, shape), + shape_argnum=1, + ), + torch.squeeze: Op( + dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim) + ), + torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)), + torch.transpose: Op( + dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1) + ), + torch.unsqueeze: Op( + dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim) + ), + Tensor.view: Op( + dim_map=lambda input, *shape: view_groups(input.shape, shape), + shape_argnum=1, + ), +} + + +def propagate_shape_and_sharding( + in_shard: Sequence[Placement], + local_in_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, +) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]: + """ + Takes as input the global shape of the tensor, and the input sharding, + and produce corresponding output sharding and shape of the output tensor. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. 
+ - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(in_shard) == len(mesh_sizes) + sharded_in_dims: Set[int] = set( + s.dim for s in in_shard if isinstance(s, Shard) + ) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + shardable_dims: torch.Tensor = torch.ones( + (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool + ) + + # in case an input dimension disappears (e.g. collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + + seen_input_dims: Set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(local_in_shape)): + shardable_dims[dim, :] = dim in seen_input_dims + + def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + return ( + local_in_shape[cmd.input_dim], + cmd if cmd.input_dim in sharded_in_dims else None, + ) + elif isinstance(cmd, Flatten): + for dim in cmd.input_dims[1:]: + if isinstance(dim, InputDim): + shardable_dims[dim.input_dim, :] = False + dim0 = cmd.input_dims[0] + return ( + prod(get_dim_size(a)[0] for a in cmd.input_dims), + dim0 + if isinstance(dim0, InputDim) + and dim0.input_dim in sharded_in_dims + else None, + ) + elif isinstance(cmd, Split): + _, in_dim = get_dim_size(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisble + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. In that case, the dimension + # needs to be divisible by the product of mesh sizes. + # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + for mesh_dim, mesh_dim_size in enumerate(mesh_sizes): + shardable_dims[in_dim.input_dim, mesh_dim] = ( + out_size % mesh_dim_size == 0 + ) + + # 2. here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, in_shard): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert ( + out_size % submesh_size == 0 + ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." 
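+                # (illustrative numbers: splitting a dim of global size 12
+                # into (6, 2) stays shardable on a mesh dim of size 2, since
+                # 6 % 2 == 0; a (3, 4) split of the same dim would not)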
+ + # we will only shard our first component of the split + return out_size, in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Singleton): + return 1, None + elif isinstance(cmd, Broadcast): + return cmd.dim_size, None + elif isinstance(cmd, NewDim): + return cmd.size, None + elif isinstance(cmd, Repeat): + size, in_dim = get_dim_size(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim, :] = False + return size * cmd.times, None + else: + raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}") + + dim_map = {} + out_shape = [] + for dim, cmd in enumerate(rule): + out_size, in_dim = get_dim_size(cmd) + out_shape.append(out_size) + if in_dim is not None: + dim_map[in_dim.input_dim] = dim + + needs_reshard = any( + isinstance(placement, Shard) + and not shardable_dims[placement.dim][mesh_dim] + for mesh_dim, placement in enumerate(in_shard) + ) + + output_placements = ( + None + if needs_reshard + else [ + Shard(dim_map[s.dim]) if isinstance(s, Shard) else s + for s in in_shard + ] + ) + + return (tuple(out_shape), output_placements, shardable_dims) + + +def register_prop_rule_map( + aten_op_name: str, local_op_name: Callable[..., torch.Tensor] +) -> None: + spec: Op = ops[local_op_name] + + @register_prop_rule(aten_op_name) + def reshape_prop(op_schema: OpSchema) -> OutputSharding: + rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_dtensor_spec = op_schema.args_schema[0] + + assert isinstance( + input_dtensor_spec, DTensorSpec + ), "Expected first input to be a DTensorSpec" + global_in_shape = input_dtensor_spec.shape + assert global_in_shape is not None, "Shape required." + + ( + global_out_shape, + shard_out, + shardable_dims, + ) = propagate_shape_and_sharding( + input_dtensor_spec.placements, + tuple(global_in_shape), + rules, + tuple(input_dtensor_spec.mesh.mesh.shape), + ) + + if shard_out is not None: + # no reshard needed + output_dtensor_spec = DTensorSpec( + mesh=input_dtensor_spec.mesh, + placements=shard_out, + shape=torch.Size(global_out_shape), + ndim=len(global_out_shape), + ) + local_out_shape = output_dtensor_spec.local_shape + + # We only need the local shape to lower he call into the local op + args = op_schema.args_schema + shape_argnum = spec.shape_argnum + if shape_argnum is not None: + op_schema.args_schema = ( + args[:shape_argnum] + + (tuple(local_out_shape),) + + args[shape_argnum + 1 :] + ) + + return OutputSharding(output_spec=output_dtensor_spec) + + else: + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... 
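+            # (illustrative: flattening input dims 0 and 1 of a tensor placed
+            # as [Shard(1)] cannot keep that dim sharded, so the suggestion
+            # below falls back to Replicate() on that mesh dim)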
+ # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + suggested_placements = [ + p + if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim] + else Replicate() + for mesh_dim, p in enumerate(input_dtensor_spec.placements) + ] + return OutputSharding( + output_spec=None, + schema_suggestions=[ + OpSchema( + func_schema=op_schema.func_schema, + args_schema=( + DTensorSpec( + placements=suggested_placements, + mesh=input_dtensor_spec.mesh, + ndim=input_dtensor_spec.ndim, + shape=input_dtensor_spec.shape, + ), + ) + + op_schema.args_schema[1:], + kwargs_schema=op_schema.kwargs_schema, + ) + ], + ) + + +register_prop_rule_map("aten.squeeze.default", torch.squeeze) +register_prop_rule_map("aten.squeeze.dim", torch.squeeze) +register_prop_rule_map("aten.view.default", Tensor.view) +register_prop_rule_map("aten.view.SymInt", Tensor.view) +register_prop_rule_map("aten._unsafe_view.default", Tensor.view) +register_prop_rule_map("aten.unsqueeze.default", torch.unsqueeze) +register_prop_rule_map("aten.expand.default", Tensor.expand) +register_prop_rule_map("aten.permute.default", torch.permute) +register_prop_rule_map("aten.repeat.default", Tensor.repeat) +register_prop_rule_map("aten.transpose.int", torch.transpose) From 527c5bdb4574f12f5071b0466ce981ce1c129d75 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 22:51:31 +0000 Subject: [PATCH 219/453] [dtensor] PART 5: move DTensor basic tests to core distributed (#88178) This PR moves DTensor basic tests to torch.distributed, including dtensor, device_mesh tests part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88178 Approved by: https://github.com/fduwjj --- test/distributed/_tensor/README.md | 11 + test/distributed/_tensor/__init__.py | 1 + test/distributed/_tensor/test_api.py | 234 ++++++++ test/distributed/_tensor/test_device_mesh.py | 518 ++++++++++++++++++ test/distributed/_tensor/test_dtensor.py | 359 ++++++++++++ test/distributed/_tensor/test_redistribute.py | 317 +++++++++++ .../_internal/distributed/_tensor/__init__.py | 0 .../distributed/_tensor/common_dtensor.py | 334 +++++++++++ 8 files changed, 1774 insertions(+) create mode 100644 test/distributed/_tensor/README.md create mode 100644 test/distributed/_tensor/__init__.py create mode 100644 test/distributed/_tensor/test_api.py create mode 100644 test/distributed/_tensor/test_device_mesh.py create mode 100644 test/distributed/_tensor/test_dtensor.py create mode 100644 test/distributed/_tensor/test_redistribute.py create mode 100644 torch/testing/_internal/distributed/_tensor/__init__.py create mode 100644 torch/testing/_internal/distributed/_tensor/common_dtensor.py diff --git a/test/distributed/_tensor/README.md b/test/distributed/_tensor/README.md new file mode 100644 index 000000000000..6235f9657d5f --- /dev/null +++ b/test/distributed/_tensor/README.md @@ -0,0 +1,11 @@ +## Run distributed tensor tests: + +from root, run (either CPU or GPU) + +`pytest test/spmd/tensor/test_tensor.py` + +`pytest test/spmd/tensor/test_ddp.py` + +run specific test case and print stdout/stderr: + +`pytest test/spmd/tensor/test_tensor.py -s -k test_tensor_from_local` diff --git a/test/distributed/_tensor/__init__.py b/test/distributed/_tensor/__init__.py new file mode 100644 index 000000000000..087882b22d1f --- /dev/null +++ b/test/distributed/_tensor/__init__.py @@ -0,0 +1 @@ +# shut up pylint diff --git a/test/distributed/_tensor/test_api.py 
b/test/distributed/_tensor/test_api.py new file mode 100644 index 000000000000..a966f30d1cb9 --- /dev/null +++ b/test/distributed/_tensor/test_api.py @@ -0,0 +1,234 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms +from torch.distributed._tensor import ( + distribute_tensor, + distribute_module, + DeviceMesh, + DTensor, + Shard, + Replicate, +) + + +class MyModel(nn.Module): + def __init__(self, n_features, n_layers, device): + super().__init__() + self.seq = nn.Sequential( + *[ + nn.Linear(n_features, n_features, device=device) + for _ in range(n_layers) + ] + ) + + def forward(self, x): + return self.seq(x) + + def reset_parameters(self): + for m in self.seq: + m.reset_parameters() + + +class DTensorAPITest(DTensorTestBase): + @property + def world_size(self) -> int: + # hard code world size to 4 as we need to test + # at least with 2d mesh + return 4 + + @with_comms + def test_distribute_tensor(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + for requires_grad in [True, False]: + + tensor_to_shard = torch.randn( + 3 * self.world_size, 3, requires_grad=requires_grad + ) + dist_tensor = distribute_tensor( + tensor_to_shard, device_mesh, shard_spec + ) + self.assertEqual( + dist_tensor.size(), torch.Size([3 * self.world_size, 3]) + ) + local_tensor = dist_tensor.to_local() + self.assertEqual(local_tensor.size(), torch.Size([3, 3])) + if requires_grad: + self.assertTrue(dist_tensor.requires_grad) + self.assertTrue(dist_tensor.is_leaf) + + @with_comms + def test_distribute_tensor_errors(self): + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size).reshape(2, 2) + ) + tensor_shape = [3 * self.world_size, 3 * self.world_size] + tensor_to_distribute = torch.randn(*tensor_shape) + + with self.assertRaisesRegex(ValueError, "must have the same length"): + shard_spec = [Shard(0)] + distribute_tensor(tensor_to_distribute, device_mesh, shard_spec) + + spec = [Shard(0), Shard(1)] + dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec) + + with self.assertRaisesRegex(ValueError, "to a different device mesh"): + new_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size) + ) + distribute_tensor(dtensor, new_mesh, [Shard(0)]) + + with self.assertRaisesRegex(ValueError, "to a different placements"): + new_spec = [Shard(0), Replicate()] + distribute_tensor(dtensor, device_mesh, new_spec) + + @with_comms + def test_distribute_tensor_uneven_sharding(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_sizes_and_shard_dims = [ + ((self.world_size * 3 + 1, 3, 3), 0), + ((self.world_size * 3 + 2, 3, 3), 0), + ((3, self.world_size * 3 + 1, 3), 1), + ((3, self.world_size * 3 + 2, 3), 1), + ((3, 3, self.world_size * 3 + 1), 2), + ((3, 3, self.world_size * 3 + 2), 2), + ] + for input_size, shard_dim in input_sizes_and_shard_dims: + shard_spec = [Shard(shard_dim)] + tensor_to_shard = torch.randn(input_size) + splitted_tensor_list = tensor_to_shard.tensor_split( + self.world_size, dim=shard_dim + ) + dist_tensor = distribute_tensor( + tensor_to_shard, device_mesh, shard_spec + ) + self.assertEqual(dist_tensor.size(), torch.Size(input_size)) + local_tensor = dist_tensor.to_local() + self.assertEqual(local_tensor, 
splitted_tensor_list[self.rank]) + + @with_comms + def test_distribute_module(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # fully shard all linear modules on dim 0 + module_to_shard = MyModel( + 5 * self.world_size, 20, device=self.device_type + ) + shard_spec = [Shard(0)] + + def shard_fn(name, module, device_mesh): + if isinstance(module, nn.Linear): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, shard_spec) + ) + module.register_parameter(name, dist_param) + + sharded_module = distribute_module( + module_to_shard, device_mesh, shard_fn + ) + for param in sharded_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, shard_spec) + + replica_spec = [Replicate()] + # fully replicate all modules without passing in partition_fn + module_to_replicate = MyModel(5, 20, device=self.device_type) + replica_module = distribute_module(module_to_replicate, device_mesh) + for param in replica_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, replica_spec) + + # fully replicate all modules by passing in partition_fn + def replicate_fn(name, module, device_mesh): + if isinstance(module, nn.Linear): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, replica_spec) + ) + module.register_parameter(name, dist_param) + + module_to_replicate = MyModel(5, 20, device=self.device_type) + replica_module = distribute_module( + module_to_replicate, device_mesh, replicate_fn + ) + for param in replica_module.parameters(): + self.assertIsInstance(param, DTensor) + self.assertEqual(param.placements, replica_spec) + + # only shard part of module, and rest of module should be replicate + def shard_fn(name, module, device_mesh): + if isinstance(module, nn.Linear) and ( + name == "seq.0" or name == "seq.8" + ): + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, shard_spec) + ) + module.register_parameter(name, dist_param) + + module_to_distribute = MyModel( + 5 * self.world_size, 20, device=self.device_type + ) + dist_module = distribute_module( + module_to_distribute, device_mesh, shard_fn + ) + for name, param in dist_module.named_parameters(): + self.assertIsInstance(param, DTensor) + if name.startswith("seq.0") or name.startswith("seq.8"): + self.assertEqual(param.placements, shard_spec) + else: + self.assertEqual(param.placements, replica_spec) + + @with_comms + def test_distribute_module_input_fn_output_fn(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + # fully replicate all linear modules + module_to_replicate = MyModel(20, 1, device=self.device_type) + + # mark input sharding on dim 0 + def input_fn(inputs, device_mesh): + return DTensor.from_local(inputs[0], device_mesh, [Shard(0)]) + + def output_fn(outputs, device_mesh): + assert isinstance(outputs, DTensor) + return outputs.to_local() + + replica_module = distribute_module( + module_to_replicate, + device_mesh, + input_fn=input_fn, + output_fn=output_fn, + ) + + input_tensor = torch.randn(5, 20, device=self.device_type) + local_out = replica_module(input_tensor) + self.assertIsInstance(local_out, torch.Tensor) + self.assertNotIsInstance(local_out, DTensor) + + # full replicate (even on inputs) + model = MyModel(10, 10, device=self.device_type) + + def replicate_input_fn(inputs, 
device_mesh): + return DTensor.from_local(inputs[0], device_mesh, [Replicate()]) + + replica_model = distribute_module( + model, + device_mesh, + input_fn=replicate_input_fn, + ) + input = torch.randn(10, 10, requires_grad=True) + output = replica_model(input) + output.sum().backward() + param_grad = list(replica_model.parameters())[0].grad + self.assertTrue(isinstance(param_grad, DTensor)) + self.assertTrue(isinstance(param_grad.placements[0], Replicate)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py new file mode 100644 index 000000000000..7088f33f42db --- /dev/null +++ b/test/distributed/_tensor/test_device_mesh.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch + +from torch.distributed.distributed_c10d import ( + ProcessGroup, + new_group, + get_global_rank, + get_world_size, +) +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor.device_mesh import DeviceMesh +from torch.distributed._tensor.placement_types import Shard + + +class DeviceMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_device_mesh_2d(self): + mesh_tensor = torch.arange(4).reshape(2, 2) + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < 2) + dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + current_rank_expected_group_ranks = ( + dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] + ) + self.assertEqual(global_ranks, current_rank_expected_group_ranks) + + @with_comms + def test_device_mesh_2d_from_dim_groups(self): + # construct a two dimension subgroups + dim_groups = [] + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim_group_ranks in expected_ranks_by_dim: + for subgroup_ranks in dim_group_ranks: + subgroup = new_group(ranks=subgroup_ranks) + if self.rank in subgroup_ranks: + dim_groups.append(subgroup) + + # construct a device mesh from the subgroups + mesh = DeviceMesh( + self.device_type, [[0, 1], [2, 3]], dim_groups=dim_groups + ) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < 2) + dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + current_rank_expected_group_ranks = ( + dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] + ) + self.assertEqual(global_ranks, current_rank_expected_group_ranks) + + @with_comms + def test_device_mesh_dim_groups_error(self): + # construct a two dimension subgroups + dim_groups = [] + expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] + for dim_group_ranks in expected_ranks_by_dim: + for subgroup_ranks in 
dim_group_ranks: + subgroup = new_group(ranks=subgroup_ranks) + if self.rank in subgroup_ranks: + dim_groups.append(subgroup) + + if len(dim_groups) > 0: + # dim_groups is not a list + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=dim_groups[0], + ) + + # dim_groups is a list, but not a list of ProcessGroup + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=[dim_groups[0], "dummy"], + ) + + # dim_groups has incorrect length + self.assertRaises( + RuntimeError, + DeviceMesh, + self.device_type, + [[0, 1], [2, 3]], + dim_groups=[dim_groups[0]], + ) + + @with_comms + def test_device_mesh_nd(self): + # construct a cuda device mesh + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + + for dim, dim_group in enumerate(dim_to_subgroups): + self.assertTrue(dim < mesh_tensor.ndim) + dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2) + # print(dim_ranks) + # dim_ranks = expected_ranks_by_dim[dim] + + dim_group_size = get_world_size(dim_group) + self.assertIsInstance(dim_group, ProcessGroup) + self.assertEqual(dim_group_size, 2) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + for ranks in dim_ranks: + if self.rank in ranks: + self.assertEqual(global_ranks, ranks.tolist()) + + +class DeviceMeshCollectiveTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_all_reduce_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + mesh.all_reduce(local_tensor, mesh_dim=0) + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + self.assertEqual(local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_broadcast_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + mesh.broadcast(local_tensor, mesh_dim=0) + self.assertEqual(local_tensor, torch.zeros(3, 3)) + + @with_comms + def test_scatter_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + scatter_tensor_shape = [3, 3, 3] + for scatter_dim in range(len(scatter_tensor_shape)): + shard_placement = Shard(scatter_dim) + scatter_tensor_shape[scatter_dim] *= self.world_size + # make the random seed same across rank + torch.manual_seed(0) + global_tensor = torch.randn( + scatter_tensor_shape, device=self.device_type + ) + splitted_list, _ = shard_placement._split_tensor( + global_tensor, mesh.size(), with_padding=True, contiguous=True + ) + recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()]) + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.scatter(recv_tensor, splitted_list, mesh_dim=0) + self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) + + @with_comms + def test_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.randn( + device_mesh.size() + 3, device_mesh.size() + 1 + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_to_scatter = tensor_to_split.clone() + tensor_splitted_list = tensor_to_split.tensor_split( + device_mesh.size(), dim=shard_dim + ) + padded_tensor_list, pad_idx = 
shard_placement._split_tensor( + tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) + device_mesh.scatter( + scattered_tensor, padded_tensor_list, mesh_dim=0 + ) + # unpad scattered_tensor + if pad_idx != 0 and my_rank >= pad_idx: + scattered_tensor = shard_placement._unpad_tensor( + scattered_tensor + ) + + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) + + @with_comms + def test_all_gather_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dims_to_gather = [0, 1] + for dim in dims_to_gather: + output_size = [3, 3] + output_size[dim] *= self.world_size + # each rank have its own tensor, all_gather gives a list + local_tensor = torch.ones(3, 3, device=self.device_type) + gathered_list = [] + for _ in range(self.world_size): + gathered_list.append(torch.zeros_like(local_tensor)) + mesh.all_gather(gathered_list, local_tensor, mesh_dim=0) + gathered_tensor = torch.cat(gathered_list, dim=dim) + self.assertEqual(gathered_tensor, torch.ones(output_size)) + + @with_comms + def test_all_gather_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_padded_list, pad_idx = shard_placement._split_tensor( + tensor_to_split, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + local_tensor = tensor_padded_list[my_rank] + gathered_list = [] + for _ in range(device_mesh.size()): + gathered_list.append(torch.empty_like(local_tensor)) + + device_mesh.all_gather( + gathered_list, + local_tensor, + mesh_dim=0, + ) + if pad_idx != 0: + gathered_list = [ + shard_placement._unpad_tensor(gathered_tensor) + if i >= pad_idx + else gathered_tensor + for i, gathered_tensor in enumerate(gathered_list) + ] + all_gathered_tensor = torch.cat(gathered_list, dim=shard_dim) + self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size()) + self.assertEqual(all_gathered_tensor, tensor_to_split) + + @with_comms + def test_reduce_scatter_1d(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dims_to_scatter = [0, 1] + for dim in dims_to_scatter: + input_size = [3, 3] + scattered_tensor = torch.empty(input_size, device=self.device_type) + input_size[dim] *= self.world_size + + input_rs_list = ( + torch.ones(input_size, device=self.device_type) * self.rank + ).tensor_split(self.world_size, dim=dim) + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + mesh.reduce_scatter(scattered_tensor, input_rs_list, mesh_dim=0) + self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_reduce_scatter_uneven(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + my_rank = device_mesh.get_rank() + tensor_to_split = ( + torch.ones( + device_mesh.size() + 3, + device_mesh.size() + 1, + device=self.device_type, + ) + * self.rank + ) + + for shard_dim in range(tensor_to_split.ndim): + shard_placement = Shard(shard_dim) + tensor_to_scatter = tensor_to_split.clone() + tensor_splitted_list = tensor_to_split.tensor_split( + device_mesh.size(), dim=shard_dim + ) + padded_tensor_list, pad_idx = shard_placement._split_tensor( + 
tensor_to_scatter, + device_mesh.size(), + with_padding=True, + contiguous=True, + ) + + res_num = ((0 + self.world_size - 1) * self.world_size) / 2 + scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) + device_mesh.reduce_scatter( + scattered_tensor, padded_tensor_list, mesh_dim=0 + ) + # unpad scattered_tensor + if pad_idx != 0 and my_rank >= pad_idx: + scattered_tensor = shard_placement._unpad_tensor( + scattered_tensor + ) + + self.assertEqual( + scattered_tensor.size(), tensor_splitted_list[my_rank].size() + ) + self.assertEqual( + scattered_tensor, + torch.ones_like(tensor_splitted_list[my_rank]) * res_num, + ) + + @with_comms + def test_all_gather_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + gathered_tensor_list = list( + torch.empty( + (dim_group_size * 3, 3), device=self.device_type + ).tensor_split(dim_group_size, dim=0) + ) + mesh.all_gather(gathered_tensor_list, local_tensor, mesh_dim=dim) + gathered_tensor = torch.cat(gathered_tensor_list) + exp_tensor = torch.ones(3 * dim_group_size, 3) + for i in range(len(global_ranks)): + exp_tensor[i * 3 : (i + 1) * 3] = ( + torch.ones(3, 3) * global_ranks[i] + ) + self.assertEqual(gathered_tensor, exp_tensor) + + @with_comms + def test_reduce_scatter_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + local_rs_list = ( + torch.ones(dim_group_size * 3, 3, device=self.device_type) + * self.rank + ).tensor_split(dim_group_size, dim=0) + scattered_tensor = torch.empty_like( + local_rs_list[mesh.get_coordinate_on_dim(dim)], + device=self.device_type, + ) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + mesh.reduce_scatter(scattered_tensor, local_rs_list, mesh_dim=dim) + res_num = torch.sum(torch.tensor(global_ranks)) + self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_all_reduce_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + cloned_local_tensor = local_tensor.clone() + mesh.all_reduce(cloned_local_tensor, mesh_dim=dim) + res_num = sum(global_ranks) + self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_broadcast_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + cloned_local_tensor = 
local_tensor.clone() + mesh.broadcast(cloned_local_tensor, mesh_dim=dim) + res_num = global_ranks[0] + self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) + + @with_comms + def test_scatter_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + scattered_tensors = [ + torch.ones(3, 3, device=self.device_type) * global_rank + for global_rank in global_ranks + ] + received_tensor = torch.empty_like( + scattered_tensors[mesh.get_coordinate_on_dim(dim)] + ) + mesh.scatter(received_tensor, scattered_tensors, mesh_dim=dim) + self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) + + @with_comms + def test_all_to_all_1d(self): + # transpose on a 2D tensor distributed over N nodes: + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + tensor_shape = [3, 3] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (rank + self.rank * self.world_size) + for rank in range(self.world_size) + ] + expected_tensor_list = [ + torch.ones(tensor_shape, device=self.device_type) + * (self.rank + rank * self.world_size) # i.e. transpose + for rank in range(self.world_size) + ] + for scatter_dim in range(len(tensor_shape)): + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.all_to_all(output_tensor_list, input_tensor_list, mesh_dim=0) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) + + self.assertEqual(output_tensor, expected_tensor) + + @with_comms + def test_all_to_all_nd(self): + mesh_tensor = torch.arange(8).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + tensor_shape = [3, 3, 3] + # check all dim groups + dim_to_subgroups = mesh.get_dim_groups() + for dim, dim_group in enumerate(dim_to_subgroups): + my_coordinate = mesh.get_coordinate_on_dim(dim) + dim_group_size = get_world_size(dim_group) + global_ranks = [ + get_global_rank(dim_group, i) for i in range(dim_group_size) + ] + input_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * (i + self.rank * dim_group_size) + for i in range(dim_group_size) + ] + expected_tensor_list = [ + torch.ones(*tensor_shape, device=self.device_type) + * ( + my_coordinate + global_rank * dim_group_size + ) # i.e. 
transpose + for global_rank in global_ranks + ] + for scatter_dim in range(len(tensor_shape)): + # input_tensor = torch.cat(input_tensor_list, dim=scatter_dim) + output_tensor_list = [ + torch.empty_like(input_tensor_list[idx]) + for idx in range(len(input_tensor_list)) + ] + # scatter on dim > 0 would generate non-contiguous tensor, verify that works + mesh.all_to_all( + output_tensor_list, input_tensor_list, mesh_dim=dim + ) + output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) + expected_tensor = torch.cat( + expected_tensor_list, dim=scatter_dim + ) + self.assertEqual(output_tensor, expected_tensor) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py new file mode 100644 index 000000000000..51ce1bd4ec58 --- /dev/null +++ b/test/distributed/_tensor/test_dtensor.py @@ -0,0 +1,359 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch + +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import DeviceMesh, DTensor, distribute_tensor +from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard + + +class DTensorTest(DTensorTestBase): + # @with_comms + # def test_tensor_constructor(self): + # import torch.distributed._tensor as dist_tensor + # shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)]) + # empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec) + # zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec) + # one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec) + + # zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec) + + # dist_tensor.empty_like(empty_tensor) + # dist_tensor.zero_like(empty_tensor) + # dist_tensor.one_like(empty_tensor) + + @with_comms + def test_dtensor_constructor(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3, requires_grad=True) + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + dist_tensor = DTensor( + local_tensor, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + self.assertEqual( + dist_tensor.size(), torch.Size((self.world_size * 3, 3)) + ) + + with self.assertWarnsRegex(UserWarning, "To construct"): + DTensor( + local_tensor, device_mesh, shard_spec, size=dist_tensor_shape + ) + + local_tensor = torch.randn(3, 3, requires_grad=False) + with self.assertWarnsRegex(UserWarning, "To construct"): + dist_tensor = DTensor( + local_tensor, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + + @with_comms + def test_dtensor_stride(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard0_spec = [Shard(0)] + local_tensor = torch.randn(4, 8) + global_shape = torch.Size([self.world_size * 4, 8]) + dist_tensor = DTensor( + local_tensor, device_mesh, shard0_spec, size=global_shape + ) + # won't affect stride + self.assertEqual(dist_tensor.stride(), (8, 1)) + + shard1_spec = [Shard(1)] + local_tensor = torch.randn(8, 4) + global_shape = torch.Size([8, self.world_size * 4]) + dist_tensor = DTensor( + local_tensor, device_mesh, shard1_spec, size=global_shape + ) + # will affect stride after DT initialized + self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 
1)) + + # if initialized from a transposed mat + local_tensor = torch.randn(8, 4, 8) + local_tensor_t = local_tensor.permute(1, 2, 0) + global_shape = torch.Size([4, self.world_size * 8, 8]) + self.assertEqual(local_tensor_t.stride(), (8, 1, 32)) + dist_tensor = DTensor( + local_tensor_t, device_mesh, shard1_spec, size=global_shape + ) + global_stride = (8 * self.world_size, 1, 32 * self.world_size) + self.assertEqual(dist_tensor.stride(), global_stride) + + @with_comms + def test_from_local(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, device_mesh, shard_spec + ) + self.assertEqual( + sharded_tensor.size(), torch.Size([self.world_size * 3, 3]) + ) + + replica_spec = [Replicate()] + ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec) + self.assertEqual(ddp_tensor.size(), local_tensor.size()) + + partial_spec = [_Partial()] + partial_tensor = DTensor.from_local( + local_tensor, device_mesh, partial_spec + ) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + + # test dist tensor works with torch.Tensor during backwards + local_tensor_with_grad = torch.randn(3, 3, requires_grad=True) + # do some operations on local tensor + local_tensor_temp = local_tensor_with_grad * 3 + # create the dist tensor with non leaf local tensor, dist tensor created + # should also be non leaf node + dist_tensor = DTensor.from_local( + local_tensor_temp, device_mesh, shard_spec + ) + self.assertFalse(dist_tensor.is_leaf) + # do some random operations on dist tensor + output = dist_tensor * 3 + self.assertIsInstance(output, DTensor) + # trigger .backward() on dist tensor directly + local_grad = torch.ones(3, 3) + grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec) + # run backward directly on dist tensor + output.backward(grad_output) + # check it gradients flow back to original torch.Tensor + self.assertIsNotNone(local_tensor_with_grad.grad) + expected_grad = torch.ones(3, 3) * 9 + self.assertEqual(local_tensor_with_grad.grad, expected_grad) + + @with_comms + def test_to_local(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + dist_tensor_shape = torch.Size([self.world_size * 3, 3]) + local_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + + sharded_tensor = DTensor( + local_tensor_with_grad, + device_mesh, + shard_spec, + size=dist_tensor_shape, + requires_grad=True, + ) + self.assertEqual(sharded_tensor.size(), dist_tensor_shape) + self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad) + + # test dist tensor works with torch.Tensor during backwards + # dist tensor created is a leaf node, do some operation on dist tensor + temp_st = sharded_tensor * 3 + + # do some operation on local tensor of the dist tensor + new_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + res = temp_st.to_local() + new_tensor_with_grad + # call backward directly on torch.Tensor, and see if it works by + # propagating through dist tensor + res.sum().backward() + self.assertIsNotNone(sharded_tensor.grad) + + self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3) + + @with_comms + def test_from_local_then_to_local(self): + # this test ensure end to end from torch.Tensor -> dist tensor -> torch.Tensor works + device_mesh = DeviceMesh(self.device_type, 
list(range(self.world_size))) + shard_spec = [Shard(0)] + + # step 1. construct from construct local tensor + local_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + # do some operations on local tensor + local_tensor_temp = local_tensor_with_grad + 8 + # step 2. create the dist tensor with non leaf local tensor, dist tensor + # created should also be non leaf node + dist_tensor = DTensor.from_local( + local_tensor_temp, device_mesh, shard_spec + ) + self.assertFalse(dist_tensor.is_leaf) + # do some random operations on dist tensor + output = dist_tensor * 6 + self.assertIsInstance(output, DTensor) + + # step 3. do some operation on local tensor of the dist tensor + new_tensor_with_grad = torch.randn( + 3, 3, device=self.device_type, requires_grad=True + ) + res = output.to_local() + new_tensor_with_grad + # call backward directly on torch.Tensor, and see if it works by + # propagating all the way back to the original torch.Tensor + res.sum().backward() + self.assertIsNotNone(local_tensor_with_grad.grad) + + expected_grad = torch.ones(3, 3) * 6 + self.assertEqual(local_tensor_with_grad.grad, expected_grad) + + @with_comms + def test_dtensor_spec_read_only_after_set(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, device_mesh, shard_spec + ) + + # modify shard_spec, and dist_tensor's spec should not be changed + shard_spec[0] = Replicate() + self.assertTrue(sharded_tensor.placements is not shard_spec) + self.assertNotEqual(sharded_tensor.placements, shard_spec) + + @with_comms + def test_dtensor_properties(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, device_mesh, shard_spec + ) + self.assertEqual(sharded_tensor.device.type, self.device_type) + + +class DTensorMeshTest(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_dtensor_device_mesh_device_conversion(self): + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + # construct from a cpu local tensor with cuda device mesh + # should automatically convert the dist tensor to cuda + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + @with_comms + def test_dtensor_api_device_mesh_context_manager(self): + with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh: + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, device_mesh=mesh, placements=shard_spec + ) + + with DeviceMesh(self.device_type, list(range(self.world_size))): + shard_spec = [Shard(0)] + local_tensor = torch.randn(3, 3) + sharded_tensor = DTensor.from_local( + local_tensor, placements=shard_spec + ) + replica_spec = [Replicate()] + replica_tensor = sharded_tensor.redistribute( + placements=replica_spec + ) + self.assertEqual( + replica_tensor.size(), torch.Size([3 * self.world_size, 3]) + ) + + @with_comms + def test_dtensor_2d_mesh(self): + mesh_tensor = torch.arange(self.world_size).reshape(2, 4) + # construct a cuda device mesh + mesh = DeviceMesh(self.device_type, 
mesh_tensor) + + # construct a dist tensor on 2d device mesh and test if works + shard_spec = [Shard(0), Shard(1)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual( + dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)]) + ) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + # if shard on the same tensor dimension + # we should correctly construct the global tensor size + shard_same_dim_spec = [Shard(0), Shard(0)] + local_tensor = torch.randn(3, 3) + dist_tensor = DTensor.from_local( + local_tensor, mesh, shard_same_dim_spec + ) + self.assertEqual( + dist_tensor.size(), torch.Size([3 * self.world_size, 3]) + ) + + @with_comms + def test_device_mesh_nd(self): + # construct a cuda device mesh + mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2) + mesh = DeviceMesh(self.device_type, mesh_tensor) + # construct a dist tensor on 3d device mesh and test if works + shard_spec = [Shard(0), Shard(1), Shard(2)] + local_tensor = torch.randn(3, 3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6])) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + # construct a dist tensor on 3d device mesh with some shards on same dim + shard_spec = [Shard(0), Shard(0), Shard(2)] + local_tensor = torch.randn(3, 3, 3) + dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec) + self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6])) + self.assertEqual(dist_tensor.device.type, self.device_type) + self.assertEqual(dist_tensor.to_local().device.type, self.device_type) + + @with_comms + def test_dtensor_spec_local_shard_offset(self): + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size).reshape(2, 4) + ) + tensor_shape = (3 * self.world_size, 3 * self.world_size) + # sharding specs and its corresponding local shard offsets + shard_spec_and_offsets = [ + ( + [Shard(0), Replicate()], + (3 * (self.world_size // 2) * (self.rank // 4), 0), + ), + ( + [Shard(1), Replicate()], + (0, 3 * (self.world_size // 2) * (self.rank // 4)), + ), + ( + [Replicate(), Shard(0)], + (3 * (self.world_size // 4) * (self.rank % 4), 0), + ), + ( + [Replicate(), Shard(1)], + (0, 3 * (self.world_size // 4) * (self.rank % 4)), + ), + ] + + # loop through all sharding specs and check local shard offsets + logical_tensor = torch.randn(tensor_shape) + for shard_spec, expected_shard_offsets in shard_spec_and_offsets: + dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec) + self.assertEqual( + expected_shard_offsets, dtensor._spec.local_offsets + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py new file mode 100644 index 000000000000..78fc991d615f --- /dev/null +++ b/test/distributed/_tensor/test_redistribute.py @@ -0,0 +1,317 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import itertools +import torch + +from torch.testing._internal.common_utils import run_tests + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor +from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard + + +class RedistributeTest(DTensorTestBase): + @with_comms + def test_shard_to_replicate_forward_backward(self): + # 1) test shard -> replicate forward + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + expected_tensor = torch.randn( + input_size, device=self.device_type, requires_grad=True + ) + dtensor = distribute_tensor( + expected_tensor.clone(), device_mesh, shard_spec + ) + reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec) + self.assertEqual(reshard_dtensor.size(), torch.Size(input_size)) + self.assertEqual(expected_tensor, reshard_dtensor.to_local()) + + # 2) test shard -> replicate backward: + # should give gradient as shard + grad_output = torch.ones_like(reshard_dtensor) + reshard_dtensor.backward(grad_output) + grad_input = dtensor.grad + self.assertEqual(grad_input.placements, shard_spec) + self.assertEqual( + grad_input.to_local(), torch.ones(dtensor.to_local().size()) + ) + + @with_comms + def test_replicate_to_replicate_forward_backward(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + local_tensor = torch.randn( + 12, 3, device=self.device_type, requires_grad=True + ) + # 1) test replicate -> replicate forward + replica_tensor = distribute_tensor( + local_tensor, device_mesh, replica_spec + ) + reshard_replica_tensor = replica_tensor.redistribute( + device_mesh, replica_spec + ) + self.assertEqual(replica_tensor.size(), local_tensor.size()) + self.assertEqual(replica_tensor, reshard_replica_tensor) + + # 2) test replicate -> replicate backward: + # should give gradient as replicate + grad_output = torch.ones_like(reshard_replica_tensor) + reshard_replica_tensor.backward(grad_output) + grad_input = replica_tensor.grad + self.assertEqual(grad_input.placements, replica_spec) + self.assertEqual(grad_input.to_local(), torch.ones(12, 3)) + + @with_comms + def test_replicate_to_shard_forward_backward(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + replica_spec = [Replicate()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + # 1) test replicate -> shard forward + local_replica = torch.randn( + input_size, device=self.device_type, requires_grad=True + ) + splitted_list = local_replica.tensor_split( + self.world_size, shard_dim + ) + # make local tensor as the element of the corresponding chunked list + local_tensor = splitted_list[self.rank] + replica_tensor = 
distribute_tensor( + local_replica, device_mesh, replica_spec + ) + reshard_tensor = replica_tensor.redistribute( + device_mesh, shard_spec + ) + self.assertEqual(reshard_tensor.size(), replica_tensor.size()) + self.assertEqual(reshard_tensor.placements, shard_spec) + self.assertEqual(reshard_tensor.to_local(), local_tensor) + + # 2) test replicate -> shard backward: + # should give gradient as replicate + grad_output = torch.ones_like(reshard_tensor) + reshard_tensor.backward(grad_output) + grad_input = replica_tensor.grad + self.assertEqual(grad_input.placements, replica_spec) + self.assertEqual(grad_input.to_local(), torch.ones(input_size)) + + @with_comms + def test_partial_to_replicate_forward_backward(self): + # Although we don't allow user to reshard to produce a partial + # placement (i.e. user can't reshard to partial), we do allow + # replicate to partial internally, and also partial to replicate + # backward should work as expected + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + partial_local = torch.randn( + 12, 3, device=self.device_type, requires_grad=True + ) + partial_spec = [_Partial()] + replica_spec = [Replicate()] + # test partial -> replicate, which trigger all_reduce + partial_tensor = DTensor.from_local( + partial_local, device_mesh, partial_spec + ) + global_partial_tensor = partial_tensor.redistribute( + device_mesh, replica_spec + ) + + self.assertEqual(partial_tensor.size(), partial_local.size()) + self.assertEqual( + partial_local * self.world_size, global_partial_tensor.to_local() + ) + + # test backward to have replicate grad on partial + global_partial_tensor.backward(torch.ones_like(global_partial_tensor)) + self.assertIsNotNone(partial_local.grad) + if device_mesh.get_rank() == 0: + self.assertEqual(partial_local.grad, torch.ones_like(partial_local)) + + @with_comms + def test_replicate_to_partial(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + local_tensor = torch.randn( + 12, 3, device=self.device_type, requires_grad=True + ) + partial_spec = _Partial() + replica_spec = Replicate() + # 1) test replicate -> partial forward + replica_tensor = distribute_tensor( + local_tensor, device_mesh, [replica_spec] + ) + with self.assertRaisesRegex( + RuntimeError, "Can not redistribute to _Partial" + ): + partial_tensor = replica_tensor.redistribute( + device_mesh, [partial_spec] + ) + + from torch.distributed._tensor.redistribute import Redistribute + + partial_tensor = Redistribute.apply( + replica_tensor, device_mesh, [partial_spec] + ) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + # test it successfully zero out the contents on other ranks + if self.rank == 0: + self.assertEqual( + replica_tensor.to_local(), partial_tensor.to_local() + ) + else: + self.assertEqual( + partial_tensor.to_local(), torch.zeros_like(local_tensor) + ) + + # replicate to partial on sub groups + local_tensor = torch.randn(12, 3, device=self.device_type) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + # 1) test replicate -> partial on 2d-mesh subgroups + replica_tensor = distribute_tensor( + local_tensor, device_mesh, [replica_spec, replica_spec] + ) + partial_tensor = Redistribute.apply( + replica_tensor, device_mesh, [partial_spec, partial_spec] + ) + self.assertEqual(partial_tensor.size(), local_tensor.size()) + + if self.rank != 3: + # replicate to partial should only zero out rank 3, and leave + # rank 0/2 (rank0 on mesh dim 1) and 
0, 1 (rank0 on mesh dim 1) un-touched + self.assertEqual( + replica_tensor.to_local(), partial_tensor.to_local() + ) + else: + self.assertEqual( + replica_tensor.to_local(), torch.zeros_like(local_tensor) + ) + + @with_comms + def test_partial_to_shard(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + partial_spec = [_Partial()] + + input_sizes_and_shard_dim = [ + ((self.world_size * 3, 3), 0), + ((self.world_size * 3 + 1, 3), 0), + ((self.world_size * 3 + 2, 3), 0), + ((3, self.world_size * 3), 1), + ((3, self.world_size * 3 + 1), 1), + ((3, self.world_size * 3 + 2), 1), + ] + + for input_size, shard_dim in input_sizes_and_shard_dim: + shard_spec = [Shard(shard_dim)] + + partial_local = torch.ones(input_size, device=self.device_type) + partial_tensor = DTensor.from_local( + partial_local, device_mesh, partial_spec, run_check=False + ) + + quot, rem = divmod(input_size[shard_dim], self.world_size) + local_shape = list(input_size) + local_shape[shard_dim] = quot + (1 if self.rank < rem else 0) + # test partial to shard, trigger reduce_scatter + scatter_shard_tensor = partial_tensor.redistribute( + device_mesh, shard_spec + ) + self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size()) + self.assertEqual(scatter_shard_tensor.placements, shard_spec) + self.assertEqual( + scatter_shard_tensor.to_local(), + torch.ones(local_shape) * self.world_size, + ) + + +class MultiDimRedistributeTest(DTensorTestBase): + @property + def world_size(self) -> int: + return 8 + + @with_comms + def test_multi_dim_mesh(self): + devices = torch.arange(self.world_size) + for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]: + mesh_shape = torch.arange(self.world_size).view(-1, 2) + device_mesh = DeviceMesh(self.device_type, mesh_shape) + tensor_shape = (16, 24) + + if torch.distributed.get_rank() == 0: + full_tensor = torch.randn(*tensor_shape) + else: + # these should be entirely ignored + # because distribute_tensor is expected to override shards in ranks != 0 + full_tensor = torch.ones(*tensor_shape) + + possibilities = [Replicate()] + [ + Shard(i) for i in range(full_tensor.ndim) + ] + all_outputs = list( + itertools.product(*(mesh_shape.ndim * [possibilities])) + ) + all_inputs = list( + itertools.product( + *(mesh_shape.ndim * [possibilities + [_Partial()]]) + ) + ) + + for inputs in all_inputs: + # if partial, temporarily make it Replicated, then replace replicated with partial afterwards + repl_inputs = [ + Replicate() if s.is_partial() else s for s in inputs + ] + dt = distribute_tensor(full_tensor, device_mesh, repl_inputs) + + if repl_inputs != inputs: + # create a new DTensor reinterpreting some of the replicated entires as "Partial" + dt = DTensor.from_local( + dt.to_local(), device_mesh, inputs, run_check=False + ) + + for outputs in all_outputs: + # redistribute on target outputs + dt2 = dt.redistribute(device_mesh, outputs) + + # replicate and then get first shard + local_full = dt2.redistribute( + device_mesh, device_mesh.ndim * [Replicate()] + ).to_local() + + if torch.distributed.get_rank() == 0: + self.assertEqual(local_full.shape, full_tensor.shape) + + num_sums = 1 + for idx, input in enumerate(inputs): + if input.is_partial(): + num_sums *= mesh_shape.size(idx) + expected = num_sums * full_tensor + self.assertEqual(local_full, expected) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/testing/_internal/distributed/_tensor/__init__.py b/torch/testing/_internal/distributed/_tensor/__init__.py new file mode 100644 
index 000000000000..e69de29bb2d1 diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py new file mode 100644 index 000000000000..cf2abe0ee8d2 --- /dev/null +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from contextlib import contextmanager +from dataclasses import dataclass +import itertools +import sys +from functools import wraps +from typing import ( + Any, + Callable, + Generator, + Iterator, + Tuple, + Dict, + Optional, + List, + Sequence, + TypeVar, + cast, +) + +import torch +import torch.distributed as dist + +from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + TEST_SKIPS, + skip_if_lt_x_gpu, +) + +from torch.distributed._tensor import ( + DeviceMesh, + Shard, + Replicate, + distribute_tensor, + redistribute, +) +from torch.distributed._tensor.api import DTensor +from torch.distributed._tensor.placement_types import Placement + +DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu" +NUM_DEVICES = 4 + +# We use this as a proxy for "multiple GPUs exist" +if torch.cuda.is_available() and torch.cuda.device_count() > 1: + # when we actually have multiple GPUs, relax the requirement to smaller counts. + NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count()) + +T = TypeVar("T") + + +def skip_unless_torch_gpu(method: T) -> T: + """ + Test decorator which skips the test unless there's a GPU available to torch. + + >>> @skip_unless_torch_gpu + >>> def test_some_method(self) -> None: + >>> ... + """ + # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set. + return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method)) + + +@dataclass +class RedistributeProfile: + num_calls: int + + +@contextmanager +def redistribute_profiler() -> Generator[RedistributeProfile, None, None]: + + orig_redistribute_dtensor = redistribute.redistribute_dtensor + profile: RedistributeProfile = RedistributeProfile(num_calls=0) + + # pyre-ignore[53] + def patched_redistribute_dtensor( + input: DTensor, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + ) -> DTensor: + result = orig_redistribute_dtensor(input, device_mesh, placements) + profile.num_calls += 1 + return result + + try: + # pyre-ignore[9] + redistribute.redistribute_dtensor = patched_redistribute_dtensor + yield profile + finally: + redistribute.redistribute_dtensor = orig_redistribute_dtensor + + +class DTensorTestBase(MultiProcessTestCase): + @property + def world_size(self) -> int: + return NUM_DEVICES + + def build_device_mesh(self) -> DeviceMesh: + return DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES))) + + def init_pg(self, backend: str = "nccl") -> None: + if backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + if backend not in ["nccl", "gloo", "mpi"]: + raise RuntimeError(f"Backend {backend} not supported!") + + dist.init_process_group( + backend=backend, + world_size=self.world_size, + rank=self.rank, # pyre-ignore[16] + init_method=f"file://{self.file_name}", # pyre-ignore[16] + ) + + # set device for nccl pg for collectives + if backend == "nccl": + torch.cuda.set_device(self.rank) + + def destroy_pg(self) -> None: + # Wait for all ranks to reach here before starting shutdown. 
+ dist.barrier() + dist.destroy_process_group() + + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + # pyre-ignore[2]: + def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None: + with redistribute_profiler() as profile: + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + # pyre can't find assertTrue anymore? + self.assertEqual(dtc.successful(), True) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual( + d_out.redistribute( + mesh, [Replicate()] * mesh.ndim + ).to_local(), + out, + ) + + +# wrapper to initialize comms (processgroup) +def with_comms( + func: Optional[ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + Callable + ] = None, + backend: Optional[str] = None, +) -> Optional[ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. + Callable +]: + assert func is not None + + @wraps(func) # pyre-ignore[6] + def wrapper( + self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc] + ) -> None: + # if backend not specified, and cuda available, then use nccl, else gloo + pg_backend = ( + "nccl" if backend is None and torch.cuda.is_available() else "gloo" + ) + if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + self.device_type = "cuda" if pg_backend == "nccl" else "cpu" + self.init_pg(backend=pg_backend) + func(self) # type: ignore[misc] + self.destroy_pg() + + return wrapper + + +# This is a class for converting args/kwargs of an op into distributed args/kwargs +class DTensorConverter(object): + def __init__( + self, + mesh: DeviceMesh, + args: Tuple[object, ...], + kwargs: Dict[str, object], + ) -> None: + self.hit = 0 + self.miss = 0 + self.mesh = mesh + self.args = args + self.kwargs = kwargs + flatten_args, flatten_args_spec = tree_flatten(args) + flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs) + + self.flatten_args: List[object] = flatten_args + self.flatten_args_spec: TreeSpec = flatten_args_spec + self.flatten_kwargs: List[object] = flatten_kwargs + self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec + + choices_for_args = [] + for arg in self.flatten_args: + if isinstance(arg, torch.Tensor): + choices_for_args.append(self.gen_sharding_choices_for_arg(arg)) + + for arg in self.flatten_kwargs: + if isinstance(arg, torch.Tensor): + choices_for_args.append(self.gen_sharding_choices_for_arg(arg)) + + self.sharding_combs: Iterator[Sequence[Placement]] = iter( + itertools.product(*choices_for_args) + ) + + def successful(self) -> bool: + return self.hit > 0 and self.miss == 0 + + def is_supported_tensor(self, t: torch.Tensor) -> bool: + # TODO: dist tensor need to support quantized and sparse + # tensors, quantized tensor might be relatively easy, but + # sparse tensor have special layouts that we need to possibly + # deal with, until we are clear about them, we don't officially + # support them. 
+ return not any( + [ + t.is_sparse_csr, + t.is_sparse, + t.is_mkldnn, + t.is_quantized, + t.is_nested, + torch._is_functional_tensor(t), + t.is_neg(), + t.is_conj(), + t.device.type in ("lazy", "meta"), + # We need a way to test if a tensor is batched but there + # is no official APi to do it + # torch._C._is_batched(t), + ] + ) + + def gen_sharding_choices_for_arg( + self, arg: torch.Tensor + ) -> Sequence[Placement]: + mesh_size = self.mesh.size() + sharding_choices: List[Placement] = [Replicate()] + # c10d collective does not support bool tensor + # for bool tensor we treat it as replicated + if arg.dtype != torch.bool: + # only generating choices with: replicate, or sharding + # evenly on a dimension that could be sharded + sharding_choices = sharding_choices + [ + Shard(i) + for i, s in enumerate(arg.shape) + if s > 1 and s % mesh_size == 0 + ] + # TODO: add multi mesh choices + # all_choices = itertools.product( + # *(self.mesh.ndim * [sharding_choices]) + # ) + return sharding_choices + + def __iter__(self) -> "DTensorConverter": + return self + + def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]: + try: + next_sharding_choices = next(self.sharding_combs) + idx = 0 + + new_args: List[object] = [] + for arg in self.flatten_args: + if isinstance(arg, torch.Tensor): + new_args.append( + self.to_dist_tensor( + arg, self.mesh, [next_sharding_choices[idx]] + ) + ) + idx += 1 + else: + new_args.append(arg) + + new_kwargs: List[object] = [] + for arg in self.flatten_kwargs: + if isinstance(arg, torch.Tensor): + new_kwargs.append( + self.to_dist_tensor( + arg, self.mesh, [next_sharding_choices[idx]] + ) + ) + idx += 1 + else: + new_kwargs.append(arg) + + return ( + tree_unflatten(new_args, self.flatten_args_spec), + tree_unflatten(new_kwargs, self.flatten_kwargs_spec), + ) + except StopIteration: + raise StopIteration + + def to_dist_tensor( + self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement] + ) -> torch.Tensor: + if type(t) is torch.Tensor or type(t) is torch.nn.Parameter: + if self.is_supported_tensor(t): + self.hit += 1 + # We cannot use distribute_tensor for bool tensors as c10d + # collectives does not support the dtype, we assume op with + # bool tensor args the same tensor so we don't need to broadcast + # TODO: add bool tensor dtype support in c10d collective + if t.dtype == torch.bool: + r = DTensor( + t, + mesh, + placements, + size=t.size(), + requires_grad=t.requires_grad, + ) + else: + r = distribute_tensor(t, mesh, placements) + if type(t) is torch.nn.Parameter: + r = torch.nn.Parameter( # type: ignore[assignment] + r, requires_grad=r.requires_grad + ) + return r + else: + self.miss += 1 + return t + elif torch.overrides.is_tensor_like(t): + # Blindly converting tensor subclasses to dist tensor can cause + # unpredictable problems, we explicitly disable this conversion + # for now (i.e. we don't support DTensor holding tensor subclass + # until there's a strong reason later). + self.miss += 1 + return t + else: + raise RuntimeError( + f"Trying to convert to DTensor, but got {type(t)}" + ) From 550a019fb85647f0bc7fe8ee231dc158b4f30d7c Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 22:51:32 +0000 Subject: [PATCH 220/453] [dtensor] PART 6: move DTensor op tests to core distributed (#88551) This PR moves DTensor op tests to core distributed, including prop_rule, pointwise op, matrix op tests, etc. 
part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88551 Approved by: https://github.com/aazzolini --- test/distributed/_tensor/test_common_rules.py | 476 +++++++++++++++++ test/distributed/_tensor/test_math_ops.py | 126 +++++ test/distributed/_tensor/test_matrix_ops.py | 302 +++++++++++ .../distributed/_tensor/test_pointwise_ops.py | 285 +++++++++++ .../_tensor/test_tp_sharding_ops.py | 101 ++++ test/distributed/_tensor/test_view_ops.py | 480 ++++++++++++++++++ 6 files changed, 1770 insertions(+) create mode 100644 test/distributed/_tensor/test_common_rules.py create mode 100644 test/distributed/_tensor/test_math_ops.py create mode 100644 test/distributed/_tensor/test_matrix_ops.py create mode 100644 test/distributed/_tensor/test_pointwise_ops.py create mode 100644 test/distributed/_tensor/test_tp_sharding_ops.py create mode 100644 test/distributed/_tensor/test_view_ops.py diff --git a/test/distributed/_tensor/test_common_rules.py b/test/distributed/_tensor/test_common_rules.py new file mode 100644 index 000000000000..ab9743c1d5e9 --- /dev/null +++ b/test/distributed/_tensor/test_common_rules.py @@ -0,0 +1,476 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torchgen.model import FunctionSchema +from torch.distributed._tensor.dispatch import OpSchema + +from torch.distributed._tensor.ops.common_rules import ( + einop_rule, + reduction_rule, + pointwise_rule, +) +from torch.distributed._tensor.placement_types import DTensorSpec +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import DeviceMesh + + +class CommonRulesTest(DTensorTestBase): + def parse_schema(self, schema_str): + return FunctionSchema.parse(schema_str) + + @property + def world_size(self) -> int: + # hard code world size to 4 as we need to test + # at least with 2d mesh + return 4 + + @with_comms + def test_einop_basic_propagation(self): + # plain einsum, mm + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = self.parse_schema( + "aten::mm(Tensor self, Tensor mat2) -> Tensor" + ) + # propagate col-wise sharding + mat1, mat2 = [-1, -1], [-1, 0] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([4, 8]) + ) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # propagate row-wise sharding + mat1, mat2 = [0, -1], [-1, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([4, 8]) + ) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # generate partial + mat1, mat2 = [-1, 0], [0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = 
DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([4, 8]) + ) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertTrue(output_spec.placements[0].is_partial()) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + @with_comms + def test_einop_pointwise_propagation(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = self.parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + # addition + mat1 = [0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 8]) + ) + output_sharding = einop_rule( + "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat1_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + # broadcast addition + mat1 = [-1, 0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4, 2]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, [-1], [], shape=torch.Size([2]) + ) + output_sharding = einop_rule( + "ijk,k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 4, 2])) + + # broadcast to a common shape + mat1_spec = DTensorSpec.from_dim_map( + mesh, [0, -1, -1], [], shape=torch.Size([8, 8, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, [-1, -1], [], shape=torch.Size([1, 8]) + ) + output_sharding = einop_rule( + "ijk,1k->ijk", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, -1, -1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8, 8])) + + @with_comms + def test_einop_merge_sharding(self): + # 2d mesh einop merge sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = self.parse_schema( + "aten::mm(Tensor self, Tensor mat2) -> Tensor" + ) + + mat1, mat2 = [0, -1], [-1, 1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([4, 8]) + ) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [0, 1]) + self.assertEqual(output_spec.shape, torch.Size([8, 8])) + + @with_comms + def test_einop_linearity(self): + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + mm_func_schema = self.parse_schema( + "aten::mm(Tensor self, Tensor mat2) -> Tensor" + ) + + mat1, mat2 = [0, -1], [-1, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [1], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([4, 8]) + ) + # if not turn on linearity, partial sum is not eligible to propagate, we return + # suggestion to reshard inputs with no partial sum (i.e. 
all_reduce one input) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(mm_func_schema, (mat1_spec, mat2_spec), {}) + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + suggested_spec = suggestions[0].args_schema[0] + self.assertFalse(suggested_spec.placements[1].is_partial()) + + # einop prop with linearity on mm, should give back suggestion + # on converting placements to partial + output_sharding = einop_rule( + "mk,kn->mn", + OpSchema(mm_func_schema, (mat1_spec, mat2_spec), {}), + linearity=True, + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + mat2_spec = suggestions[0].args_schema[1] + # mat2 mesh dim 1 should become partial now! + self.assertTrue(mat2_spec.placements[1].is_partial()) + + # einop prop with linearity on point-wise, should give back suggestion + # on converting placements to partial + add_func_schema = self.parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + mat1, mat2 = [0, -1], [0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [1], shape=torch.Size([8, 6]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([8, 6]) + ) + + output_sharding = einop_rule( + "ij,ij->ij", + OpSchema(add_func_schema, (mat1_spec, mat2_spec), {}), + linearity=True, + ) + self.assertIsNone(output_sharding.output_spec) + suggestions = output_sharding.schema_suggestions + self.assertIsNotNone(suggestions) + mat2_spec = suggestions[0].args_schema[1] + # mat2 mesh dim 1 should become partial now! + self.assertTrue(mat2_spec.placements[1].is_partial()) + + @with_comms + def test_einop_multi_sharding_on_mesh_dim(self): + # einop prop with multi sharding on same mesh dim + mesh_shape = torch.arange(self.world_size) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = self.parse_schema( + "aten::mm(Tensor self, Tensor mat2) -> Tensor" + ) + mat1, mat2 = [0, -1], [0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 12]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([12, 4]) + ) + output_sharding = einop_rule( + "mk,kn->mn", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the second + # arg by all_gather its tensor dim sharding + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(schema_suggestion.args_schema[0].dim_map, [0, -1]) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, [-1, -1]) + + @with_comms + def test_einop_errors(self): + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = self.parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + mat1, mat2 = [0, -1], [1, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([8, 4]) + ) + + with self.assertRaisesRegex( + RuntimeError, "sharded two different ways:" + ): + einop_rule( + "ij,ij->ij", OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + + @with_comms + def test_pointwise_rules_broadcasting(self): + 
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = self.parse_schema( + "where.self(Tensor condition, Tensor self, Tensor other) -> Tensor" + ) + inp1, inp2, inp3 = [0], [], [-1, -1] + condition = DTensorSpec.from_dim_map( + mesh, inp1, [], shape=torch.Size([8]) + ) + self_tensor = DTensorSpec.from_dim_map( + mesh, inp2, [], shape=torch.Size([]) + ) + other_tensor = DTensorSpec.from_dim_map( + mesh, inp3, [], shape=torch.Size([1, 1]) + ) + # propagate point-wise sharding with broadcasting + output_sharding = pointwise_rule( + OpSchema(func_schema, (condition, self_tensor, other_tensor), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + self.assertEqual(output_spec.shape, [1, 8]) + + @with_comms + def test_pointwise_rules_suggestion(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = self.parse_schema( + "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor" + ) + # propagate point-wise sharding + inp1, inp2 = [-1, -1], [-1, 0] + mat1_spec = DTensorSpec.from_dim_map( + mesh, inp1, [], shape=torch.Size([8, 4]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, inp2, [], shape=torch.Size([8, 4]) + ) + # adding a positional argument -1 to arg schema + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec, -1), {}) + ) + self.assertIsNone(output_sharding.output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion from pointwise rules still have + # the positional args that are not DTensorSpec + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(len(schema_suggestion.args_schema), 3) + self.assertEqual(schema_suggestion.args_schema[2], -1) + + @with_comms + def test_pointwise_multi_sharding_on_mesh_dim(self): + # 2d mesh pointwise sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = self.parse_schema( + "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + ) + + # basic case to test implicit broadcasting shape alignment + mat1, mat2 = [-1, 0], [0] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([20, 6]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([6]) + ) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNotNone(output_spec) + self.assertEqual(output_spec.dim_map, [-1, 0]) + + # more advanced case that needs reshard one input to align sharding + mat1, mat2 = [0, -1, -1, 1], [0, -1, 1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([12, 1, 1, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([12, 4, 8]) + ) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the first + # arg by all_gather first tensor dim sharding + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual( + schema_suggestion.args_schema[0].dim_map, [-1, -1, -1, 1] + ) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat2) + + @with_comms 
+ def test_pointwise_enforce_sharding_multi_sharding_on_mesh_dim(self): + # 2d mesh pointwise sharding + mesh_shape = torch.arange(self.world_size).reshape( + self.world_size // 2, self.world_size // 2 + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + + func_schema = self.parse_schema( + "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)" + ) + + # more advanced case that needs reshard one input to align sharding + mat1, mat2 = [0, -1, 1], [-1, -1, 0] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([12, 4, 8]) + ) + mat2_spec = DTensorSpec.from_dim_map( + mesh, mat2, [], shape=torch.Size([12, 1, 8]) + ) + output_sharding = pointwise_rule( + OpSchema(func_schema, (mat1_spec, mat2_spec), {}) + ) + output_spec = output_sharding.output_spec + self.assertIsNone(output_spec) + self.assertIsNotNone(output_sharding.schema_suggestions) + + # ensure that the suggestion is to reshard the second + # arg as we should enforce the sharding of the first arg + schema_suggestion = output_sharding.schema_suggestions[0] + self.assertEqual(schema_suggestion.args_schema[0].dim_map, mat1) + self.assertEqual(schema_suggestion.args_schema[1].dim_map, mat1) + + @with_comms + def test_reduction_rule(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + + func_schema = self.parse_schema( + "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor" + ) + # reduction on a 2d mat + mat1 = [0, -1] + mat1_spec = DTensorSpec.from_dim_map( + mesh, mat1, [], shape=torch.Size([8, 4]) + ) + # reduction on dim 0 + output_sharding_0 = reduction_rule( + OpSchema(func_schema, (mat1_spec, 0), {}), + dims=[0], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_0.output_spec) + self.assertEqual(output_sharding_0.output_spec.dim_map, [-1]) + # pending sum on dim 0 + self.assertEqual(output_sharding_0.output_spec.sums, [0]) + self.assertEqual(output_sharding_0.output_spec.shape, torch.Size([4])) + + # reduction on dim 1 + output_sharding_1 = reduction_rule( + OpSchema(func_schema, (mat1_spec, 1), {}), + dims=[1], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_1.output_spec) + self.assertEqual(output_sharding_1.output_spec.dim_map, [0]) + self.assertEqual(output_sharding_1.output_spec.sums, []) + self.assertEqual(output_sharding_1.output_spec.shape, torch.Size([8])) + + # full reduction if not specify dim + output_sharding_all_dim = reduction_rule( + OpSchema(func_schema, (mat1_spec,), {}), + dims=[0, 1], + reduction_linear=True, + ) + self.assertIsNotNone(output_sharding_all_dim.output_spec) + self.assertEqual(output_sharding_all_dim.output_spec.dim_map, []) + # pending sum on mesh + self.assertEqual(output_sharding_all_dim.output_spec.sums, [0]) + self.assertEqual( + output_sharding_all_dim.output_spec.shape, torch.Size([]) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py new file mode 100644 index 000000000000..403f22d8325e --- /dev/null +++ b/test/distributed/_tensor/test_math_ops.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests + +from torch.distributed._tensor import distribute_tensor +from torch.distributed._tensor.placement_types import Shard, Replicate +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + skip_unless_torch_gpu, +) +import itertools + + +class DistMathOpsTest(DTensorTestBase): + @with_comms + def test_sum(self): + device_mesh = self.build_device_mesh() + + shard_spec = [Shard(0)] + + tensor_to_sum = torch.randn(12, 8, 8) + + mat1 = distribute_tensor(tensor_to_sum, device_mesh, shard_spec) + + keep_dim_or_not = [True, False, None] + for dim in range(tensor_to_sum.ndim): + for keep_dim in keep_dim_or_not: + sum_args = (dim, keep_dim) if keep_dim is not None else (dim,) + dim_sumed_tensor = tensor_to_sum.sum(*sum_args) + dt_dim_sumed_tensor = mat1.sum(*sum_args).redistribute( + device_mesh, [Replicate()] * device_mesh.ndim + ) + self.assertEqual( + dt_dim_sumed_tensor.to_local(), dim_sumed_tensor + ) + + full_sumed_tensor = tensor_to_sum.sum() + dt_sum = mat1.sum().redistribute( + device_mesh, [Replicate()] * device_mesh.ndim + ) + self.assertEqual(dt_sum.to_local(), full_sumed_tensor) + + # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU + @with_comms + def test_softmax_fwd(self): + device_mesh = self.build_device_mesh() + + x = torch.rand(8, 12, 16, device=self.device_type) + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for softmax_dim, shard_dim in test_list: + local_y = torch.nn.functional.softmax( + x, dim=softmax_dim, dtype=torch.float32 + ) + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + if dims[shard_dim] == dims[softmax_dim]: + with self.assertRaisesRegex( + Exception, "Cannot run .* on sharding dimension!$" + ): + dist_y = torch.nn.functional.softmax( + dist_x, dim=softmax_dim, dtype=torch.float32 + ) + else: + dist_y = torch.nn.functional.softmax( + dist_x, dim=softmax_dim, dtype=torch.float32 + ) + self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim)) + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + + # TODO: get test_softmax_with_bwd pass on CPU + # DTensor's _softmax_backward_data produces wrong result on CPU on certain dimension. 
+ # fail_on_cpu_list = [(0, -1), (1, -1)] + @with_comms + @skip_unless_torch_gpu + def test_softmax_with_bwd(self): + device_mesh = self.build_device_mesh() + + dims = range(3) # used to convert -1 to the actual dim + softmax_dims = [-1, 0, 1, 2] + shard_dims = [-1, 0, 1, 2] + test_list = list(itertools.product(softmax_dims, shard_dims)) + + for params in test_list: + softmax_dim, shard_dim = params + x = torch.rand( + 8, 12, 16, device=self.device_type, requires_grad=True + ) + self.assertTrue(x.requires_grad) + local_y = torch.nn.functional.softmax( + x, dim=softmax_dim, dtype=torch.float32 + ).sum() + local_y.backward() + + dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)]) + self.assertTrue(dist_x.requires_grad) + if dims[softmax_dim] == dims[shard_dim]: + with self.assertRaisesRegex( + Exception, "Cannot run .* on sharding dimension!$" + ): + dist_softmax = dist_x.softmax(dim=softmax_dim) + else: + dist_softmax = dist_x.softmax(dim=softmax_dim) + self.assertTrue( + dist_softmax.placements[0].is_shard(dim=shard_dim) + ) + dist_y = dist_softmax.sum() + dist_y = dist_y.redistribute(device_mesh, [Replicate()]) + self.assertEqual(dist_y.to_local(), local_y) + self.assertIsNone(dist_x.grad) + dist_y.backward() + self.assertIsNotNone(dist_x.grad) + dist_x_grad = dist_x.grad.redistribute( + device_mesh, [Replicate()] + ) + self.assertEqual(dist_x_grad.to_local(), x.grad) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_matrix_ops.py b/test/distributed/_tensor/test_matrix_ops.py new file mode 100644 index 000000000000..ed2af130ac88 --- /dev/null +++ b/test/distributed/_tensor/test_matrix_ops.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.distributed._tensor.api import DTensor +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + skip_unless_torch_gpu, +) +from torch.distributed._tensor import distribute_tensor, DeviceMesh +from torch.distributed._tensor.placement_types import Placement, Shard, Replicate, _Partial +from typing import List, Optional, cast +import itertools + + +class DistMatrixOpsTest(DTensorTestBase): + @with_comms + def test_addmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + replica_spec = [Replicate()] + + tensor_to_shard = torch.randn(12, 8) + mat1 = distribute_tensor(tensor_to_shard, device_mesh, shard_spec) + tensor_to_replicate = torch.randn(8, 4) + mat2 = distribute_tensor(tensor_to_replicate, device_mesh, replica_spec) + input_tensor = torch.randn(4) + input = distribute_tensor(input_tensor, device_mesh, replica_spec) + + dist_res = torch.addmm(input, mat1, mat2) + local_res = torch.addmm( + input_tensor, tensor_to_shard, tensor_to_replicate + ) + self.assertEqual( + dist_res.redistribute(device_mesh, replica_spec).to_local(), + local_res, + ) + + @with_comms + def test_addmm_auto_redistribute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard0_spec = [Shard(0)] + shard1_spec = [Shard(1)] + replica_spec = [Replicate()] + + tensor_to_shard1 = torch.randn(12, 8, requires_grad=True) + mat1 = distribute_tensor(tensor_to_shard1, device_mesh, shard1_spec) + tensor_to_shard0 = torch.randn(8, 4, requires_grad=True) + mat2 = distribute_tensor(tensor_to_shard0, device_mesh, shard0_spec) + input_tensor = torch.randn(4, 
requires_grad=True)
+        input = distribute_tensor(input_tensor, device_mesh, replica_spec)
+
+        local_res = torch.addmm(
+            input_tensor, tensor_to_shard1, tensor_to_shard0
+        )
+        dist_res = torch.addmm(input, mat1, mat2)
+
+        # test if addmm output is a partial
+        self.assertIsInstance(dist_res, DTensor)
+        self.assertIsInstance(dist_res.placements[0], _Partial)
+
+        # test if result is the same as tensor
+        replica_res = dist_res.redistribute(device_mesh, replica_spec)
+        dist_local_res = replica_res.to_local()
+        self.assertEqual(local_res, dist_local_res)
+
+        # backward checks
+        dist_local_res.sum().backward()
+        local_res.sum().backward()
+        self.assertIsNotNone(mat2.grad)
+        mat2_grad = mat2.grad.redistribute(device_mesh, replica_spec)
+        self.assertEqual(mat2_grad.to_local(), tensor_to_shard0.grad)
+
+    @with_comms
+    def test_mm(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard0_spec = Shard(0)
+        shard1_spec = Shard(1)
+        replica_spec = Replicate()
+
+        t1 = torch.randn(12, 8, requires_grad=True)
+        t2 = torch.randn(8, 16, requires_grad=True)
+        local_res = torch.mm(t1, t2)
+
+        def test_placement_comb(
+            placements1: List[Placement], placements2: List[Placement]
+        ) -> None:
+            dt1 = distribute_tensor(t1, device_mesh, placements1)
+            dt2 = distribute_tensor(t2, device_mesh, placements2)
+            dist_res: DTensor = cast(DTensor, torch.mm(dt1, dt2)).redistribute(
+                device_mesh, [replica_spec]
+            )
+            self.assertEqual(dist_res.to_local(), local_res)
+            # backward
+            grad_dist_res = torch.ones_like(dist_res)
+            dist_res.backward(grad_dist_res)
+            self.assertIsNotNone(dt1.grad)
+
+        placement_specs = [shard0_spec, shard1_spec, replica_spec]
+        shard_specs_comb = list(
+            itertools.product(placement_specs, placement_specs)
+        )
+        for spec in shard_specs_comb:
+            test_placement_comb([spec[0]], [spec[1]])
+
+    @with_comms
+    def test_t(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        tensor_to_transpose = torch.randn(12, 8, requires_grad=True)
+        mat = distribute_tensor(tensor_to_transpose, device_mesh, shard_spec)
+        transposed_mat = mat.t()
+        self.assertEqual(transposed_mat.size(), torch.Size([8, 12]))
+        self.assertEqual(transposed_mat.placements, [Shard(1)])
+        transposed_mat2 = transposed_mat.t()
+        self.assertEqual(transposed_mat2.size(), torch.Size([12, 8]))
+        self.assertEqual(transposed_mat2.placements, shard_spec)
+
+    @with_comms
+    def test_t_partial(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+
+        a = torch.randn(12, 8)
+        b = torch.randn(8, 4)
+        c = torch.mm(a, b).t()
+
+        da = distribute_tensor(a, device_mesh, [Shard(1)])
+        db = distribute_tensor(b, device_mesh, [Shard(0)])
+
+        # mm(da, db) should return a _Partial tensor.
+ # transposing it should keep it _Partial + dc = torch.mm(da, db).t() + + self.assertTrue(isinstance(dc.placements[0], _Partial)) + + # check that the local and distributed op results match + self.assertEqual( + c, + dc.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_baddbmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + tensor = torch.rand( + 4, 4, 8, device=self.device_type, requires_grad=True + ) + batch_1 = torch.rand( + 4, 4, 8, device=self.device_type, requires_grad=True + ) + batch_2 = torch.rand( + 4, 8, 8, device=self.device_type, requires_grad=True + ) + + def test_placement_comb( + tensor_placements: List[Placement], + batch_1_placements: List[Placement], + batch_2_placements: List[Placement], + beta: int, + alpha: int, + batch_1_grad: Optional[torch.Tensor], + ) -> None: + tensor_dt = distribute_tensor( + tensor, device_mesh, tensor_placements + ) + batch_1_dt = distribute_tensor( + batch_1, device_mesh, batch_1_placements + ) + batch_2_dt = distribute_tensor( + batch_2, device_mesh, batch_2_placements + ) + dist_res = cast( + DTensor, + torch.baddbmm( + tensor_dt, batch_1_dt, batch_2_dt, beta=beta, alpha=alpha + ), + ).redistribute(device_mesh, [Replicate()]) + dist_local_res = dist_res.to_local() + assert not torch.isnan(local_result).any() + assert not torch.isnan(dist_local_res).any() + self.assertEqual(dist_local_res.detach(), local_result.detach()) + + # TODO: add test backward + # grad_dist_res = torch.ones_like(dist_res) + # dist_res.backward(grad_dist_res) + # self.assertIsNotNone(batch_1_dt.grad) + # batch_1_grad_local = batch_1_dt.grad.redistribute( + # device_mesh, [Replicate()] + # ).to_local() + # self.assertEqual(batch_1_grad_local, batch_1_grad) + + shard0_spec = Shard(0) + shard1_spec = Shard(1) + shard2_spec = Shard(2) + replica_spec = Replicate() + shard_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec] + shard_specs_comb = list( + itertools.product(shard_specs, shard_specs, shard_specs) + ) + passlist = [ + (shard0_spec, shard0_spec, shard0_spec), + (shard0_spec, shard0_spec, replica_spec), + (shard0_spec, shard1_spec, shard0_spec), + (shard0_spec, shard2_spec, shard0_spec), + (shard1_spec, shard1_spec, replica_spec), + (shard0_spec, replica_spec, shard0_spec), + (shard2_spec, replica_spec, shard2_spec), + (shard2_spec, shard0_spec, shard2_spec), + (shard2_spec, shard1_spec, shard2_spec), + (shard2_spec, shard2_spec, shard2_spec), + (replica_spec, shard0_spec, shard0_spec), + (replica_spec, shard1_spec, replica_spec), + (replica_spec, shard2_spec, shard1_spec), + (replica_spec, replica_spec, shard2_spec), + (replica_spec, replica_spec, replica_spec), + ] + # If beta is 0, input tensor will be ignored + numeric_params_comb = [ + (0.0, 0.5), # zero-beta + (0.8, 0.5), # non-zero-beta + ] + + for beta, alpha in numeric_params_comb: + local_result = torch.baddbmm( + tensor, batch_1, batch_2, beta=beta, alpha=alpha + ) + grad_local_res = torch.ones_like(local_result) + local_result.backward(grad_local_res) + # tests that currently pass + for spec in passlist: + test_placement_comb( + [spec[0]], [spec[1]], [spec[2]], beta, alpha, batch_1.grad + ) + # TODO: support these tests + shard_specs_comb = [ + spec for spec in shard_specs_comb if spec not in passlist + ] + for spec in shard_specs_comb: + with self.assertRaises(Exception): + test_placement_comb( + 
[spec[0]], + [spec[1]], + [spec[2]], + beta, + alpha, + batch_1.grad, + ) + + @with_comms + def test_bmm(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + mat1 = torch.rand(4, 8, 4, device=self.device_type, requires_grad=True) + mat2 = torch.rand(4, 4, 8, device=self.device_type, requires_grad=True) + local_result = torch.bmm(mat1, mat2) + grad_local_res = torch.ones_like(local_result) + local_result.backward(grad_local_res) + + def test_placement_comb( + placements1: List[Placement], + placements2: List[Placement], + ) -> None: + mat1_dt = distribute_tensor(mat1, device_mesh, placements1) + mat2_dt = distribute_tensor(mat2, device_mesh, placements2) + dist_res = cast(DTensor, torch.bmm(mat1_dt, mat2_dt)).redistribute( + device_mesh, [Replicate()] + ) + dist_local_res = dist_res.to_local() + self.assertEqual(dist_local_res, local_result) + + # test backward + # TODO: figure out (replicate, shard1) fail on backward + # it generates a different grad shape + grad_dist_res = torch.ones_like(dist_res) + dist_res.backward(grad_dist_res) + self.assertIsNotNone(mat1_dt.grad) + mat1_dt_grad = cast(DTensor, mat1_dt.grad) + mat1_grad_local = mat1_dt_grad.redistribute( + device_mesh, [Replicate()] + ).to_local() + self.assertEqual(mat1_grad_local, mat1.grad) + + shard0_spec = Shard(0) + shard1_spec = Shard(1) + shard2_spec = Shard(2) + replica_spec = Replicate() + placement_specs = [shard0_spec, shard1_spec, shard2_spec, replica_spec] + shard_specs_comb = list( + itertools.product(placement_specs, placement_specs) + ) + + # tests that currently pass + for spec in shard_specs_comb: + test_placement_comb([spec[0]], [spec[1]]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py new file mode 100644 index 000000000000..5069166dee27 --- /dev/null +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -0,0 +1,285 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +from typing import Sequence, Any, Dict, Callable, Optional +from unittest import skip + +import torch +from torch import Tensor +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + skip_unless_torch_gpu, +) + +from torch.distributed._tensor import DeviceMesh, DTensor, distribute_tensor +from torch.distributed._tensor.placement_types import ( + Shard, + Replicate, + _Partial, + Placement, +) +from torch.distributed.distributed_c10d import ReduceOp + +import torch.utils._pytree as pytree + + +def no_op(): + return None + + +def deepcopy_convert_to_dtensor( + val: Any, + device_mesh: DeviceMesh, + placements: Sequence[Placement], +) -> Any: + """ + Recursively convert (over Sequence and Dict types) Tensors into DTensors. + + :param device_mesh: the DeviceMesh to use. + :param placements: the Placement list to use. + :return: the transformed structure. + """ + + def f(x): + if isinstance(x, Tensor) and not isinstance(x, DTensor): + return distribute_tensor( + x, + device_mesh=device_mesh, + placements=placements, + ) + return x + + return pytree.tree_map(f, [val])[0] + + +def deepcopy_convert_from_dtensor(val: Any) -> Any: + """ + Recursive convert any DTensor to local Tensor. + + :param val: the structure to coerce. + :return: the coerced structure. 
+ """ + + def f(x): + if isinstance(x, DTensor): + return x.redistribute( + device_mesh=x.device_mesh, + placements=[Replicate()] * x.device_mesh.ndim, + ).to_local() + return x + + return pytree.tree_map(f, [val])[0] + + +class DistElementwiseOpsTest(DTensorTestBase): + def _compare_pairwise_ops( + self, + *, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + op: Callable, + pre_op_fn: Optional[Callable] = None, + args: Sequence[Any] = tuple(), + kwargs: Optional[Dict[str, Any]] = None, + ): + if pre_op_fn is None: + pre_op_fn = no_op + + if not kwargs: + kwargs = {} + + dargs = deepcopy_convert_to_dtensor( + args, + device_mesh=device_mesh, + placements=placements, + ) + dkwargs = deepcopy_convert_to_dtensor( + kwargs, + device_mesh=device_mesh, + placements=placements, + ) + + pre_op_fn() + + # run the reference first, in case the call is broken; + # it's better to debug an incorrect call at this point. + reference_result = op(*args, **kwargs) + + pre_op_fn() + + dist_result = op(*dargs, **dkwargs) + + collected_result = deepcopy_convert_from_dtensor(dist_result) + + self.assertEqual(reference_result, collected_result) + + # TODO: We need to add CPU tests for ops in the future. + def _run_sharded_elementwise_ops( + self, + *, + device_mesh: DeviceMesh, + placements: Sequence[Placement], + pre_op_fn: Optional[Callable] = None, + input_size: Sequence[int], + op: Callable, + **kwargs, + ): + if pre_op_fn is None: + pre_op_fn = no_op + + input_tensor = torch.randn( + *input_size, + device=self.device_type, + requires_grad=True, + ) + + self._compare_pairwise_ops( + device_mesh=device_mesh, + placements=placements, + pre_op_fn=pre_op_fn, + op=op, + args=(input_tensor,), + kwargs=kwargs, + ) + + @with_comms + def test_activations(self): + device_mesh = self.build_device_mesh() + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.nn.functional.gelu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.nn.functional.gelu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(1)], + input_size=(3, 12), + op=torch.nn.functional.relu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.nn.functional.relu, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.sigmoid, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Replicate()], + input_size=(8, 5), + op=torch.sigmoid, + ) + + @with_comms + @skip( + "testing RNG based ops is broken: https://github.com/pytorch/tau/issues/494" + ) + def test_dropout(self): + device_mesh = self.build_device_mesh() + + def _reset_random_seed(): + torch.manual_seed(self.rank + 4) + + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(0)], + input_size=(8, 5), + op=torch.nn.functional.dropout, + pre_op_fn=_reset_random_seed, + p=0.4, + training=False, + ) + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[Shard(1)], + input_size=(3, 14), + op=torch.nn.functional.dropout, + pre_op_fn=_reset_random_seed, + p=0.5, + training=True, + ) + + @with_comms + @skip_unless_torch_gpu + def test_dropout_backward(self): + device_mesh = self.build_device_mesh() + placements = [Shard(0)] + + input_size = (8, 5) + + grad_output = torch.rand( + input_size, + 
device=self.device_type, + requires_grad=True, + ) + mask = ( + torch.rand( + input_size, + device=self.device_type, + requires_grad=False, + ) + < 0.8 + ) + + self._compare_pairwise_ops( + device_mesh=device_mesh, + placements=placements, + op=torch.ops.aten.native_dropout_backward, + kwargs=dict( + grad_output=grad_output, + mask=mask, + scale=0.3, + ), + ) + + @with_comms + def test_dropout_errors(self): + device_mesh = self.build_device_mesh() + with self.assertRaisesRegex(RuntimeError, "supported"): + self._run_sharded_elementwise_ops( + device_mesh=device_mesh, + placements=[_Partial(ReduceOp.SUM)], + input_size=(8, 5), + op=torch.nn.functional.dropout, + ) + + @with_comms + def test_mul_out(self): + device_mesh = self.build_device_mesh() + torch.manual_seed(self.rank) + shard_spec = [Shard(0)] + input_size = (8, 4) + input_tensor = torch.randn(*input_size, device=self.device_type) + dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + + other_tensor = torch.randn(*input_size, device=self.device_type) + other_dtensor = DTensor.from_local( + other_tensor, device_mesh, shard_spec + ) + + output_tensor = torch.randn(*input_size, device=self.device_type) + output_dtensor = DTensor.from_local( + output_tensor, device_mesh, shard_spec + ) + dt = torch.mul(dtensor, other_dtensor, out=output_dtensor) + expected = torch.mul(input_tensor, other_tensor, out=output_tensor) + self.assertEqual(input_tensor, dtensor.to_local()) + self.assertEqual(expected, dt.to_local()) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_tp_sharding_ops.py b/test/distributed/_tensor/test_tp_sharding_ops.py new file mode 100644 index 000000000000..acd28fe3a306 --- /dev/null +++ b/test/distributed/_tensor/test_tp_sharding_ops.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import DeviceMesh, DTensor, Shard, Replicate, distribute_tensor + + +class TPShardingOpsTest(DTensorTestBase): + @property + def world_size(self) -> int: + return 4 + + @with_comms + def test_sharded_view(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(0) + tensor = torch.rand(16, 35, 26) + sharding = [Shard(0)] + st = distribute_tensor(tensor, device_mesh, sharding).view(8, 4, 35, 13) + st_new = distribute_tensor( + tensor.view(8, 4, 35, 13), device_mesh, sharding + ) + self.assertEqual(st.to_local(), st_new.to_local()) + self.assertEqual(st.placements[0], st_new.placements[0]) + + @with_comms + def test_sharded_transpose(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.transpose(0, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=2)) + self.assertEqual(new_dt.to_local(), tensor.transpose(0, 2)) + new_dt = dist_tensor.transpose(1, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=0)) + self.assertEqual(new_dt.to_local(), tensor.transpose(1, 2)) + + @with_comms + def test_sharded_permute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.permute(1, 0, 2) + self.assertTrue(new_dt.placements[0].is_shard(dim=1)) + self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2)) + + @with_comms + def test_replicated_permute(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(0) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Replicate()] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + new_dt = dist_tensor.permute(1, 0, 2) + self.assertTrue(new_dt.placements[0].is_replicate()) + self.assertEqual(new_dt.to_local(), tensor.permute(1, 0, 2)) + self.assertEqual(new_dt.stride(), tensor.permute(1, 0, 2).stride()) + + @with_comms + def test_sharded_cat(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor_1 = torch.rand(3, 5, 6) + tensor_2 = torch.rand(3, 5, 6) + tensor_3 = torch.rand(3, 5, 6) + sharding = [Shard(0)] + dt_1 = DTensor.from_local(tensor_1, device_mesh, sharding) + dt_2 = DTensor.from_local(tensor_2, device_mesh, sharding) + dt_3 = DTensor.from_local(tensor_3, device_mesh, sharding) + new_dt = torch.cat([dt_1, dt_2, dt_3]) + cat_dt = DTensor.from_local( + torch.cat([tensor_1, tensor_2, tensor_3]), device_mesh, sharding + ) + self.assertEqual(new_dt.to_local(), cat_dt.to_local()) + self.assertEqual(new_dt.size(), cat_dt.size()) + + @with_comms + def test_sharded_split(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(2)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + dt_list = 
dist_tensor.split(dist_tensor.size(-1) // 2, dim=-1) + local_tensors = tensor.split(3, dim=-1) + for idx, dt in enumerate(dt_list): + self.assertTrue(dt.placements[0].is_shard(dim=2)) + self.assertEqual(dt.to_local(), local_tensors[idx]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_view_ops.py b/test/distributed/_tensor/test_view_ops.py new file mode 100644 index 000000000000..c1c5a03b9113 --- /dev/null +++ b/test/distributed/_tensor/test_view_ops.py @@ -0,0 +1,480 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +from typing import List, cast +from torch.distributed._tensor.placement_types import Placement +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + redistribute_profiler, + with_comms, +) +from torch.distributed._tensor import DeviceMesh, Shard, Replicate, distribute_tensor +from torch.distributed._tensor.ops.view_ops import ( + ops, + Singleton, + Broadcast, + Flatten, + Repeat, + Split, + InputDim, + view_groups, +) +from torch import Tensor, rand, randn +from torch.testing._internal.common_utils import run_tests +from torch.utils._pytree import tree_flatten + +import itertools +import torch +import torch.distributed as dist + + +class TestViewOps(DTensorTestBase): + def test_view_groups(self): + self.assertEquals( + view_groups([2, 3], [3, 2]), + ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ), + ) + self.assertEquals( + view_groups([3, 4, 5], [12, 5]), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + self.assertEquals( + view_groups([2, 3, 4, 5, 7], [12, 70]), + ( + Split( + Flatten( + ( + InputDim(0), + InputDim(1), + InputDim(2), + InputDim(3), + InputDim(4), + ) + ), + (12, 70), + 0, + ), + Split( + Flatten( + ( + InputDim(0), + InputDim(1), + InputDim(2), + InputDim(3), + InputDim(4), + ) + ), + (12, 70), + 1, + ), + ), + ) + self.assertEquals( + view_groups([2, 3, 4, 5, 7], [3, 8, 7, 5]), + ( + Split( + Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 0 + ), + Split( + Flatten((InputDim(0), InputDim(1), InputDim(2))), (3, 8), 1 + ), + Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 0), + Split(Flatten((InputDim(3), InputDim(4))), (7, 5), 1), + ), + ) + self.assertEquals( + view_groups([3, 4, 8, 3], [12, 4, 2, 3]), + ( + Flatten((InputDim(0), InputDim(1))), + Split(InputDim(2), (4, 2), 0), + Split(InputDim(2), (4, 2), 1), + InputDim(3), + ), + ) + self.assertEquals( + view_groups([3, 24], [1, 3, 2, 4, 1, 3, 1]), + ( + Singleton(), + InputDim(0), + Split(InputDim(1), (2, 4, 3), 0), + Split(InputDim(1), (2, 4, 3), 1), + Singleton(), + Split(InputDim(1), (2, 4, 3), 2), + Singleton(), + ), + ) + self.assertEquals( + view_groups([1, 1, 3, 2, 1, 1], [6, 1, 1, 1]), + ( + Flatten((InputDim(2), InputDim(3))), + Singleton(), + Singleton(), + Singleton(), + ), + ) + self.assertEquals( + view_groups([1, 1, 12, 1, 1, 1, 2, 5, 1], [3, 4, 1, 10]), + ( + Split(InputDim(2), (3, 4), 0), + Split(InputDim(2), (3, 4), 1), + Singleton(), + Flatten((InputDim(6), InputDim(7))), + ), + ) + self.assertEquals( + view_groups([2, 3, 4], [2, -1, 4]), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + @property + def world_size(self) -> int: + return 6 + + def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh): + spec = ops[op] + rules = spec.dim_map(*args, **kwargs) + outputs = op(*args, **kwargs) + flat_args, _ = tree_flatten(args) + in_shape = flat_args[0].shape + + 
no_shard_dims = set() + for rule in rules: + if isinstance(rule, Repeat): + if isinstance(rule.input_dim, InputDim): + no_shard_dims.add(rule.input_dim.input_dim) + elif isinstance(rule, Flatten): + for dim in rule.input_dims[1:]: + if isinstance(dim, InputDim): + no_shard_dims.add(dim.input_dim) + elif isinstance(rule, Split): + if isinstance(rule.input_dim, Flatten): + for dim in rule.input_dim.input_dims[1:]: + if isinstance(dim, InputDim): + no_shard_dims.add(dim.input_dim) + + if op == torch.unbind: + no_shard_dims.add(kwargs.get("dim", 0)) + + sharding_choices = cast(List[Placement], [Replicate()]) + [ + Shard(i) + for i, s in enumerate(in_shape) + if s > 1 and i not in no_shard_dims + ] + + all_sharding_choices = itertools.product( + *(device_mesh.ndim * [sharding_choices]) + ) + + for in_shard in all_sharding_choices: + # print(f' |--- {in_shard}') + in_dt = distribute_tensor(args[0], device_mesh, in_shard) + + with redistribute_profiler() as profiler: + out_dt = op(in_dt, *args[1:], **kwargs) + + self.assertEqual( + profiler.num_calls, 0, "Expected no redistribution." + ) + + full_out = out_dt.redistribute( + device_mesh, device_mesh.ndim * [Replicate()] + ).to_local() + + if dist.get_rank() == 0: + self.assertEqual(outputs, full_out) + + def dimmap_test(self, op, args, expected_rule_output): + rules = ops[op].dim_map(*args) + self.assertEquals(rules, expected_rule_output) + self.call_dt_test(op, args, {}, self.device_mesh) + + @with_comms + def test_view_ops(self): + self.device_mesh = DeviceMesh( + self.device_type, torch.arange(dist.get_world_size()).view(-1, 2) + ) + self.dimmap_test(torch.atleast_1d, (randn(()),), (Singleton(),)) + self.dimmap_test(torch.atleast_1d, (randn(24),), (InputDim(0),)) + self.dimmap_test( + torch.atleast_1d, (randn(24, 36),), (InputDim(0), InputDim(1)) + ) + + self.dimmap_test( + torch.atleast_2d, (randn(()),), (Singleton(), Singleton()) + ) + self.dimmap_test( + torch.atleast_2d, (randn(24),), (Singleton(), InputDim(0)) + ) + self.dimmap_test( + torch.atleast_2d, (randn(24, 36),), (InputDim(0), InputDim(1)) + ) + self.dimmap_test( + torch.atleast_2d, + (randn(24, 36, 48),), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + torch.atleast_3d, + (randn(()),), + (Singleton(), Singleton(), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24),), + (Singleton(), InputDim(0), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36),), + (InputDim(0), InputDim(1), Singleton()), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36, 42),), + (InputDim(0), InputDim(1), InputDim(2)), + ) + self.dimmap_test( + torch.atleast_3d, + (randn(24, 36, 42, 24),), + (InputDim(0), InputDim(1), InputDim(2), InputDim(3)), + ) + + with self.assertRaises(AssertionError): + ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4)) + + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (1, 24, 36)), + (Singleton(), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (42, 24, 36)), + (Broadcast(Singleton(), 42), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 1, 36), (12, 24, 24, 36)), + ( + Broadcast(Singleton(), 12), + InputDim(0), + Broadcast(InputDim(1), 24), + InputDim(2), + ), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 36), (-1, 36)), + (InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.broadcast_to, + (rand(24, 1, 36), (-1, 1, 36)), + (InputDim(0), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + 
torch.broadcast_to, + (randn(36, 1, 24), (12, 36, 42, 24)), + ( + Broadcast(Singleton(), 12), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + ), + ) + + self.dimmap_test( + Tensor.expand, + (randn(24, 1, 36, 1), 36, 24, 42, -1, 24), + ( + Broadcast(Singleton(), 36), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + Broadcast(InputDim(3), 24), + ), + ) + + self.dimmap_test( + Tensor.expand, + (randn(24, 1, 36, 1), (36, 24, 42, -1, 24)), + ( + Broadcast(Singleton(), 36), + InputDim(0), + Broadcast(InputDim(1), 42), + InputDim(2), + Broadcast(InputDim(3), 24), + ), + ) + + self.dimmap_test( + torch.flatten, + (randn(24, 36),), + (Flatten((InputDim(0), InputDim(1))),), + ) + self.dimmap_test(torch.flatten, (randn(42),), (InputDim(0),)) + self.dimmap_test(torch.flatten, (randn(()),), (Singleton(),)) + + self.dimmap_test( + torch.movedim, + (randn(12, 24, 48, 96), 1, 2), + (InputDim(0), InputDim(2), InputDim(1), InputDim(3)), + ) + self.dimmap_test( + torch.movedim, + (randn(6, 12, 24), 1, 0), + (InputDim(1), InputDim(0), InputDim(2)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 12, 6), (1, 2), (0, 1)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 6, 12), (0, 2, 1), (2, 1, 0)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(24, 12), (1, 0), (0, 1)), + (InputDim(1), InputDim(0)), + ) + + self.dimmap_test( + torch.movedim, + (randn(36, 24, 12), (1, 2), (0, 1)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + self.dimmap_test( + torch.movedim, + (randn(36, 24, 12), (1, 2), (-3, -2)), + (InputDim(1), InputDim(2), InputDim(0)), + ) + + self.dimmap_test( + torch.permute, + (randn(24, 36, 42), (2, 0, 1)), + (InputDim(2), InputDim(0), InputDim(1)), + ) + self.dimmap_test( + torch.permute, + (randn(24, 36, 42), (-1, -3, -2)), + (InputDim(2), InputDim(0), InputDim(1)), + ) + + self.dimmap_test( + torch.ravel, + (randn(24, 36),), + (Flatten((InputDim(0), InputDim(1))),), + ) + self.dimmap_test(torch.ravel, (randn(42),), (InputDim(0),)) + self.dimmap_test(torch.ravel, (randn(()),), (Singleton(),)) + + self.dimmap_test( + Tensor.repeat, + (randn(24, 36), 1, 2, 1, 1, 2), + ( + Singleton(), + Broadcast(Singleton(), 2), + Singleton(), + InputDim(0), + Repeat(InputDim(1), 2), + ), + ) + + self.dimmap_test( + torch.reshape, + (randn(6, 12, 24), (72, 24)), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + + self.dimmap_test( + torch.tile, + (randn(24, 36), (1, 2, 1, 1, 2)), + ( + Singleton(), + Broadcast(Singleton(), 2), + Singleton(), + InputDim(0), + Repeat(InputDim(1), 2), + ), + ) + self.dimmap_test( + torch.tile, + (randn(42, 24, 36), (1, 3)), + (InputDim(0), InputDim(1), Repeat(InputDim(2), 3)), + ) + + self.dimmap_test( + torch.transpose, + (randn(24, 60, 42, 60), 2, 0), + (InputDim(2), InputDim(1), InputDim(0), InputDim(3)), + ) + self.dimmap_test( + torch.transpose, + (randn(24, 60, 42, 60), -1, 0), + (InputDim(3), InputDim(1), InputDim(2), InputDim(0)), + ) + + self.dimmap_test( + torch.unsqueeze, + (randn(42, 24, 36), 1), + (InputDim(0), Singleton(), InputDim(1), InputDim(2)), + ) + + self.dimmap_test( + Tensor.view, + (randn(6, 12, 24), 72, 24), + (Flatten((InputDim(0), InputDim(1))), InputDim(2)), + ) + + self.dimmap_test(Tensor.view, (randn(1, 1, 12), -1), (InputDim(2),)) + + self.dimmap_test( + Tensor.view, + (randn(1, 1, 42, 24), -1), + (Flatten((InputDim(2), InputDim(3))),), + ) + + self.dimmap_test( + Tensor.view, + (randn(1, 1, 42, 1, 24, 1), -1), + 
(Flatten((InputDim(2), InputDim(4))),), + ) + + self.dimmap_test( + Tensor.view, + (randn(48, 35, 26), (24, 4, 35, 13)), + ( + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=0, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=1, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=2, + ), + Split( + Flatten(input_dims=(InputDim(0), InputDim(1), InputDim(2))), + group_shape=(24, 4, 35, 13), + split_id=3, + ), + ), + ) + + +if __name__ == "__main__": + run_tests() From 0230e52b541358cec075b9b9f3e6286d3964848f Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 22:51:33 +0000 Subject: [PATCH 221/453] [dtensor] PART 7: move remaining DTensor tests to core distributed (#88179) This PR moves remaining tests, i.e. tensor_ops, op db tests to core distributed part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88179 Approved by: https://github.com/aazzolini --- test/distributed/_tensor/test_dtensor_ops.py | 704 ++++++++++++++++++ test/distributed/_tensor/test_tensor_ops.py | 365 +++++++++ .../_tensor/dtensor_lagging_op_db.py | 661 ++++++++++++++++ .../_tensor/gen_dtensor_lagging_op_db.py | 67 ++ 4 files changed, 1797 insertions(+) create mode 100644 test/distributed/_tensor/test_dtensor_ops.py create mode 100644 test/distributed/_tensor/test_tensor_ops.py create mode 100644 torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py create mode 100644 torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py new file mode 100644 index 000000000000..22ae5807d5f3 --- /dev/null +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -0,0 +1,704 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +import sys +import unittest +import warnings + +from torch.overrides import resolve_name +from torch.utils._pytree import tree_flatten, tree_map +from torch.testing._internal.common_utils import ( + suppress_warnings, + TEST_WITH_ASAN, + run_tests, +) +import torch.distributed as dist +from torch.testing._internal.common_device_type import ( + ops, + instantiate_device_type_tests, +) +import torch.testing._internal.common_methods_invocations as common_ops +from torch.testing._internal.common_methods_invocations import DecorateInfo + +from torch.distributed._tensor import DTensor, DeviceMesh, Replicate +from torch.testing._internal.distributed._tensor.dtensor_lagging_op_db import dtensor_lagging_op_db +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + TEST_SKIPS, + DTensorConverter, + DEVICE_TYPE, +) + +# rewrite common size variables to sth can be sharded evenly +# we can enable uneven shards later, but need to adjust more on +# sample inputs (i.e. 
view/reshape need to adjust shape size as well) +common_ops.L = 24 +common_ops.M = 12 +common_ops.S = 4 +common_ops.XS = 2 + + +def assert_ref_dtensor_equal(test_case, dtensor_rs, rs): + flat_dtensor_rs, _ = tree_flatten(dtensor_rs) + flat_rs, _ = tree_flatten(rs) + test_case.assertEqual(len(flat_dtensor_rs), len(flat_rs)) + for dtensor_r, r in zip(flat_dtensor_rs, flat_rs): + + if not isinstance(r, torch.Tensor): + continue + + test_case.assertIsInstance(dtensor_r, torch.Tensor) + test_case.assertEqual( + dtensor_r.shape, + r.shape, + f"Shape mismatch! original shape:{r.shape}, dtensor shape: {dtensor_r.shape}", + ) + test_case.assertEqual( + dtensor_r.requires_grad, + r.requires_grad, + "op result requires_grad mismatch!" + f"original requires_grad: {r.requires_grad}, " + f"dtensor requires_grad: {dtensor_r.requires_grad}", + ) + + test_case.assertEqual(dtensor_r.to_local(), r) + + +# Copied from functorch +def xfail(op_name, variant_name="", *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, True) + + +def skip(op_name, variant_name="", *, device_type=None, dtypes=None): + return (op_name, variant_name, device_type, dtypes, False) + + +def skipOps(test_case_name, base_test_name, to_skip): + all_opinfos = dtensor_lagging_op_db + for xfail in to_skip: + op_name, variant_name, device_type, dtypes, expected_failure = xfail + matching_opinfos = [ + o + for o in all_opinfos + if o.name == op_name and o.variant_test_name == variant_name + ] + assert len(matching_opinfos) >= 1, f"Couldn't find OpInfo for {xfail}" + for opinfo in matching_opinfos: + decorators = list(opinfo.decorators) + if expected_failure: + decorator = DecorateInfo( + unittest.expectedFailure, + test_case_name, + base_test_name, + device_type=device_type, + dtypes=dtypes, + ) + decorators.append(decorator) + else: + decorator = DecorateInfo( + unittest.skip("Skipped!"), + test_case_name, + base_test_name, + device_type=device_type, + dtypes=dtypes, + ) + decorators.append(decorator) + opinfo.decorators = tuple(decorators) + + # This decorator doesn't modify fn in any way + def wrapped(fn): + return fn + + return wrapped + + +# Re-generate this failed list, turn on dry_run of the below func +# check_dtensor_func(self, test, op, dry_run=True), then run sth +# like python test/spmd/tensor/test_dtensor_ops.py > failed.expect +dtensor_fails = { + # these sometimes pass and sometimes fail + # we need to remove many of them from list once op + # get full support with varying sharding specs + xfail("__getitem__"), + xfail("__rsub__"), + xfail("masked.amax"), + xfail("masked.amin"), + xfail("masked.argmax"), + xfail("masked.argmin"), + xfail("masked.cumprod"), + xfail("masked.cumsum"), + xfail("masked.log_softmax"), + xfail("masked.logaddexp"), + xfail("masked.logsumexp"), + xfail("masked.median"), + xfail("masked.norm"), + xfail("masked.prod"), + xfail("masked.softmin"), + xfail("masked.softmax"), + xfail("masked.sum"), + xfail("addbmm"), + xfail("addmv"), + xfail("addr"), + xfail("all"), + xfail("allclose"), + xfail("amax"), + xfail("amin"), + xfail("aminmax"), + xfail("any"), + xfail("arange"), + xfail("argmax"), + xfail("argmin"), + xfail("argsort"), + xfail("as_strided"), + xfail("as_strided_scatter"), + xfail("baddbmm"), + xfail("bernoulli"), + xfail("block_diag"), + xfail("broadcast_shapes"), + xfail("cat"), + xfail("cartesian_prod"), + xfail("cdist"), + xfail("cholesky"), + xfail("cholesky_inverse"), + xfail("cholesky_solve"), + xfail("chunk"), + xfail("clamp"), + xfail("clamp_max"), + 
xfail("clamp_min"), + xfail("column_stack"), + xfail("combinations"), + xfail("complex"), + xfail("constant_pad_nd"), + xfail("copysign"), + xfail("corrcoef"), + xfail("count_nonzero"), + xfail("cov"), + xfail("cross"), + xfail("cummax"), + xfail("cummin"), + xfail("cumsum"), + xfail("cumulative_trapezoid"), + xfail("diag"), + xfail("diag_embed"), + xfail("diagflat"), + xfail("diagonal"), + xfail("diagonal_copy"), + xfail("diagonal_scatter"), + xfail("diff"), + xfail("dist"), + xfail("dot"), + xfail("dstack"), + xfail("einsum"), + xfail("empty"), + xfail("empty_like"), + xfail("eq"), + xfail("eye"), + xfail("fft.fft2"), + xfail("fft.fft"), + xfail("fft.fftn"), + xfail("fft.fftshift"), + xfail("fft.ifft2"), + xfail("fft.ifft"), + xfail("fft.ifftshift"), + xfail("fft.ihfft2"), + xfail("fft.ihfft"), + xfail("fft.ihfftn"), + xfail("fft.irfft2"), + xfail("fft.irfftn"), + xfail("fft.rfft2"), + xfail("fft.rfft"), + xfail("fft.rfftn"), + xfail("flip"), + xfail("fliplr"), + xfail("flipud"), + xfail("floor_divide"), + xfail("fmax"), + xfail("fmin"), + xfail("frexp"), + xfail("full"), + xfail("gather"), + xfail("geqrf"), + xfail("gradient"), + xfail("heaviside"), + xfail("histc"), + xfail("histogram"), + xfail("histogramdd"), + xfail("hstack"), + xfail("index_add"), + xfail("index_copy"), + xfail("index_fill"), + xfail("index_put"), + xfail("index_reduce"), + xfail("index_select"), + xfail("isfinite"), + xfail("isin"), + xfail("isinf"), + xfail("isnan"), + xfail("isneginf"), + xfail("isposinf"), + xfail("kthvalue"), + xfail("linalg.cholesky"), + xfail("linalg.cholesky_ex"), + xfail("linalg.cond"), + xfail("linalg.cross"), + xfail("linalg.det"), + xfail("linalg.det", "singular"), + xfail("linalg.eig"), + xfail("linalg.eigh"), + xfail("linalg.eigvals"), + xfail("linalg.eigvalsh"), + xfail("linalg.householder_product"), + xfail("linalg.inv"), + xfail("linalg.inv_ex"), + xfail("linalg.ldl_factor"), + xfail("linalg.ldl_factor_ex"), + xfail("linalg.ldl_solve"), + xfail("linalg.lstsq"), + xfail("linalg.lstsq", "grad_oriented"), + xfail("linalg.lu"), + xfail("linalg.lu_factor"), + xfail("linalg.lu_factor_ex"), + xfail("linalg.lu_solve"), + xfail("linalg.matrix_norm"), + xfail("linalg.matrix_power"), + xfail("linalg.matrix_rank"), + xfail("linalg.matrix_rank", "hermitian"), + xfail("linalg.multi_dot"), + xfail("linalg.norm"), + xfail("linalg.norm", "subgradients_at_zero"), + xfail("linalg.pinv"), + xfail("linalg.pinv", "hermitian"), + xfail("linalg.qr"), + xfail("linalg.slogdet"), + xfail("linalg.solve"), + xfail("linalg.solve_ex"), + xfail("linalg.solve_triangular"), + xfail("linalg.svd"), + xfail("linalg.svdvals"), + xfail("linalg.tensorinv"), + xfail("linalg.tensorsolve"), + xfail("linalg.vander"), + xfail("linalg.vecdot"), + xfail("linalg.vector_norm"), + xfail("linspace"), + xfail("log_softmax"), + xfail("log_softmax", "with_dtype"), + xfail("logcumsumexp"), + xfail("logdet"), + xfail("logical_not"), + xfail("logspace"), + xfail("logsumexp"), + xfail("lt"), + xfail("lu"), + xfail("lu_solve"), + xfail("lu_unpack"), + xfail("masked_fill"), + xfail("masked_scatter"), + xfail("masked_select"), + xfail("matrix_exp"), + xfail("max", "binary"), + xfail("max", "reduction_no_dim"), + xfail("max", "reduction_with_dim"), + xfail("maximum"), + xfail("median"), + xfail("min", "binary"), + xfail("min", "reduction_no_dim"), + xfail("min", "reduction_with_dim"), + xfail("minimum"), + xfail("mode"), + xfail("msort"), + xfail("multinomial"), + xfail("mv"), + xfail("max_pool2d_with_indices_backward", ""), + 
xfail("nanmean"), + xfail("nanmedian"), + xfail("nanquantile"), + xfail("nansum"), + xfail("native_batch_norm"), + xfail("native_layer_norm"), + xfail("narrow_copy"), + xfail("ne"), + xfail("new_empty"), + xfail("new_empty_strided"), + xfail("transpose"), + xfail("nn.functional.adaptive_avg_pool1d"), + xfail("nn.functional.adaptive_avg_pool2d"), + xfail("nn.functional.adaptive_avg_pool3d"), + xfail("nn.functional.adaptive_max_pool1d"), + xfail("nn.functional.adaptive_max_pool2d"), + xfail("nn.functional.adaptive_max_pool3d"), + xfail("nn.functional.alpha_dropout"), + xfail("nn.functional.avg_pool1d"), + xfail("nn.functional.avg_pool2d"), + xfail("nn.functional.avg_pool3d"), + xfail("nn.functional.batch_norm"), + xfail("nn.functional.batch_norm", "without_cudnn"), + xfail("nn.functional.bilinear"), + xfail("nn.functional.binary_cross_entropy"), + xfail("nn.functional.binary_cross_entropy_with_logits"), + xfail("nn.functional.celu"), + xfail("nn.functional.conv1d"), + xfail("nn.functional.conv2d"), + xfail("nn.functional.conv_transpose1d"), + xfail("nn.functional.conv_transpose2d"), + xfail("nn.functional.conv_transpose3d"), + xfail("nn.functional.cosine_similarity"), + xfail("nn.functional.cross_entropy"), + xfail("nn.functional.ctc_loss"), + xfail("nn.functional.dropout"), + xfail("nn.functional.dropout2d"), + xfail("nn.functional.dropout3d"), + xfail("nn.functional.elu"), + xfail("nn.functional.fractional_max_pool2d"), + xfail("nn.functional.fractional_max_pool3d"), + xfail("nn.functional.gaussian_nll_loss"), + xfail("nn.functional.glu"), + xfail("nn.functional.grid_sample"), + xfail("nn.functional.group_norm"), + xfail("nn.functional.hardshrink"), + xfail("nn.functional.hardsigmoid"), + xfail("nn.functional.hardswish"), + xfail("nn.functional.hardtanh"), + xfail("nn.functional.huber_loss"), + xfail("nn.functional.instance_norm"), + xfail("nn.functional.interpolate", "area"), + xfail("nn.functional.interpolate", "bicubic"), + xfail("nn.functional.interpolate", "bilinear"), + xfail("nn.functional.interpolate", "linear"), + xfail("nn.functional.interpolate", "nearest"), + xfail("nn.functional.interpolate", "trilinear"), + xfail("nn.functional.layer_norm"), + xfail("nn.functional.leaky_relu"), + xfail("nn.functional.linear"), + xfail("nn.functional.local_response_norm"), + xfail("nn.functional.logsigmoid"), + xfail("nn.functional.margin_ranking_loss"), + xfail("nn.functional.max_pool1d"), + xfail("nn.functional.max_pool2d"), + xfail("nn.functional.max_pool3d"), + xfail("nn.functional.max_unpool1d"), + xfail("nn.functional.max_unpool1d", "grad"), + xfail("nn.functional.max_unpool2d"), + xfail("nn.functional.max_unpool2d", "grad"), + xfail("nn.functional.max_unpool3d"), + xfail("nn.functional.max_unpool3d", "grad"), + xfail("nn.functional.mish"), + xfail("nn.functional.mse_loss"), + xfail("nn.functional.multi_margin_loss"), + xfail("nn.functional.multilabel_margin_loss"), + xfail("nn.functional.multilabel_soft_margin_loss"), + xfail("nn.functional.nll_loss"), + xfail("nn.functional.normalize"), + xfail("nn.functional.pad", "circular"), + xfail("nn.functional.pad", "constant"), + xfail("nn.functional.pad", "reflect"), + xfail("nn.functional.pad", "replicate"), + xfail("nn.functional.pairwise_distance"), + xfail("nn.functional.pdist"), + xfail("nn.functional.pixel_shuffle"), + xfail("nn.functional.pixel_unshuffle"), + xfail("nn.functional.poisson_nll_loss"), + xfail("nn.functional.prelu"), + xfail("nn.functional.relu6"), + xfail("nn.functional.rrelu"), + xfail("nn.functional.selu"), + 
xfail("nn.functional.silu"), + xfail("nn.functional.smooth_l1_loss"), + xfail("nn.functional.soft_margin_loss"), + xfail("nn.functional.softplus"), + xfail("nn.functional.softshrink"), + xfail("nn.functional.threshold"), + xfail("nn.functional.triplet_margin_loss"), + xfail("nn.functional.triplet_margin_with_distance_loss"), + xfail("nn.functional.unfold"), + xfail("nn.functional.upsample_bilinear"), + xfail("nn.functional.upsample_nearest"), + xfail("nonzero"), + xfail("norm"), + xfail("norm", "fro"), + xfail("norm", "inf"), + xfail("norm", "nuc"), + xfail("normal"), + xfail("normal", "number_mean"), + xfail("ormqr"), + xfail("ones"), + xfail("pca_lowrank"), + xfail("pinverse"), + xfail("polar"), + xfail("put"), + xfail("qr"), + xfail("quantile"), + xfail("rad2deg"), + xfail("rand_like"), + xfail("randint_like"), + xfail("randint"), + xfail("randn"), + xfail("randn_like"), + xfail("renorm"), + xfail("repeat_interleave"), + xfail("resize_"), + xfail("resize_as_"), + xfail("roll"), + xfail("rot90"), + xfail("rsub"), + xfail("scalar_tensor"), + xfail("scatter_add"), + xfail("scatter"), + xfail("scatter_reduce", "amax"), + xfail("scatter_reduce", "amin"), + xfail("scatter_reduce", "mean"), + xfail("scatter_reduce", "prod"), + xfail("scatter_reduce", "sum"), + xfail("searchsorted"), + xfail("select"), + xfail("select_scatter"), + xfail("signbit"), + xfail("sort"), + xfail("sparse.sampled_addmm"), + xfail("special.airy_ai"), + xfail("special.bessel_j0"), + xfail("special.bessel_j1"), + xfail("special.bessel_y0"), + xfail("special.bessel_y1"), + xfail("special.chebyshev_polynomial_t"), + xfail("special.chebyshev_polynomial_u"), + xfail("special.entr"), + xfail("special.erfcx"), + xfail("special.hermite_polynomial_h"), + xfail("special.hermite_polynomial_he"), + xfail("special.i0e"), + xfail("special.i1"), + xfail("special.i1e"), + xfail("special.laguerre_polynomial_l"), + xfail("special.log_ndtr"), + xfail("special.modified_bessel_i0"), + xfail("special.modified_bessel_i1"), + xfail("special.modified_bessel_k0"), + xfail("special.modified_bessel_k1"), + xfail("special.ndtri"), + xfail("special.scaled_modified_bessel_k0"), + xfail("special.scaled_modified_bessel_k1"), + xfail("special.spherical_bessel_j0"), + xfail("special.xlog1py"), + xfail("special.zeta"), + xfail("split"), + xfail("split", "list_args"), + xfail("split_with_sizes"), + xfail("signal.windows.cosine"), + xfail("signal.windows.exponential"), + xfail("signal.windows.gaussian"), + xfail("signal.windows.kaiser"), + xfail("squeeze"), + xfail("stack"), + xfail("std"), + xfail("std_mean"), + xfail("stft"), + xfail("svd"), + xfail("svd_lowrank"), + xfail("symeig"), + xfail("t"), + xfail("take_along_dim"), + xfail("take"), + xfail("tensor_split"), + xfail("to_sparse"), + xfail("topk"), + xfail("trace"), + xfail("trapezoid"), + xfail("trapz"), + xfail("triangular_solve"), + xfail("tril"), + xfail("triu"), + xfail("unbind"), + xfail("unfold"), + xfail("unfold_copy"), + xfail("uniform"), + xfail("unflatten"), + xfail("unique_consecutive"), + xfail("unique"), + xfail("var_mean"), + xfail("vdot"), + xfail("view_as_complex"), + xfail("vstack"), + xfail("zeros"), + # ops inside this might even fail without dtensor + # tests, as we rescale op db common test size factor (i.e. L, M, S) + # which triggered the orignal function run failures with input + # generation becomes wrong, we skip them for now but should enable later. 
+ # TODO: need to clean this list and remove all cases + skip("argwhere"), + skip("cumprod"), + skip("__rmatmul__"), + skip("meshgrid", "list_of_tensors"), + skip("meshgrid", "variadic_tensors"), + skip("nn.functional._scaled_dot_product_attention"), + skip("nn.functional.softmin"), + skip("nn.functional.embedding"), + skip("nn.functional.embedding_bag"), + skip("nn.functional.feature_alpha_dropout", "with_train"), + skip("nn.functional.feature_alpha_dropout", "without_train"), + skip("nn.functional.hinge_embedding_loss"), + skip("nn.functional.cosine_embedding_loss"), + skip("fft.hfft"), + skip("fft.hfft2"), + skip("fft.hfft2"), + skip("fft.hfftn"), + skip("fft.ifftn"), + skip("fft.irfft"), + skip("istft"), + skip("isclose"), + skip("isreal"), + skip("matmul"), + skip("masked.mean"), + skip("masked.var"), + skip("masked.std"), + skip("masked.normalize"), + skip("prod"), + skip("segment_reduce", "lengths"), + skip("segment_reduce", "offsets"), +} + + +# Add a list of ops that are currently failing BW pass +skip_bw = [ + None, # corresponds to the transpose ops 'H' and 'T' + "torch.bucketize", + "torch.conj_physical", + "torch.eq", + "torch.isfinite", + "torch.isnan", +] + + +def run_dtensor_crossref(test_case, func, args, kwargs): + to_dtensor = DTensorConverter(test_case.mesh, args, kwargs) + + # TODO: also handle cases where func raise an exception + rs = func(*args, **kwargs) + + def to_replicate(e: object) -> object: + return ( + e.redistribute(test_case.mesh, test_case.mesh.ndim * [Replicate()]) + if isinstance(e, DTensor) + else e + ) + + try: + # Suppress warnings, this doesn't matter for test_meta.py + # but it does matter if you want to use this decorator + # for cross-ref testing, as some tests may be looking at + # errors + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # for every comb of sharding choices, we test if it works + for dtensor_args, dtensor_kwargs in to_dtensor: + # Only attempt if we managed to convert all tensors to DTensor + # (if any of them failed, we're in a mixed tensor situation and + # this is not allowed in DTensor) + if to_dtensor.successful(): + # Handle special cases first if there's any + # Suppress warnings, this doesn't matter for test_meta.py + # but it does matter if you want to use this decorator + # for cross-ref testing, as some tests may be looking at + # errors + dtensor_rs = func(*dtensor_args, **dtensor_kwargs) + + # we need to skip tests containing tensors of zero elmeents for now. + # see issue: https://github.com/pytorch/tau/issues/470 + # TODO remove this once issue above fixed. + flat_args, _ = tree_flatten(dtensor_rs) + if any( + isinstance(e, torch.Tensor) and e.numel() == 0 + for e in flat_args + ): + continue + + # redistribute/all_gather the results to compare with normal output + dtensor_rs = tree_map(to_replicate, dtensor_rs) + try: + if resolve_name(func) not in skip_bw: + if isinstance(dtensor_rs, DTensor): + dtensor_rs.to_local().sum().backward() + elif isinstance(dtensor_rs, tuple): + dtensor_rs[0].to_local().sum().backward() + + except Exception as e: + # TODO(anj): Remove this guard exception after gaining more confidence. 
+ if torch.distributed.get_rank() == 0: + print( + f"failed to run BW: {resolve_name(func)}, {func}, {str(e)})" + ) + assert_ref_dtensor_equal(test_case, dtensor_rs, rs) + else: + raise RuntimeError( + f"failed to convert args to DTensor; " + f"originally (*{args}, **{kwargs})" + ) + except Exception as e: + raise RuntimeError( + f"failed to run: {resolve_name(func)}, with (*{args}, **{kwargs})" + ) from e + + return rs + + +def check_dtensor_func(test_case, test_func, opinfo, dry_run=False): + try: + test_func() + except Exception: + test_case.destroy_pg() + if not dry_run: + raise + if dist.get_rank() == 0: + if opinfo.variant_test_name: + print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),") + else: + print(f"xfail('{opinfo.name}'),") + else: + test_case.destroy_pg() + + +class TestDTensorOps(DTensorTestBase): + @property + def world_size(self) -> int: + return 4 + + # only allow float dytpe for now, we can relax this constraint + # when feel necessary later (i.e when adding quantization support). + @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") + @suppress_warnings + @ops(dtensor_lagging_op_db, allowed_dtypes=(torch.float,)) + @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails) + def test_dtensor_op_db(self, dtype, op): + pg_backend = "nccl" if DEVICE_TYPE == "cuda" else "gloo" + if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size: + sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) + + self.init_pg(backend=pg_backend) + self.mesh = DeviceMesh(DEVICE_TYPE, torch.arange(self.world_size)) + + # test each op with dist tensor inputs and normal inputs + def test(): + samples = op.sample_inputs(DEVICE_TYPE, dtype, requires_grad=True) + for sample_input in samples: + args = [sample_input.input] + list(sample_input.args) + kwargs = sample_input.kwargs + + run_dtensor_crossref(self, op.op, args, kwargs) + # we need to figure out a way to test the out variant, out variant testing + # is tricky, as we need to pre allocate the dtensor out, some of them rely + # on sharding placements to be pre-known (i.e. mm.out) + # if isinstance(expected, torch.Tensor) and op.supports_out: + # func(*args, **kwargs, out=expected) + + check_dtensor_func(self, test, op) + + +# only instantiate tests for DEVICE_TYPE alone (i.e. either CPU or GPU) +instantiate_device_type_tests( + TestDTensorOps, globals(), only_for=(DEVICE_TYPE,) +) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py new file mode 100644 index 000000000000..1ba3f6d5f95b --- /dev/null +++ b/test/distributed/_tensor/test_tensor_ops.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorConverter, + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor +from torch.distributed._tensor.placement_types import Shard, Replicate, _Partial + + +class DistTensorOpsTest(DTensorTestBase): + @with_comms + def test_aten_contiguous(self): + # this op not covered by dtensor_ops + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + self._test_op( + mesh, + lambda x: torch.ops.aten.contiguous(x), + torch.randn(16, 32), + ) + + @with_comms + def test_detach(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + tensor_to_detach = torch.randn(12, 8, requires_grad=True) + mat = distribute_tensor(tensor_to_detach, device_mesh, shard_spec) + detached_mat = mat.detach() + self.assertFalse(detached_mat is mat) + + @with_comms + def test_clone(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + specs = [[Replicate()], [Shard(0)]] + tensor_to_clone = torch.randn(12, 8, requires_grad=True) + for spec in specs: + mat = distribute_tensor(tensor_to_clone, device_mesh, spec) + cloned_mat = mat.clone() + self.assertFalse(cloned_mat is mat) + self.assertEqual(cloned_mat.to_local(), mat.to_local()) + + @with_comms + def test_contiguous(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + tensor = torch.rand(3, 5, 6, requires_grad=True) + sharding = [Shard(0)] + dist_tensor = DTensor.from_local(tensor, device_mesh, sharding) + self.assertTrue(dist_tensor.is_contiguous()) + # shard on dim 0 should not change stride (30, 6, 1) + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + new_dt = dist_tensor.transpose(0, 2) + self.assertFalse(new_dt.is_contiguous()) + self.assertFalse(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(new_dt.stride(), (1, 6, 30)) + + new_dt = new_dt.contiguous() + self.assertTrue(new_dt.is_contiguous()) + self.assertTrue(new_dt.to_local().is_contiguous()) + # check stride + self.assertEqual(dist_tensor.stride(), tensor.stride()) + + # check backward + new_dt.to_local().sum().backward() + self.assertEqual(tensor.grad, torch.ones(3, 5, 6)) + + @with_comms + def test_inplace_op(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + dt_to_add = distribute_tensor(input_tensor, mesh, [Shard(0)]) + dt_to_mul = dt_to_add.clone() + expected_add_dt = dt_to_add.clone() + 3 + add_res = dt_to_add.add_(3) + expected_mul_dt = dt_to_mul.clone() * 3 + mul_res = dt_to_mul.mul_(3) + # inplace op should be the same instance before and after + self.assertTrue(add_res is dt_to_add) + self.assertEqual(add_res.to_local(), expected_add_dt.to_local()) + + self.assertTrue(mul_res is dt_to_mul) + self.assertEqual(mul_res.to_local(), expected_mul_dt.to_local()) + + # test inplace op self and other dtensor with other specs + # and make sure out spec not change + shard_spec = [Shard(0)] + partial_spec = [_Partial()] + dt_to_inplace_add = distribute_tensor(input_tensor, mesh, shard_spec) + partial_grad = DTensor.from_local( + torch.randn(12, 3), mesh, partial_spec + ) + res = dt_to_inplace_add.add_(partial_grad) + self.assertTrue(res is dt_to_inplace_add) + self.assertTrue(res.placements == shard_spec) + 
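A minimal sketch of what the _Partial() placement exercised by test_inplace_op above represents, assuming torch.distributed is already initialized (e.g. with the gloo backend, as DTensorTestBase does) and a 1-D DeviceMesh over all ranks; the helper name partial_sum_demo is illustrative only:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Replicate, _Partial


def partial_sum_demo() -> torch.Tensor:
    # One mesh dimension spanning every rank in the default process group.
    mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))
    # _Partial() marks the DTensor as "pending reduction": each rank holds a
    # partial contribution that has not yet been summed across the mesh.
    local = torch.ones(4, 3)
    partial_dt = DTensor.from_local(local, mesh, [_Partial()])
    # Redistributing to Replicate() performs the all-reduce, so every rank
    # ends up holding the sum of the per-rank contributions.
    replicated = partial_dt.redistribute(mesh, [Replicate()])
    return replicated.to_local()  # == dist.get_world_size() * torch.ones(4, 3)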
+ @with_comms + def test_op_out_variant(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + input_tensor = torch.randn((12, 3), device=self.device_type) + sharded_dt_input = distribute_tensor(input_tensor, mesh, [Shard(0)]) + expected_dt = sharded_dt_input.clone() + 3 + sharded_dt_out = sharded_dt_input.clone() + res = torch.add(sharded_dt_input, 3, out=sharded_dt_out) + # op out variant should be the same instance before and after + self.assertTrue(res is sharded_dt_out) + self.assertEqual(sharded_dt_out.to_local(), expected_dt.to_local()) + + # test op out variant with other spec and make sure out spec not change + replica_spec = [Replicate()] + replicate_out = distribute_tensor(input_tensor, mesh, replica_spec) + expected_dt = replicate_out.clone() + 3 + res = torch.add(sharded_dt_input, 3, out=replicate_out) + self.assertTrue(res is replicate_out) + self.assertTrue(res.placements == replica_spec) + self.assertEqual(replicate_out.to_local(), expected_dt.to_local()) + + @with_comms + def test_empty_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + empty_like_dt = torch.empty_like(dist_tensor) + # empty is not deterministic, so we only check that the shard propagation worked + self.assertEqual((4, 8), empty_like_dt.to_local().shape) + + @with_comms + def test_fill_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.fill_(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + self.assertEqual(full_expected, dist_tensor.to_local()) + + @with_comms + def test_full_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + full_like_dt = torch.full_like(dist_tensor, 42.0) + full_expected = torch.full((4, 8), 42.0) + self.assertEqual(full_expected, full_like_dt.to_local()) + + @with_comms + def test_ones_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(4, 8) + self.assertEqual(ones_expected, ones_like_dt.to_local()) + + @with_comms + def test_ones_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + ones_like_dt = torch.ones_like(dist_tensor) + ones_expected = torch.ones(dist_tensor.shape) + self.assertEqual( + ones_expected, + ones_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_fill_inplace_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, 
requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + torch.fill_(dist_tensor, 42) + fill_expected = torch.full( + dist_tensor.shape, 42, dtype=input_tensor.dtype + ) + self.assertEqual( + fill_expected, + dist_tensor.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_zeros_like_partial_sum(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [_Partial()] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + assert dist_tensor.shape == (4, 8) + + zeros_like_dt = torch.zeros_like(dist_tensor) + zeros_expected = torch.zeros(dist_tensor.shape) + self.assertEqual( + zeros_expected, + zeros_like_dt.redistribute(device_mesh, [Replicate()]).to_local(), + ) + + @with_comms + def test_zero_inplace(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zero_(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + self.assertEqual(zeros_expected, dist_tensor.to_local()) + + @with_comms + def test_zeros_like(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + shard_spec = [Shard(0)] + + input_tensor = torch.randn(4, 8, requires_grad=True) + dist_tensor = DTensor.from_local(input_tensor, device_mesh, shard_spec) + zeros_like_dt = torch.zeros_like(dist_tensor) + zeros_expected = torch.zeros(4, 8) + self.assertEqual(zeros_expected, zeros_like_dt.to_local()) + + def _test_op(self, mesh, op_call, *args, **kwargs): + out = op_call(*args, **kwargs) + dtc = DTensorConverter(mesh, args, kwargs) + for d_args, d_kwargs in dtc: + self.assertTrue(dtc.successful()) + d_out = op_call(*d_args, **d_kwargs) + self.assertEqual( + d_out.redistribute(mesh, [Replicate()] * mesh.ndim).to_local(), + out, + ) + + @with_comms + def test_index(self): + meshes = [ + DeviceMesh( + self.device_type, list(range(self.world_size)) + ), # 1D mesh + # TODO(@azzolini): un-comment when DTensorConverter supports N-D mesh + # DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, -1)), # 2D mesh + ] + for mesh in meshes: + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(1, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x.index_select(0, y), + torch.randn(16, 32, 16), + torch.randint(5, (4,)), + ) + self._test_op( + mesh, + lambda x, y: x[y], + torch.randn(16, 32, 16), + torch.randint(5, (12,)), + ) + self._test_op( + mesh, + lambda x, y: x[:, y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 12)), + ) + self._test_op( + mesh, + lambda x, y: x[..., y], + torch.randn(16, 32, 16), + torch.randint(5, (4, 8, 16)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y], + torch.randn(16, 32, 16), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + self._test_op( + 
mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 8, 12)), + ) + # broadcast in inner dimensions + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (12, 1, 12)), + ) + # implicit (left-padded) broadcast + self._test_op( + mesh, + lambda x, y, z: x[:, z, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(5, (12, 8, 12)), + torch.randint(2, (8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, y, :, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, y, :], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 12)), + torch.randint(5, (12, 8, 12)), + ) + self._test_op( + mesh, + lambda x, y, z: x[z, :, :, y], + torch.randn(16, 32, 16, 12), + torch.randint(2, (8, 1)), + torch.randint(5, (12, 8, 12)), + ) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py new file mode 100644 index 000000000000..abd0ccfe0a09 --- /dev/null +++ b/torch/testing/_internal/distributed/_tensor/dtensor_lagging_op_db.py @@ -0,0 +1,661 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List +from torch.testing._internal.common_methods_invocations import op_db, OpInfo + +# Generated from test/gen_dtensor_op_db.py via +# python spmd/testing/gen_dtensor_lagging_op_db.py > spmd/testing/dtensor_lagging_op_db.py +# +# This approach is copied from functorch: +# People add new OpInfos to PyTorch all the time. +# We want them to be able to add OpInfos without breaking our CI. 
+# To achieve this, we keep our OpInfo library behind that of Pytorch's and +# we periodically update our OpInfo library by regenerating this file +_dtensor_lagging_meta = { + ("H", ""), + ("T", ""), + ("__getitem__", ""), + ("__radd__", ""), + ("__rand__", ""), + ("__rdiv__", ""), + ("__rmatmul__", ""), + ("__rmod__", ""), + ("__rmul__", ""), + ("__ror__", ""), + ("__rpow__", ""), + ("__rsub__", ""), + ("__rxor__", ""), + ("abs", ""), + ("acos", ""), + ("acosh", ""), + ("add", ""), + ("addbmm", ""), + ("addcdiv", ""), + ("addcmul", ""), + ("addmm", ""), + ("addmm", "decomposed"), + ("addmv", ""), + ("addr", ""), + ("all", ""), + ("allclose", ""), + ("amax", ""), + ("amin", ""), + ("aminmax", ""), + ("angle", ""), + ("any", ""), + ("arange", ""), + ("argmax", ""), + ("argmin", ""), + ("argsort", ""), + ("argwhere", ""), + ("as_strided", ""), + ("as_strided_scatter", ""), + ("asin", ""), + ("asinh", ""), + ("atan", ""), + ("atan2", ""), + ("atanh", ""), + ("atleast_1d", ""), + ("atleast_2d", ""), + ("atleast_3d", ""), + ("baddbmm", ""), + ("bernoulli", ""), + ("bfloat16", ""), + ("bincount", ""), + ("bitwise_and", ""), + ("bitwise_left_shift", ""), + ("bitwise_not", ""), + ("bitwise_or", ""), + ("bitwise_right_shift", ""), + ("bitwise_xor", ""), + ("block_diag", ""), + ("bmm", ""), + ("bool", ""), + ("broadcast_shapes", ""), + ("broadcast_tensors", ""), + ("broadcast_to", ""), + ("bucketize", ""), + ("byte", ""), + ("cartesian_prod", ""), + ("cat", ""), + ("cdist", ""), + ("cdouble", ""), + ("ceil", ""), + ("cfloat", ""), + ("chalf", ""), + ("char", ""), + ("cholesky", ""), + ("cholesky_inverse", ""), + ("cholesky_solve", ""), + ("chunk", ""), + ("clamp", ""), + ("clamp_max", ""), + ("clamp_min", ""), + ("clone", ""), + ("column_stack", ""), + ("combinations", ""), + ("complex", ""), + ("conj", ""), + ("conj_physical", ""), + ("constant_pad_nd", ""), + ("contiguous", ""), + ("copysign", ""), + ("corrcoef", ""), + ("cos", ""), + ("cosh", ""), + ("count_nonzero", ""), + ("cov", ""), + ("cross", ""), + ("cummax", ""), + ("cummin", ""), + ("cumprod", ""), + ("cumsum", ""), + ("cumulative_trapezoid", ""), + ("deg2rad", ""), + ("diag", ""), + ("diag_embed", ""), + ("diagflat", ""), + ("diagonal", ""), + ("diagonal_copy", ""), + ("diagonal_scatter", ""), + ("diff", ""), + ("digamma", ""), + ("dist", ""), + ("div", "floor_rounding"), + ("div", "no_rounding_mode"), + ("div", "trunc_rounding"), + ("dot", ""), + ("double", ""), + ("dsplit", ""), + ("dstack", ""), + ("einsum", ""), + ("empty", ""), + ("empty_like", ""), + ("eq", ""), + ("equal", ""), + ("erf", ""), + ("erfc", ""), + ("erfinv", ""), + ("exp", ""), + ("exp2", ""), + ("expand", ""), + ("expand_as", ""), + ("expm1", ""), + ("eye", ""), + ("fft.fft", ""), + ("fft.fft2", ""), + ("fft.fftn", ""), + ("fft.fftshift", ""), + ("fft.hfft", ""), + ("fft.hfft2", ""), + ("fft.hfftn", ""), + ("fft.ifft", ""), + ("fft.ifft2", ""), + ("fft.ifftn", ""), + ("fft.ifftshift", ""), + ("fft.ihfft", ""), + ("fft.ihfft2", ""), + ("fft.ihfftn", ""), + ("fft.irfft", ""), + ("fft.irfft2", ""), + ("fft.irfftn", ""), + ("fft.rfft", ""), + ("fft.rfft2", ""), + ("fft.rfftn", ""), + ("fill", ""), + ("flatten", ""), + ("flip", ""), + ("fliplr", ""), + ("flipud", ""), + ("float", ""), + ("float_power", ""), + ("floor", ""), + ("floor_divide", ""), + ("fmax", ""), + ("fmin", ""), + ("fmod", ""), + ("frac", ""), + ("frexp", ""), + ("full", ""), + ("full_like", ""), + ("gather", ""), + ("gcd", ""), + ("ge", ""), + ("geqrf", ""), + ("gradient", ""), + ("gt", ""), + ("half", 
""), + ("heaviside", ""), + ("histc", ""), + ("histogram", ""), + ("histogramdd", ""), + ("hsplit", ""), + ("hstack", ""), + ("hypot", ""), + ("i0", ""), + ("igamma", ""), + ("igammac", ""), + ("imag", ""), + ("index_add", ""), + ("index_copy", ""), + ("index_fill", ""), + ("index_put", ""), + ("index_reduce", ""), + ("index_select", ""), + ("inner", ""), + ("int", ""), + ("isclose", ""), + ("isfinite", ""), + ("isin", ""), + ("isinf", ""), + ("isnan", ""), + ("isneginf", ""), + ("isposinf", ""), + ("isreal", ""), + ("istft", ""), + ("jiterator_2inputs_2outputs", ""), + ("jiterator_4inputs_with_extra_args", ""), + ("jiterator_binary", ""), + ("jiterator_binary_return_by_ref", ""), + ("jiterator_unary", ""), + ("kron", ""), + ("kthvalue", ""), + ("lcm", ""), + ("ldexp", ""), + ("le", ""), + ("lerp", ""), + ("lgamma", ""), + ("linalg.cholesky", ""), + ("linalg.cholesky_ex", ""), + ("linalg.cond", ""), + ("linalg.cross", ""), + ("linalg.det", ""), + ("linalg.det", "singular"), + ("linalg.eig", ""), + ("linalg.eigh", ""), + ("linalg.eigvals", ""), + ("linalg.eigvalsh", ""), + ("linalg.householder_product", ""), + ("linalg.inv", ""), + ("linalg.inv_ex", ""), + ("linalg.ldl_factor", ""), + ("linalg.ldl_factor_ex", ""), + ("linalg.ldl_solve", ""), + ("linalg.lstsq", ""), + ("linalg.lstsq", "grad_oriented"), + ("linalg.lu", ""), + ("linalg.lu_factor", ""), + ("linalg.lu_factor_ex", ""), + ("linalg.lu_solve", ""), + ("linalg.matrix_norm", ""), + ("linalg.matrix_power", ""), + ("linalg.matrix_rank", ""), + ("linalg.matrix_rank", "hermitian"), + ("linalg.multi_dot", ""), + ("linalg.norm", ""), + ("linalg.norm", "subgradients_at_zero"), + ("linalg.pinv", ""), + ("linalg.pinv", "hermitian"), + ("linalg.pinv", "singular"), + ("linalg.qr", ""), + ("linalg.slogdet", ""), + ("linalg.solve", ""), + ("linalg.solve_ex", ""), + ("linalg.solve_triangular", ""), + ("linalg.svd", ""), + ("linalg.svdvals", ""), + ("linalg.tensorinv", ""), + ("linalg.tensorsolve", ""), + ("linalg.vander", ""), + ("linalg.vecdot", ""), + ("linalg.vector_norm", ""), + ("linspace", ""), + ("log", ""), + ("log10", ""), + ("log1p", ""), + ("log2", ""), + ("log_softmax", ""), + ("log_softmax", "with_dtype"), + ("logaddexp", ""), + ("logaddexp2", ""), + ("logcumsumexp", ""), + ("logdet", ""), + ("logical_and", ""), + ("logical_not", ""), + ("logical_or", ""), + ("logical_xor", ""), + ("logit", ""), + ("logspace", ""), + ("logsumexp", ""), + ("long", ""), + ("lt", ""), + ("lu", ""), + ("lu_solve", ""), + ("lu_unpack", ""), + ("mH", ""), + ("mT", ""), + ("masked.amax", ""), + ("masked.amin", ""), + ("masked.argmax", ""), + ("masked.argmin", ""), + ("masked.cumprod", ""), + ("masked.cumsum", ""), + ("masked.log_softmax", ""), + ("masked.logaddexp", ""), + ("masked.logsumexp", ""), + ("masked.mean", ""), + ("masked.median", ""), + ("masked.norm", ""), + ("masked.normalize", ""), + ("masked.prod", ""), + ("masked.softmax", ""), + ("masked.softmin", ""), + ("masked.std", ""), + ("masked.sum", ""), + ("masked.var", ""), + ("masked_fill", ""), + ("masked_scatter", ""), + ("masked_select", ""), + ("matmul", ""), + ("matrix_exp", ""), + ("max", "binary"), + ("max", "reduction_no_dim"), + ("max", "reduction_with_dim"), + ("max_pool2d_with_indices_backward", ""), + ("maximum", ""), + ("mean", ""), + ("median", ""), + ("meshgrid", "list_of_tensors"), + ("meshgrid", "variadic_tensors"), + ("min", "binary"), + ("min", "reduction_no_dim"), + ("min", "reduction_with_dim"), + ("minimum", ""), + ("mm", ""), + ("mode", ""), + ("movedim", ""), + ("msort", 
""), + ("mul", ""), + ("multinomial", ""), + ("mv", ""), + ("mvlgamma", "mvlgamma_p_1"), + ("mvlgamma", "mvlgamma_p_3"), + ("mvlgamma", "mvlgamma_p_5"), + ("nan_to_num", ""), + ("nanmean", ""), + ("nanmedian", ""), + ("nanquantile", ""), + ("nansum", ""), + ("narrow", ""), + ("narrow_copy", ""), + ("native_batch_norm", ""), + ("native_layer_norm", ""), + ("ne", ""), + ("neg", ""), + ("new_empty", ""), + ("new_empty_strided", ""), + ("new_full", ""), + ("new_ones", ""), + ("new_zeros", ""), + ("nextafter", ""), + ("nn.functional._scaled_dot_product_attention", ""), + ("nn.functional.adaptive_avg_pool1d", ""), + ("nn.functional.adaptive_avg_pool2d", ""), + ("nn.functional.adaptive_avg_pool3d", ""), + ("nn.functional.adaptive_max_pool1d", ""), + ("nn.functional.adaptive_max_pool2d", ""), + ("nn.functional.adaptive_max_pool3d", ""), + ("nn.functional.alpha_dropout", ""), + ("nn.functional.avg_pool1d", ""), + ("nn.functional.avg_pool2d", ""), + ("nn.functional.avg_pool3d", ""), + ("nn.functional.batch_norm", ""), + ("nn.functional.batch_norm", "without_cudnn"), + ("nn.functional.bilinear", ""), + ("nn.functional.binary_cross_entropy", ""), + ("nn.functional.binary_cross_entropy_with_logits", ""), + ("nn.functional.celu", ""), + ("nn.functional.conv1d", ""), + ("nn.functional.conv2d", ""), + ("nn.functional.conv_transpose1d", ""), + ("nn.functional.conv_transpose2d", ""), + ("nn.functional.conv_transpose3d", ""), + ("nn.functional.cosine_embedding_loss", ""), + ("nn.functional.cosine_similarity", ""), + ("nn.functional.cross_entropy", ""), + ("nn.functional.ctc_loss", ""), + ("nn.functional.dropout", ""), + ("nn.functional.dropout2d", ""), + ("nn.functional.dropout3d", ""), + ("nn.functional.elu", ""), + ("nn.functional.embedding", ""), + ("nn.functional.embedding_bag", ""), + ("nn.functional.feature_alpha_dropout", "with_train"), + ("nn.functional.feature_alpha_dropout", "without_train"), + ("nn.functional.fractional_max_pool2d", ""), + ("nn.functional.fractional_max_pool3d", ""), + ("nn.functional.gaussian_nll_loss", ""), + ("nn.functional.gelu", ""), + ("nn.functional.glu", ""), + ("nn.functional.grid_sample", ""), + ("nn.functional.group_norm", ""), + ("nn.functional.hardshrink", ""), + ("nn.functional.hardsigmoid", ""), + ("nn.functional.hardswish", ""), + ("nn.functional.hardtanh", ""), + ("nn.functional.hinge_embedding_loss", ""), + ("nn.functional.huber_loss", ""), + ("nn.functional.instance_norm", ""), + ("nn.functional.interpolate", "area"), + ("nn.functional.interpolate", "bicubic"), + ("nn.functional.interpolate", "bilinear"), + ("nn.functional.interpolate", "linear"), + ("nn.functional.interpolate", "nearest"), + ("nn.functional.interpolate", "trilinear"), + ("nn.functional.kl_div", ""), + ("nn.functional.l1_loss", ""), + ("nn.functional.layer_norm", ""), + ("nn.functional.leaky_relu", ""), + ("nn.functional.linear", ""), + ("nn.functional.local_response_norm", ""), + ("nn.functional.logsigmoid", ""), + ("nn.functional.margin_ranking_loss", ""), + ("nn.functional.max_pool1d", ""), + ("nn.functional.max_pool2d", ""), + ("nn.functional.max_pool3d", ""), + ("nn.functional.max_unpool1d", ""), + ("nn.functional.max_unpool1d", "grad"), + ("nn.functional.max_unpool2d", ""), + ("nn.functional.max_unpool2d", "grad"), + ("nn.functional.max_unpool3d", ""), + ("nn.functional.max_unpool3d", "grad"), + ("nn.functional.mish", ""), + ("nn.functional.mse_loss", ""), + ("nn.functional.multi_margin_loss", ""), + ("nn.functional.multilabel_margin_loss", ""), + 
("nn.functional.multilabel_soft_margin_loss", ""), + ("nn.functional.nll_loss", ""), + ("nn.functional.normalize", ""), + ("nn.functional.one_hot", ""), + ("nn.functional.pad", "circular"), + ("nn.functional.pad", "constant"), + ("nn.functional.pad", "reflect"), + ("nn.functional.pad", "replicate"), + ("nn.functional.pairwise_distance", ""), + ("nn.functional.pdist", ""), + ("nn.functional.pixel_shuffle", ""), + ("nn.functional.pixel_unshuffle", ""), + ("nn.functional.poisson_nll_loss", ""), + ("nn.functional.prelu", ""), + ("nn.functional.relu", ""), + ("nn.functional.relu6", ""), + ("nn.functional.rrelu", ""), + ("nn.functional.selu", ""), + ("nn.functional.silu", ""), + ("nn.functional.silu", "complex"), + ("nn.functional.smooth_l1_loss", ""), + ("nn.functional.soft_margin_loss", ""), + ("nn.functional.softmin", ""), + ("nn.functional.softmin", "with_dtype"), + ("nn.functional.softplus", ""), + ("nn.functional.softshrink", ""), + ("nn.functional.softsign", ""), + ("nn.functional.tanhshrink", ""), + ("nn.functional.threshold", ""), + ("nn.functional.triplet_margin_loss", ""), + ("nn.functional.triplet_margin_with_distance_loss", ""), + ("nn.functional.unfold", ""), + ("nn.functional.upsample_bilinear", ""), + ("nn.functional.upsample_nearest", ""), + ("nonzero", ""), + ("norm", ""), + ("norm", "fro"), + ("norm", "inf"), + ("norm", "nuc"), + ("normal", ""), + ("normal", "number_mean"), + ("ones", ""), + ("ones_like", ""), + ("ormqr", ""), + ("outer", ""), + ("pca_lowrank", ""), + ("permute", ""), + ("pinverse", ""), + ("polar", ""), + ("polygamma", "polygamma_n_0"), + ("polygamma", "polygamma_n_1"), + ("polygamma", "polygamma_n_2"), + ("polygamma", "polygamma_n_3"), + ("polygamma", "polygamma_n_4"), + ("positive", ""), + ("pow", ""), + ("prod", ""), + ("put", ""), + ("qr", ""), + ("quantile", ""), + ("rad2deg", ""), + ("rand_like", ""), + ("randint", ""), + ("randint_like", ""), + ("randn", ""), + ("randn_like", ""), + ("ravel", ""), + ("real", ""), + ("reciprocal", ""), + ("remainder", ""), + ("renorm", ""), + ("repeat", ""), + ("repeat_interleave", ""), + ("reshape", ""), + ("reshape_as", ""), + ("resize_", ""), + ("resize_as_", ""), + ("resolve_conj", ""), + ("resolve_neg", ""), + ("roll", ""), + ("rot90", ""), + ("round", ""), + ("round", "decimals_0"), + ("round", "decimals_3"), + ("round", "decimals_neg_3"), + ("rsqrt", ""), + ("rsub", ""), + ("scalar_tensor", ""), + ("scatter", ""), + ("scatter_add", ""), + ("scatter_reduce", "amax"), + ("scatter_reduce", "amin"), + ("scatter_reduce", "mean"), + ("scatter_reduce", "prod"), + ("scatter_reduce", "sum"), + ("searchsorted", ""), + ("segment_reduce", "lengths"), + ("segment_reduce", "offsets"), + ("select", ""), + ("select_scatter", ""), + ("sgn", ""), + ("short", ""), + ("sigmoid", ""), + ("sign", ""), + ("signal.windows.cosine", ""), + ("signal.windows.exponential", ""), + ("signal.windows.gaussian", ""), + ("signal.windows.kaiser", ""), + ("signbit", ""), + ("sin", ""), + ("sinc", ""), + ("sinh", ""), + ("slice", ""), + ("slice_scatter", ""), + ("softmax", ""), + ("softmax", "with_dtype"), + ("sort", ""), + ("sparse.sampled_addmm", ""), + ("special.airy_ai", ""), + ("special.bessel_j0", ""), + ("special.bessel_j1", ""), + ("special.bessel_y0", ""), + ("special.bessel_y1", ""), + ("special.chebyshev_polynomial_t", ""), + ("special.chebyshev_polynomial_u", ""), + ("special.chebyshev_polynomial_v", ""), + ("special.chebyshev_polynomial_w", ""), + ("special.entr", ""), + ("special.erfcx", ""), + ("special.hermite_polynomial_h", ""), + 
("special.hermite_polynomial_he", ""), + ("special.i0e", ""), + ("special.i1", ""), + ("special.i1e", ""), + ("special.laguerre_polynomial_l", ""), + ("special.legendre_polynomial_p", ""), + ("special.log_ndtr", ""), + ("special.modified_bessel_i0", ""), + ("special.modified_bessel_i1", ""), + ("special.modified_bessel_k0", ""), + ("special.modified_bessel_k1", ""), + ("special.ndtr", ""), + ("special.ndtri", ""), + ("special.polygamma", "special_polygamma_n_0"), + ("special.scaled_modified_bessel_k0", ""), + ("special.scaled_modified_bessel_k1", ""), + ("special.shifted_chebyshev_polynomial_t", ""), + ("special.shifted_chebyshev_polynomial_u", ""), + ("special.shifted_chebyshev_polynomial_v", ""), + ("special.shifted_chebyshev_polynomial_w", ""), + ("special.spherical_bessel_j0", ""), + ("special.xlog1py", ""), + ("special.zeta", ""), + ("split", ""), + ("split", "list_args"), + ("split_with_sizes", ""), + ("sqrt", ""), + ("square", ""), + ("squeeze", ""), + ("stack", ""), + ("std", ""), + ("std_mean", ""), + ("stft", ""), + ("sub", ""), + ("sum", ""), + ("sum_to_size", ""), + ("svd", ""), + ("svd_lowrank", ""), + ("symeig", ""), + ("t", ""), + ("take", ""), + ("take_along_dim", ""), + ("tan", ""), + ("tanh", ""), + ("tensor_split", ""), + ("tensordot", ""), + ("tile", ""), + ("to", ""), + ("to_sparse", ""), + ("topk", ""), + ("trace", ""), + ("transpose", ""), + ("trapezoid", ""), + ("trapz", ""), + ("triangular_solve", ""), + ("tril", ""), + ("tril_indices", ""), + ("triu", ""), + ("triu_indices", ""), + ("true_divide", ""), + ("trunc", ""), + ("unbind", ""), + ("unflatten", ""), + ("unfold", ""), + ("unfold_copy", ""), + ("uniform", ""), + ("unique", ""), + ("unique_consecutive", ""), + ("unsqueeze", ""), + ("var", ""), + ("var_mean", ""), + ("vdot", ""), + ("view", ""), + ("view_as", ""), + ("view_as_complex", ""), + ("view_as_real", ""), + ("vsplit", ""), + ("vstack", ""), + ("where", ""), + ("xlogy", ""), + ("zero_", ""), + ("zeros", ""), + ("zeros_like", ""), +} + + +def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool: + return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta + + +dtensor_lagging_op_db: List[OpInfo] = [ + opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo) +] diff --git a/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py new file mode 100644 index 000000000000..f684f77ed2c4 --- /dev/null +++ b/torch/testing/_internal/distributed/_tensor/gen_dtensor_lagging_op_db.py @@ -0,0 +1,67 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple +from torch.testing._internal.common_methods_invocations import op_db + + +def num_leading_spaces(line: str) -> int: + result = len(line) - len(line.lstrip()) + # Empty space handling + if result == 0: + return 999999 + return result + + +def deindent(code: str) -> str: + lines = code.split("\n") + min_leading_spaces = min(map(num_leading_spaces, lines)) + lines = [line[min_leading_spaces:] for line in lines] + return "\n".join(lines) + + +if __name__ == "__main__": + supported: List[Tuple[str, str]] = [ + (opinfo.name, opinfo.variant_test_name) for opinfo in op_db + ] + supported = sorted(supported) + print( + deindent( + """\ + # Copyright (c) Facebook, Inc. and its affiliates. + # All rights reserved. 
+ # + # This source code is licensed under the BSD-style license found in the + # LICENSE file in the root directory of this source tree. + from typing import List + from torch.testing._internal.common_methods_invocations import op_db, OpInfo + # Generated from test/gen_dtensor_op_db.py via + # python spmd/testing/gen_dtensor_lagging_op_db.py > spmd/testing/dtensor_lagging_op_db.py + # + # This approach is copied from functorch: + # People add new OpInfos to PyTorch all the time. + # We want them to be able to add OpInfos without breaking our CI. + # To achieve this, we keep our OpInfo library behind that of Pytorch's and + # we periodically update our OpInfo library by regenerating this file""" + ) + ) + + print("_dtensor_lagging_meta = {") + for name, variant in supported: + print(f" {(name, variant)},") + print("}") + + print( + deindent( + """\ + def in_dtensor_lagging_op_db(opinfo: OpInfo) -> bool: + return (opinfo.name, opinfo.variant_test_name) in _dtensor_lagging_meta + + dtensor_lagging_op_db: List[OpInfo] = [ + opinfo for opinfo in op_db if in_dtensor_lagging_op_db(opinfo) + ]""" + ) + ) From f20b3f2e5734b23a9e0a898196ddf77aa90323b8 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Tue, 15 Nov 2022 22:51:33 +0000 Subject: [PATCH 222/453] [dtensor] PART 8: move tensor parallel api and tests to core distributed (#88180) This PR moves tensor/parallel folder and tests to torch.distributed. part of https://github.com/pytorch/pytorch/issues/88838 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88180 Approved by: https://github.com/aazzolini --- test/distributed/_tensor/parallel/__init__.py | 0 .../_tensor/parallel/test_2d_parallel.py | 223 ++++++++ .../_tensor/parallel/test_tp_examples.py | 516 ++++++++++++++++++ .../parallel/test_view_sharding_dim_change.py | 30 + .../distributed/_tensor/parallel/__init__.py | 10 + .../_tensor/parallel/_view_with_dim_change.py | 108 ++++ torch/distributed/_tensor/parallel/api.py | 86 +++ torch/distributed/_tensor/parallel/fsdp.py | 357 ++++++++++++ .../parallel/multihead_attention_tp.py | 273 +++++++++ 9 files changed, 1603 insertions(+) create mode 100644 test/distributed/_tensor/parallel/__init__.py create mode 100644 test/distributed/_tensor/parallel/test_2d_parallel.py create mode 100644 test/distributed/_tensor/parallel/test_tp_examples.py create mode 100644 test/distributed/_tensor/parallel/test_view_sharding_dim_change.py create mode 100644 torch/distributed/_tensor/parallel/__init__.py create mode 100644 torch/distributed/_tensor/parallel/_view_with_dim_change.py create mode 100644 torch/distributed/_tensor/parallel/api.py create mode 100644 torch/distributed/_tensor/parallel/fsdp.py create mode 100644 torch/distributed/_tensor/parallel/multihead_attention_tp.py diff --git a/test/distributed/_tensor/parallel/__init__.py b/test/distributed/_tensor/parallel/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/distributed/_tensor/parallel/test_2d_parallel.py b/test/distributed/_tensor/parallel/test_2d_parallel.py new file mode 100644 index 000000000000..7a3779c296c3 --- /dev/null +++ b/test/distributed/_tensor/parallel/test_2d_parallel.py @@ -0,0 +1,223 @@ +# Owner(s): ["oncall: distributed"] + +from typing import Any + + +import torch +import torch.nn.functional as F +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from 
torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +from torch.distributed._tensor import ( + distribute_tensor, + DeviceMesh, + DTensor as DT, + Shard, + Replicate, +) + +import torch.distributed.distributed_c10d as distributed_c10d + +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.distributed._tensor.parallel.fsdp import is_available + +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) + +# Tensor-Parallel degree +TP_DEGREE = 2 +LR = 3e-5 + +OPS_NOT_SHARD = [ + "net3.weight", + "net3.bias", +] + +SHARD_PARAMS = [ + "net1.weight", + "net1.bias", + "net2.weight", +] + + +class SimpleModel(torch.nn.Module): + def __init__(self): + super(SimpleModel, self).__init__() + self.net1 = torch.nn.Linear(5, 8) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(8, 4) + self.net3 = torch.nn.Linear(4, 12) + + def forward(self, x): + x = F.relu(self.net1(x)) + x = F.relu(self.net2(x)) + x = F.relu(self.net3(x)) + return x + + +def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module: + def hook_func(_module, _input, output): + if isinstance(output, DT): + replica_placement = [Replicate()] + return output.redistribute( + output.device_mesh, replica_placement + ).to_local() + + module.register_forward_hook(hook_func) + return module + + +def _replicate_input_tensor( + module: torch.nn.Module, device_mesh, replica_placement +) -> torch.nn.Module: + def hook_func(_, input): + if not isinstance(input[0], DT): + return DT.from_local( + input[0], device_mesh, replica_placement, run_check=False + ) + + module.register_forward_pre_hook(hook_func) + return module + + +def shard_module(m, pg): + start_idx = distributed_c10d.get_global_rank(pg, 0) + device_mesh = DeviceMesh( + "cuda", list(range(start_idx, start_idx + pg.size())), dim_groups=[pg] + ) + col_wise_sharding = [Shard(0)] + row_wise_sharding = [Shard(1)] + replicate = [Replicate()] + m.net1.weight = torch.nn.Parameter( + distribute_tensor(m.net1.weight, device_mesh, col_wise_sharding), + ) + m.net2.weight = torch.nn.Parameter( + distribute_tensor(m.net2.weight, device_mesh, row_wise_sharding) + ) + m.net1.bias = torch.nn.Parameter( + distribute_tensor(m.net1.bias, device_mesh, col_wise_sharding) + ) + m.net2.bias = torch.nn.Parameter( + distribute_tensor(m.net2.bias, device_mesh, replicate) + ) + m = _replicate_input_tensor(m, device_mesh, replicate) + m.net2 = _aggregate_local_tensor(m.net2) + + +def _shard_wrap_module(module, module_shard, fsdp_wrap, tp_pg, fsdp_pg): + if module_shard: + # Fetch the module sharding planner. 
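+        # shard_module applies the Megatron-style TP plan over tp_pg: net1 is sharded
+        # column-wise, net2 row-wise, with input replication and output aggregation
+        # handled by the registered hooks.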
+ shard_module(module, tp_pg) + + if fsdp_wrap and module_shard: + return FSDP(module, process_group=fsdp_pg) + if fsdp_wrap: + return FSDP(module, process_group=distributed_c10d._get_default_group()) + return module + + +def init_model(model_parallel_size=TP_DEGREE): + rank = dist.get_rank() + torch.cuda.set_device(rank) + world_size = dist.get_world_size() + + model = SimpleModel().cuda(rank) + + # 2-D mesh is [dp, tp] + twod_mesh = DeviceMesh( + device_type="cuda", + mesh=torch.arange(0, world_size).view(model_parallel_size, -1), + ) + + fsdp_pg = twod_mesh.get_dim_groups()[0] + tp_pg = twod_mesh.get_dim_groups()[1] + + # Create Input + model = _shard_wrap_module(model, True, True, tp_pg, fsdp_pg) + return model, tp_pg, fsdp_pg + + +def is_nested_tensor(val: Any) -> bool: + if isinstance(val, ShardedTensor): + if len(val.local_shards()) == 0: + return False + if isinstance(val.local_shards()[0].tensor, ShardedTensor): + return True + if isinstance(val.local_shards()[0].tensor, DT): + raise ValueError("Cannot handle DT nested insided ST") + # Safety valve for when this eventually happen + elif isinstance(val, DT) and isinstance( + val._local_tensor, (DT, ShardedTensor) + ): + raise ValueError("Cannot handle nested DT") + return False + + +class Test2dParallelIntegration(DTensorTestBase): + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_functionality(self) -> None: + if not is_available(): + self.skipTest("FSDP 2d parallel integration not available") + + model_tp = init_model()[0] + + with FSDP.state_dict_type(model_tp, StateDictType.SHARDED_STATE_DICT): + state_dict = model_tp.state_dict() + # TODO once 2D is out, validate the nesting + self.assertTrue(is_nested_tensor(state_dict["net1.weight"])) + self.assertFalse(is_nested_tensor(state_dict["net3.bias"])) + + optim = torch.optim.Adam(model_tp.parameters(), lr=0.0001) + + # Create Input + input_seed = self.rank + torch.manual_seed(input_seed + 1) + input = torch.rand(4, 5).cuda(self.rank) + + model_tp(input).sum().backward() + optim.step() + + optim_state = FSDP.sharded_optim_state_dict(model_tp, optim) + # TODO once 2D is out, validate the nesting + self.assertTrue( + is_nested_tensor(optim_state["state"]["net1.weight"]["exp_avg"]) + ) + self.assertFalse( + is_nested_tensor(optim_state["state"]["net3.bias"]["exp_avg"]) + ) + + @with_comms + @skip_if_lt_x_gpu(4) + def test_2d_fsdp_integration_correctness(self) -> None: + if not is_available(): + self.skipTest("FSDP 2d parallel integration not available") + torch.manual_seed(0) + model = SimpleModel().cuda(self.rank) + model = FSDP(model) + torch.manual_seed(0) + model_2d, _, dp_pg = init_model() + + optim = torch.optim.Adam(model.parameters(), lr=0.0001) + optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.0001) + + for i in range(5): + # Ensure all input across TP ranks are same. + torch.manual_seed(i + dist.get_rank(dp_pg)) + input = torch.rand(4, 5).cuda(self.rank) + output = model(input) + output_2d = model_2d(input) + self.assertEqual(output, output_2d) + output.sum().backward() + output_2d.sum().backward() + optim.step() + optim_2d.step() + self.assertEqual(model(input), model_2d(input)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/parallel/test_tp_examples.py b/test/distributed/_tensor/parallel/test_tp_examples.py new file mode 100644 index 000000000000..582108ea7599 --- /dev/null +++ b/test/distributed/_tensor/parallel/test_tp_examples.py @@ -0,0 +1,516 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, + NUM_DEVICES, + skip_unless_torch_gpu, +) +from torch.distributed._tensor import ( + distribute_tensor, + distribute_module, + DeviceMesh, + DTensor, + Shard, + Replicate, +) +from torch.distributed._tensor.parallel import ( + TensorParallelMultiheadAttention, + tp_shard_self_attn, + replicate_input, + replicate_output, +) + + +class MLPModule(torch.nn.Module): + def __init__(self, device): + super(MLPModule, self).__init__() + torch.manual_seed(5) + self.net1 = torch.nn.Linear(10, 16, device=device) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(16, 12, device=device) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module: + def hook_func(_module, _input, output): + if isinstance(output, DTensor): + replica_placement = [Replicate()] * device_mesh.ndim + return ( + output.redistribute(output.device_mesh, replica_placement) + .contiguous() + .to_local() + ) + + module.register_forward_hook(hook_func) + return module + + +def shard_mlp(m, device_type, tp_size): + start_idx = 0 + device_mesh = DeviceMesh( + device_type, + list(range(start_idx, start_idx + tp_size)), + ) + col_wise_sharding = [Shard(0)] + row_wise_sharding = [Shard(1)] + replicate = [Replicate()] * device_mesh.ndim + + def shard_params(name, module, device_mesh): + if isinstance(module, nn.Linear): + if name == "net1": + sharded_weight = nn.Parameter( + distribute_tensor( + module.weight, device_mesh, col_wise_sharding + ) + ) + sharded_bias = nn.Parameter( + distribute_tensor( + module.bias, device_mesh, col_wise_sharding + ) + ) + module.register_parameter("weight", sharded_weight) + module.register_parameter("bias", sharded_bias) + elif name == "net2": + sharded_weight = nn.Parameter( + distribute_tensor( + module.weight, device_mesh, row_wise_sharding + ) + ) + replicated_bias = nn.Parameter( + distribute_tensor(module.bias, device_mesh, replicate) + ) + module.register_parameter("weight", sharded_weight) + module.register_parameter("bias", replicated_bias) + + def aggregate_output(outputs, device_mesh): + assert isinstance(outputs, DTensor) + return ( + outputs.redistribute(device_mesh, replicate).contiguous().to_local() + ) + + dist_mod = distribute_module( + m, + device_mesh, + partition_fn=shard_params, + input_fn=replicate_input, + output_fn=aggregate_output, + ) + return dist_mod + + +class MultiheadAttnWrap(nn.Module): + def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None): + super().__init__() + self.attn = nn.MultiheadAttention( + embed_dim, num_heads, add_bias_kv=add_bias_kv, device=device + ) + + def forward(self, query, key, value): + return self.attn(query, key, value) + + +class DistTensorParallelExampleTest(DTensorTestBase): + @with_comms + def test_mlp_megatron_e2e(self): + inp_size = [5, 10] + # Ensure all tp ranks have same input. + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + model = MLPModule(self.device_type) + model_tp = MLPModule(self.device_type) + + # Ensure model are initialized the same way. 
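+        # (both MLPModule instances call torch.manual_seed(5) in __init__, so their
+        # parameters start out identical)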
+ self.assertEqual(model.net1.weight, model_tp.net1.weight) + self.assertEqual(model.net1.bias, model_tp.net1.bias) + self.assertEqual(model.net2.weight, model_tp.net2.weight) + self.assertEqual(model.net2.bias, model_tp.net2.bias) + + # Shard module and initialize optimizer. + LR = 0.25 + shard_mlp(model_tp, self.device_type, NUM_DEVICES) + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp) + output_tp = model_tp(inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.net1.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + + # Ensure gradients are same. + self.assertEqual( + model.net1.weight.grad, + model_tp.net1.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net1.bias.grad, + model_tp.net1.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.weight.grad, + model_tp.net2.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.bias.grad, + model_tp.net2.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. + self.assertEqual( + model.net1.weight, + model_tp.net1.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net1.bias, + model_tp.net1.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.net2.weight, + model_tp.net2.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + # Due to the trick we use for Partial aggregation, we only check the weight when local_rank = 0. + if self.rank == 0: + self.assertEqual( + model.net2.bias, + model_tp.net2.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp) + output_tp = model_tp(inp) + self.assertEqual(output, output_tp) + + # TensorParallelMultiheadAttention == dist_module(TensorParallelMultiheadAttention) + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_self_attn_megatron_e2e(self): + inp_size = [8, 12, 16] + # Ensure all tp ranks have same input. + torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + + # Initialize model using same seed. + torch.manual_seed(5) + model = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + torch.manual_seed(5) + model_tp = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + + # Ensure model are initialized the same way. + self.assertEqual(model.qkv.weight, model_tp.qkv.weight) + self.assertEqual(model.qkv.bias, model_tp.qkv.bias) + self.assertEqual(model.proj.weight, model_tp.proj.weight) + self.assertEqual(model.proj.bias, model_tp.proj.bias) + + # Shard module and initialize optimizer. 
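+        # distribute_module below partitions the attention weights with tp_shard_self_attn
+        # and replicates inputs/outputs across the 1-D device mesh.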
+ device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) + distribute_module( + model_tp, + device_mesh, + partition_fn=tp_shard_self_attn, + input_fn=replicate_input, + output_fn=replicate_output, + ) + + device_mesh = model_tp.qkv.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + # Ensure model are initialized the same way. + self.assertEqual( + model.qkv.weight, + model_tp.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + LR = 0.25 + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.qkv.weight.device_mesh + # Ensure gradients are same. + self.assertEqual( + model.qkv.weight.grad, + model_tp.qkv.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias.grad, + model_tp.qkv.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight.grad, + model_tp.proj.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias.grad, + model_tp.proj.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. + self.assertEqual( + model.qkv.weight, + model_tp.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + # TensorParallelMultiheadAttention == dist_module(torch.nn.MultiheadAttention) + # baddbmm introduces nan occasionally on CPU: https://github.com/pytorch/pytorch/issues/80588 + @with_comms + @skip_unless_torch_gpu + def test_self_attn_replacement_megatron_e2e(self): + inp_size = [8, 12, 16] + # Ensure all tp ranks have same input. 
+ torch.manual_seed(0) + inp = torch.rand(*inp_size, device=self.device_type) + + # TODO: our sharding function cannot shard the root node + torch.manual_seed(5) + model = TensorParallelMultiheadAttention( + 16, + 8, + tp_size=NUM_DEVICES, + add_bias_kv=True, + device=self.device_type, + ) + model_tp = MultiheadAttnWrap( + 16, 8, add_bias_kv=True, device=self.device_type + ) + + # TODO: somehow using torch.nn.MultiheadAttention's initial params does not work + # Use TensorParallelMultiheadAttention parameters instead + x = model.qkv.weight.clone().detach().requires_grad_() + model_tp.attn.register_parameter( + "in_proj_weight", torch.nn.Parameter(x) + ) + + x = model.qkv.bias.clone().detach().requires_grad_() + model_tp.attn.register_parameter("in_proj_bias", torch.nn.Parameter(x)) + + x = model.proj.weight.clone().detach().requires_grad_() + model_tp.attn.out_proj.register_parameter( + "weight", torch.nn.Parameter(x) + ) + + x = model.proj.bias.clone().detach().requires_grad_() + model_tp.attn.out_proj.register_parameter("bias", torch.nn.Parameter(x)) + + # check if parameters are same + self.assertEqual(model.qkv.weight, model_tp.attn.in_proj_weight) + self.assertEqual(model.qkv.bias, model_tp.attn.in_proj_bias) + self.assertEqual(model.proj.weight, model_tp.attn.out_proj.weight) + self.assertEqual(model.proj.bias, model_tp.attn.out_proj.bias) + + # Shard module and initialize optimizer. + device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) + distribute_module( + model_tp, + device_mesh, + partition_fn=tp_shard_self_attn, + input_fn=replicate_input, + output_fn=replicate_output, + ) + + device_mesh = model_tp.attn.qkv.weight.device_mesh + replicate = [Replicate()] * device_mesh.ndim + # Ensure model are initialized the same way. + self.assertEqual( + model.qkv.weight, + model_tp.attn.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.attn.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.attn.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.attn.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + LR = 0.25 + optim = torch.optim.SGD(model.parameters(), lr=LR) + optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) + + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + output.sum().backward() + output_tp.sum().backward() + + device_mesh = model_tp.attn.qkv.weight.device_mesh + # Ensure gradients are same. + self.assertEqual( + model.qkv.weight.grad, + model_tp.attn.qkv.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias.grad, + model_tp.attn.qkv.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight.grad, + model_tp.attn.proj.weight.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias.grad, + model_tp.attn.proj.bias.grad.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + optim.step() + optim_tp.step() + + # Ensure model weights are still same after update. 
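+        # redistribute(..., Replicate()) followed by to_local() materializes the full,
+        # unsharded parameter on every rank so it can be compared with the single-device model.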
+ self.assertEqual( + model.qkv.weight, + model_tp.attn.qkv.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.qkv.bias, + model_tp.attn.qkv.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.weight, + model_tp.attn.proj.weight.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + self.assertEqual( + model.proj.bias, + model_tp.attn.proj.bias.redistribute( + device_mesh=device_mesh, placements=replicate + ).to_local(), + ) + + inp = torch.rand(*inp_size, device=self.device_type) + output = model(inp, inp, inp) + output_tp = model_tp(inp, inp, inp) + self.assertEqual(output, output_tp) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/parallel/test_view_sharding_dim_change.py b/test/distributed/_tensor/parallel/test_view_sharding_dim_change.py new file mode 100644 index 000000000000..4648d930b9eb --- /dev/null +++ b/test/distributed/_tensor/parallel/test_view_sharding_dim_change.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torch.distributed._tensor import DeviceMesh, DTensor, Shard +from torch.distributed._tensor.parallel._view_with_dim_change import ( + _view_with_sharding_dim_change, +) + + +class TPViewShardingDimChangeTest(DTensorTestBase): + @with_comms + def test_view_with_sharding_dim_change(self): + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + torch.manual_seed(self.rank) + tensor = torch.rand(3, 5, 6, device=self.device_type) + sharding = [Shard(2)] + dt = DTensor.from_local(tensor, device_mesh, sharding) + dt = _view_with_sharding_dim_change(dt, 1, (3, -1, 6)) + self.assertTrue(dt.placements[0].is_shard(dim=1)) + self.assertEqual(dt.to_local(), tensor.view(3, -1, 6)) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tensor/parallel/__init__.py b/torch/distributed/_tensor/parallel/__init__.py new file mode 100644 index 000000000000..5725c5077d4b --- /dev/null +++ b/torch/distributed/_tensor/parallel/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from torch.distributed._tensor.parallel.multihead_attention_tp import ( + TensorParallelMultiheadAttention, +) + +from torch.distributed._tensor.parallel.api import ( + tp_shard_self_attn, + replicate_input, + replicate_output, +) diff --git a/torch/distributed/_tensor/parallel/_view_with_dim_change.py b/torch/distributed/_tensor/parallel/_view_with_dim_change.py new file mode 100644 index 000000000000..7988129318b7 --- /dev/null +++ b/torch/distributed/_tensor/parallel/_view_with_dim_change.py @@ -0,0 +1,108 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from typing import Tuple, Union + +import torch +from torch.distributed._tensor import DTensor as DT +from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.ops.utils import prod + + +def _view_with_sharding_dim_change( + tensor: Union[torch.Tensor, DT], sharding_dim: int, shape: Tuple[int, ...] +) -> Union[torch.Tensor, DT]: + """ + We change the implicit sharding dim for a distributed tensor without comms. 
+ Because if we don't change sharding dim, we will ended up having more comms that are not necessary. + Note that this op will produce invalid DTensor, you will need to call this op in pair to recover + it back to a valid DTensor. + + This should only be used when implicitly changing sharding dim doesn't have semantic issue. + """ + if isinstance(tensor, DT): + # pyre-fixme[16]: Undefined attribute. + return _ViewAndRedistribute.apply(tensor, sharding_dim, shape) + else: + return tensor.view(shape) + + +class _ViewAndRedistribute(torch.autograd.Function): + @staticmethod + # pyre-fixme[14]: Inconsistent override. + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + self: DT, + sharding_dim: int, + shape: Tuple[int, ...], + ) -> DT: + ctx.previous_placement = self.placements + ctx.previous_device_mesh = self.device_mesh + ctx.previous_local_shape = self.to_local().size() + ctx.previous_global_shape = self.size() + assert ( + self.device_mesh.ndim == 1 + ), "Only support 1D Device Mesh for _ViewAndRedistribute." + if ( + self.placements[0].is_shard(dim=sharding_dim) + or self.placements[0].is_replicate() + or self.placements[0].is_partial() + ): + # pyre-fixme[7]: Incompatible return type. + return self.view(shape) # type: ignore[return-value] + else: + if sharding_dim < 0: + sharding_dim += self.dim() + + device_mesh = self.device_mesh + world_size = device_mesh.size(dim=0) + new_sharding_placement = [Shard(sharding_dim)] + + # Fix shape + try: + infer_idx = shape.index(-1) + except ValueError: + infer_idx = None # type: ignore[assignment] + + # Infer the dim which is specified with -1. + if infer_idx is not None: + st_size = prod(self.size()) # type: ignore[attr-defined] + shape_size = -1 * prod(shape) # type: ignore[attr-defined] + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + shape = ( + *shape[:infer_idx], + st_size // shape_size, + *shape[infer_idx + 1 :], + ) + + # pyre-fixme[60]: Concatenation not yet support for multiple variadic + new_local_tensor_size = ( + *shape[:sharding_dim], + shape[sharding_dim] // world_size, + *shape[sharding_dim + 1 :], + ) + new_local_tensor = self.to_local().view(*new_local_tensor_size) + + return DT( + new_local_tensor, + device_mesh, + new_sharding_placement, + size=torch.Size(shape), + requires_grad=new_local_tensor.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: DT) -> Tuple[DT, None, None]: # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + previous_local_tensor_size = ctx.previous_local_shape + previous_global_shape = ctx.previous_global_shape + return ( + DT( + grad_output.to_local().view(*previous_local_tensor_size), + previous_device_mesh, + previous_placement, + size=previous_global_shape, + requires_grad=grad_output.requires_grad, + ), + None, + None, + ) diff --git a/torch/distributed/_tensor/parallel/api.py b/torch/distributed/_tensor/parallel/api.py new file mode 100644 index 000000000000..7ab3ad2199f2 --- /dev/null +++ b/torch/distributed/_tensor/parallel/api.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +import torch +import torch.nn as nn +from typing import Sequence, Tuple +from torch.distributed._tensor import ( + distribute_tensor, + DTensor, + Shard, + Replicate, + DeviceMesh, + Placement, +) +from torch.distributed._tensor.parallel import TensorParallelMultiheadAttention + + +def replicate_input( + inputs: Sequence[torch.Tensor], device_mesh: DeviceMesh +) -> Tuple[DTensor, ...]: + replicate = [Replicate()] * device_mesh.ndim + return tuple( + DTensor.from_local(tensor, device_mesh, replicate) for tensor in inputs + ) + + +def replicate_output(output: DTensor, device_mesh: DeviceMesh) -> torch.Tensor: + if isinstance(output, DTensor): + replicate = [Replicate()] * output.device_mesh.ndim + # TODO: can the output be left incontiguous? + return ( + output.redistribute(output.device_mesh, replicate) + .to_local() + .contiguous() + ) + + +def tp_shard_self_attn( + name: str, module: nn.Module, device_mesh: DeviceMesh +) -> None: + col_wise_sharding: Sequence[Placement] = [Shard(0)] + row_wise_sharding: Sequence[Placement] = [Shard(1)] + replicate: Sequence[Placement] = [Replicate()] * device_mesh.ndim + + def _shard_self_attn_params(name: str, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + if name == "qkv": + sharded_weight = nn.Parameter( + distribute_tensor( + module.weight, device_mesh, col_wise_sharding + ) + ) + module.register_parameter("weight", sharded_weight) + if module.bias is not None: + sharded_bias = nn.Parameter( + distribute_tensor( + module.bias, device_mesh, col_wise_sharding + ) + ) + module.register_parameter("bias", sharded_bias) + elif name == "proj": + sharded_weight = nn.Parameter( + distribute_tensor( + module.weight, device_mesh, row_wise_sharding + ) + ) + module.register_parameter("weight", sharded_weight) + if module.bias is not None: + replicated_bias = nn.Parameter( + distribute_tensor(module.bias, device_mesh, replicate) + ) + module.register_parameter("bias", replicated_bias) + + if isinstance(module, TensorParallelMultiheadAttention): # shard TPMA + for n, m in module.named_children(): + _shard_self_attn_params(n, m) + else: + for n, m in module.named_children(): # replace with TPMA + if isinstance(m, nn.MultiheadAttention): + tp_multi_head_attention = TensorParallelMultiheadAttention( + m.embed_dim, + m.num_heads, + device=torch.device(device_mesh.device_type), + tp_size=device_mesh.size(0), # group size on dim 0 + add_bias_kv=m.bias_k is not None, + ) + tp_multi_head_attention.copy(m) + module.register_module(n, tp_multi_head_attention) diff --git a/torch/distributed/_tensor/parallel/fsdp.py b/torch/distributed/_tensor/parallel/fsdp.py new file mode 100644 index 000000000000..1f1123c51775 --- /dev/null +++ b/torch/distributed/_tensor/parallel/fsdp.py @@ -0,0 +1,357 @@ +import warnings +import copy +from typing import List, NamedTuple, Optional, Tuple, cast + +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as c10d + +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor + +import torch.distributed._shard.sharding_spec as shard_spec +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ( + ChunkShardingSpec, +) + +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) + +from torch.distributed._shard.sharding_spec import ( + ShardMetadata, +) + +from torch.distributed.remote_device import _remote_device + +from torch.distributed._tensor import ( + DTensor as 
DistributedTensor, + DeviceMesh, + Shard as DShard, +) +from torch.distributed._tensor.placement_types import Placement + +__all__ = ["is_available"] + + +class _STShardingInfo(NamedTuple): + """:class:`ShardedTensor` sharding information.""" + + sharding_spec: Optional[shard_spec.ShardingSpec] + global_size: Optional[torch.Size] + process_group: Optional[c10d.ProcessGroup] + device_mesh: Optional[DeviceMesh] + placements: Optional[List[Placement]] + + +def _get_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + placement = tensor.placements[0] + offsets = [0] * len(tensor.size()) + num_chunks = device_mesh.size(dim=0) + + if tensor.placements[0].is_shard(): + shard_dim = cast(DShard, placement).dim + chunk_size = tensor.size(shard_dim) // num_chunks + offsets[shard_dim] = chunk_size + + return (torch.Size(offsets), tensor._local_tensor.size()) + + +def _get_box_for( + tensor: DistributedTensor, idx: int +) -> Tuple[torch.Size, torch.Size]: + offsets, size = _get_box(tensor) + return (torch.Size([val * idx for val in offsets]), size) + + +def _get_local_box(tensor: DistributedTensor) -> Tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + dim_0_coord = device_mesh.get_coordinate_on_dim(0) + assert dim_0_coord is not None + return _get_box_for(tensor, dim_0_coord) + + +def _create_shard_md_from_dt( + dt: DistributedTensor, current_rank: int +) -> ShardMetadata: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + offsets, sizes = _get_local_box(dt) + return ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=f"rank:{current_rank}/{dt._local_tensor.device}", + ) + + +def _create_sharded_tensor_md_from_dt( + dt: DistributedTensor, dt_pg: c10d.ProcessGroup +) -> ShardedTensorMetadata: + # This is where it gets tricky, we have to produce a ShardedTensor that has full coverage + # and yet has only one valid shard for the current rank. + + shards_md = [] + my_rank = dist.get_rank(dt_pg) + scapegoat_rank = 0 if my_rank > 0 else 1 + + if dt.placements[0].is_shard(): + shard_count = dt_pg.size() + else: + shard_count = 1 + + for i in range(shard_count): + offsets, sizes = _get_box_for(dt, i) + shards_md.append( + ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}", + ) + ) + + return ShardedTensorMetadata( + shards_metadata=shards_md, + size=dt.size(), + tensor_properties=TensorProperties( + dtype=dt.dtype, + layout=dt.layout, + requires_grad=dt.requires_grad, + # ignore memory_format and pin_memory as those are not supported by DT + ), + ) + + +def _get_dt_pg(dt: DistributedTensor) -> c10d.ProcessGroup: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + return mesh.get_dim_groups()[0] + + +def _rewrite_spec_if_needed( + spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int +) -> shard_spec.ShardingSpec: + """ + Rewrite ``spec`` to match the device of ``tensor``. + + FSDP.sharded_optim_state_dict sneakly ships optimizer state to CPU so if the original ShardingSpec + produces CUDA metadata, ST construction bombs. 
+ """ + if not isinstance(spec, ChunkShardingSpec): + return spec + + # let's see if we need + rewrite = False + for p in spec.placements: + p = cast(_remote_device, p) + if p.rank() == rank and p.device() != tensor.device: + rewrite = True + break + if rewrite: + spec = copy.deepcopy(spec) + for i, placement in enumerate(spec.placements): + placement = cast(_remote_device, placement) + if placement.rank() == rank and placement.device() != tensor.device: + spec.placements[i] = _remote_device( + f"rank:{rank}/{tensor.device}" + ) + + return spec + + +def _flatten_tensor( + tensor: torch.Tensor, +) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]: + if type(tensor) is ShardedTensor: + return tensor.local_tensor(), _STShardingInfo( + tensor.sharding_spec(), + tensor.size(), + tensor._process_group, + None, + None, + ) + elif type(tensor) is DistributedTensor: + tensor._local_tensor.requires_grad_() + return tensor._local_tensor, _STShardingInfo( + None, + None, + None, + tensor.device_mesh, + list(tensor.placements), + ) + return tensor, None + + +def _unflatten_tensor( + tensor: torch.Tensor, sharding_info: _STShardingInfo +) -> torch.Tensor: + result: torch.Tensor + + if sharding_info.sharding_spec is not None: + assert sharding_info.global_size is not None + result = ShardedTensor._init_from_local_tensor( + tensor, + _rewrite_spec_if_needed( + sharding_info.sharding_spec, + tensor, + dist.get_rank(sharding_info.process_group), + ), + sharding_info.global_size, + process_group=cast(dist.ProcessGroup, sharding_info.process_group), + ) + else: + result = DistributedTensor.from_local( + tensor, + device_mesh=sharding_info.device_mesh, + placements=sharding_info.placements, + run_check=False, + ) + + _set_fsdp_flattened(result) + return result + + +def _chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, +) -> torch.Tensor: + if type(tensor) is ShardedTensor: + assert len(tensor.local_shards()) == 1 + + inner_param = tensor.local_tensor() + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + num_devices_per_node, + pg, + ) + + outer_local_shard = tensor.local_shards()[0] + shards: List[Shard] = [ + Shard(inner_st, copy.deepcopy(outer_local_shard.metadata)) + ] + st_meta = copy.deepcopy(tensor.metadata()) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=tensor._process_group, + init_rrefs=False, + ) + return st_outer + elif type(tensor) is DistributedTensor: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + inner_param = tensor._local_tensor + + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + torch.cuda.device_count(), + pg, + ) + + dt_pg = _get_dt_pg(tensor) + # We do this differently here, we create a ST with no local shards then patch it + shards = [ + Shard( + inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg)) + ) + ] + + st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=dt_pg, + init_rrefs=False, + ) + + return st_outer + else: + return _create_chunk_sharded_tensor( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _pre_load_state_dict( 
+ tensor: torch.Tensor, +) -> Tuple[torch.Tensor, List[Shard]]: + shards = cast(ShardedTensor, tensor).local_shards() + if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor: + inner_tensor = shards[0].tensor + shards = inner_tensor.local_shards() # pyre-ignore[16] + tensor = inner_tensor + + return (tensor, shards if len(shards) > 0 else []) + + +try: + from torch.distributed.fsdp._fsdp_extensions import ( + _set_fsdp_extensions, + FSDPExtensions, + ) + from torch.distributed.fsdp._common_utils import _set_fsdp_flattened + + class DTensorExtensions(FSDPExtensions): + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[_STShardingInfo]]: + return _flatten_tensor(tensor) + + def post_unflatten_transform( + self, tensor: torch.Tensor, param_extension: _STShardingInfo + ) -> torch.Tensor: + return _unflatten_tensor(tensor, param_extension) + + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + ) -> torch.Tensor: + return _chunk_tensor( + tensor, rank, world_size, num_devices_per_node, pg + ) + + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> Tuple[torch.Tensor, List[Shard]]: + return _pre_load_state_dict(tensor) + + _set_fsdp_extensions(DTensorExtensions()) + + def is_available() -> bool: + return True + +except BaseException as e: + warnings.warn( + "PyTorch doesn't have TensorFlattener extension point available" + "2D parallelism won't work with FSDP" + f"exception: {e}" + ) + + def is_available() -> bool: + return False diff --git a/torch/distributed/_tensor/parallel/multihead_attention_tp.py b/torch/distributed/_tensor/parallel/multihead_attention_tp.py new file mode 100644 index 000000000000..3071f42632fd --- /dev/null +++ b/torch/distributed/_tensor/parallel/multihead_attention_tp.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# pyre-ignore-all-errors[6] + +import math + +import torch +from torch.distributed._tensor import DTensor as DT +from torch.distributed._tensor.placement_types import Shard +from torch.distributed._tensor.parallel._view_with_dim_change import ( + _view_with_sharding_dim_change, +) + +from typing import Optional, Union + + +# TODO: Add a test to test equivalence between our Multihead Attention +# with other mainstream ones (Megatron-LM or PyTorch). +def _stride_same_as_shard( + tensor: torch.Tensor, tp_size: int, chunk_dim: int, cat_dim: int +) -> torch.Tensor: + """ + Adjust local tensor's stride same as the sharded situation. + So that view result will keeps the same. + """ + if isinstance(tensor, DT): + return tensor + view_size = list(tensor.size()) + view_size[chunk_dim] //= tp_size + return torch.cat( + [t.view(*view_size) for t in tensor.chunk(tp_size, dim=chunk_dim)], + dim=cat_dim, + ).contiguous() + + +class TensorParallelMultiheadAttention(torch.nn.Module): + """ + Multi-head Attention block from Transformer models. + Since we need some customizations for the attention layer, + we are writing a customized but mathematically equivalent + attention module as defined in torch.nn. + + Note that: + We now only support the case when it's self attention with + limited input args and we also assume that the input tensor + has a dimension of three. Although we do implement the logic + for multihead attention, it was not fully tested. 
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + batch_first: bool = False, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + tp_size: int = 1, + self_attention: bool = True, + ) -> None: + super(TensorParallelMultiheadAttention, self).__init__() + self.device: torch.device = ( + torch.device("cuda" if torch.cuda.is_available() else "cpu") + if device is None + else device + ) + self.num_heads = num_heads + self.hidden_size = embed_dim + self.hidden_size_per_attention_head: int = self.hidden_size // num_heads + self.scale: float = self.hidden_size_per_attention_head**-0.5 + if self_attention: + self.qkv: torch.nn.Module = torch.nn.Linear( + embed_dim, embed_dim * 3, bias=add_bias_kv, device=self.device + ) + torch.nn.init.xavier_uniform_(self.qkv.weight) + if add_bias_kv: + torch.nn.init.zeros_(self.qkv.bias) + else: + self.query: torch.nn.Module = torch.nn.Linear( + embed_dim, embed_dim, bias=add_bias_kv, device=self.device + ) + self.key: torch.nn.Module = torch.nn.Linear( + embed_dim, embed_dim, bias=add_bias_kv, device=self.device + ) + self.value: torch.nn.Module = torch.nn.Linear( + embed_dim, embed_dim, bias=add_bias_kv, device=self.device + ) + torch.nn.init.xavier_uniform_(self.query.weight) + torch.nn.init.xavier_uniform_(self.key.weight) + torch.nn.init.xavier_uniform_(self.value.weight) + if add_bias_kv: + torch.nn.init.zeros_(self.query.bias) + torch.nn.init.zeros_(self.key.bias) + torch.nn.init.zeros_(self.value.bias) + self.proj: torch.nn.Module = torch.nn.Linear( + embed_dim, embed_dim, bias=bias, device=self.device + ) + torch.nn.init.kaiming_uniform_(self.proj.weight, a=math.sqrt(5)) + if bias: + torch.nn.init.zeros_(self.proj.bias) + self.tp_size = tp_size + self.hidden_size = embed_dim + self.norm_factor: float = math.sqrt(self.hidden_size_per_attention_head) + self.self_attention = self_attention + + def forward( + self, + query: Union[torch.Tensor, DT], + key: Union[torch.Tensor, DT], + value: Union[torch.Tensor, DT], + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + average_attn_weights: bool = True, + ) -> Union[torch.Tensor, DT]: + b, sq, h = query.shape + sk = key.size(1) + nh = self.num_heads + hn = self.hidden_size_per_attention_head + + # x: [b, sq/sk/sv, h] + # =================== + # Permute. [sq/sk/sv, b, h] + # =================== + if not self.self_attention: + # ===================== + # Query, Key, and Value + # ===================== + query = query.permute(1, 0, 2).contiguous() + key = key.permute(1, 0, 2).contiguous() + value = value.permute(1, 0, 2).contiguous() + + # Attention heads [sq/sk/sv, b, h] --> [sq/sk/sv * b, (nh * hn)] + query = query.view(-1, h) + key = key.view(-1, h) + value = value.view(-1, h) + + query_layer = _view_with_sharding_dim_change( + self.query(query), 1, (sq, b * nh, hn) + ) + key_layer = _view_with_sharding_dim_change( + self.key(key), 1, (sk, b * nh, hn) + ) + value_layer = _view_with_sharding_dim_change( + self.value(value), 1, (sk, b * nh, hn) + ) + else: + assert torch.equal(query, key) and torch.equal( + query, value + ), "inputs are different for self-attention." 
+ # ===================== + # Query + # ===================== + query = query.permute(1, 0, 2).contiguous() + + # Attention heads [sq, b, h] --> [sq * b, (nh * 3 * hn)] + query = query.view(-1, h) + mixed_x_layer = self.qkv(query) + + # [sq * b, 3 * h] --> [sq, b, nh, 3 * hn] + mixed_x_layer = _view_with_sharding_dim_change( + mixed_x_layer, 2, (sq, b, nh, 3 * hn) + ) + + # [sq, b, nh, 3 * hn] --> 3 [sq, b, nh, hn] + last_dim = mixed_x_layer.dim() - 1 + last_dim_size = mixed_x_layer.size(last_dim) // 3 + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + last_dim_size, dim=last_dim + ) + + query_layer = _stride_same_as_shard(query_layer, self.tp_size, 2, 1) + key_layer = _stride_same_as_shard(key_layer, self.tp_size, 2, 1) + value_layer = _stride_same_as_shard(value_layer, self.tp_size, 2, 1) + # [sq, b, nh, hn] -> [sq, b * nh, hn] + query_layer = _view_with_sharding_dim_change( + query_layer, 1, (sq, b * nh, -1) + ) + key_layer = _view_with_sharding_dim_change( + key_layer, 1, (sq, b * nh, -1) + ) + value_layer = _view_with_sharding_dim_change( + value_layer, 1, (sq, b * nh, -1) + ) + + # =================================== + # Raw attention scores. [b, nh, s, s] + # =================================== + + factor = self.tp_size if isinstance(query_layer, DT) else 1 + # preallocting result tensor: [b * nh, sq, sk] + matmul_result = torch.empty( + b * nh // factor, + sq, + sk, + dtype=query_layer.dtype, + device=self.device, + ) + if isinstance(query_layer, DT): + matmul_result = DT.from_local( + matmul_result, + query_layer.device_mesh, + [Shard(0)], + run_check=False, + ) + + # Raw attention scores. [b * nh, sq, sk] + attn = torch.baddbmm( + matmul_result, + query_layer.transpose(0, 1), # [b * nh, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * nh, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # =============== + # Attention probs + # =============== + attn = attn.softmax(dim=-1) + + # ========================= + # Context layer. [sq * b, hidden] + # ========================= + + # bmm: [b * nh, sq, hn] + context_layer = torch.bmm(attn, value_layer.transpose(0, 1)) + + # change view [nh, b, sq, hn] + context_layer = context_layer.view(nh, b, sq, hn) + + # [nh, b, sq, hn] --> [sq, b, nh, hn] + context_layer = context_layer.permute(2, 1, 0, 3).contiguous() + + # [sq, b, nh, hn] --> [sq * b, hidden] + context_layer = _view_with_sharding_dim_change( + context_layer.contiguous(), 1, (-1, self.hidden_size) + ) + + # ================= + # Projection. [sq, b, h] + # ================= + output = self.proj(context_layer).view(sq, b, h) + + # =================== + # Permute. [b, sq, h] + # =================== + output = output.permute(1, 0, 2) + + return output + + def copy(self, that: torch.nn.MultiheadAttention) -> None: + # TODO: current implementation assume `self` is a self attention module + assert ( + self.hidden_size == that.embed_dim + ), "embed_dim must be equal in TensorParallelMultiheadAttention.copy()!" 
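+        # copy() re-registers the parameters of an existing nn.MultiheadAttention on this
+        # module, so the tensor-parallel block starts from exactly the same weights.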
+
+        if that.in_proj_weight is not None:
+            self.qkv.register_parameter("weight", that.in_proj_weight)
+        if that.in_proj_bias is not None:
+            self.qkv.register_parameter("bias", that.in_proj_bias)
+        if that.out_proj.weight is not None:
+            # TODO: The use of Parameter is to avoid `mypy` issue caused
+            # by the `tensor` type annotation on Linear.weight to which
+            # a Parameter object is actually assigned
+            self.proj.register_parameter(
+                "weight", torch.nn.Parameter(that.out_proj.weight)
+            )
+        if that.out_proj.bias is not None:
+            self.proj.register_parameter("bias", that.out_proj.bias)

From 9d2f5a278414aeaa6f3277c5b15aee4938601fa6 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 16 Nov 2022 08:51:30 +0000
Subject: [PATCH 223/453] [dynamo] Support if cond on NNModuleVariable (#89095)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89095
Approved by: https://github.com/yanboliang, https://github.com/mlazos
---
 test/dynamo/test_misc.py          | 28 ++++++++++++++++++++++++++++
 torch/_dynamo/symbolic_convert.py |  5 +++++
 2 files changed, 33 insertions(+)

diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index e27f7bc5198d..8f79f2476aee 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2885,6 +2885,34 @@ def func(x, y):
         self.assertTrue(same(ref, res))
         self.assertTrue(same(x, x1))

+    def test_if_cond_nn_mod(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self, output_relu=True):
+                super(MockModule, self).__init__()
+                self.relu = torch.nn.ReLU() if output_relu else None
+
+            def forward(self, x):
+                x = torch.sin(x)
+                if self.relu:
+                    x = self.relu(x)
+                return x
+
+        model = MockModule()
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+        x = torch.rand(4)
+        ref = model(x)
+        res = opt_model(x)
+        self.assertTrue(same(ref, res))
+
+        model = MockModule(output_relu=False)
+        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
+
+        x = torch.rand(4)
+        ref = model(x)
+        res = opt_model(x)
+        self.assertTrue(same(ref, res))
+
     class CustomFunc(torch.autograd.Function):
         @staticmethod

diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index d5c05f76efb0..d2bc5332719c 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -252,6 +252,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction):
                 + if_next
                 + if_jump
             )
+        elif isinstance(value, NNModuleVariable):
+            # Equivalent of "self.nn_module is not None"
+            if truth_fn(value):
+                push and self.push(value)
+            self.jump(inst)
         elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
             self
         ):

From 9d28775c1d28ab7c1dd93479a58bdafb9b626341 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Wed, 16 Nov 2022 09:45:49 +0000
Subject: [PATCH 224/453] Revert "Rewrite assert statement with torch._assert under config (#88246)"

This reverts commit 62ba15e10e875ce088dff26e872605ee70c8c04a.
Reverted https://github.com/pytorch/pytorch/pull/88246 on behalf of https://github.com/DanilBaibak due to breaking internal builds --- test/dynamo/test_repros.py | 92 ------------------------------ torch/_dynamo/config.py | 3 - torch/_dynamo/symbolic_convert.py | 94 ------------------------------- 3 files changed, 189 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e30a1275ed13..503231b4cb12 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1938,98 +1938,6 @@ def fn(x): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) - @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) - def test_rewrite_assert_with_msg(self): - def f(x): - b = x.sin() - assert x[0] == 3, "First dim need to be 3" - return x.cos() + b - - args = (torch.Tensor([3, 4, 5]),) - cnt = torch._dynamo.testing.CompileCounter() - - opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) - self.assertTrue(same(f(*args), opt_f(*args))) - self.assertEqual(cnt.op_count, 6) - self.assertEqual(cnt.frame_count, 1) - - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) - self.assertTrue(same(exported(*args), f(*args))) - - with self.assertRaisesRegex(AssertionError, ""): - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) - - @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) - def test_not_rewrite_assert_for_other_errors(self): - def f(x): - b = x.sin() - if not x.sum() <= 3: - raise ValueError("input sum needs to be 3") - return x.cos() + b - - args = (torch.Tensor([3, 4, 5]),) - opt_fn = torch._dynamo.optimize("eager")(f) - with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): - opt_fn(*args) - - # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement. 
- @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) - def test_rewrite_assert_with_fstring_msg(self): - def f(x): - b = x.sin() - assert x[0] == 3, f"First dim need to be {x[0]}" - return x.cos() + b - - args = (torch.Tensor([3, 4, 5]),) - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) - - @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) - def test_rewrite_assert_without_msg(self): - def f(x): - b = x.sin() - assert x[0] == 3 - return x.cos() + b - - args = (torch.Tensor([3, 4, 5]),) - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) - self.assertTrue(same(exported(*args), f(*args))) - - with self.assertRaisesRegex(AssertionError, ""): - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) - - @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) - def test_rewrite_assert_noop(self): - def f(x): - b = x.sin() - assert True - assert x.dtype == torch.float32 - return x.cos() + b - - args = (torch.Tensor([3, 4, 5]),) - exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) - self.assertTrue(same(exported(*args), f(*args))) - - cnt = torch._dynamo.testing.CompileCounter() - opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) - self.assertTrue(same(f(*args), opt_f(*args))) - # torch._assert shouldn't be in the graph - self.assertEqual(cnt.op_count, 3) - self.assertEqual(cnt.frame_count, 1) - - exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) - self.assertTrue(same(exported(*args), f(*args))) - - @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False) - def test_not_rewrite_assert(self): - def f(x): - b = x.sin() - assert x[0] == 3 - return x.cos() + b - - with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): - torch._dynamo.export(f, torch.Tensor([3, 4, 5])) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 39a1a6433419..12088383e741 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -87,9 +87,6 @@ # if an exception is encountered replay_record_enabled = False -# Rewrite assert statement in python with torch._assert -rewrite_assert_with_torch_assert = True - # Show a warning on every graph break print_graph_breaks = False diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d2bc5332719c..e64804cb68b2 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -53,7 +53,6 @@ fake_tensors_available, graph_break_dup_warning_checker, istype, - proxy_args_kwargs, ) from .variables.base import MutableLocal, typestr, VariableTracker from .variables.builder import VariableBuilder, wrap_fx_proxy @@ -122,103 +121,10 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction): return impl -def _detect_and_normalize_assert_statement( - self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool -): - # Detect if this jump instruction is assert and normalize the assert - # by pushing dummy error message when nothing is given. 
- # - # Python 3.9 assertion is in following format: - # 18 POP_JUMP_IF_TRUE 28 - # 20 LOAD_ASSERTION_ERROR - # 22 LOAD_CONST 3 ('Assert message') -> optional instruction - # 24 CALL_FUNCTION 1 -> optional instruction - # 26 RAISE_VARARGS - # - # Python 3.8 assertion is in following format: - # 18 POP_JUMP_IF_TRUE 28 - # 20 LOAD_GLOBAL 0 (Assertion type) - # 22 LOAD_CONST 3 ('Assert message') -> optional instruction - # 24 CALL_FUNCTION 1 -> optional instruction - # 26 RAISE_VARARGS 1 - - if (truth_fn is not operator.truth) or push: - return False - - current_instruction_pointer = self.instruction_pointer - inst = self.instructions[current_instruction_pointer] - # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 - if sys.version_info < (3, 9): - if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": - return False - else: - if inst.opname != "LOAD_ASSERTION_ERROR": - return False - - current_instruction_pointer += 1 - - if current_instruction_pointer >= len(self.instructions): - return False - - inst = self.instructions[current_instruction_pointer] - has_error_msg = False - # DETECT RAISE_VARARGS or LOAD CONST - if inst.opname == "LOAD_CONST": - if not isinstance(inst.argval, str): - return False - self.LOAD_CONST(inst) - has_error_msg = True - - # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION - current_instruction_pointer += 1 - if current_instruction_pointer >= len(self.instructions): - return False - inst = self.instructions[current_instruction_pointer] - if inst.opname != "CALL_FUNCTION": - return False - - # CALL_FUNCTION should be followed by RAISE_VARARGS - current_instruction_pointer += 1 - if current_instruction_pointer >= len(self.instructions): - return False - inst = self.instructions[current_instruction_pointer] - - if inst.opname != "RAISE_VARARGS": - return False - - if not has_error_msg: - # Push dummy value instead of error message - self.push(ConstantVariable("assertion error")) - - return True - - def generic_jump(truth_fn: typing.Callable, push: bool): def inner(self: "InstructionTranslatorBase", inst: Instruction): value: VariableTracker = self.pop() self.output.guards.update(value.guards) - if ( - config.rewrite_assert_with_torch_assert - and _detect_and_normalize_assert_statement(self, truth_fn, push) - ): - error_msg: VariableTracker = self.pop() - self.output.guards.update(error_msg.guards) - # Skip over things like `assert True` - if value.is_python_constant() and bool(value.as_python_constant()): - self.jump(inst) - return - - # Manually insert torch._assert instead of python assert and jump over - # assert related instructions as we don't need them anymore. 
- self.output.create_proxy( - "call_function", - torch._assert, - *proxy_args_kwargs((value, error_msg), {}), - current_tx=self, - ) - self.jump(inst) - return - if value.is_python_constant(): if truth_fn(value.as_python_constant()): push and self.push(value) From 52701227737489392e59fe57ded40226bf0811f6 Mon Sep 17 00:00:00 2001 From: Jiawen Liu Date: Wed, 16 Nov 2022 10:37:26 +0000 Subject: [PATCH 225/453] [Inductor] Build FX Linear + Permute Vertical Fusion in Inductor (#89118) Summary: Build fx-based linear/matmul/bmm + permute/transpose vertical fusion in Inductor For an internal Ads model: **1.15x -> 1.36x speedup** Test Plan: CI Reviewed By: bertmaher, jansel, jianyuh Differential Revision: D41071665 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89118 Approved by: https://github.com/jianyuh --- test/inductor/test_torchinductor.py | 109 +++++++++++++++ torch/_inductor/config.py | 4 + torch/_inductor/overrides.py | 206 +++++++++++++++++++++++++++- 3 files changed, 316 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index dcb01b9ec78c..1265ca3e7872 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10,6 +10,7 @@ import typing import unittest import weakref +from typing import Any, Callable from unittest.mock import patch import torch @@ -18,6 +19,7 @@ from torch._dynamo.debug_utils import same_two_models from torch._dynamo.testing import rand_strided, same from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, @@ -39,6 +41,14 @@ from torch._inductor import codecache, config, metrics from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing + from torch._inductor.overrides import ( + linear_permute_fusion, + linear_transpose, + permute_linear_fusion, + permute_matmul_fusion, + transpose_linear, + transpose_matmul, + ) from torch._inductor.sizevars import SizeVarAllocator from torch._inductor.utils import has_torchvision_roi_align, timed @@ -113,6 +123,29 @@ def maybe_test(*args, **kwargs): return wrap_test +PassFunc = Callable[[torch.fx.GraphModule, Any], torch.fx.GraphModule] + + +def chain_passes(*passes: PassFunc) -> PassFunc: + def parent_pass(module: torch.fx.GraphModule, input: Any) -> torch.fx.GraphModule: + for pass_ in passes: + if isinstance(module, torch.fx.GraphModule): + ShapeProp(module).propagate(*input) + module = pass_(module) + return module + + return parent_pass + + +def count_call_function(module: torch.fx.GraphModule, target_op: Any) -> int: + return sum( + [ + 1 if (n.op == "call_function" and n.target == target_op) else 0 + for n in module.graph.nodes + ] + ) + + class TestCase(TorchTestCase): @classmethod def setUpClass(cls): @@ -1586,6 +1619,82 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) + @unittest.skipIf(HAS_CPU, "Support GPU so far") + def test_linear_permute_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + a0 = torch.nn.functional.linear(input, self.weight, self.bias) + b0 = a0.permute(0, 2, 1) + return b0 + + m, k, n = 16, 8, 4 + trace_func = 
chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, m, k) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_linear_transpose = count_call_function(traced, linear_transpose) + self.assertEqual(num_linear, 0) + self.assertEqual(num_linear_transpose, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + @unittest.skipIf(HAS_CPU, "Support GPU so far") + def test_permute_linear_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.nn.functional.linear(input1, self.weight, self.bias) + return output + + m, k, n = 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, k, m) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_transpose_linear = count_call_function(traced, transpose_linear) + self.assertEqual(num_linear, 0) + self.assertEqual(num_transpose_linear, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + @unittest.skipIf(HAS_CPU, "Support GPU so far") + def test_permute_bmm_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, batch: int, k: int, n: int): + super().__init__() + self.other = torch.randn(batch, k, n) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.bmm(input1, self.other) + return output + + batch, m, k, n = 6, 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) + module = TestModule(batch, k, n).eval() + input = torch.randn(batch, k, m) + traced = trace_func(module, [input]) + num_bmm = count_call_function(traced, torch.bmm) + num_transpose_matmul = count_call_function(traced, transpose_matmul) + self.assertEqual(num_bmm, 0) + self.assertEqual(num_transpose_matmul, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + def test_slice1(self): def fn(a): return ( diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index d376fe3e8bf7..c552101c1cae 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -75,6 +75,10 @@ shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1" alignment_size = 4 +# Fx-based linear/matmul/bmm + permute/transpose vertical fusion +permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1" + + # config specific to codegen/cpp.pp class cpp: # set to torch.get_num_threads() diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 3a95aa7ce880..9a8bc6266ac0 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -19,6 +19,8 @@ from torch.nn.utils.fusion import fuse_conv_bn_eval from torch.overrides import TorchFunctionMode +from . 
import config + log = logging.getLogger(__name__) @@ -425,14 +427,23 @@ def check_node_is_add_inplace(node): def fuse_fx(gm: torch.fx.GraphModule, example_inputs): + is_cpu = all( + example_input.device == torch.device("cpu") for example_input in example_inputs + ) + + if config.permute_fusion and not is_cpu: + # For linear permute fusion, we need to check input info to identify + # and perform proper permutation/transpose + ShapeProp(gm).propagate(*example_inputs) + gm = linear_permute_fusion(gm) + gm = permute_linear_fusion(gm) + gm = permute_matmul_fusion(gm) + # make sure the autograd is disabled. if torch.is_grad_enabled(): return gm if not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()): return gm - is_cpu = all( - example_input.device == torch.device("cpu") for example_input in example_inputs - ) if not is_cpu: return gm gm = fuse_conv_bn(gm) @@ -528,6 +539,195 @@ def _philox_rand_like(input, seed, offset): return torch.rand_like(input) +class NormalizedLinearNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.nn.functional.linear] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_weight(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["weight"] + + def get_bias(self) -> torch.fx.Node: + if len(self.node.args) > 2: + return self.node.args[2] + else: + return self.node.kwargs["bias"] + + +class NormalizedMatmulNode: + def __init__(self, node: torch.fx.Node) -> None: + assert node.op == "call_function" + assert node.target in [torch.bmm, torch.matmul] + self.node: torch.fx.Node = node + + def get_input(self) -> torch.fx.Node: + if len(self.node.args) > 0: + return self.node.args[0] + else: + return self.node.kwargs["input"] + + def get_other(self) -> torch.fx.Node: + if len(self.node.args) > 1: + return self.node.args[1] + else: + return self.node.kwargs["other"] + + +def check_permute(node: torch.fx.Node): + ranks = len(node.meta["tensor_meta"].shape) + if len(node.args) > 3: + permutation = [node.args[i] % ranks for i in range(1, ranks + 1)] + elif ( + "permutation" in node.kwargs + and node.kwargs["permutation"] is not None + and len(node.kwargs["permutation"]) > 2 + ): + permutation = [i % ranks for i in node.kwargs["permutation"]] + else: + return False + allowed_permutation = list(range(ranks)) + allowed_permutation[-1] = ranks - 2 + allowed_permutation[-2] = ranks - 1 + return permutation == allowed_permutation + + +def linear_permute_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if ( + node.op == "call_method" + and node.target == "permute" + and check_permute(node) + ): + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_function" + and input_node.target == torch.nn.functional.linear + ): + normalized = NormalizedLinearNode(input_node) + input = normalized.get_input() + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + linear_transpose, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# Y1 = X * W^T + bias +# Y2 = Y1.permute(0, 
2, 1) +# ----> +# Y2 = (W * X^T + bias.unsqueeze(-1))^T +def linear_transpose( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(weight, input.transpose(-1, -2)) + bias.unsqueeze(-1) + + +def permute_linear_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and node.target == torch.nn.functional.linear: + if len(node.args) > 0: + input_node = node.args[0] + else: + input_node = node.kwargs["input"] + if ( + input_node.op == "call_method" + and input_node.target == "permute" + and check_permute(input_node) + ): + normalized = NormalizedLinearNode(node) + if len(input_node.args) > 0: + input = input_node.args[0] + else: + input = input_node.kwargs["input"] + weight = normalized.get_weight() + bias = normalized.get_bias() + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_linear, args=(input, weight, bias) + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +def permute_matmul_fusion(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in module.graph.nodes: + if node.op == "call_function" and ( + node.target == torch.bmm or node.target == torch.matmul + ): + normalized = NormalizedMatmulNode(node) + A = normalized.get_input() + B = normalized.get_other() + Atrans = Btrans = False + if A.op == "call_method" and A.target == "permute" and check_permute(A): + Atrans = True + if len(A.args) > 0: + A = A.args[0] + else: + A = A.kwargs["input"] + + if B.op == "call_method" and B.target == "permute" and check_permute(B): + Btrans = True + if len(B.args) > 0: + B = B.args[0] + else: + B = B.kwargs["input"] + + if Atrans or Btrans: + with module.graph.inserting_before(node): + fused_node = module.graph.call_function( + transpose_matmul, + args=(A, B, Atrans, Btrans), + ) + node.replace_all_uses_with(fused_node) + + module.graph.lint() + module.graph.eliminate_dead_code() + module.recompile() + return module + + +# X1 = X.permute(0, 2, 1) +# Y1 = X1 * W1^T + bias1 +# ----> +# Y2 = X1.transpose(-1, -2) * W1^T + bias1 +def transpose_linear( + input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor +) -> torch.Tensor: + return torch.matmul(input.transpose(-1, -2), weight.t()) + bias + + +def transpose_matmul(A: torch.Tensor, B: torch.Tensor, Atrans: bool, Btrans: bool): + if Atrans: + A = A.transpose(-1, -2) + if Btrans: + B = B.transpose(-1, -2) + return torch.matmul(A, B) + + def replace_and_fuse_for_binary( computation_node, node, fuse_func, attr, modules, index_node, index_pointwise ): From dc40d3f93f849e467b2b56595a01f28e84ac7fa2 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 15 Nov 2022 19:24:31 +0000 Subject: [PATCH 226/453] Add meta impl for grid_sampler_2d_backward (#88745) TODO: add an OpInfo Pull Request resolved: https://github.com/pytorch/pytorch/pull/88745 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 6 +--- test/functorch/test_ops.py | 2 ++ test/inductor/test_torchinductor_opinfo.py | 1 + test/test_proxy_tensor.py | 8 ++--- torch/_meta_registrations.py | 27 ++++++++++++++++ .../_internal/common_methods_invocations.py | 31 +++++++++++++++++++ 6 files changed, 66 insertions(+), 9 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 752b03ac9984..1dc5476158f9 100644 --- 
a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1128,7 +1128,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('nn.functional.embedding_bag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.fractional_max_pool2d', ''), # rand() received an invalid combination of arguments - g... xfail('nn.functional.fractional_max_pool3d', ''), # rand() received an invalid combination of arguments - g... - xfail('nn.functional.grid_sample', ''), # prims::arange() Expected a value of type 'number' for argument... + xfail('nn.functional.grid_sample', ''), # RuntimeError: aten.grid_sampler_3d.default - couldn't find sym ... xfail('nn.functional.group_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('nn.functional.hinge_embedding_loss', ''), # aten.zeros_like.default - couldn't find symbolic meta... xfail('nn.functional.interpolate', 'area'), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1182,10 +1182,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('repeat_interleave', ''), # aten.repeat_interleave.Te... xfail('reshape_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('roll', ''), # narrow() received an invalid combination of arguments - got (FakeTensor, int, torch._C... - xfail('round', ''), # aten.round.default - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_0'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_neg_3'), # aten.round.decimals - couldn't find symbolic meta function/decompos... xfail('segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio... xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio... xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 643ff0ec862a..91ea2443777b 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1204,6 +1204,7 @@ def get_vjp(cotangents, *primals): xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp xfail('nn.functional.embedding_bag', ''), # NYI: forward-AD for _embedding_bag xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d + xfail('grid_sampler_2d', ''), # NYI: forward AD for grid_sampler_2d xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. 
buffer @@ -1343,6 +1344,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.fractional_max_pool3d'), # calls random op xfail('nn.functional.gaussian_nll_loss'), # data depenedant flow xfail('nn.functional.grid_sample'), # Forward AD not implemented and no decomposition + xfail('grid_sampler_2d'), # Forward AD not implemented and no decomposition xfail('nn.functional.hardsigmoid'), # Forward AD not implemented and no decomposition xfail('nn.functional.hinge_embedding_loss'), # vmap: inplace into a regular tensor xfail('nn.functional.huber_loss'), # Forward AD not implemented and no decomposition diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 83d8d40e21ec..7db9d13733b4 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -333,6 +333,7 @@ def process(device_type): "nn.functional.adaptive_avg_pool2d": {f16}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.grid_sample": {f16}, + "grid_sampler_2d": {f16}, "nn.functional.gaussian_nll_loss": {f16, f32, f64}, "nn.functional.one_hot": {i64}, "nn.functional.rrelu": {f16, f32, f64}, diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 59b08eea8dce..8dc42be7fdfb 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1308,10 +1308,6 @@ def f(a, b, c, d, e): xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('roll', ''), # Tensors of type TensorImpl do not have numel - xfail('round', ''), # aten.round.default - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_0'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition - xfail('round', 'decimals_neg_3'), # aten.round.decimals - couldn't find symbolic meta function/decomposition xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ... 
xfail('segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition @@ -1441,6 +1437,10 @@ def f(a, b, c, d, e): xfail('uniform', ''), # aten.uniform_.default - couldn't find symbolic meta function/decomposition xfail('unique', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition xfail('xlogy', ''), # aten.xlogy_.Tensor - couldn't find symbolic meta function/decomposition + xfail('round', ''), # aten.round_.default - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_0'), # aten.round_.decimals - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_3'), # aten.round_.decimals - couldn't find symbolic meta function/decomposition + xfail('round', 'decimals_neg_3') # aten.round_.decimals - couldn't find symbolic meta function/decomposition } # Copies inputs to inplace operations to avoid inplace modifications diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index be7370e344f0..4fa3ab09d275 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6,6 +6,7 @@ from torch import Tensor from torch._decomp import _add_op_to_registry, global_decomposition_table, meta_table from torch._ops import OpOverload +from torch._prims import _elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND from torch._prims_common import ( check, corresponding_complex_dtype, @@ -1166,6 +1167,13 @@ def meta_binop_inplace_alpha(self, other, alpha=1): return self +@register_meta([aten.round.default, aten.round.decimals]) +def meta_round(self, **kwargs): + return _elementwise_meta( + self, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT + ) + + @register_meta(aten.zero.default) def meta_zero(self): return self.new_empty(self.shape) @@ -1474,6 +1482,25 @@ def meta_max_pool2d_with_indices( ) +@register_meta(aten.grid_sampler_2d_backward.default) +def grid_sampler_2d_backward_meta( + grad_output, + input, + grid, + interpolation_mode, + padding_mode, + align_corners, + output_mask, +): + input_requires_grad = output_mask[0] + if input_requires_grad: + grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format) + else: + grad_input = None + grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format) + return (grad_input, grad_grid) + + @register_meta([aten.full.default]) def full(size, fill_value, *args, **kwargs): return torch.empty(size, *args, **kwargs) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5e60eff2865e..e498e4f28509 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6969,6 +6969,28 @@ def sample_inputs_grid_sample(op_info, device, dtype, requires_grad, **kwargs): align_corners=align_corners, ) +def sample_inputs_grid_sampler_2d(op_info, device, dtype, requires_grad, **kwargs): + # We get better tests if we change the range of the values to something like [-2,2] + # because for grid (second tensor argument) the "useful" range is [-1,1] and this way + # you get a better combination of out-of-range and in-range test cases + _make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, + low=-2, high=2) + + batch_size = 2 + num_channels = 3 + modes = (0, 1, 2) + align_cornerss = (False, 
True) + padding_modes = (0, 1, 2) + + for mode, padding_mode, align_corners in itertools.product(modes, padding_modes, align_cornerss): + yield SampleInput( + _make_tensor((batch_size, num_channels, S, L)), + _make_tensor((batch_size, num_channels, M, 2)), + mode, + padding_mode, + align_corners, + ) + def sample_inputs_cosine_embedding_loss(op_info, device, dtype, requires_grad, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -16190,6 +16212,15 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_grid_sample, supports_gradgrad=False, gradcheck_nondet_tol=1e-15), + # TODO: delete this OpInfo once we add meta support for grid_sampler_3d + OpInfo( + "grid_sampler_2d", + dtypes=floating_types(), + dtypesIfCUDA=floating_types_and(torch.float16), + supports_out=False, + sample_inputs_func=sample_inputs_grid_sampler_2d, + supports_gradgrad=False, + gradcheck_nondet_tol=1e-15), OpInfo( "argwhere", ref=np.argwhere, From 57af0c82454c199ab7a734c3d12df93c93f50812 Mon Sep 17 00:00:00 2001 From: Nikita Karetnikov Date: Wed, 16 Nov 2022 11:25:35 +0100 Subject: [PATCH 227/453] Bug fix: make sure `copy_impl` doesn't read out of bounds (#88544) Fixes #88543. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88544 Approved by: https://github.com/lezcano --- aten/src/ATen/native/Copy.cpp | 7 +++++- test/test_torch.py | 45 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index c6b82426d3bf..dc30db8e1100 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -124,12 +124,17 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) // 1. Memory Format for source and destination tensors is contiguous. // 2. Device for both the source and destination tensor is CPU. // 3. dtype conversion between FP32->FP16 and FP16->FP32. + // This checks that self.sizes() == src.sizes() because this code path doesn't + // support broadcasting. This also guards against out of bounds memory access + // when copying, see fbgemm::Float16ToFloat_ref. + // https://github.com/pytorch/pytorch/issues/88543 #ifdef USE_FBGEMM if (((self.dtype() == at::kFloat && src.dtype() == at::kHalf) || (self.dtype() == at::kHalf && src.dtype() == at::kFloat)) && (self.device().is_cpu() && src.device().is_cpu()) && ((self.is_contiguous() && src.is_contiguous()) || - (self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) { + (self.is_non_overlapping_and_dense() && self.strides() == src.strides())) && + (self.sizes() == src.sizes())) { if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) { auto* output_ptr = reinterpret_cast(self.data_ptr()); diff --git a/test/test_torch.py b/test/test_torch.py index 3ebc92676fe0..31759213ecef 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -7661,6 +7661,51 @@ def test_copy_many_to_one(self): # storage to a single storage would cause RuntimeError to be thrown self.assertRaises(RuntimeError, lambda: torch.zeros(1, 6).expand(5, 6).copy_(torch.zeros(5, 6))) + def test_copy_float16(self): + # Check that fbgemm code no longer reads memory out of bounds, see + # copy_impl and fbgemm::Float16ToFloat_ref. + # https://github.com/pytorch/pytorch/issues/88543 + + # Types to test different code paths in copy_impl. 
+ dtypes = ( + # out_dtype, src_dtype + (torch.float32, torch.float16), # fbgemm + (torch.float16, torch.float32), # fbgemm + (torch.float32, torch.float32), # TensorIterator + ) + + cases = ( + # out_shape, src_shape, is_ok + # These cases used to crash with fbgemm, make sure these also raise + # exceptions with TensorIterator. + ((1, 2, 3), (0, 2, 3), False), # same strides, not allowed by TI + ((1, 5, 6), (4, 5, 6), False), # same strides, not allowed by TI + (1, (0, 2, 3), False), # different strides + ((4, 5, 6), (0, 2, 3), False), # different strides + ((4, 5, 6), (1, 2, 3), False), # different strides + ((4, 5, 6), (6, 5, 4), False), # same numel + + # These cases should pass with fbgemm and TensorIterator. + ((4, 5, 6), (1, 5, 6), True), # same strides + ((4, 5, 6), (4, 5, 6), True), # same strides + ((0, 2, 3), 1, True), # different strides, allowed by TI + ((4, 5, 6), (4, 5, 1), True), # different strides, allowed by TI + ) + + for (out_shape, src_shape, is_ok), (out_dtype, src_dtype) in itertools.product(cases, dtypes): + out = torch.zeros(out_shape, dtype=out_dtype, device=torch.device('cpu')) + src = torch.ones(src_shape, dtype=src_dtype, device=torch.device('cpu')) + if is_ok: + if torch.cuda.is_available(): + out_cuda = out.cuda() + src_cuda = src.cuda() + res = out.copy_(src) + if torch.cuda.is_available(): + res_cuda = out_cuda.copy_(src_cuda) + self.assertEqual(res, res_cuda) + else: + self.assertRaises(RuntimeError, lambda: out.copy_(src)) + # FIXME: Port to a more appropriate test suite def _test_to_with_layout(self, layout): def test_copy_behavior(t, non_blocking=False): From 5e0c01330c76c003e55aec29bfb3e83926ee933a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Nov 2022 10:10:27 -0800 Subject: [PATCH 228/453] SymIntArrayRef type caster (#89074) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89074 Approved by: https://github.com/SherlockNoMad --- torch/csrc/utils.cpp | 42 +++++++++++++++++++++++++++++++++++++++ torch/csrc/utils/pybind.h | 16 +++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index b2eac4b54fa1..5fc91d68dd18 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -348,5 +349,46 @@ handle type_caster::cast( return handle(THPUtils_packInt64Array(src.size(), src.data())); } +bool type_caster::load(handle src, bool) { + PyObject* source = src.ptr(); + + auto tuple = PyTuple_Check(source); + if (tuple || PyList_Check(source)) { + // NOLINTNEXTLINE(bugprone-branch-clone) + const auto size = + tuple ? PyTuple_GET_SIZE(source) : PyList_GET_SIZE(source); + v_value.resize(size); + for (const auto idx : c10::irange(size)) { + PyObject* obj = + tuple ? 
PyTuple_GET_ITEM(source, idx) : PyList_GET_ITEM(source, idx); + + if (THPVariable_Check(obj)) { + // TODO: this is for consistency with IntArrayRef but arguably + // we shouldn't really allow this on pybind11 casters + v_value[idx] = THPVariable_Unpack(obj).item(); + } else if (torch::is_symint(py::handle(obj))) { + v_value[idx] = py::handle(obj).cast(); + } else if (PyLong_Check(obj)) { + v_value[idx] = c10::SymInt(THPUtils_unpackIndex(obj)); + } else { + return false; + } + } + value = v_value; + return true; + } + return false; +} +handle type_caster::cast( + at::SymIntArrayRef src, + return_value_policy /* policy */, + handle /* parent */) { + py::list t(src.size()); + for (const auto i : c10::irange(src.size())) { + t[i] = py::cast(src[i]); + } + return t.release(); +} + } // namespace detail } // namespace pybind11 diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index c43cf5e73283..85532a42cee2 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -109,6 +109,22 @@ struct TORCH_PYTHON_API type_caster { std::vector v_value; }; +template <> +struct TORCH_PYTHON_API type_caster { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("at::SymIntArrayRef")); + + bool load(handle src, bool); + static handle cast( + at::SymIntArrayRef src, + return_value_policy /* policy */, + handle /* parent */); + + private: + std::vector v_value; +}; + template <> struct TORCH_PYTHON_API type_caster { public: From 09ed8b67e24cfe29f3fa7b5dd28eaa7749229f12 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 15 Nov 2022 10:10:28 -0800 Subject: [PATCH 229/453] SymIntify convolution backend calculation (#89069) We will need this to implement a convolution meta function that is SymInt aware. I use templates so that regular convolution code is not affected by the change. No tests for symbolic ints directly; that will come in a subsequent PR which also needs to refactor fake tensors. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89069 Approved by: https://github.com/SherlockNoMad --- aten/src/ATen/native/ConvUtils.h | 79 ++++-- aten/src/ATen/native/Convolution.cpp | 319 +++++++++++++----------- aten/src/ATen/native/utils/ParamUtils.h | 7 +- c10/core/SymInt.h | 13 + torch/csrc/Module.cpp | 12 +- 5 files changed, 256 insertions(+), 174 deletions(-) diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index b8e2b0842a00..880ce0c2af54 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -110,8 +110,8 @@ enum class ConvBackend { // This overload is exposed to python for testing, etc. TORCH_API ConvBackend select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias_opt, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups, const at::OptionalIntArrayRef bias_sizes_opt); + IntArrayRef stride, SymIntArrayRef padding, IntArrayRef dilation, + bool transposed, SymIntArrayRef output_padding, int64_t groups, const at::OptionalSymIntArrayRef bias_sizes_opt); TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input, const Tensor& weight, @@ -200,15 +200,16 @@ static void convolution_shape_check( // as conv_output_size loses information; this is why conv_input_size // takes an extra output_padding argument to resolve the ambiguity. 
-static inline std::vector conv_output_size( - IntArrayRef input_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +template +static inline std::vector _conv_output_size( + ArrayRef input_size, ArrayRef weight_size, + ArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { // ASSERT(input_size.size() > 2) // ASSERT(input_size.size() == weight_size.size()) bool has_dilation = dilation.size() > 0; auto dim = input_size.size(); - std::vector output_size(dim); + std::vector output_size(dim); output_size[0] = input_size[input_batch_size_dim]; output_size[1] = weight_size[weight_output_channels_dim]; for (const auto d : c10::irange(2, dim)) { @@ -219,40 +220,84 @@ static inline std::vector conv_output_size( return output_size; } -static inline std::vector conv_input_size( - IntArrayRef output_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +static inline std::vector conv_output_size( + IntArrayRef input_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +static inline std::vector conv_output_size( + SymIntArrayRef input_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +template +std::vector _conv_input_size( + ArrayRef output_size, ArrayRef weight_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { // ASSERT(output_size.size() > 2) // ASSERT(output_size.size() == weight_size.size()) auto dim = output_size.size(); - std::vector input_size(dim); + std::vector input_size(dim); input_size[0] = output_size[output_batch_size_dim]; input_size[1] = weight_size[weight_input_channels_dim] * groups; for (const auto d : c10::irange(2, dim)) { - int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1; - input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) + + auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1; + input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + kernel + output_padding[d - 2]; } return input_size; } -static inline std::vector conv_weight_size( - IntArrayRef input_size, IntArrayRef output_size, +static inline std::vector conv_input_size( + SymIntArrayRef output_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_input_size( + IntArrayRef output_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +template +std::vector _conv_weight_size( + ArrayRef input_size, ArrayRef output_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { auto dim = input_size.size(); - std::vector weight_size(dim); + std::vector weight_size(dim); weight_size[0] = output_size[1]; weight_size[1] = input_size[1] / groups; for (const auto d : 
c10::irange(2, dim)) { - int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] - + 2 * padding[d - 2] - output_padding[d - 2]; + auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] + + padding[d - 2] * 2 - output_padding[d - 2]; weight_size[d] = (kernel - 1) / dilation[d - 2] + 1; } return weight_size; } +static inline std::vector conv_weight_size( + SymIntArrayRef input_size, SymIntArrayRef output_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_weight_size( + IntArrayRef input_size, IntArrayRef output_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 29b2ce804c80..bf7017f20a4f 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -83,10 +83,11 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { // Check workload to activate fast depthwise FP16 cudnn conv kernels +template bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { - int w = input.size(3); // same as h - int ch = input.size(1); - int bs = input.size(0); + auto w = at::symint::size(input, 3); // same as h + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); if (stride==1) { if (w >= 7) { // All batch sizes and nb_channels @@ -205,27 +206,28 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { } // simplified version for cudnn 8.2 and above +template bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) { // 1D conv - if(input.size(2) == 1 && stride == 1){ + if(at::symint::size(input, 2) == 1 && stride == 1){ return true; } // 2d conv // only square filters - if (weight.size(2) != weight.size(3)) return false; - int filter = weight.size(3); + if (at::symint::size(weight, 2) != at::symint::size(weight, 3)) return false; + auto filter = at::symint::size(weight, 3); // only 1/3/5 filter if (filter != 1 && filter != 3 && filter != 5) return false; // we don't enforce square input but only check width to reduce heuristic space - if (input.size(3) < 7) return false; // min width 7 - int w = input.size(3); + if (at::symint::size(input, 3) < 7) return false; // min width 7 + auto w = at::symint::size(input, 3); // only 1/2 stride, use cudnn for all stride 1 if (stride == 1) return true; if (stride != 2) return false; - int ch = input.size(1); - int bs = input.size(0); + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); // special case since bs1 show good perf in lots of cases if (bs == 1) { if (filter == 1 && w <= 28) return true; @@ -240,13 +242,42 @@ bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int str } +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + return 
xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed); +} + +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalSymIntArrayRef bias_sizes_opt, + const SymIntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + // Never use xnnpack for symbolic tracing + return false; +} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +// This struct is templated so that we can run backend selection in a dynamic +// shapes context; all of the real kernel selection in eager mode runs with +// int64_t +template struct ConvParams { std::vector stride; - std::vector padding; + std::vector padding; std::vector dilation; bool transposed; - std::vector output_padding; + std::vector output_padding; int groups; bool benchmark; bool deterministic; @@ -322,12 +353,12 @@ struct ConvParams { #if defined(__ARM_NEON__) // Currently only 3x3 depthwise convolutions on tensors of float are supported. return (input.ndimension() == 4) && - (input.size(1) == groups) && + (at::symint::size(input, 1) == groups) && (weight.ndimension() == 4 ) && - (weight.size(0) % input.size(1) == 0) && - (weight.size(1) == 1) && - (weight.size(2) == 3) && - (weight.size(3) == 3) && + (at::symint::size(weight, 0) % at::symint::size(input, 1) == 0) && + (at::symint::size(weight, 1) == 1) && + (at::symint::size(weight, 2) == 3) && + (at::symint::size(weight, 3) == 3) && (input.device().is_cpu()) && (input.scalar_type() == at::kFloat) && input.is_contiguous() && @@ -345,23 +376,23 @@ struct ConvParams { bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const { constexpr int64_t int_max = std::numeric_limits::max(); - int64_t numel_input = input.numel(); + auto numel_input = at::symint::numel(input); // empty input if (numel_input == 0) { return false; } // input size can not be reduced to the range of int by splitting the batch dim - int64_t n = input.size(0); + auto n = at::symint::size(input, 0); if (numel_input / n > int_max) { return true; } // output size can not be reduced to the range of int by splitting the batch dim - int64_t outsize = 1; + T outsize = 1; if (transposed) { - std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); + auto o = conv_input_size(at::symint::sizes(input), at::symint::sizes(weight), padding, output_padding, stride, dilation, groups); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } else { - std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); + auto o = conv_output_size(at::symint::sizes(input), at::symint::sizes(weight), padding, stride, dilation); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } return outsize > int_max; @@ -417,10 +448,10 @@ struct ConvParams { is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks !is_dilated() && // no dilation supported - (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d - input.size(1) >= 32); // min 32 channels supported) + (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d + at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); } } // 
keep (7600 <= cudnn < 8200) code unchanged @@ -430,14 +461,14 @@ struct ConvParams { weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - weight.size(2) == weight.size(3) && // only square kernels - input.size(2) >= 7 && // min width/height 7 + at::symint::size(weight, 2) == at::symint::size(weight, 3) && // only square kernels + at::symint::size(input, 2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported stride[0] == stride[1] && // equal strides - ((weight.size(3) == 3) || (weight.size(3) == 1)) && - input.size(1) >= 32); // min 32 channels supported) + ((at::symint::size(weight, 3) == 3) || (at::symint::size(weight, 3) == 1)) && + at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); + return check_cudnn_depthwise_workload(input, stride[0]); } else { return false; } @@ -473,12 +504,12 @@ struct ConvParams { !transposed && // or transposed tensors // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, // but THNN is faster when single-threaded. - (is_strided() || is_dilated() || input.size(0) >= 16 || - weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && + (is_strided() || is_dilated() || at::symint::size(input, 0) >= 16 || + at::symint::size(weight, -1) != 1 || at::symint::size(weight, -2) != 1 || at::get_num_threads() > 1) && (groups > 1 - || (weight.size(-1) > 3 && weight.size(-2) > 3) - || input.size(0) > 1 - || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster + || (at::symint::size(weight, -1) > 3 && at::symint::size(weight, -2) > 3) + || at::symint::size(input, 0) > 1 + || at::symint::size(input, 0)*at::symint::size(input, 1)*at::symint::size(input, 2)*at::symint::size(input, 3) > 20480) // for some case, native is faster ); #endif @@ -493,20 +524,23 @@ struct ConvParams { !transposed && // or transposed tensors input.ndimension() == 4 && // must be in NCHW format weight.ndimension() == 4 && - (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 + (at::symint::size(weight, 2) < 17) && (at::symint::size(weight, 3) < 17) // NNPACK only supports kernels up to 16x16 #if !defined(C10_MOBILE) - && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable + && at::symint::size(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable #endif ; #endif return false; } bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const { + const at::OptionalArrayRef bias_sizes_opt) const { #if defined(C10_MOBILE) if (!transposed) { - return (input.size(1) == groups) && - xnnpack::use_convolution2d( + // NB: for the call here, it MATTERS that we are templated. 
If you + // untemplate this to always use SymInt, the function + // xnnpack_use_convolution2d will always return false + return (at::symint::size(input, 1) == groups) && + xnnpack_use_convolution2d( input, weight, bias_sizes_opt, @@ -543,33 +577,12 @@ struct ConvParams { return input.is_cuda() && !transposed && (input.ndimension() == 4 || input.ndimension() == 5) && - input.size(1) == groups && + at::symint::size(input, 1) == groups && groups > 1 && // no point if there is only a single group - weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels + at::symint::size(weight, 0) % at::symint::size(input, 1) == 0; // output channels must be a multiple of input channels } }; -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. -ConvBackend _select_conv_backend( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias_opt, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -// For BC reasons, have a copy that does not require bias_opt -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - DEFINE_DISPATCH(conv_depthwise2d_backward_stub); DEFINE_DISPATCH(conv_depthwise3d_backward_stub); DEFINE_DISPATCH(cudnn_convolution_backward_stub); @@ -591,13 +604,14 @@ REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); -std::ostream& operator<<(std::ostream & out, const ConvParams& params) { +template +std::ostream& operator<<(std::ostream & out, const ConvParams& params) { out << "ConvParams {" << " stride = " << IntArrayRef{params.stride} - << " padding = " << IntArrayRef{params.padding} + << " padding = " << ArrayRef{params.padding} << " dilation = " << IntArrayRef{params.dilation} << " transposed = " << params.transposed - << " output_padding = " << IntArrayRef{params.output_padding} + << " output_padding = " << ArrayRef{params.output_padding} << " groups = " << params.groups << " benchmark = " << params.benchmark << " deterministic = " << params.deterministic @@ -607,9 +621,10 @@ std::ostream& operator<<(std::ostream & out, const ConvParams& params) { return out; } +template static void check_shape_forward(const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, const at::Tensor& bias, - const ConvParams& params) { + const c10::ArrayRef& weight_sizes, const at::Tensor& bias, + const ConvParams& params) { int64_t k = input.ndimension(); int64_t weight_dim = weight_sizes.size(); int64_t groups = params.groups; @@ -624,7 +639,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ", - input.sizes(), " instead"); + at::symint::sizes(input), " instead"); TORCH_CHECK(weight_sizes[0] >= groups, "Given groups=", groups, ", expected weight to be at least ", groups, " at dimension 0, but got weight of size ", 
weight_sizes, " instead"); @@ -634,23 +649,23 @@ static void check_shape_forward(const at::Tensor& input, "] instead"); if (!transposed) { - std::vector input_shape; - std::vector kernel_shape; + std::vector input_shape; + std::vector kernel_shape; bool kernel_size_correct = true; - TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups), + TORCH_CHECK(at::symint::size(input, 1) == (weight_sizes[1] * groups), "Given groups=", groups, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", - (weight_sizes[1] * groups), " channels, but got ", input.size(1), + (weight_sizes[1] * groups), " channels, but got ", at::symint::size(input, 1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]), + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[0]), "Given weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements", - ", but got bias of size ", bias.sizes(), " instead"); + ", but got bias of size ", at::symint::sizes(bias), " instead"); for (const auto i : c10::irange(2, k)) { - input_shape.push_back(input.size(i) + 2 * padding[i-2]); + input_shape.push_back(at::symint::size(input, i) + 2 * padding[i-2]); // log new kernel size considering dilation kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1); if (input_shape.back() < kernel_shape.back()) { @@ -676,22 +691,23 @@ static void check_shape_forward(const at::Tensor& input, "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); } } else { // transposed - TORCH_CHECK(input.size(1) == weight_sizes[0], + TORCH_CHECK(at::symint::size(input, 1) == weight_sizes[0], "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", weight_sizes[0], - " channels, but got ", input.size(1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[1] * groups), + " channels, but got ", at::symint::size(input, 1), " channels instead"); + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[1] * groups), "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements", ", but got bias of size ", bias.sizes(), " instead"); } } +template static void check_shape_backward( const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, - const ConvParams& params) { - check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); + const c10::ArrayRef& weight_sizes, + const ConvParams& params) { + check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); } // Given an input tensor and an expected number of spatial dimensions, checks that the @@ -1149,71 +1165,25 @@ at::Tensor convolution_overrideable( TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); } -// Selects a backend for convolution based on the inputs and params. 
-ConvBackend select_conv_backend( - const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, - IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, - bool transposed_, IntArrayRef output_padding_, int64_t groups_, const at::OptionalIntArrayRef bias_sizes_opt) { - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - auto& ctx = at::globalContext(); - auto k = weight_r.ndimension(); - int64_t dim = k - 2; - ConvParams params; - params.stride = expand_param_if_needed(stride_, "stride", dim); - params.padding = expand_param_if_needed(padding_, "padding", dim); - params.dilation = expand_param_if_needed(dilation_, "dilation", dim); - params.transposed = transposed_; - params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); - params.groups = groups_; - params.benchmark = ctx.benchmarkCuDNN(); - params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); - params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); - - auto input = input_r; - auto weight = weight_r; - check_shape_forward(input, weight.sizes(), bias, params); - - // Expand 1d -> 2d. - // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { - // avoid accidentally going through NHWC for permuted 3d input. - input = input.contiguous(); - params.view1d_as_2d(); - input = view4d(input); - weight = view4d(weight); - } - - auto bias_sizes = bias.defined() ? c10::optional(bias.sizes()) : bias_sizes_opt; - bool need_backward = GradMode::is_enabled() && - (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); -} - -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params) { - return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); -} - +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. +template ConvBackend _select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias, - const at::OptionalIntArrayRef bias_sizes_opt, + const at::OptionalArrayRef bias_sizes_opt, const bool need_backward, - const ConvParams& params) { + const ConvParams& params) { // don't send empty inputs through backends - if (input.size(0) == 0 || input.size(1) == 0) { + if (at::symint::size(input, 0) == 0 || at::symint::size(input, 1) == 0) { return input.is_mkldnn() ? 
ConvBackend::MkldnnEmpty : ConvBackend::Empty; - } else if (input.numel() == 0) { - TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", input.sizes()); + } else if (at::symint::numel(input) == 0) { + TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes(input)); } if (params.is_depthwise(input, weight)) { @@ -1305,12 +1275,65 @@ ConvBackend _select_conv_backend( AT_ERROR("unsupported ConvNd parameters"); } +// Selects a backend for convolution based on the inputs and params. +ConvBackend select_conv_backend( + const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, + IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) { + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + auto& ctx = at::globalContext(); + auto k = weight_r.ndimension(); + int64_t dim = k - 2; + ConvParams params; + params.stride = expand_param_if_needed(stride_, "stride", dim); + params.padding = expand_param_if_needed(padding_, "padding", dim); + params.dilation = expand_param_if_needed(dilation_, "dilation", dim); + params.transposed = transposed_; + params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); + params.groups = groups_; + params.benchmark = ctx.benchmarkCuDNN(); + params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); + params.cudnn_enabled = ctx.userEnabledCuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN(); + + auto input = input_r; + auto weight = weight_r; + check_shape_forward(input, weight.sym_sizes(), bias, params); + + // Expand 1d -> 2d. + // This is only done for backends that don't natively support 1d spatial input. + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { + // avoid accidentally going through NHWC for permuted 3d input. + input = input.contiguous(); + params.view1d_as_2d(); + input = view4d(input); + weight = view4d(weight); + } + + auto bias_sizes = bias.defined() ? c10::optional(bias.sym_sizes()) : bias_sizes_opt; + bool need_backward = GradMode::is_enabled() && + (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); + return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); +} + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params) { + return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); +} + at::Tensor _convolution_nogroup_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::NnpackSpatial: @@ -1341,7 +1364,7 @@ at::Tensor _convolution_nogroup_backend( static inline std::vector calc_output_size( const Tensor& input, const Tensor& weight, - const ConvParams& params) { + const ConvParams& params) { std::vector output_size = params.transposed ? 
conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding, params.stride, params.dilation, params.groups) : @@ -1422,7 +1445,7 @@ at::Tensor _convolution( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); TORCH_CHECK(groups_ > 0, "non-positive groups is not supported"); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1450,7 +1473,7 @@ at::Tensor _convolution( auto bias_sizes_opt = bias.defined() ? c10::optional(bias.sizes()) : c10::nullopt; bool need_backward = GradMode::is_enabled() && (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - ConvBackend backend = _select_conv_backend(input, weight, bias, bias_sizes_opt, need_backward, params); + ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params); at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); // Call the backend. @@ -1663,7 +1686,7 @@ std::tuple _convolution_double_backward( const c10::option auto weight = weight_r; int64_t dim = weight.ndimension() - 2; - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1726,7 +1749,7 @@ std::tuple _convolution_double_backward( const c10::option if (ggI.defined()) { // Modified params with correct padding - ConvParams gw_conv_params(params); + ConvParams gw_conv_params(params); // Disable groups as they are handled separately auto groups = gw_conv_params.groups; @@ -1795,7 +1818,7 @@ std::tuple _convolution_double_backward( const c10::option Tensor gI; if (input.numel() != 0) { if (ggW.defined()) { - ConvParams gi_conv_params(params); + ConvParams gi_conv_params(params); gi_conv_params.transposed = !params.transposed; if (params.transposed) { @@ -1851,7 +1874,7 @@ std::tuple _convolution_backward_nogroup_bac const Tensor& weight, const std::array output_mask, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::Slow2d: @@ -1916,7 +1939,7 @@ std::tuple convolution_backward( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); auto& ctx = at::globalContext(); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride, "stride", dim); params.padding = expand_param_if_needed(padding, "padding", dim); params.dilation = expand_param_if_needed(dilation, "dilation", dim); diff --git a/aten/src/ATen/native/utils/ParamUtils.h b/aten/src/ATen/native/utils/ParamUtils.h index 376467ff79cf..7c89a3316cb4 100644 --- a/aten/src/ATen/native/utils/ParamUtils.h +++ b/aten/src/ATen/native/utils/ParamUtils.h @@ -6,12 +6,13 @@ namespace at { namespace native { -inline std::vector expand_param_if_needed( - IntArrayRef list_param, +template +inline std::vector expand_param_if_needed( + ArrayRef list_param, const char* param_name, int64_t expected_dim) { if (list_param.size() == 1) { - return std::vector(expected_dim, list_param[0]); + return std::vector(expected_dim, list_param[0]); } else if ((int64_t)list_param.size() != expected_dim) { std::ostringstream ss; ss << 
"expected " << param_name << " to be a single integer value or a " diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 9ab72a077680..6355f1339505 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -235,6 +235,19 @@ inline c10::SymInt multiply_integers(const C& container) { [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } +template < + typename Iter, + typename = std::enable_if_t::value_type, + c10::SymInt>::value>> +inline c10::SymInt multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, + end, + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + inline SymInt operator+(int64_t a, const SymInt& b) { return c10::SymInt(a) + b; } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b8693a484ed9..607373625724 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1408,10 +1408,10 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias_opt, at::IntArrayRef stride_, - at::IntArrayRef padding_, + at::SymIntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::IntArrayRef output_padding_, + at::SymIntArrayRef output_padding_, int64_t groups_) { return at::native::select_conv_backend( input, @@ -1442,13 +1442,13 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias, at::IntArrayRef stride_, - at::IntArrayRef padding_, + at::SymIntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::IntArrayRef output_padding_, + at::SymIntArrayRef output_padding_, int64_t groups_, - c10::optional> bias_sizes_opt) { - c10::OptionalArrayRef ref = c10::nullopt; + c10::optional> bias_sizes_opt) { + c10::OptionalArrayRef ref = c10::nullopt; if (bias_sizes_opt) { ref = (*bias_sizes_opt); } From 37d54239c7ea88fd9c98dcac3fcc9b98a6f9e9d1 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 16 Nov 2022 05:58:02 -0800 Subject: [PATCH 230/453] Towards unifying symbolic and non symbolic fake tensor (#89038) Fake tensor behaves pretty differently depending on if you have symbolic shapes or not. This leads to bugs; for example, we weren't getting correct convolution_backward strides because we bypassed the correct stride logic in fake tensor on symbolic shapes. This PR attempts to unify the two codepaths. I don't manage to unify everything, but I get most of it. The algorithm is delicate and I'm still hosing down test failures. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89038 Approved by: https://github.com/anjali411 --- aten/src/ATen/native/TensorFactories.cpp | 6 --- test/functorch/test_aotdispatch.py | 1 - test/test_proxy_tensor.py | 21 +++------ torch/_meta_registrations.py | 44 +++++++++++++++--- torch/_ops.py | 1 + torch/_prims/__init__.py | 5 +- torch/_prims_common/__init__.py | 3 ++ torch/_subclasses/fake_tensor.py | 58 +++++++++--------------- 8 files changed, 71 insertions(+), 68 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 9d1c6d8a3633..7245cb77b1c5 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -325,12 +325,6 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = self.options() .merge_in(options_) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 1dc5476158f9..ae216f9be4a4 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1011,7 +1011,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8dc42be7fdfb..0a24807af55f 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1151,9 +1151,7 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1235,8 +1233,6 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition 
xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... @@ -1281,7 +1277,6 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... @@ -1298,7 +1293,6 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. 
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition @@ -1347,11 +1341,15 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) +outplace_symbolic_tensor_failures = { + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition +} + inplace_symbolic_tensor_failures = { - xfail('abs', ''), # aten.abs_.default - couldn't find symbolic meta function/decomposition xfail('acos', ''), # aten.acos_.default - couldn't find symbolic meta function/decomposition xfail('acosh', ''), # aten.acosh_.default - couldn't find symbolic meta function/decomposition - xfail('addbmm', ''), # aten.addbmm_.default - couldn't find symbolic meta function/decomposition xfail('addcdiv', ''), # aten.addcdiv_.default - couldn't find symbolic meta function/decomposition xfail('addcmul', ''), # aten.addcmul_.default - couldn't find symbolic meta function/decomposition xfail('addmm', ''), # aten.addmm_.default - couldn't find symbolic meta function/decomposition @@ -1365,7 +1363,6 @@ def f(a, b, c, d, e): xfail('clamp', ''), # aten.clamp_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_max', ''), # aten.clamp_max_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_min', ''), # aten.clamp_min_.Tensor - couldn't find symbolic meta function/decomposition - xfail('conj_physical', ''), # aten.conj_physical_.default - couldn't find symbolic meta function/decomposition xfail('copysign', ''), # aten.copysign_.Tensor - couldn't find symbolic meta function/decomposition xfail('cos', ''), # aten.cos_.default - couldn't find symbolic meta function/decomposition xfail('cosh', ''), # aten.cosh_.default - couldn't find symbolic meta function/decomposition @@ -1382,7 +1379,6 @@ def f(a, b, c, d, e): xfail('expm1', ''), # aten.expm1_.default - couldn't find symbolic meta function/decomposition xfail('float_power', ''), # the base given to float_power_ has dtype Float but the operation's result requires dtype Double xfail('floor', ''), # aten.floor_.default - couldn't find symbolic meta function/decomposition - xfail('floor_divide', ''), # aten.floor_divide_.Tensor - couldn't find symbolic meta function/decomposition xfail('fmod', ''), # aten.fmod_.Tensor - couldn't find symbolic meta function/decomposition xfail('frac', ''), # aten.frac_.default - couldn't find symbolic meta function/decomposition xfail('ge', ''), # aten.ge_.Tensor - couldn't find symbolic meta function/decomposition @@ -1398,7 +1394,6 @@ def f(a, b, c, d, e): xfail('log1p', ''), # aten.log1p_.default - couldn't find symbolic meta function/decomposition xfail('log2', ''), # aten.log2_.default - couldn't find symbolic meta function/decomposition xfail('log', ''), # aten.log_.default - couldn't find symbolic meta function/decomposition - xfail('logit', ''), # aten.logit_.default - couldn't find symbolic meta function/decomposition xfail('lt', ''), # aten.lt_.Tensor - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_1'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_3'), # aten.mvlgamma_.default - couldn't find symbolic meta 
function/decomposition @@ -1408,7 +1403,6 @@ def f(a, b, c, d, e): xfail('neg', ''), # aten.neg_.default - couldn't find symbolic meta function/decomposition xfail('nextafter', ''), # aten.nextafter_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.celu', ''), # aten.celu_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.dropout3d', ''), # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition xfail('nn.functional.elu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.hardsigmoid', ''), # aten.hardsigmoid_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.mish', ''), # aten.mish_.default - couldn't find symbolic meta function/decomposition @@ -1426,7 +1420,6 @@ def f(a, b, c, d, e): xfail('sinh', ''), # aten.sinh_.default - couldn't find symbolic meta function/decomposition xfail('sqrt', ''), # aten.sqrt_.default - couldn't find symbolic meta function/decomposition xfail('square', ''), # aten.pow_.Scalar - couldn't find symbolic meta function/decomposition - xfail('squeeze', ''), # aten.squeeze_.default - couldn't find symbolic meta function/decomposition xfail('t', ''), # aten.t_.default - couldn't find symbolic meta function/decomposition xfail('tan', ''), # aten.tan_.default - couldn't find symbolic meta function/decomposition xfail('tanh', ''), # aten.tanh_.default - couldn't find symbolic meta function/decomposition @@ -1516,7 +1509,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 4fa3ab09d275..abcd1ead8b43 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1513,7 +1513,6 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, - aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1521,6 +1520,44 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + if layout == torch.sparse_coo: + check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 0) + + res._coalesced_(True) + return res + return aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): @@ -1894,11 +1931,6 @@ def activate_meta(): # 
Instead, we should be letting those decompositions run, and writing meta kernels # only for the base operators. pass - elif op_overload.is_view: - # Attempting to register a python meta kernel for a view operator. - # We shouldn't do this, because the output will report as not having aliased storages. - # All view ops have meta kernels in C++ today, so we should use those instead. - pass elif op_overload.name() in { "aten::empty_strided", # causing infinite recursion, test_meta.py "aten::clone", # causing infinite recursion diff --git a/torch/_ops.py b/torch/_ops.py index 9163932144d0..b20398a7f3ab 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -365,6 +365,7 @@ def handler(*args, **kwargs): return handler final_key = resolve_key(self, key) + # print(self, key, final_key) r = self.py_kernels.get(final_key, final_key) self._dispatch_cache[key] = r return r diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index da8d9af723ac..a4bac68f0ff1 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1150,9 +1150,6 @@ def _minimum_aten( # # View operations -# -# TODO: model view relationships -# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1170,7 +1167,7 @@ def _as_strided_meta( a._typed_storage(), size, stride, storage_offset ) - return TensorMeta(a, shape=size, strides=stride) + return torch.as_strided(a, size, stride, storage_offset) def _as_strided_aten( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 128796dfa3d0..041448e8102a 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -291,6 +291,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ + if a.is_sparse: + return False + # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 5d3d3a0e32fe..9a0ac050e6b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,7 +1,6 @@ import contextlib import functools import itertools -import sys import weakref from dataclasses import dataclass from functools import partial @@ -297,8 +296,9 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - # Not in_kernel_invocation_manager as no fake tensor inputs - with no_dispatch(): + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? 
hmmm) + with in_kernel_invocation_manager(fake_mode): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -821,40 +821,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - from torch._decomp import decomposition_table - - with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table - # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise - if has_symbolic_sizes and not self.cpp_meta_supports_symint(func): - from torch._decomp import meta_table as meta_table + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table - if func == aten.size.default: - sys.stderr.write( - "Trying to call aten.size on a tensor with symbolic shapes. " - "It's likely that this is from calling tensor.shape in C++" + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) ) - # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` - return None - - with self: - if func in meta_table: - r = meta_table[func](*args, **kwargs) - return r - if func in decomposition_table: + ): + with self: return decomposition_table[func](*args, **kwargs) - if ( - func in decomposition_table - and torch_decomp_decompositions(func) - and all(not e.is_sparse for e in flat_arg_fake_tensors) - ): with self: - return decomposition_table[func](*args, **kwargs) + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them @@ -865,12 +855,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with self: return func.prim_meta_impl(*args, **kwargs) - if has_symbolic_sizes: - if not self.cpp_meta_supports_symint(func): - raise RuntimeError( - f"{func} - couldn't find symbolic meta function/decomposition" - ) - # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors # and then afterwards wrapping them to a FakeTensor From 9fe36a02146c57ed8165bb8914708437043899ab Mon Sep 17 00:00:00 2001 From: mindest Date: Wed, 16 Nov 2022 15:08:41 +0000 Subject: [PATCH 231/453] [ONNX] Extra support for bernoulli export (#88655) * add opset 15 support for `bernoulli`. * add extra export options for different `bernoulli` cases: `x.bernoulli(p)` where `p` is a tensor or float. 
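A minimal export sketch of the two cases (mirroring the test patterns added in this patch; the output file names are placeholders, and opset 15 is assumed since that is the opset this change targets):

```python
import torch

class BernoulliFloatP(torch.nn.Module):
    def forward(self, x):
        # p given as a float
        return torch.bernoulli(x, 0.2)

class BernoulliTensorP(torch.nn.Module):
    def forward(self, x):
        # p given as a tensor of element-wise probabilities
        return torch.rand_like(x).bernoulli_(x)

x = torch.rand(3, 3)
torch.onnx.export(BernoulliFloatP(), x, "bernoulli_float_p.onnx", opset_version=15)
torch.onnx.export(BernoulliTensorP(), x, "bernoulli_tensor_p.onnx", opset_version=15)
```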
Fixes #88299 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88655 Approved by: https://github.com/BowenBao --- test/onnx/test_pytorch_onnx_onnxruntime.py | 17 +++++++++++++++++ torch/onnx/symbolic_opset15.py | 16 ++++++++++++++++ torch/onnx/symbolic_opset9.py | 9 +++++---- 3 files changed, 38 insertions(+), 4 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index e4fc3f83b288..7ae9d8edaccc 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2643,6 +2643,23 @@ def forward(self, x): x = torch.empty(2, 3, 3, dtype=torch.double).uniform_(0, 1) self.run_test(Bernoulli(), x) + def test_bernoulli_p(self): + class Bernoulli_float(torch.nn.Module): + def forward(self, x): + return torch.mul(x, torch.bernoulli(x, 0.2).size(0)) + + class Bernoulli_tensor(torch.nn.Module): + def forward(self, x): + return torch.mul(x, torch.rand_like(x).bernoulli_(x).size(0)) + + x = torch.rand(3, 3) + self.run_test(Bernoulli_float(), x) + self.run_test(Bernoulli_tensor(), x) + + x = torch.rand(2, 3, 3, dtype=torch.double) + self.run_test(Bernoulli_float(), x) + self.run_test(Bernoulli_tensor(), x) + @unittest.skip("Bug in ORT, skip test until rel-1.11.") @skipIfUnsupportedMinOpsetVersion(14) def test_reshape_allowzero(self): diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index efb96c717fde..4f316a77f62e 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -54,6 +54,22 @@ def aten__isnot_(g: jit_utils.GraphContext, self, other): return aten__is_(g, self, other) +@_onnx_symbolic("aten::bernoulli") +@_beartype.beartype +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): + symbolic_helper._unimplemented( + "Bernoulli", "out parameter is not supported for bernoulli", input + ) + if generator is not None and not symbolic_helper._is_none(generator): + symbolic_helper._unimplemented( + "Bernoulli", "generator is not supported for bernoulli", input + ) + if p is None or symbolic_helper._is_none(p): + return g.op("Bernoulli", input) + return opset9.bernoulli(g, input, p, generator, out) + + @_onnx_symbolic("prim::unchecked_cast") @_beartype.beartype def prim_unchecked_cast(g: jit_utils.GraphContext, self): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index c02fb0f20090..9984f602425c 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -4942,8 +4942,8 @@ def rrelu(g: jit_utils.GraphContext, input, lower, upper, training, generator): @_onnx_symbolic("aten::bernoulli") @_beartype.beartype -def bernoulli(g: jit_utils.GraphContext, input, generator=None, out=None): - if out is not None: +def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None): + if out is not None and not symbolic_helper._is_none(out): symbolic_helper._unimplemented( "Bernoulli", "out parameter is not supported for bernoulli", input ) @@ -4960,14 +4960,15 @@ def bernoulli(g: jit_utils.GraphContext, input, generator=None, out=None): "Bernoulli", "input dtype not accessible", input ) - p = g.op( + rands = g.op( "RandomUniformLike", input, high_f=1.0, low_f=0.0, dtype_i=dtype.onnx_type(), ) - output = g.op("Less", p, input) + prob = p if p is not None and not symbolic_helper._is_none(p) else input + output = g.op("Less", rands, prob) return g.op("Cast", output, 
to_i=dtype.onnx_type())

From abe41aee776e7ab39c34f28a88f03a03dc6f1479 Mon Sep 17 00:00:00 2001
From: AllenTiTaiWang
Date: Wed, 16 Nov 2022 06:30:03 +0000
Subject: [PATCH 232/453] [ONNX] Support custom Op with onnx-script local function (#86906)

Extend `register_custom_op` to support onnx-script local functions. The FunctionProto from onnx-script is represented as a custom op and inserted into the ModelProto so the op can be executed.

NOTE: I ran experiments on the >2GB case with a simple model that has large initializers:

```python
import torch

class Net(torch.nn.Module):
    def __init__(self, B, C):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm((B, C), eps=1e-3)

    def forward(self, x):
        return self.layer_norm(x)

N, B, C = 3, 25000, 25000
model = Net(B, C)
x = torch.randn(N, B, C)
torch.onnx.export(model, x, "large_model.onnx", opset_version=12)
```

It turns out that `model_bytes` does not exceed 2GB after the `_export_onnx` pybind C++ function: initializers are split into external files inside that function, and the model is serialized before the bytes are returned, since protobuf does not allow messages larger than 2GB under any circumstances.

The test cases can be found in the follow-up PR #86907.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86906
Approved by: https://github.com/justinchuby, https://github.com/BowenBao
---
 .jenkins/caffe2/test.sh           |   2 +-
 torch/onnx/_internal/jit_utils.py |  64 ++++++++++-
 torch/onnx/utils.py               | 183 +++++++++++++++++++++++-------
 3 files changed, 205 insertions(+), 44 deletions(-)

diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh
index 2b6f7ec6b246..42111ea22bdd 100755
--- a/.jenkins/caffe2/test.sh
+++ b/.jenkins/caffe2/test.sh
@@ -176,7 +176,7 @@ fi
 ##############
 if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
   pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)"
-  pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.12.1 beartype==0.10.4
+  pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.12.1 beartype==0.10.4 onnx==1.12.0
   # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21.
   # We don't actually need it for our tests, but it's imported if it's present, so uninstall.
   pip uninstall -q --yes numba
diff --git a/torch/onnx/_internal/jit_utils.py b/torch/onnx/_internal/jit_utils.py
index 6354cea73fc0..a8740a4a2ff6 100644
--- a/torch/onnx/_internal/jit_utils.py
+++ b/torch/onnx/_internal/jit_utils.py
@@ -12,7 +12,8 @@
 from torch import _C
 from torch._C import _onnx as _C_onnx
 from torch.onnx._globals import GLOBALS
-from torch.onnx._internal import _beartype
+from torch.onnx._internal import _beartype, registration
+
 _ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
 _SKIP_NODE_ATTRIBUTES = {"inplace", "aten"}
@@ -98,6 +99,49 @@ def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
             **kwargs,
         )

+    @_beartype.beartype
+    def onnxscript_op(
+        self,
+        onnx_fn,  # TODO(titaiwang): annotate this when onnx-script becomes dependency
+        *raw_args: Union[torch.Tensor, _C.Value],
+        outputs: int = 1,
+        **kwargs,
+    ):
+        """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes.
+ + onnx-script repository: https://github.com/microsoft/onnx-script + + Args: + onnx_fn: ONNXFunction from onnx-script; An example can be found at + https://github.com/microsoft/onnx-script#example + raw_args: The inputs to the operator; usually provided + as arguments to the `symbolic` definition. + outputs: The number of outputs this operator returns. + By default an operator is assumed to return a single output. + If `outputs` is greater than one, this functions returns a tuple + of output `Value`, representing each output of the ONNX operator + in order. + kwargs: The attributes of the ONNX operator, whose keys are named + according to the following convention: `alpha_f` indicates + the `alpha` attribute with type `f`. The valid type specifiers are + `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute + specified with type float accepts either a single float, or a + list of floats (e.g., you would say `dims_i` for a `dims` attribute + that takes a list of integers). + + Returns: + The value representing the single output of this operator (see the `outputs` + keyword argument for multi-return nodes). + """ + # NOTE(titaiwang): This is using class attributes, and it needs to be updated + # if onnx-script makes any change on these. + symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.opname}" + opset_version = onnx_fn.opset.version + + registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn) + + return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs) + @_beartype.beartype def add_op_with_blocks( @@ -332,3 +376,21 @@ def parse_node_kind(kind: str) -> Tuple[str, str]: if "::" in opname: raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.") return domain, opname + + +@_beartype.beartype +def is_aten(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "aten" + + +@_beartype.beartype +def is_prim(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "prim" + + +@_beartype.beartype +def is_onnx(domain: str) -> bool: + """Check if the domain is official.""" + return domain == "onnx" diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index b30b71812aae..9d6ec0b32523 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1597,49 +1597,14 @@ def _export( model_file_location, node_attr_to_name, ) + # insert function_proto into model_proto. 
+ proto = _add_onnxscript_fn( + proto, + custom_opsets, + ) if verbose: torch.onnx.log("Exported graph: ", graph) - if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: - assert len(export_map) == 0 - with torch.serialization._open_file_like(f, "wb") as opened_file: - opened_file.write(proto) - elif export_type in [ - _exporter_states.ExportTypes.ZIP_ARCHIVE, - _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, - ]: - compression = ( - zipfile.ZIP_DEFLATED - if export_type - == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE - else zipfile.ZIP_STORED - ) - with zipfile.ZipFile(f, "w", compression=compression) as z: - z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, proto) - for k, v in export_map.items(): - z.writestr(k, v) - elif export_type == _exporter_states.ExportTypes.DIRECTORY: - if os.path.exists(f): - assert os.path.isdir(f) - else: - os.makedirs(f) - - model_proto_file = os.path.join( - f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME - ) - with torch.serialization._open_file_like( - model_proto_file, "wb" - ) as opened_file: - opened_file.write(proto) - - for k, v in export_map.items(): - weight_proto_file = os.path.join(f, k) - with torch.serialization._open_file_like( - weight_proto_file, "wb" - ) as opened_file: - opened_file.write(v) - else: - raise RuntimeError("Unknown export type") - + _export_file(proto, f, export_type, export_map) # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX, # we can skip this check. # If large model format export is enabled, proto will only contain data location instead of @@ -1660,6 +1625,138 @@ def _export( return torch_out +@_beartype.beartype +def _export_file( + model_bytes: bytes, + f: Union[io.BytesIO, str], + export_type: str, + export_map: Mapping[str, bytes], +) -> None: + """export/write model bytes into directory/protobuf/zip""" + # TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f, + # but beartype raises beartype.roar.BeartypeDecorHintNonpepException, + # as os.PathLike[str] uncheckable at runtime + if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) + elif export_type in [ + _exporter_states.ExportTypes.ZIP_ARCHIVE, + _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, + ]: + compression = ( + zipfile.ZIP_DEFLATED + if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE + else zipfile.ZIP_STORED + ) + with zipfile.ZipFile(f, "w", compression=compression) as z: + z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes) + for k, v in export_map.items(): + z.writestr(k, v) + elif export_type == _exporter_states.ExportTypes.DIRECTORY: + if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type] + raise ValueError( + f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}" + ) + if not os.path.exists(f): # type: ignore[arg-type] + os.makedirs(f) # type: ignore[arg-type] + + model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type] + with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file: + opened_file.write(model_bytes) + + for k, v in export_map.items(): + weight_proto_file = os.path.join(f, k) # type: ignore[arg-type] + with torch.serialization._open_file_like( + weight_proto_file, "wb" + ) as opened_file: + opened_file.write(v) + else: + raise RuntimeError("Unknown export 
type") + + +@_beartype.beartype +def _add_onnxscript_fn( + model_bytes: bytes, + custom_opsets: Mapping[str, int], +) -> bytes: + """Insert model-included custom onnx-script function into ModelProto""" + + # TODO(titaiwang): remove this when onnx becomes dependency + try: + import onnx + except ImportError: + raise errors.OnnxExporterError("Module onnx is not installed!") + + # For > 2GB model, onnx.load_fromstring would fail. However, because + # in _export_onnx, the tensors should be saved separately if the proto + # size > 2GB, and if it for some reason did not, the model would fail on + # serialization anyway in terms of the protobuf limitation. So we don't + # need to worry about > 2GB model getting here. + model_proto = onnx.load_from_string(model_bytes) + + # Iterate graph nodes to insert only the included custom + # function_proto into model_proto + # TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction + # calling other ONNXFunction scenario, neither does it here + onnx_function_list = list() # type: ignore[var-annotated] + included_node_func = set() # type: Set[str] + # onnx_function_list and included_node_func are expanded in-place + _find_onnxscript_op( + model_proto.graph, included_node_func, custom_opsets, onnx_function_list + ) + + if onnx_function_list: + model_proto.functions.extend(onnx_function_list) + model_bytes = model_proto.SerializeToString() + return model_bytes + + +@_beartype.beartype +def _find_onnxscript_op( + graph_proto, + included_node_func: Set[str], + custom_opsets: Mapping[str, int], + onnx_function_list: List, +): + """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op.""" + for node in graph_proto.node: + node_kind = node.domain + "::" + node.op_type + # Recursive is needed for control flow nodes: IF/Loop which has inner graph_proto + for attr in node.attribute: + if attr.g is not None: + _find_onnxscript_op( + attr.g, included_node_func, custom_opsets, onnx_function_list + ) + # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry + onnx_function_group = registration.registry.get_function_group(node_kind) + # Ruled out corner cases: onnx/prim in registry + if ( + node.domain + and not jit_utils.is_aten(node.domain) + and not jit_utils.is_prim(node.domain) + and not jit_utils.is_onnx(node.domain) + and onnx_function_group is not None + and node_kind not in included_node_func + ): + specified_version = custom_opsets.get(node.domain, 1) + onnx_fn = onnx_function_group.get(specified_version) + if onnx_fn is not None: + # TODO(titaiwang): to_function_proto is onnx-script API and can be annotated + # after onnx-script is dependency + onnx_function_list.append(onnx_fn.to_function_proto()) # type: ignore[attr-defined] + included_node_func.add(node_kind) + continue + raise errors.UnsupportedOperatorError( + node_kind, + specified_version, + onnx_function_group.get_min_supported() + if onnx_function_group + else None, + ) + return onnx_function_list, included_node_func + + @_beartype.beartype def _apply_friendly_debug_names(graph, params): for n in graph.nodes(): @@ -1959,7 +2056,9 @@ def _verify_custom_op_name(symbolic_name: str): @_beartype.beartype def register_custom_op_symbolic( - symbolic_name: str, symbolic_fn: Callable, opset_version: int + symbolic_name: str, + symbolic_fn: Callable, + opset_version: int, ): """Registers a symbolic function for a custom operator. 
From cf4b4b1b060fd48d4103acb4d0422e88c7e3b69e Mon Sep 17 00:00:00 2001 From: Angel Avila Date: Wed, 16 Nov 2022 16:30:56 +0000 Subject: [PATCH 233/453] Fix python types in pybind function signatures (#89115) Fixes #88958 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89115 Approved by: https://github.com/ezyang --- torch/csrc/utils/pybind.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index 85532a42cee2..c582dee1d2f6 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -36,7 +36,7 @@ template <> struct TORCH_PYTHON_API type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::Tensor, _("at::Tensor")); + PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor")); bool load(handle src, bool); @@ -51,7 +51,7 @@ template <> struct type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::Storage, _("at::Storage")); + PYBIND11_TYPE_CASTER(at::Storage, _("torch.storage._StorageBase")); bool load(handle src, bool) { PyObject* obj = src.ptr(); @@ -74,7 +74,7 @@ template <> struct type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::Generator, _("at::Generator")); + PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator")); bool load(handle src, bool) { PyObject* obj = src.ptr(); @@ -97,7 +97,7 @@ template <> struct TORCH_PYTHON_API type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::IntArrayRef, _("at::IntArrayRef")); + PYBIND11_TYPE_CASTER(at::IntArrayRef, _("typing.Tuple[int, ...]")); bool load(handle src, bool); static handle cast( @@ -129,7 +129,7 @@ template <> struct TORCH_PYTHON_API type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::MemoryFormat, _("at::MemoryFormat")); + PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format")); bool load(handle src, bool) { PyObject* obj = src.ptr(); @@ -151,7 +151,7 @@ template <> struct type_caster { public: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - PYBIND11_TYPE_CASTER(at::Device, _("at::Device")); + PYBIND11_TYPE_CASTER(at::Device, _("torch.device")); // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device // cannot be default-initialized, we provide this constructor to explicitly @@ -206,7 +206,7 @@ struct type_caster template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(c10::SymInt, _("SymInt")); + PYBIND11_TYPE_CASTER(c10::SymInt, _("torch._prims_common.IntLike")); bool load(py::handle src, bool); static py::handle cast( @@ -218,7 +218,7 @@ struct type_caster { template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(c10::SymFloat, _("SymFloat")); + PYBIND11_TYPE_CASTER(c10::SymFloat, _("torch._prims_common.FloatLike")); bool load(py::handle src, bool); static py::handle cast( From 90db86be108184a6c86c73e1b01012352c72e66b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Nov 2022 16:36:27 +0000 Subject: [PATCH 234/453] Revert "SymIntify convolution backend calculation (#89069)" This reverts commit 09ed8b67e24cfe29f3fa7b5dd28eaa7749229f12. 
Reverted https://github.com/pytorch/pytorch/pull/89069 on behalf of https://github.com/DanilBaibak due to breaking internal builds --- aten/src/ATen/native/ConvUtils.h | 79 ++---- aten/src/ATen/native/Convolution.cpp | 319 +++++++++++------------- aten/src/ATen/native/utils/ParamUtils.h | 7 +- c10/core/SymInt.h | 13 - torch/csrc/Module.cpp | 12 +- 5 files changed, 174 insertions(+), 256 deletions(-) diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 880ce0c2af54..b8e2b0842a00 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -110,8 +110,8 @@ enum class ConvBackend { // This overload is exposed to python for testing, etc. TORCH_API ConvBackend select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias_opt, - IntArrayRef stride, SymIntArrayRef padding, IntArrayRef dilation, - bool transposed, SymIntArrayRef output_padding, int64_t groups, const at::OptionalSymIntArrayRef bias_sizes_opt); + IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, + bool transposed, IntArrayRef output_padding, int64_t groups, const at::OptionalIntArrayRef bias_sizes_opt); TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input, const Tensor& weight, @@ -200,16 +200,15 @@ static void convolution_shape_check( // as conv_output_size loses information; this is why conv_input_size // takes an extra output_padding argument to resolve the ambiguity. -template -static inline std::vector _conv_output_size( - ArrayRef input_size, ArrayRef weight_size, - ArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +static inline std::vector conv_output_size( + IntArrayRef input_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { // ASSERT(input_size.size() > 2) // ASSERT(input_size.size() == weight_size.size()) bool has_dilation = dilation.size() > 0; auto dim = input_size.size(); - std::vector output_size(dim); + std::vector output_size(dim); output_size[0] = input_size[input_batch_size_dim]; output_size[1] = weight_size[weight_output_channels_dim]; for (const auto d : c10::irange(2, dim)) { @@ -220,84 +219,40 @@ static inline std::vector _conv_output_size( return output_size; } -static inline std::vector conv_output_size( - IntArrayRef input_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() -) { - return _conv_output_size(input_size, weight_size, padding, stride, dilation); -} - -static inline std::vector conv_output_size( - SymIntArrayRef input_size, SymIntArrayRef weight_size, - SymIntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() -) { - return _conv_output_size(input_size, weight_size, padding, stride, dilation); -} - -template -std::vector _conv_input_size( - ArrayRef output_size, ArrayRef weight_size, - ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +static inline std::vector conv_input_size( + IntArrayRef output_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { // ASSERT(output_size.size() > 2) // ASSERT(output_size.size() == weight_size.size()) auto dim = output_size.size(); - std::vector input_size(dim); + std::vector input_size(dim); input_size[0] = output_size[output_batch_size_dim]; input_size[1] = weight_size[weight_input_channels_dim] * groups; for 
(const auto d : c10::irange(2, dim)) { - auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1; - input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + + int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1; + input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) + kernel + output_padding[d - 2]; } return input_size; } -static inline std::vector conv_input_size( - SymIntArrayRef output_size, SymIntArrayRef weight_size, - SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups -) { - return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); -} - -static inline std::vector conv_input_size( - IntArrayRef output_size, IntArrayRef weight_size, +static inline std::vector conv_weight_size( + IntArrayRef input_size, IntArrayRef output_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups -) { - return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); -} - -template -std::vector _conv_weight_size( - ArrayRef input_size, ArrayRef output_size, - ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { auto dim = input_size.size(); - std::vector weight_size(dim); + std::vector weight_size(dim); weight_size[0] = output_size[1]; weight_size[1] = input_size[1] / groups; for (const auto d : c10::irange(2, dim)) { - auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] - + padding[d - 2] * 2 - output_padding[d - 2]; + int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] + + 2 * padding[d - 2] - output_padding[d - 2]; weight_size[d] = (kernel - 1) / dilation[d - 2] + 1; } return weight_size; } -static inline std::vector conv_weight_size( - SymIntArrayRef input_size, SymIntArrayRef output_size, - SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups -) { - return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); -} - -static inline std::vector conv_weight_size( - IntArrayRef input_size, IntArrayRef output_size, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups -) { - return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); -} - static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index bf7017f20a4f..29b2ce804c80 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -83,11 +83,10 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { // Check workload to activate fast depthwise FP16 cudnn conv kernels -template bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { - auto w = at::symint::size(input, 3); // same as h - auto ch = at::symint::size(input, 1); - auto bs = at::symint::size(input, 0); + int w = input.size(3); // same as h + int ch = input.size(1); + int bs = input.size(0); if (stride==1) { if (w >= 7) { // All batch sizes and nb_channels @@ -206,28 +205,27 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { } // simplified version for cudnn 8.2 and above -template bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& 
input, int stride, const at::Tensor& weight) { // 1D conv - if(at::symint::size(input, 2) == 1 && stride == 1){ + if(input.size(2) == 1 && stride == 1){ return true; } // 2d conv // only square filters - if (at::symint::size(weight, 2) != at::symint::size(weight, 3)) return false; - auto filter = at::symint::size(weight, 3); + if (weight.size(2) != weight.size(3)) return false; + int filter = weight.size(3); // only 1/3/5 filter if (filter != 1 && filter != 3 && filter != 5) return false; // we don't enforce square input but only check width to reduce heuristic space - if (at::symint::size(input, 3) < 7) return false; // min width 7 - auto w = at::symint::size(input, 3); + if (input.size(3) < 7) return false; // min width 7 + int w = input.size(3); // only 1/2 stride, use cudnn for all stride 1 if (stride == 1) return true; if (stride != 2) return false; - auto ch = at::symint::size(input, 1); - auto bs = at::symint::size(input, 0); + int ch = input.size(1); + int bs = input.size(0); // special case since bs1 show good perf in lots of cases if (bs == 1) { if (filter == 1 && w <= 28) return true; @@ -242,42 +240,13 @@ bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int str } -bool xnnpack_use_convolution2d( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const IntArrayRef padding, - const IntArrayRef stride, - const IntArrayRef dilation, - const int64_t groups, - const bool transposed) { - return xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed); -} - -bool xnnpack_use_convolution2d( - const Tensor& input, - const Tensor& weight, - const at::OptionalSymIntArrayRef bias_sizes_opt, - const SymIntArrayRef padding, - const IntArrayRef stride, - const IntArrayRef dilation, - const int64_t groups, - const bool transposed) { - // Never use xnnpack for symbolic tracing - return false; -} - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) -// This struct is templated so that we can run backend selection in a dynamic -// shapes context; all of the real kernel selection in eager mode runs with -// int64_t -template struct ConvParams { std::vector stride; - std::vector padding; + std::vector padding; std::vector dilation; bool transposed; - std::vector output_padding; + std::vector output_padding; int groups; bool benchmark; bool deterministic; @@ -353,12 +322,12 @@ struct ConvParams { #if defined(__ARM_NEON__) // Currently only 3x3 depthwise convolutions on tensors of float are supported. 
return (input.ndimension() == 4) && - (at::symint::size(input, 1) == groups) && + (input.size(1) == groups) && (weight.ndimension() == 4 ) && - (at::symint::size(weight, 0) % at::symint::size(input, 1) == 0) && - (at::symint::size(weight, 1) == 1) && - (at::symint::size(weight, 2) == 3) && - (at::symint::size(weight, 3) == 3) && + (weight.size(0) % input.size(1) == 0) && + (weight.size(1) == 1) && + (weight.size(2) == 3) && + (weight.size(3) == 3) && (input.device().is_cpu()) && (input.scalar_type() == at::kFloat) && input.is_contiguous() && @@ -376,23 +345,23 @@ struct ConvParams { bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const { constexpr int64_t int_max = std::numeric_limits::max(); - auto numel_input = at::symint::numel(input); + int64_t numel_input = input.numel(); // empty input if (numel_input == 0) { return false; } // input size can not be reduced to the range of int by splitting the batch dim - auto n = at::symint::size(input, 0); + int64_t n = input.size(0); if (numel_input / n > int_max) { return true; } // output size can not be reduced to the range of int by splitting the batch dim - T outsize = 1; + int64_t outsize = 1; if (transposed) { - auto o = conv_input_size(at::symint::sizes(input), at::symint::sizes(weight), padding, output_padding, stride, dilation, groups); + std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } else { - auto o = conv_output_size(at::symint::sizes(input), at::symint::sizes(weight), padding, stride, dilation); + std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } return outsize > int_max; @@ -448,10 +417,10 @@ struct ConvParams { is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks !is_dilated() && // no dilation supported - (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d - at::symint::size(input, 1) >= 32); // min 32 channels supported) + (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d + input.size(1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); } } // keep (7600 <= cudnn < 8200) code unchanged @@ -461,14 +430,14 @@ struct ConvParams { weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - at::symint::size(weight, 2) == at::symint::size(weight, 3) && // only square kernels - at::symint::size(input, 2) >= 7 && // min width/height 7 + weight.size(2) == weight.size(3) && // only square kernels + input.size(2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported stride[0] == stride[1] && // equal strides - ((at::symint::size(weight, 3) == 3) || (at::symint::size(weight, 3) == 1)) && - at::symint::size(input, 1) >= 32); // min 32 channels supported) + ((weight.size(3) == 3) || (weight.size(3) == 1)) && + input.size(1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); + return check_cudnn_depthwise_workload(input, stride[0]); } else { return false; } @@ -504,12 +473,12 @@ struct ConvParams { !transposed && // or 
transposed tensors // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, // but THNN is faster when single-threaded. - (is_strided() || is_dilated() || at::symint::size(input, 0) >= 16 || - at::symint::size(weight, -1) != 1 || at::symint::size(weight, -2) != 1 || at::get_num_threads() > 1) && + (is_strided() || is_dilated() || input.size(0) >= 16 || + weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && (groups > 1 - || (at::symint::size(weight, -1) > 3 && at::symint::size(weight, -2) > 3) - || at::symint::size(input, 0) > 1 - || at::symint::size(input, 0)*at::symint::size(input, 1)*at::symint::size(input, 2)*at::symint::size(input, 3) > 20480) // for some case, native is faster + || (weight.size(-1) > 3 && weight.size(-2) > 3) + || input.size(0) > 1 + || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster ); #endif @@ -524,23 +493,20 @@ struct ConvParams { !transposed && // or transposed tensors input.ndimension() == 4 && // must be in NCHW format weight.ndimension() == 4 && - (at::symint::size(weight, 2) < 17) && (at::symint::size(weight, 3) < 17) // NNPACK only supports kernels up to 16x16 + (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 #if !defined(C10_MOBILE) - && at::symint::size(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable + && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable #endif ; #endif return false; } bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalArrayRef bias_sizes_opt) const { + const at::OptionalIntArrayRef bias_sizes_opt) const { #if defined(C10_MOBILE) if (!transposed) { - // NB: for the call here, it MATTERS that we are templated. If you - // untemplate this to always use SymInt, the function - // xnnpack_use_convolution2d will always return false - return (at::symint::size(input, 1) == groups) && - xnnpack_use_convolution2d( + return (input.size(1) == groups) && + xnnpack::use_convolution2d( input, weight, bias_sizes_opt, @@ -577,12 +543,33 @@ struct ConvParams { return input.is_cuda() && !transposed && (input.ndimension() == 4 || input.ndimension() == 5) && - at::symint::size(input, 1) == groups && + input.size(1) == groups && groups > 1 && // no point if there is only a single group - at::symint::size(weight, 0) % at::symint::size(input, 1) == 0; // output channels must be a multiple of input channels + weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels } }; +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. 
+ConvBackend _select_conv_backend( + const Tensor& input, + const Tensor& weight, + const c10::optional& bias_opt, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params); + DEFINE_DISPATCH(conv_depthwise2d_backward_stub); DEFINE_DISPATCH(conv_depthwise3d_backward_stub); DEFINE_DISPATCH(cudnn_convolution_backward_stub); @@ -604,14 +591,13 @@ REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); -template -std::ostream& operator<<(std::ostream & out, const ConvParams& params) { +std::ostream& operator<<(std::ostream & out, const ConvParams& params) { out << "ConvParams {" << " stride = " << IntArrayRef{params.stride} - << " padding = " << ArrayRef{params.padding} + << " padding = " << IntArrayRef{params.padding} << " dilation = " << IntArrayRef{params.dilation} << " transposed = " << params.transposed - << " output_padding = " << ArrayRef{params.output_padding} + << " output_padding = " << IntArrayRef{params.output_padding} << " groups = " << params.groups << " benchmark = " << params.benchmark << " deterministic = " << params.deterministic @@ -621,10 +607,9 @@ std::ostream& operator<<(std::ostream & out, const ConvParams& params) { return out; } -template static void check_shape_forward(const at::Tensor& input, - const c10::ArrayRef& weight_sizes, const at::Tensor& bias, - const ConvParams& params) { + const c10::IntArrayRef& weight_sizes, const at::Tensor& bias, + const ConvParams& params) { int64_t k = input.ndimension(); int64_t weight_dim = weight_sizes.size(); int64_t groups = params.groups; @@ -639,7 +624,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ", - at::symint::sizes(input), " instead"); + input.sizes(), " instead"); TORCH_CHECK(weight_sizes[0] >= groups, "Given groups=", groups, ", expected weight to be at least ", groups, " at dimension 0, but got weight of size ", weight_sizes, " instead"); @@ -649,23 +634,23 @@ static void check_shape_forward(const at::Tensor& input, "] instead"); if (!transposed) { - std::vector input_shape; - std::vector kernel_shape; + std::vector input_shape; + std::vector kernel_shape; bool kernel_size_correct = true; - TORCH_CHECK(at::symint::size(input, 1) == (weight_sizes[1] * groups), + TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups), "Given groups=", groups, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", - (weight_sizes[1] * groups), " channels, but got ", at::symint::size(input, 1), + (weight_sizes[1] * groups), " channels, but got ", input.size(1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[0]), + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]), "Given weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements", - ", but got bias of size ", at::symint::sizes(bias), " instead"); + ", but got bias of size ", 
bias.sizes(), " instead"); for (const auto i : c10::irange(2, k)) { - input_shape.push_back(at::symint::size(input, i) + 2 * padding[i-2]); + input_shape.push_back(input.size(i) + 2 * padding[i-2]); // log new kernel size considering dilation kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1); if (input_shape.back() < kernel_shape.back()) { @@ -691,23 +676,22 @@ static void check_shape_forward(const at::Tensor& input, "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); } } else { // transposed - TORCH_CHECK(at::symint::size(input, 1) == weight_sizes[0], + TORCH_CHECK(input.size(1) == weight_sizes[0], "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", weight_sizes[0], - " channels, but got ", at::symint::size(input, 1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[1] * groups), + " channels, but got ", input.size(1), " channels instead"); + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[1] * groups), "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements", ", but got bias of size ", bias.sizes(), " instead"); } } -template static void check_shape_backward( const at::Tensor& input, - const c10::ArrayRef& weight_sizes, - const ConvParams& params) { - check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); + const c10::IntArrayRef& weight_sizes, + const ConvParams& params) { + check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); } // Given an input tensor and an expected number of spatial dimensions, checks that the @@ -1165,25 +1149,71 @@ at::Tensor convolution_overrideable( TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); } -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. -template +// Selects a backend for convolution based on the inputs and params. 
+ConvBackend select_conv_backend( + const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, + IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, IntArrayRef output_padding_, int64_t groups_, const at::OptionalIntArrayRef bias_sizes_opt) { + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + auto& ctx = at::globalContext(); + auto k = weight_r.ndimension(); + int64_t dim = k - 2; + ConvParams params; + params.stride = expand_param_if_needed(stride_, "stride", dim); + params.padding = expand_param_if_needed(padding_, "padding", dim); + params.dilation = expand_param_if_needed(dilation_, "dilation", dim); + params.transposed = transposed_; + params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); + params.groups = groups_; + params.benchmark = ctx.benchmarkCuDNN(); + params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); + params.cudnn_enabled = ctx.userEnabledCuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN(); + + auto input = input_r; + auto weight = weight_r; + check_shape_forward(input, weight.sizes(), bias, params); + + // Expand 1d -> 2d. + // This is only done for backends that don't natively support 1d spatial input. + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { + // avoid accidentally going through NHWC for permuted 3d input. + input = input.contiguous(); + params.view1d_as_2d(); + input = view4d(input); + weight = view4d(weight); + } + + auto bias_sizes = bias.defined() ? c10::optional(bias.sizes()) : bias_sizes_opt; + bool need_backward = GradMode::is_enabled() && + (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); + return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); +} + +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params) { + return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); +} + ConvBackend _select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias, - const at::OptionalArrayRef bias_sizes_opt, + const at::OptionalIntArrayRef bias_sizes_opt, const bool need_backward, - const ConvParams& params) { + const ConvParams& params) { // don't send empty inputs through backends - if (at::symint::size(input, 0) == 0 || at::symint::size(input, 1) == 0) { + if (input.size(0) == 0 || input.size(1) == 0) { return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty; - } else if (at::symint::numel(input) == 0) { - TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes(input)); + } else if (input.numel() == 0) { + TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", input.sizes()); } if (params.is_depthwise(input, weight)) { @@ -1275,65 +1305,12 @@ ConvBackend _select_conv_backend( AT_ERROR("unsupported ConvNd parameters"); } -// Selects a backend for convolution based on the inputs and params. 
-ConvBackend select_conv_backend( - const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, - IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_, - bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) { - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - auto& ctx = at::globalContext(); - auto k = weight_r.ndimension(); - int64_t dim = k - 2; - ConvParams params; - params.stride = expand_param_if_needed(stride_, "stride", dim); - params.padding = expand_param_if_needed(padding_, "padding", dim); - params.dilation = expand_param_if_needed(dilation_, "dilation", dim); - params.transposed = transposed_; - params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); - params.groups = groups_; - params.benchmark = ctx.benchmarkCuDNN(); - params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); - params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); - - auto input = input_r; - auto weight = weight_r; - check_shape_forward(input, weight.sym_sizes(), bias, params); - - // Expand 1d -> 2d. - // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { - // avoid accidentally going through NHWC for permuted 3d input. - input = input.contiguous(); - params.view1d_as_2d(); - input = view4d(input); - weight = view4d(weight); - } - - auto bias_sizes = bias.defined() ? c10::optional(bias.sym_sizes()) : bias_sizes_opt; - bool need_backward = GradMode::is_enabled() && - (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); -} - -// For BC reasons, have a copy that does not require bias_opt -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params) { - return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); -} - at::Tensor _convolution_nogroup_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::NnpackSpatial: @@ -1364,7 +1341,7 @@ at::Tensor _convolution_nogroup_backend( static inline std::vector calc_output_size( const Tensor& input, const Tensor& weight, - const ConvParams& params) { + const ConvParams& params) { std::vector output_size = params.transposed ? conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding, params.stride, params.dilation, params.groups) : @@ -1445,7 +1422,7 @@ at::Tensor _convolution( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); TORCH_CHECK(groups_ > 0, "non-positive groups is not supported"); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1473,7 +1450,7 @@ at::Tensor _convolution( auto bias_sizes_opt = bias.defined() ? 
c10::optional(bias.sizes()) : c10::nullopt; bool need_backward = GradMode::is_enabled() && (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params); + ConvBackend backend = _select_conv_backend(input, weight, bias, bias_sizes_opt, need_backward, params); at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); // Call the backend. @@ -1686,7 +1663,7 @@ std::tuple _convolution_double_backward( const c10::option auto weight = weight_r; int64_t dim = weight.ndimension() - 2; - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1749,7 +1726,7 @@ std::tuple _convolution_double_backward( const c10::option if (ggI.defined()) { // Modified params with correct padding - ConvParams gw_conv_params(params); + ConvParams gw_conv_params(params); // Disable groups as they are handled separately auto groups = gw_conv_params.groups; @@ -1818,7 +1795,7 @@ std::tuple _convolution_double_backward( const c10::option Tensor gI; if (input.numel() != 0) { if (ggW.defined()) { - ConvParams gi_conv_params(params); + ConvParams gi_conv_params(params); gi_conv_params.transposed = !params.transposed; if (params.transposed) { @@ -1874,7 +1851,7 @@ std::tuple _convolution_backward_nogroup_bac const Tensor& weight, const std::array output_mask, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::Slow2d: @@ -1939,7 +1916,7 @@ std::tuple convolution_backward( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); auto& ctx = at::globalContext(); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride, "stride", dim); params.padding = expand_param_if_needed(padding, "padding", dim); params.dilation = expand_param_if_needed(dilation, "dilation", dim); diff --git a/aten/src/ATen/native/utils/ParamUtils.h b/aten/src/ATen/native/utils/ParamUtils.h index 7c89a3316cb4..376467ff79cf 100644 --- a/aten/src/ATen/native/utils/ParamUtils.h +++ b/aten/src/ATen/native/utils/ParamUtils.h @@ -6,13 +6,12 @@ namespace at { namespace native { -template -inline std::vector expand_param_if_needed( - ArrayRef list_param, +inline std::vector expand_param_if_needed( + IntArrayRef list_param, const char* param_name, int64_t expected_dim) { if (list_param.size() == 1) { - return std::vector(expected_dim, list_param[0]); + return std::vector(expected_dim, list_param[0]); } else if ((int64_t)list_param.size() != expected_dim) { std::ostringstream ss; ss << "expected " << param_name << " to be a single integer value or a " diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 6355f1339505..9ab72a077680 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -235,19 +235,6 @@ inline c10::SymInt multiply_integers(const C& container) { [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } -template < - typename Iter, - typename = std::enable_if_t::value_type, - c10::SymInt>::value>> -inline c10::SymInt multiply_integers(Iter begin, Iter end) { - return std::accumulate( - begin, - end, - c10::SymInt(1), - [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); -} - 
inline SymInt operator+(int64_t a, const SymInt& b) { return c10::SymInt(a) + b; } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 607373625724..b8693a484ed9 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1408,10 +1408,10 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias_opt, at::IntArrayRef stride_, - at::SymIntArrayRef padding_, + at::IntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::SymIntArrayRef output_padding_, + at::IntArrayRef output_padding_, int64_t groups_) { return at::native::select_conv_backend( input, @@ -1442,13 +1442,13 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias, at::IntArrayRef stride_, - at::SymIntArrayRef padding_, + at::IntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::SymIntArrayRef output_padding_, + at::IntArrayRef output_padding_, int64_t groups_, - c10::optional> bias_sizes_opt) { - c10::OptionalArrayRef ref = c10::nullopt; + c10::optional> bias_sizes_opt) { + c10::OptionalArrayRef ref = c10::nullopt; if (bias_sizes_opt) { ref = (*bias_sizes_opt); } From fe276ea0f9b4cce9c7d32157f831897fbbd1c85a Mon Sep 17 00:00:00 2001 From: Kirtesh Patil Date: Wed, 16 Nov 2022 16:40:24 +0000 Subject: [PATCH 235/453] [UCC] Add pre & post processing for CPU collectives (#89030) Summary: The CPU block in `collective_post` was missing pre & post processing. The reduce-scatter implementation expects use of the pre-processing callback to flatten the input tensors; however, the missing invocation meant garbage values were being passed. Test Plan: Tested the reduce-scatter collective using PARAM Reviewed By: eastzone Differential Revision: D41291592 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89030 Approved by: https://github.com/kingchc, https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupUCC.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp index 5f286b7a716c..ad135062a702 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp @@ -789,7 +789,9 @@ c10::intrusive_ptr ProcessGroupUCC::collective_post( work->future_ = c10::make_intrusive( c10::ListType::create(c10::TensorType::get())); } + preproc(); comm->enqueue_collective(std::move(data), work, coll, team); + postproc(); return work; } #ifdef USE_CUDA From cf6003f0469ae1440d4a8585860c2c5f4c738707 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 16 Nov 2022 16:52:47 +0000 Subject: [PATCH 236/453] Revert "Towards unifying symbolic and non symbolic fake tensor (#89038)" This reverts commit 37d54239c7ea88fd9c98dcac3fcc9b98a6f9e9d1.
Reverted https://github.com/pytorch/pytorch/pull/89038 on behalf of https://github.com/ezyang due to executorch segfaults --- aten/src/ATen/native/TensorFactories.cpp | 6 +++ test/functorch/test_aotdispatch.py | 1 + test/test_proxy_tensor.py | 21 ++++++--- torch/_meta_registrations.py | 44 +++--------------- torch/_ops.py | 1 - torch/_prims/__init__.py | 5 +- torch/_prims_common/__init__.py | 3 -- torch/_subclasses/fake_tensor.py | 58 +++++++++++++++--------- 8 files changed, 68 insertions(+), 71 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 7245cb77b1c5..9d1c6d8a3633 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -325,6 +325,12 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + + TORCH_CHECK( + !(options_.has_memory_format() && optional_memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + TensorOptions options = self.options() .merge_in(options_) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ae216f9be4a4..1dc5476158f9 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1011,6 +1011,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 0a24807af55f..8dc42be7fdfb 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1151,7 +1151,9 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition + xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1233,6 +1235,8 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition 
xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... @@ -1277,6 +1281,7 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... @@ -1293,6 +1298,7 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. 
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition @@ -1341,15 +1347,11 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) -outplace_symbolic_tensor_failures = { - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition -} - inplace_symbolic_tensor_failures = { + xfail('abs', ''), # aten.abs_.default - couldn't find symbolic meta function/decomposition xfail('acos', ''), # aten.acos_.default - couldn't find symbolic meta function/decomposition xfail('acosh', ''), # aten.acosh_.default - couldn't find symbolic meta function/decomposition + xfail('addbmm', ''), # aten.addbmm_.default - couldn't find symbolic meta function/decomposition xfail('addcdiv', ''), # aten.addcdiv_.default - couldn't find symbolic meta function/decomposition xfail('addcmul', ''), # aten.addcmul_.default - couldn't find symbolic meta function/decomposition xfail('addmm', ''), # aten.addmm_.default - couldn't find symbolic meta function/decomposition @@ -1363,6 +1365,7 @@ def f(a, b, c, d, e): xfail('clamp', ''), # aten.clamp_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_max', ''), # aten.clamp_max_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_min', ''), # aten.clamp_min_.Tensor - couldn't find symbolic meta function/decomposition + xfail('conj_physical', ''), # aten.conj_physical_.default - couldn't find symbolic meta function/decomposition xfail('copysign', ''), # aten.copysign_.Tensor - couldn't find symbolic meta function/decomposition xfail('cos', ''), # aten.cos_.default - couldn't find symbolic meta function/decomposition xfail('cosh', ''), # aten.cosh_.default - couldn't find symbolic meta function/decomposition @@ -1379,6 +1382,7 @@ def f(a, b, c, d, e): xfail('expm1', ''), # aten.expm1_.default - couldn't find symbolic meta function/decomposition xfail('float_power', ''), # the base given to float_power_ has dtype Float but the operation's result requires dtype Double xfail('floor', ''), # aten.floor_.default - couldn't find symbolic meta function/decomposition + xfail('floor_divide', ''), # aten.floor_divide_.Tensor - couldn't find symbolic meta function/decomposition xfail('fmod', ''), # aten.fmod_.Tensor - couldn't find symbolic meta function/decomposition xfail('frac', ''), # aten.frac_.default - couldn't find symbolic meta function/decomposition xfail('ge', ''), # aten.ge_.Tensor - couldn't find symbolic meta function/decomposition @@ -1394,6 +1398,7 @@ def f(a, b, c, d, e): xfail('log1p', ''), # aten.log1p_.default - couldn't find symbolic meta function/decomposition xfail('log2', ''), # aten.log2_.default - couldn't find symbolic meta function/decomposition xfail('log', ''), # aten.log_.default - couldn't find symbolic meta function/decomposition + xfail('logit', ''), # aten.logit_.default - couldn't find symbolic meta function/decomposition xfail('lt', ''), # aten.lt_.Tensor - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_1'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_3'), # aten.mvlgamma_.default - couldn't find symbolic meta 
function/decomposition @@ -1403,6 +1408,7 @@ def f(a, b, c, d, e): xfail('neg', ''), # aten.neg_.default - couldn't find symbolic meta function/decomposition xfail('nextafter', ''), # aten.nextafter_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.celu', ''), # aten.celu_.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.dropout3d', ''), # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition xfail('nn.functional.elu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.hardsigmoid', ''), # aten.hardsigmoid_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.mish', ''), # aten.mish_.default - couldn't find symbolic meta function/decomposition @@ -1420,6 +1426,7 @@ def f(a, b, c, d, e): xfail('sinh', ''), # aten.sinh_.default - couldn't find symbolic meta function/decomposition xfail('sqrt', ''), # aten.sqrt_.default - couldn't find symbolic meta function/decomposition xfail('square', ''), # aten.pow_.Scalar - couldn't find symbolic meta function/decomposition + xfail('squeeze', ''), # aten.squeeze_.default - couldn't find symbolic meta function/decomposition xfail('t', ''), # aten.t_.default - couldn't find symbolic meta function/decomposition xfail('tan', ''), # aten.tan_.default - couldn't find symbolic meta function/decomposition xfail('tanh', ''), # aten.tanh_.default - couldn't find symbolic meta function/decomposition @@ -1509,7 +1516,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index abcd1ead8b43..4fa3ab09d275 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1513,6 +1513,7 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, + aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1520,44 +1521,6 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) -# zeros_like is special cased to work for sparse -@register_meta(aten.zeros_like.default) -def zeros_like( - self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None -): - if layout == torch.sparse_coo: - check( - memory_format is None, - lambda: "memory format option is only supported by strided tensors", - ) - - res = torch.empty( - 0, - dtype=self.dtype if dtype is None else dtype, - layout=layout, - device=self.device if device is None else device, - pin_memory=pin_memory, - ) - - if self.is_sparse: - res.sparse_resize_and_clear_( - self.size(), self.sparse_dim(), self.dense_dim() - ) - else: - res.sparse_resize_and_clear_(self.size(), self.dim(), 0) - - res._coalesced_(True) - return res - return aten.empty_like.default( - self, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - memory_format=memory_format, - ) - - # hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): @@ -1931,6 +1894,11 @@ def activate_meta(): # 
Instead, we should be letting those decompositions run, and writing meta kernels # only for the base operators. pass + elif op_overload.is_view: + # Attempting to register a python meta kernel for a view operator. + # We shouldn't do this, because the output will report as not having aliased storages. + # All view ops have meta kernels in C++ today, so we should use those instead. + pass elif op_overload.name() in { "aten::empty_strided", # causing infinite recursion, test_meta.py "aten::clone", # causing infinite recursion diff --git a/torch/_ops.py b/torch/_ops.py index b20398a7f3ab..9163932144d0 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -365,7 +365,6 @@ def handler(*args, **kwargs): return handler final_key = resolve_key(self, key) - # print(self, key, final_key) r = self.py_kernels.get(final_key, final_key) self._dispatch_cache[key] = r return r diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index a4bac68f0ff1..da8d9af723ac 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1150,6 +1150,9 @@ def _minimum_aten( # # View operations +# +# TODO: model view relationships +# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1167,7 +1170,7 @@ def _as_strided_meta( a._typed_storage(), size, stride, storage_offset ) - return torch.as_strided(a, size, stride, storage_offset) + return TensorMeta(a, shape=size, strides=stride) def _as_strided_aten( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 041448e8102a..128796dfa3d0 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -291,9 +291,6 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ - if a.is_sparse: - return False - # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9a0ac050e6b9..5d3d3a0e32fe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,6 +1,7 @@ import contextlib import functools import itertools +import sys import weakref from dataclasses import dataclass from functools import partial @@ -296,9 +297,8 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - # _like constructors have fake tensor inputs (maybe this causes the non-like - # to fail? 
hmmm) - with in_kernel_invocation_manager(fake_mode): + # Not in_kernel_invocation_manager as no fake tensor inputs + with no_dispatch(): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -821,30 +821,40 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - # If there's a Python meta, prefer that over the decomposition - from torch._decomp import meta_table as meta_table + from torch._decomp import decomposition_table + + with self: + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r - if func not in meta_table and not self.cpp_meta_supports_symint(func): - from torch._decomp import decomposition_table + # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise + if has_symbolic_sizes and not self.cpp_meta_supports_symint(func): + from torch._decomp import meta_table as meta_table - # Prefer Python decompositions over C++ ones - if func in decomposition_table and ( - has_symbolic_sizes - or ( - # TODO: Remove these exclusions, so that we can remove - # this leg entirely - torch_decomp_decompositions(func) - and all(not e.is_sparse for e in flat_arg_fake_tensors) + if func == aten.size.default: + sys.stderr.write( + "Trying to call aten.size on a tensor with symbolic shapes. " + "It's likely that this is from calling tensor.shape in C++" ) - ): - with self: - return decomposition_table[func](*args, **kwargs) + # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` + return None with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: + if func in meta_table: + r = meta_table[func](*args, **kwargs) return r + if func in decomposition_table: + return decomposition_table[func](*args, **kwargs) + + if ( + func in decomposition_table + and torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) + ): + with self: + return decomposition_table[func](*args, **kwargs) # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them @@ -855,6 +865,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with self: return func.prim_meta_impl(*args, **kwargs) + if has_symbolic_sizes: + if not self.cpp_meta_supports_symint(func): + raise RuntimeError( + f"{func} - couldn't find symbolic meta function/decomposition" + ) + # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors # and then afterwards wrapping them to a FakeTensor From 7f55db4fb0fb12ed593c7f23de01bfb9330b7dd5 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 16 Nov 2022 16:59:36 +0000 Subject: [PATCH 237/453] add quantize_decomposed_dynamic to op lib (#88855) Summary: Needed for dynamic quant reference pattern graphs. 
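(Editor's note, not part of the original patch: a rough usage sketch of the new op. Unlike the static `quantize_per_tensor`, the dynamic variant takes no scale or zero point; it derives them from the input's observed min/max at call time and then dispatches to the ordinary per-tensor quantize. The exact schema and registration are in the diff below; the snippet only assumes what that diff adds.)

```python
import torch
# Importing the module registers the quantized_decomposed ops, as the new test does.
import torch.ao.quantization.fx._decomposed  # noqa: F401

x = torch.randn(5, 10)
# quant_min/quant_max and the target dtype are passed explicitly; scale and
# zero_point are computed internally from x via a MinMaxObserver.
xq = torch.ops.quantized_decomposed.quantize_per_tensor_dynamic(x, 0, 255, torch.uint8)
```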
Test Plan: added unittest Differential Revision: D41205030 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88855 Approved by: https://github.com/jerryzh168 --- .../core/test_quantized_tensor.py | 25 +++++++++++++ torch/ao/quantization/fx/_decomposed.py | 36 +++++++++++++++---- torch/ao/quantization/utils.py | 1 - 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index a2043509f1f1..dab53de5b107 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import TestCase, DeterministicGuard import torch.testing._internal.hypothesis_utils as hu from torch.testing._internal.common_quantization import get_supported_device_types +from torch.ao.quantization import MinMaxObserver hu.assert_deadline_disabled() @@ -1498,6 +1499,30 @@ def test_decomposed_dequantize(self): self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) self.assertEqual(dequantized_X, dequantized_decomposed_X) + def test_decomposed_quantize_dynamic(self): + import torch.ao.quantization.fx._decomposed + X = torch.randn(5, 10) + dtype = torch.uint8 + qdtype = torch.quint8 + scale, zero_point = torch._choose_qparams_per_tensor(X, False) + quant_min, quant_max = 0, 255 + + quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype) + dequantized_X = torch.dequantize(quantized_X) + + quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor_dynamic( + X, quant_min, quant_max, dtype) + + # observer logic is what quantize_per_tensor_dynamic does internally + observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max) + observer(X) + scale_decomposed, zero_point_decomposed = observer.calculate_qparams() + dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor( + quantized_decomposed_X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype + ) + self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) + self.assertEqual(dequantized_X, dequantized_decomposed_X) + if __name__ == '__main__': raise RuntimeError("This test file is not meant to be run directly, use:\n\n" "\tpython test/test_quantization.py TESTNAME\n\n" diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index 001fa16f8cd3..3f4d38872e17 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -1,16 +1,13 @@ import torch -from torch.library import Library, impl +from torch.library import impl, Library +from torch.ao.quantization import MinMaxObserver # Note: decomposed means decomposed quantized tensor, using decomposed so that the # name is not too long quantized_decomposed_lib = Library("quantized_decomposed", "DEF") -quantized_decomposed_lib.define( - "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> Tensor") - -@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") -def quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype): - assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" +# Helper to check the passed in quant min and max are valid for the dtype +def _quant_min_max_bounds_check(quant_min, quant_max, dtype): quant_min_lower_bound = 0 quant_max_upper_bound = 0 if 
dtype == torch.uint8: @@ -30,6 +27,14 @@ def quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype): "quant_max out of bound for dtype, " \ f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}" +quantized_decomposed_lib.define( + "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd") +def quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype): + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + inv_scale = 1.0 / scale return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype) @@ -50,3 +55,20 @@ def dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype) return (input.to(torch.float32) - zero_point) * scale else: raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + +quantized_decomposed_lib.define( + "quantize_per_tensor_dynamic(Tensor input, int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor_dynamic", "CompositeExplicitAutograd") +def quantize_per_tensor_dynamic(input, quant_min, quant_max, dtype): + assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" + _quant_min_max_bounds_check(quant_min, quant_max, dtype) + + # Its weird to create an observer manually just to calculate qparams. I tried refactoring this functionality out of observer + # into a util and then use that util directly, but I kept running into jit typing errors related to torch.qscheme not + # being recognized as a type. TODO: properly refactor this out to avoid observer overhead + tensor_dtype_to_observer_dtype = {torch.uint8: torch.quint8, torch.int8: torch.qint8} + observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max, dtype=tensor_dtype_to_observer_dtype[dtype]) + observer(input) + scale, zero_point = observer.calculate_qparams() + return torch.ops.quantized_decomposed.quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index afa278a795dd..9f3dc712a9fe 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -546,7 +546,6 @@ def _patched_module_call(self, *args, **kwargs): torch.nn.Module.__call__ = orig_module_call return fqn_to_example_inputs - __all__ = [ "NodePattern", "Pattern", From d2d22d89d92bf7d6bb02417dab04027d7fcc80d3 Mon Sep 17 00:00:00 2001 From: bmedishe Date: Wed, 16 Nov 2022 17:42:26 +0000 Subject: [PATCH 238/453] test_unary_ufuncs few tests enabled on rocm which are passing (#89007) This PR enables tests from the test package test_unary_ufuncs.py::TestUnaryUfuncsCUDA that were previously skipped on ROCm and now pass
test_file | test_name | test_class
-- | -- | --
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_float16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_float32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_int16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_int32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_2_cuda_int64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_float16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_float32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_int16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_int32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_polygamma_polygamma_n_4_cuda_int64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_large_tan_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_bfloat16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_float16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_float32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_int16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_int32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_int64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_int8 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_atan_cuda_uint8 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_float16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_float32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_int16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_int32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_int64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_int8 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_2_cuda_uint8 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_float16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_float32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_float64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_int16 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_int32 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_int64 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_int8 | (__main__.TestUnaryUfuncsCUDA)
test_unary_ufuncs | test_reference_numerics_small_polygamma_polygamma_n_4_cuda_uint8 | (__main__.TestUnaryUfuncsCUDA)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89007 Approved by: https://github.com/mruberry --- .../_internal/common_methods_invocations.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index e498e4f28509..1b8920cbc867 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8740,7 +8740,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): decorators=(precisionOverride({torch.bfloat16: 1e-2}),), skips=( DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - active_if=TEST_WITH_ROCM, device_type='cuda'), + active_if=TEST_WITH_ROCM, device_type='cuda', dtypes=[torch.complex64, torch.complex128]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', active_if=TEST_WITH_ROCM, device_type='cuda', dtypes=[torch.complex128]), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', @@ -13320,9 +13320,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - device_type='cuda', dtypes=[torch.float64], - active_if=TEST_WITH_ROCM), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), ), @@ -13841,11 +13838,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), # Mismatch: https://github.com/pytorch/pytorch/issues/55357 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - active_if=TEST_WITH_ROCM), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - active_if=TEST_WITH_ROCM),), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),), sample_kwargs=lambda device, dtype, input: ({'n': 2}, {'n': 2}), # polygamma functions have multiple singularities at x <= 0 reference_numerics_filter=NumericsFilter(condition=lambda x: x < 0.1, safe_val=1)), @@ -13888,11 +13881,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators'), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon'), # Mismatch: https://github.com/pytorch/pytorch/issues/55357 - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', - active_if=TEST_WITH_ROCM), - DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', - active_if=TEST_WITH_ROCM),), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal'),), sample_kwargs=lambda device, dtype, input: ({'n': 4}, {'n': 4}), # polygamma functions have multiple singularities at x <= 0 reference_numerics_filter=NumericsFilter(condition=lambda 
x: x < 0.1, safe_val=1)), From e1ecf53d8480899b5b41c295e52eafb7347f0141 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 16 Nov 2022 14:09:58 +0000 Subject: [PATCH 239/453] Simplify linspace decomp and increase its tolerance (#87203) This is an interesting one. Since this is an operation that's intrinsically defined on the reals, we should always perform the ops in that dtype and just cast to the desired dtype at the end. This simplifies the decomposition. Now, I started looking at this one when I started seeing failures on a test that's added in a later PR. What's going on here is that, by upcasting to a higher dtype and then casting down to integers, we sometimes get an off-by-one error. I think this is fine, as the decomposition is more accurate than the original function, which goes in line with the whole PrimTorch effort.
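(Editor's illustration, not part of the original patch: the off-by-one mentioned above comes from doing the arithmetic on the reals and casting to the integer dtype only at the very end. A result that is mathematically an integer can come out of the float64 computation just below that integer, and the final cast truncates it. A minimal sketch with made-up values:)

```python
import torch

# Hypothetical values: both "mean" 2, but the one that lands just below 2
# truncates to 1 when cast to an integer dtype at the end.
exact = torch.tensor(2.0, dtype=torch.float64)
just_below = torch.tensor(1.9999999999999998, dtype=torch.float64)
print(exact.to(torch.int64).item())       # 2
print(just_below.to(torch.int64).item())  # 1
```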
We take the abs to deal with complex numbers + # We want to perform operations near zero, which is where floating points are most precise + # thus, we perform the following optimisation: # If weight.abs() >= 0.5: # return (1 - weight) * (start - end) + end mask = weight.abs() >= 0.5 @@ -4205,22 +4206,22 @@ def linspace( pin_memory: bool = False, requires_grad: bool = False, ) -> TensorLikeType: - if dtype is None: - dtype = torch.get_default_dtype() - - # NB: NumPy actually doesn't do this cast, but for this ref, I'd rather have this - # cast than not, because it allows us to always go into the precise path - # if dtype is integral and not worry about whether start/end are float - if prims.utils.is_integer_dtype(dtype): - if isinstance(start, FloatLike): - start = sym_int(start) - if isinstance(end, FloatLike): - end = sym_int(end) - if py_any(isinstance(arg, complex) for arg in (start, end, steps)): - raise NotImplementedError - assert not isinstance(start, complex) and not isinstance(end, complex) # for mypy + default_complex_dtype = utils.corresponding_complex_dtype( + torch.get_default_dtype() + ) + if dtype is None: + dtype = default_complex_dtype + else: + check( + utils.is_complex_dtype(dtype), + lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}", + ) + else: + dtype = dtype or torch.get_default_dtype() + assert isinstance(dtype, torch.dtype) + # steps does not participate in the computation of the dtype check( isinstance(steps, IntLike), lambda: "steps must be int, not float", @@ -4236,41 +4237,27 @@ def linspace( "requires_grad": requires_grad, } if steps == 0: - ret = torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - elif steps == 1: - ret = torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - elif start == end: - ret = torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[call-overload] - else: - if prims.utils.is_integer_dtype(dtype): - # We need to cast to int, so to avoid off-by-one issues - # do the entire computation with ints when we can - assert isinstance(start, IntLike) and isinstance(end, IntLike) - step_size_x_denom = end - start - eps = 1 if end > start else -1 - denom = steps - 1 - ret = prims.to_dtype( - torch.arange( - start * denom, - end * denom + eps, - step_size_x_denom, - dtype=torch.int64, - **factory_kwargs, # type: ignore[arg-type] - ) - / denom, - dtype, - ) - else: - step_size = (end - start) / (steps - 1) - eps = step_size / 2 - ret = prims.to_dtype( - torch.arange( # type: ignore[call-overload] - start, end + eps, step_size, dtype=torch.float64, **factory_kwargs - ), - dtype, - ) - - return ret + return torch.full((0,), 0, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if steps == 1: + return torch.full((1,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + if start == end: + return torch.full((steps,), start, dtype=dtype, **factory_kwargs) # type: ignore[arg-type] + + # arange returns values in the interval [start, end) so we add an an eps to make it [start, end] + # The eps is small enough as to always add just the element end + step_size = 1 / (steps - 1) + eps = step_size / 2 + # arange returns a tensor of size divup(end - start, step) and thus, for the arguemnts below + # ceil(div(1 + step_size/2, 1/(steps - 1)) = steps - 1 + ceil(1 / 2) = steps + # torch.arange is an scan algorithm, so we need a high-precision dtype + rg = torch.arange( + 0, 1 + eps, step_size, 
dtype=torch.float64, **factory_kwargs # type: ignore[arg-type] + ) + double_dtype = torch.complex128 if utils.is_complex_dtype(dtype) else torch.float64 + rg = _maybe_convert_to_dtype(rg, double_dtype) # type: ignore[assignment] + cast = partial(torch.full, (1,), dtype=double_dtype, **factory_kwargs) + out = torch.lerp(cast(start), cast(end), rg) + return _maybe_convert_to_dtype(out, dtype) # type: ignore[return-value] @register_decomposition(torch.ops.aten.logspace) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1b8920cbc867..f81aa4f5024c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9937,21 +9937,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - - # cpu implementation is wrong on some integral types - # https://github.com/pytorch/pytorch/issues/81996 - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', - dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', - dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), - # cuda implementation is off-by-one on some inputs due to precision issues - # https://github.com/pytorch/pytorch/issues/82230 - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_quick', - dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), - device_type="cuda"), - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', - dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), - device_type="cuda"), # UserWarning: CUDA caching allocator reports a memory leak not verified by the driver API # in __main__.TestJitCUDA.test_variant_consistency_jit_logspace_cuda_complex64! # Caching allocator allocated memory was 0 and is now reported as 307200 on device 0. 
@@ -16965,9 +16950,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # cpu implementation is wrong on some integral types # https://github.com/pytorch/pytorch/issues/81996 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', - dtypes=(torch.int16, torch.int32, torch.int64), device_type="cpu"), + dtypes=(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), device_type="cpu"), # cuda implementation is off-by-one on some inputs due to precision issues # https://github.com/pytorch/pytorch/issues/82230 From 33209153035ef60f84014983186f9eefde7dab72 Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 16 Nov 2022 14:09:59 +0000 Subject: [PATCH 240/453] Fix decomp for embedding_backward and simplify the decomposition of embedding_dense and embedding_dense_backward (#87204) See the title Pull Request resolved: https://github.com/pytorch/pytorch/pull/87204 Approved by: https://github.com/Chillee --- test/test_decomp.py | 2 -- torch/_decomp/decompositions.py | 54 ++++++++++++++++----------------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/test/test_decomp.py b/test/test_decomp.py index dc94b6714ccd..ad8cf27ae0f2 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -310,8 +310,6 @@ def normalize_op_input_output(f, sample, requires_grad=True): CROSS_REF_BACKWARD_EXCLUDE_SET = { # Decomposed backward formula is not as precise - ("cuda", torch.float16, "nn.functional.embedding"), - ("cuda", torch.bfloat16, "nn.functional.embedding"), ("cpu", torch.bfloat16, "nn.functional.hardswish"), ("cuda", torch.float16, "nn.functional.cross_entropy"), } diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 7c84cb7e2ca8..7e3d31bb9746 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -12,7 +12,12 @@ from torch import Tensor from torch._decomp import register_decomposition from torch._prims_common import IntLike, NumberType, TensorLike, TensorSequenceType -from torch._prims_common.wrappers import _maybe_resize_out, _safe_copy_out, out_wrapper +from torch._prims_common.wrappers import ( + _maybe_convert_to_dtype, + _maybe_resize_out, + _safe_copy_out, + out_wrapper, +) from torch.fx.experimental.symbolic_shapes import guard_int, sym_float, sym_int from torch.utils._pytree import tree_flatten, tree_map @@ -1039,22 +1044,19 @@ def embedding( sparse: bool = False, ) -> Tensor: assert weight.dim() == 2, "'weight' must be 2-D" - # TODO: Assert not ported over yet - # auto indices_arg = TensorArg(indices, "indices", 1); - # checkScalarTypes("embedding", indices_arg, {kLong, kInt}); - - if indices.dim() == 1: - return weight.index_select(0, indices) - - size = list(indices.shape) - for d in weight.shape[1:]: - size.append(d) - - return weight.index_select(0, indices.reshape(-1)).view(size) + # Nb. 
scale_grad_by_freq is not used in the forward + if indices.ndim <= 1: + # We need this one as weight[indices] calls item() in these cases + out = weight.index_select(0, indices) + if indices.ndim == 0: + out = out.squeeze(0) + return out + else: + return weight[indices] -# TODO: Correct the type promotion semantics @register_decomposition(aten.embedding_dense_backward) +@pw_cast_for_opmath def embedding_dense_backward( grad_output: Tensor, indices: Tensor, @@ -1062,22 +1064,20 @@ def embedding_dense_backward( padding_idx: int, scale_grad_by_freq: bool, ): - numel = indices.numel() - grad = grad_output.reshape(numel, grad_output.size(-1)) - grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1])) - indices_rank1 = indices.reshape(numel) + indices = _maybe_convert_to_dtype(indices, torch.long) # type: ignore[assignment] if scale_grad_by_freq: counts = indices.new_zeros((num_weights,)) - ones = indices.new_ones((numel,)) - counts = counts.index_put([indices_rank1], ones, accumulate=True) - grad_weights_scale = counts[indices_rank1] - grad = grad / grad_weights_scale.unsqueeze(1) - skip_padding = (indices_rank1 != padding_idx).unsqueeze(1) - skip_padding = skip_padding.expand_as(grad) - zero_grad = torch.full_like(grad, 0) - return grad_weight.index_put( - [indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True + ones = torch.ones_like(indices) + counts = counts.index_put([indices], ones, accumulate=True) + grad_weights_scale = counts[indices] + grad_output = grad_output / grad_weights_scale.unsqueeze(1) + + mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim) + grad = grad_output.masked_fill(mask, 0) + grad_weight = grad_output.new_zeros( + (num_weights,) + grad_output.shape[indices.ndim :] ) + return grad_weight.index_put([indices], grad, accumulate=True) def prod(x: List[int]): From 58ebf92cf06bd68ca7aba0e29526e9004d53f08d Mon Sep 17 00:00:00 2001 From: lezcano Date: Wed, 16 Nov 2022 14:09:59 +0000 Subject: [PATCH 241/453] Add bfloat16 support to torch.prod to align with torch.cumprod (#87205) As per title Pull Request resolved: https://github.com/pytorch/pytorch/pull/87205 Approved by: https://github.com/mruberry --- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 2 +- .../testing/_internal/common_methods_invocations.py | 5 +---- .../testing/_internal/opinfo/definitions/_masked.py | 12 ++---------- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index a4345c3fd5d8..a82f3ed3eaa1 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -184,7 +184,7 @@ static void prod_kernel_impl(TensorIterator& iter) { // NOLINTNEXTLINE(bugprone-argument-comment) /*identity=*/1); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "prod_cpu", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, iter.dtype(), "prod_out_cpu", [&] { binary_kernel_reduce_vec( iter, [=](scalar_t a, scalar_t b) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index f81aa4f5024c..af4539ee5fec 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -9312,9 +9312,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): skips=( # cumprod does not handle correctly out= dtypes DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # 
RuntimeError: "prod_cpu" not implemented for 'BFloat16' - DecorateInfo(unittest.expectedFailure, 'TestDecomp', 'test_comprehensive', - dtypes=(torch.bfloat16,), device_type='cpu'), ), # gradgradcheck fails in fast_mode=True: #56275 sample_inputs_func=sample_inputs_cumprod, @@ -16441,7 +16438,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, promotes_int_to_int64=True, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, - dtypes=all_types_and_complex_and(torch.bool), + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_prod, ref=reference_reduction_numpy(np.prod), diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index 92231229bb5e..f4b590fe2520 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -446,8 +446,8 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar supports_sparse=True, supports_sparse_csr=True, promotes_int_to_int64=True, - # FIXME: "prod_cpu" not implemented for 'Half' or 'BFloat16' - dtypes=all_types_and_complex_and(torch.bool), + # FIXME: "prod_cpu" not implemented for 'Half' + dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and( torch.bool, torch.float16, torch.bfloat16 ), @@ -549,14 +549,6 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar DecorateInfo( unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" ), - # RuntimeError: "prod_cpu" not implemented for 'BFloat16' - DecorateInfo( - unittest.expectedFailure, - "TestDecomp", - "test_comprehensive", - dtypes=(torch.bfloat16,), - device_type="cpu", - ), DecorateInfo( toleranceOverride({torch.float32: tol(atol=1e-5, rtol=1e-5)}), "TestCompositeCompliance", From 7b0adc290a744de42e875822a1be4fa2b8d96147 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 16 Nov 2022 12:40:27 +0000 Subject: [PATCH 242/453] Run tests from test/inductor in inductor CI job (#88957) CUDA inductor tests are currently not run in CI because the only jobs that have triton installed don't actually run these test. 
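For anyone reproducing this locally, the command below is a sketch of what the new CI step runs; the flag and arguments are taken verbatim from the `test_inductor` change to `.jenkins/pytorch/test.sh` in this patch, and a CUDA machine with triton installed is assumed.

```
# same invocation the updated test_inductor() CI step uses (assumes CUDA + triton)
PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include inductor/test_torchinductor --include inductor/test_torchinductor_opinfo --verbose
```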
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88957 Approved by: https://github.com/ngimel, https://github.com/seemethere --- .jenkins/pytorch/test.sh | 1 + test/inductor/test_torchinductor_opinfo.py | 36 ++++++---------------- test/run_test.py | 1 + 3 files changed, 11 insertions(+), 27 deletions(-) diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 5fa54f538f35..135fb50762d6 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -250,6 +250,7 @@ test_inductor_distributed() { test_inductor() { python test/run_test.py --include test_modules test_ops --verbose + PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include inductor/test_torchinductor --include inductor/test_torchinductor_opinfo --verbose # TODO: investigate "RuntimeError: CUDA driver API confirmed a leak" # seen intest_ops_gradients.py # pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64" diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 7db9d13733b4..67b64c73a8ef 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -127,35 +127,14 @@ def process(device_type): } inductor_skips["cuda"] = { - # flaky - "__rdiv__": {b8, f16, f32, f64, i32, i64}, - "masked.prod": {f16, f32, f64}, - "linalg.vander": {f32, f64}, - "sparse.sampled_addmm": {f32, f64}, - "broadcast_tensors": {f16, f32, f64}, - "dsplit": {f16, f32, f64}, # Jiterator kernel is not expected to work with inductor "jiterator_2inputs_2outputs": {b8, f16, f32, f64, i32, i64}, "jiterator_4inputs_with_extra_args": {b8, f16, f32, f64, i32, i64}, "jiterator_binary": {b8, f16, f32, f64, i32, i64}, "jiterator_binary_return_by_ref": {b8, f16, f32, f64, i32, i64}, "jiterator_unary": {b8, f16, f32, f64, i32, i64}, - # Disabled on migration to core - "linalg.pinv.singular": {f32, f64}, - "linalg.householder_product": {f32}, - # These might be passing now? 
- "__getitem__": {b8, f16, f32, f64, i32, i64}, - "nn.functional.conv_transpose3d": {f16}, - "max.reduction_with_dim": {i32, i64}, - "min.reduction_with_dim": {i32, i64}, - "linalg.lu": {f32, f64}, - "lu_unpack": {f32, f64}, + # flaky "native_batch_norm": {f16, f32, f64}, - "native_layer_norm": {f16, f32, f64}, - # Issues on sm86 periodic job (complex numbers) - "cdouble": {b8, f16, f32, f64, i32, i64}, - "cfloat": {b8, f16, f32, f64, i32, i64}, - "randint": {b8, f16, f32, f64, i32, i64}, } inductor_expected_failures_single_sample = defaultdict(dict) @@ -280,6 +259,7 @@ def process(device_type): "mH": {b8, f16, f32, f64, i32, i64}, "mT": {b8, f16, f32, f64, i32, i64}, "__getitem__": {b8, f16, f32, f64, i32, i64}, + "__rdiv__": {b8, f16, f32, f64, i32, i64}, "allclose": {f16, f32, f64}, "angle": {f32, f64}, "argwhere": {b8, f16, f32, f64, i32, i64}, @@ -287,6 +267,8 @@ def process(device_type): "bernoulli": {f16, f32, f64}, "bincount": {i32, i64}, "bucketize": {b8, f16, f32, f64, i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -322,13 +304,13 @@ def process(device_type): "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, - "lu_unpack": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "masked.argmax": {f16, f32, f64, i32}, "masked.argmin": {f16, f32, f64, i32}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, - "max.reduction_with_dim": {b8, i32, i64}, - "min.reduction_with_dim": {b8, i32, i64}, + "max.reduction_with_dim": {b8}, + "min.reduction_with_dim": {b8}, "multinomial": {f16, f32, f64}, "nn.functional.adaptive_avg_pool2d": {f16}, "nn.functional.ctc_loss": {f32, f64}, @@ -346,12 +328,14 @@ def process(device_type): "pow": {i32, i64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "round.decimals_3": {f16}, "scatter_reduce.prod": {f16, f32, f64}, "segment_reduce.lengths": {f16, f32, f64}, "sgn": {f16, f32, f64}, + "sparse.sampled_addmm": {f32, f64}, "stft": {f32, f64}, "svd_lowrank": {f32, f64}, "tensor_split": {b8, f16, f32, f64, i32, i64}, @@ -375,8 +359,6 @@ def process(device_type): "linalg.vector_norm": {f64, f64}, "kron": {f16}, "nanquantile": {f32, f64}, - "native_batch_norm": {f16, f32, f64}, - "native_layer_norm": {f16, f32, f64}, "nn.functional._scaled_dot_product_attention": {f16}, "nn.functional.avg_pool2d": {f16, f32, f64}, "nn.functional.batch_norm.without_cudnn": {f16}, diff --git a/test/run_test.py b/test/run_test.py index 1273ab45c4fb..8a25a2e70785 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -792,6 +792,7 @@ def run_test_ops(test_module, test_directory, options): "distributed/rpc/test_share_memory": get_run_test_with_subprocess_fn(), "distributed/rpc/cuda/test_tensorpipe_agent": get_run_test_with_subprocess_fn(), "doctests": run_doctests, + "inductor/test_torchinductor_opinfo": run_test_ops, "test_ops": run_test_ops, "test_ops_gradients": run_test_ops, "test_ops_fwd_gradients": run_test_ops, From a6ef2c7634e2a77fe698d5335d29e10ca24cdf2b Mon Sep 17 00:00:00 2001 From: Huy Do Date: Wed, 16 Nov 2022 18:25:38 +0000 Subject: [PATCH 243/453] Support test-config filter logic for rocm (#89046) The logic used by `mem_leak_check` 
https://github.com/pytorch/pytorch/pull/88373 is currently not applied to rocm, i.e. https://hud.pytorch.org/pytorch/pytorch/commit/06486cd0087200e08ebb8a9518e064251c7c5309 because its workflows don't have the test-config filtering logic yet (linux, mac, and windows all have it already). In another work, rocm tests always run with mem leak check disabled at the moment. We want that but also to run the test with mem leak check enabled periodically one per day. This PR closes that gap Pull Request resolved: https://github.com/pytorch/pytorch/pull/89046 Approved by: https://github.com/clee2000 --- .github/workflows/_rocm-test.yml | 28 +++++++++++++++++++++++++--- .github/workflows/periodic.yml | 20 +++++++++++--------- .github/workflows/pull.yml | 5 +++++ .github/workflows/trunk.yml | 11 ++++++----- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index dd1a0830275c..be4a5c9dcc6c 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -39,12 +39,34 @@ env: GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} jobs: + # This needs to be run right before the test starts so that it can gather the + # latest labels from the PR + filter: + runs-on: [self-hosted, linux.large] + outputs: + test-matrix: ${{ steps.filter.outputs.test-matrix }} + is-test-matrix-empty: ${{ steps.filter.outputs.is-test-matrix-empty }} + steps: + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@master + with: + fetch-depth: 1 + submodules: false + + - name: Select all requested test configurations + id: filter + uses: ./.github/actions/filter-test-configs + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + test-matrix: ${{ inputs.test-matrix }} + test: - # Don't run on forked repos. 
- if: github.repository_owner == 'pytorch' + needs: filter + # Don't run on forked repos or empty test matrix + if: github.repository_owner == 'pytorch' && needs.filter.outputs.is-test-matrix-empty == 'False' timeout-minutes: 300 strategy: - matrix: ${{ fromJSON(inputs.test-matrix) }} + matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} fail-fast: false runs-on: ${{ matrix.runner }} steps: diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 61302e1a0d61..b5512b20eaae 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -41,6 +41,10 @@ jobs: with: build-environment: linux-focal-rocm5.2-py3.8 docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 + test-matrix: | + { include: [ + { config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, + ]} linux-focal-rocm5_2-py3_8-slow-test: name: linux-focal-rocm5.2-py3.8-slow @@ -49,10 +53,7 @@ jobs: with: build-environment: linux-focal-rocm5.2-py3.8 docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-slow-build.outputs.docker-image }} - test-matrix: | - { include: [ - { config: "slow", shard: 1, num_shards: 1, runner: "linux.rocm.gpu" }, - ]} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-slow-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} @@ -63,6 +64,11 @@ jobs: with: build-environment: linux-focal-rocm5.2-py3.8 docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 + test-matrix: | + { include: [ + { config: "distributed", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, + { config: "distributed", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, + ]} linux-focal-rocm5_2-py3_8-distributed-test: name: linux-focal-rocm5.2-py3.8-distributed @@ -71,11 +77,7 @@ jobs: with: build-environment: linux-focal-rocm5.2-py3.8 docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-distributed-build.outputs.docker-image }} - test-matrix: | - { include: [ - { config: "distributed", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, - { config: "distributed", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, - ]} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-distributed-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index c3d530e3e718..3208cb198bb4 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -308,3 +308,8 @@ jobs: build-environment: linux-focal-rocm5.2-py3.8 docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 sync-tag: rocm-build + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, + { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, + ]} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index cb5d1291833a..6779a362209c 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -298,6 +298,11 @@ jobs: build-environment: linux-focal-rocm5.2-py3.8 docker-image-name: pytorch-linux-focal-rocm5.2-py3.8 sync-tag: rocm-build + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, + { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, + ]} 
linux-focal-rocm5_2-py3_8-test: name: linux-focal-rocm5.2-py3.8 @@ -306,11 +311,7 @@ jobs: with: build-environment: linux-focal-rocm5.2-py3.8 docker-image: ${{ needs.linux-focal-rocm5_2-py3_8-build.outputs.docker-image }} - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "linux.rocm.gpu" }, - { config: "default", shard: 2, num_shards: 2, runner: "linux.rocm.gpu" }, - ]} + test-matrix: ${{ needs.linux-focal-rocm5_2-py3_8-build.outputs.test-matrix }} secrets: AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID: ${{ secrets.AWS_OSSCI_METRICS_V2_ACCESS_KEY_ID }} AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY: ${{ secrets.AWS_OSSCI_METRICS_V2_SECRET_ACCESS_KEY }} From 61801799a0a6a2fe0b577450c1fdd55af6063664 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 15 Nov 2022 13:27:57 -0800 Subject: [PATCH 244/453] [Quant][bc-breaking] Remove overwrite_output_observer (#88620) Summary: When the BackendConfig was first introduced, `overwrite_output_observer` and `overwrite_output_fake_quantize` were added to ensure fixed qparams ops like `torch.nn.Sigmoid` and `torch.nn.Tanh` used the correct observers and fake quantizes. However, this is hacky because the BackendConfig should not set the observer constructors themselves, but should instead specify only requirements on the observers. Later, https://github.com/pytorch/pytorch/pull/80184 added the correct observers to `get_default_qconfig_mapping` along with validation logic that throws an error if incorrect observers were specified. With this change, we no longer need to overwrite the observers from the BackendConfig, since we expect the user to pass in the correct observers for these ops. This commit removes these overwrite observer settings in the BackendConfig. Instead, we represent the observer constraints for fixed qparams ops through the existing DTypeWithConstraints mechanism. Note that, however, to be consistent with other DTypeWithConstraints checks, we no longer throw an error if an incorrect observer is specified, but simply ignore the offending QConfig and log a warning instead. This is the BC-breaking part of the change. BC-breaking notes: ``` from torch.ao.quantization.qconfig import default_qconfig from torch.ao.quantization.quantize_fx import prepare_fx model = ModelWithFixedQParamsOps() qconfig_mapping = QConfigMapping().set_global(default_qconfig) example_inputs = ... prepare_fx(model, qconfig_mapping, example_inputs) ``` Before this commit, running the above leads to an exception because the wrong observers are used for fixed qparams ops. After this commit, the above will only encounter a warning, and the fixed qparams ops will not be quantized. In both cases, switching to `get_default_qconfig_mapping` will cause the fixed qparams ops to be quantized. 
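For completeness, a minimal sketch of the recommended usage after this change, reusing the same placeholder model and `example_inputs` as in the snippet above:

```
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

model = ModelWithFixedQParamsOps()
# The default mapping supplies the FixedQParams observers that fixed qparams
# ops (e.g. sigmoid, tanh) require, so these ops will be quantized.
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
example_inputs = ...
prepare_fx(model, qconfig_mapping, example_inputs)
```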
Test Plan: python test/test_quantization.py TestQuantizeFx python test/test_quantization.py TestQuantizeFxOps Reviewers: jerryzh168, vkuzo Subscribers: jerryzh168, vkuzo Pull Request resolved: https://github.com/pytorch/pytorch/pull/88620 Approved by: https://github.com/jerryzh168 --- test/quantization/core/test_backend_config.py | 25 +----- test/quantization/fx/test_quantize_fx.py | 16 ++-- .../ao/quantization/backend_config/README.md | 4 + .../_common_operator_config_utils.py | 82 ++++++++++++++++--- .../backend_config/backend_config.py | 26 ++---- torch/ao/quantization/backend_config/utils.py | 8 -- .../quantization/fx/backend_config_utils.py | 33 +------- torch/ao/quantization/fx/prepare.py | 61 +------------- .../quantization/fx/quantization_patterns.py | 20 +---- torch/ao/quantization/fx/utils.py | 44 +++++++++- torch/ao/quantization/qconfig_mapping.py | 1 + 11 files changed, 137 insertions(+), 183 deletions(-) diff --git a/test/quantization/core/test_backend_config.py b/test/quantization/core/test_backend_config.py index aa9de64824bc..e641e58bb2aa 100644 --- a/test/quantization/core/test_backend_config.py +++ b/test/quantization/core/test_backend_config.py @@ -13,10 +13,8 @@ DTypeWithConstraints, ObservationType, ) -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize from torch.ao.quantization.fuser_method_mappings import _reverse_sequential_wrapper2 from torch.ao.quantization.fx.quantization_patterns import _default_root_node_getter -from torch.ao.quantization.observer import default_fixed_qparams_range_0to1_observer class TestBackendConfig(QuantizationTestCase): @@ -118,7 +116,6 @@ def test_dtype_config_to_dict(self): "input": 1, "weight": 2, } - _fake_quantize = FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer) def _extra_inputs_getter(self, p): return (torch.rand(3, 3),) @@ -141,9 +138,7 @@ def _get_backend_op_config2(self): ._set_extra_inputs_getter(self._extra_inputs_getter) \ ._set_num_tensor_args_to_observation_type(self._num_tensor_args_to_observation_type) \ ._set_input_type_to_index(self._input_type_to_index) \ - ._set_input_output_observed(False) \ - ._set_overwrite_output_fake_quantize(self._fake_quantize) \ - ._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) + ._set_input_output_observed(False) def _get_backend_pattern_config_dict1(self): return { @@ -167,8 +162,6 @@ def _get_backend_pattern_config_dict2(self): "num_tensor_args_to_observation_type": self._num_tensor_args_to_observation_type, "input_type_to_index": self._input_type_to_index, "input_output_observed": False, - "overwrite_output_fake_quantize": self._fake_quantize, - "overwrite_output_observer": default_fixed_qparams_range_0to1_observer } def test_backend_op_config_set_observation_type(self): @@ -246,18 +239,6 @@ def test_backend_op_config_set_input_output_observed(self): conf._set_input_output_observed(False) self.assertEqual(conf._input_output_observed, False) - def test_backend_op_config_set_overwrite_output_fake_quantize(self): - conf = BackendPatternConfig(torch.sigmoid) - self.assertTrue(conf._overwrite_output_fake_quantize is None) - conf._set_overwrite_output_fake_quantize(self._fake_quantize) - self.assertEqual(conf._overwrite_output_fake_quantize, self._fake_quantize) - - def test_backend_op_config_set_overwrite_output_observer(self): - conf = BackendPatternConfig(torch.sigmoid) - self.assertTrue(conf._overwrite_output_observer is None) - 
conf._set_overwrite_output_observer(default_fixed_qparams_range_0to1_observer) - self.assertEqual(conf._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) - def test_backend_op_config_from_dict(self): conf_dict1 = self._get_backend_pattern_config_dict1() conf1 = BackendPatternConfig.from_dict(conf_dict1) @@ -273,8 +254,6 @@ def test_backend_op_config_from_dict(self): self.assertEqual(len(conf1._num_tensor_args_to_observation_type), 0) self.assertEqual(len(conf1._input_type_to_index), 0) self.assertTrue(conf1._input_output_observed is None) - self.assertTrue(conf1._overwrite_output_fake_quantize is None) - self.assertTrue(conf1._overwrite_output_observer is None) # Test temporary/internal keys conf_dict2 = self._get_backend_pattern_config_dict2() conf2 = BackendPatternConfig.from_dict(conf_dict2) @@ -290,8 +269,6 @@ def test_backend_op_config_from_dict(self): self.assertEqual(conf2._num_tensor_args_to_observation_type, self._num_tensor_args_to_observation_type) self.assertEqual(conf2._input_type_to_index, self._input_type_to_index) self.assertEqual(conf2._input_output_observed, False) - self.assertEqual(conf2._overwrite_output_fake_quantize, self._fake_quantize) - self.assertEqual(conf2._overwrite_output_observer, default_fixed_qparams_range_0to1_observer) def test_backend_op_config_to_dict(self): conf1 = self._get_backend_op_config1() diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 6721e397180e..6c631a24abc6 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -6792,9 +6792,8 @@ def forward(self, x): M(), data, quant_type, custom_qconfig_dict=qconfig_mapping, expected_node_occurrence=node_occurrence, is_reference=True) - def test_fixed_qparams_ops_qconfig_error(self): - """ Test that a proper error message is shown when user don't specify the correct - qconfig for fixed qaprams ops + def test_fixed_qparams_ops_wrong_qconfig(self): + """ Test that wrong qconfigs for fixed qparams ops results in the ops not being quantized. """ class M(torch.nn.Module): def __init__(self): @@ -6814,8 +6813,15 @@ def forward(self, x): data = (torch.randn((2, 2, 2, 2), dtype=torch.float),) qconfig_mapping = QConfigMapping().set_global(default_qconfig) m = M().eval() - with self.assertRaisesRegex(ValueError, "get_default_qconfig_mapping"): - m = prepare_fx(m, qconfig_mapping, data) + node_occurrence = { + ns.call_function(torch.quantize_per_tensor): 0, + ns.call_method("dequantize"): 0, + } + self.checkGraphModeFxOp( + m, data, QuantType.STATIC, custom_qconfig_dict=qconfig_mapping, + expected_node_occurrence=node_occurrence, is_reference=True) + self.assertTrue(isinstance(m.sigmoid, torch.nn.Sigmoid)) + self.assertTrue(isinstance(m.tanh, torch.nn.Tanh)) @skipIfNoFBGEMM def test_general_shape_ops(self): diff --git a/torch/ao/quantization/backend_config/README.md b/torch/ao/quantization/backend_config/README.md index b8d8ceb3e38d..985765e6badc 100644 --- a/torch/ao/quantization/backend_config/README.md +++ b/torch/ao/quantization/backend_config/README.md @@ -152,3 +152,7 @@ The user's QConfig may specify `quant_min` and `quant_max`, which are min and ma #### Scale range Similarly, the user's QConfig may specify a minimum value for the quantization scale (currently exposed as `eps` but will change in the future to better reflect the semantics). Here we set the lower bound for the `scale_min` to represent the limits of the backend. 
If a QConfig's min scale value falls below this limit, the QConfig will be treated as violating this constraint. Note that `scale_max_upper_bound` is currently not used, because there is no corresponding mechanism to enforce this on the observer yet. + +#### Fixed quantization parameters + +For ops with fixed quantization parameters such as `torch.nn.Sigmoid` or `torch.nn.Tanh`, the BackendConfig can specify the specific scale and zero point values as constraints on the input and output activations. The user's QConfigs for these ops must use `FixedQParamsObserver` or `FixedQParamsFakeQuantize` for their activations with matching scale and zero point values, otherwise these QConfigs will be ignored. diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index c2f0f7227b10..47a0b3024208 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -1,3 +1,4 @@ +import copy import operator import torch import torch.nn.functional as F @@ -7,13 +8,13 @@ import torch.nn.qat as nnqat import torch.nn.quantized._reference as nnqr from collections import namedtuple -from typing import List +from typing import Callable, Dict, List, Union from .backend_config import ( BackendPatternConfig, DTypeConfig, + DTypeWithConstraints, ObservationType, ) -from ..fake_quantize import FixedQParamsFakeQuantize from ..fuser_method_mappings import ( _reverse_sequential_wrapper2, _reverse2, @@ -23,7 +24,6 @@ fuse_linear_bn, fuse_convtranspose_bn, ) -from ..qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER # TODO: rename to be more explict, e.g. qat_conv_relu _ConvMetadata = namedtuple( @@ -48,6 +48,38 @@ nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) +# Add constraints for fixed qparams ops like sigmoid and tanh to ensure values +# fall within the proper ranges, e.g. 
[0, 1] for sigmoid, [-1, 1] for tanh +_FIXED_QPARAM_OP_0TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=1.0 / 256.0, + zero_point_exact_match=0, +) +_FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS = DTypeWithConstraints( + dtype=torch.quint8, + quant_min_lower_bound=0, + quant_max_upper_bound=255, + scale_exact_match=2.0 / 256.0, + zero_point_exact_match=128, +) +_FIXED_QPARAMS_OP_TO_CONSTRAINTS: Dict[Union[Callable, str], DTypeWithConstraints] = { + torch.nn.Hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.functional.hardsigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "hardsigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.sigmoid: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + "sigmoid_": _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Softmax: _FIXED_QPARAM_OP_0TO1_CONSTRAINTS, + torch.nn.Tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + torch.tanh: _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, + "tanh_": _FIXED_QPARAM_OP_NEG1TO1_CONSTRAINTS, +} + def _get_binary_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: binary_op_configs: List[BackendPatternConfig] = [] num_tensor_args_to_observation_type_mapping = { @@ -393,21 +425,45 @@ def _get_default_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPat ) return configs +def _add_fixed_qparams_to_dtype_configs( + dtype_configs: List[DTypeConfig], + constraints: DTypeWithConstraints, +) -> List[DTypeConfig]: + """ + Return a copy of the list of DTypeConfigs where activations are subject to the specified + constraints required for fixed qparams ops. + + If the data type doesn't match the one in the constraints, simply leave the corresponding + DTypeConfig unchanged. + + If `scale_min_lower_bound` or `scale_max_upper_bound` is specified in the activations, + throw an exception since these settings are incompatible with fixed qparams ops. 
+ """ + new_dtype_configs = [] + for dtype_config in dtype_configs: + dc = copy.deepcopy(dtype_config) + for orig_constraints in [dc.input_dtype_with_constraints, dc.output_dtype_with_constraints]: + if orig_constraints.dtype != constraints.dtype: + continue + if orig_constraints.scale_min_lower_bound is not None: + raise ValueError("scale_min_lower_bound is invalid for fixed qparams ops: %s" % dtype_config) + if orig_constraints.scale_max_upper_bound is not None: + raise ValueError("scale_max_upper_bound is invalid for fixed qparams ops: %s" % dtype_config) + orig_constraints.quant_min_lower_bound = constraints.quant_min_lower_bound + orig_constraints.quant_max_upper_bound = constraints.quant_max_upper_bound + orig_constraints.scale_exact_match = constraints.scale_exact_match + orig_constraints.zero_point_exact_match = constraints.zero_point_exact_match + new_dtype_configs.append(dc) + return new_dtype_configs + def _get_fixed_qparams_op_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: fixed_qparams_op_configs = [] - for fixed_qparam_op, output_observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items(): + for fixed_qparam_op, constraints in _FIXED_QPARAMS_OP_TO_CONSTRAINTS.items(): + new_dtype_configs = _add_fixed_qparams_to_dtype_configs(dtype_configs, constraints) fixed_qparams_op_configs.append( - # TODO: The _overwrite_output keys are temporary, since we don't want to put observer - # in the configs we expect that it's provided by user - # What we want to put here is the requirement on observers, in this case dtype, - # quant_min, quant_max etc., but we need to first move all configs to - # backend_config_dict to do that, we'll remove these keys after we fully migrated - # everything to use backend_config_dict BackendPatternConfig(fixed_qparam_op) .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 - .set_dtype_configs(dtype_configs) - ._set_overwrite_output_fake_quantize(FixedQParamsFakeQuantize.with_args(observer=output_observer)) - ._set_overwrite_output_observer(output_observer)) + .set_dtype_configs(new_dtype_configs)) return fixed_qparams_op_configs def _get_share_qparams_op_configs(dtype_configs): diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 1305c32a4ea8..e8af42ff4b6a 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -3,7 +3,6 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union import torch -from torch.ao.quantization.observer import _PartialWrapper from torch.ao.quantization.utils import Pattern from enum import Enum @@ -42,8 +41,6 @@ NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY = "num_tensor_args_to_observation_type" INPUT_TYPE_TO_INDEX_DICT_KEY = "input_type_to_index" INPUT_OUTPUT_OBSERVED_DICT_KEY = "input_output_observed" -OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY = "overwrite_output_fake_quantize" -OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer" # TODO: maybe rename this to something that's not related to observer @@ -69,14 +66,17 @@ class ObservationType(Enum): @dataclass class DTypeWithConstraints: """ - Config for specifying additional constraints for a given dtype, such as quantization value - ranges and scale value ranges, to be used in :class:`~torch.ao.quantization.backend_config.DTypeConfig`. 
+ Config for specifying additional constraints for a given dtype, such as quantization + value ranges, scale value ranges, and fixed quantization params, to be used in + :class:`~torch.ao.quantization.backend_config.DTypeConfig`. """ dtype: Optional[torch.dtype] = None quant_min_lower_bound: Union[int, float, None] = None quant_max_upper_bound: Union[int, float, None] = None scale_min_lower_bound: Union[int, float, None] = None scale_max_upper_bound: Union[int, float, None] = None + scale_exact_match: Optional[float] = None + zero_point_exact_match: Optional[int] = None @dataclass @@ -336,8 +336,6 @@ def __init__(self, pattern: Pattern): self._num_tensor_args_to_observation_type: Dict[int, ObservationType] = {} self._input_type_to_index: Dict[str, int] = {} self._input_output_observed: Optional[bool] = None - self._overwrite_output_fake_quantize: Optional[_PartialWrapper] = None - self._overwrite_output_observer: Optional[_PartialWrapper] = None def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig: """ @@ -433,14 +431,6 @@ def _set_input_output_observed(self, input_output_observed: bool) -> BackendPatt self._input_output_observed = input_output_observed return self - def _set_overwrite_output_fake_quantize(self, overwrite_output_fake_quantize: _PartialWrapper) -> BackendPatternConfig: - self._overwrite_output_fake_quantize = overwrite_output_fake_quantize - return self - - def _set_overwrite_output_observer(self, overwrite_output_observer: _PartialWrapper) -> BackendPatternConfig: - self._overwrite_output_observer = overwrite_output_observer - return self - @classmethod def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig: """ @@ -487,8 +477,6 @@ def _get_dtype_config(obj: Any) -> DTypeConfig: backend_pattern_config_dict.get(NUM_TENSOR_ARGS_TO_OBSERVATION_TYPE_DICT_KEY, {})) conf._set_input_type_to_index(backend_pattern_config_dict.get(INPUT_TYPE_TO_INDEX_DICT_KEY, {})) conf._set_input_output_observed(backend_pattern_config_dict.get(INPUT_OUTPUT_OBSERVED_DICT_KEY, None)) - conf._set_overwrite_output_fake_quantize(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY, None)) - conf._set_overwrite_output_observer(backend_pattern_config_dict.get(OVERWRITE_OUTPUT_OBSERVER_DICT_KEY, None)) return conf def to_dict(self) -> Dict[str, Any]: @@ -521,8 +509,4 @@ def to_dict(self) -> Dict[str, Any]: backend_pattern_config_dict[INPUT_TYPE_TO_INDEX_DICT_KEY] = self._input_type_to_index if self._input_output_observed is not None: backend_pattern_config_dict[INPUT_OUTPUT_OBSERVED_DICT_KEY] = self._input_output_observed - if self._overwrite_output_fake_quantize is not None: - backend_pattern_config_dict[OVERWRITE_OUTPUT_FAKE_QUANTIZE_DICT_KEY] = self._overwrite_output_fake_quantize - if self._overwrite_output_observer is not None: - backend_pattern_config_dict[OVERWRITE_OUTPUT_OBSERVER_DICT_KEY] = self._overwrite_output_observer return backend_pattern_config_dict diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index cdc58327fbee..fc7e9aca9ff6 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -5,7 +5,6 @@ import torch.nn.functional as F from .backend_config import BackendConfig, DTypeConfig from ..utils import Pattern -from ..observer import _PartialWrapper __all__ = [ "get_pattern_to_dtype_configs", @@ -86,13 +85,6 @@ def get_root_node(node_pattern): root_node_getter_mapping[pattern] = 
config._root_node_getter return root_node_getter_mapping -def get_fixed_qparams_op_to_overwrite_output_observer(backend_config: BackendConfig) -> Dict[Union[Callable, str], _PartialWrapper]: - fixed_qparam_op_to_overwrite_output_observer: Dict[Union[Callable, str], _PartialWrapper] = {} - for pattern, config in backend_config.configs.items(): - if config._overwrite_output_observer is not None: - fixed_qparam_op_to_overwrite_output_observer[pattern] = config._overwrite_output_observer # type: ignore[index] - return fixed_qparam_op_to_overwrite_output_observer - def get_fusion_pattern_to_extra_inputs_getter(backend_config: BackendConfig) -> Dict[Pattern, Callable]: """ Get a map from fusion pattern to a function that returns extra input nodes from the fusion pattern, in the order required by the root node. This is optional, diff --git a/torch/ao/quantization/fx/backend_config_utils.py b/torch/ao/quantization/fx/backend_config_utils.py index eef4979a0a06..50c6b6a27ede 100644 --- a/torch/ao/quantization/fx/backend_config_utils.py +++ b/torch/ao/quantization/fx/backend_config_utils.py @@ -5,7 +5,6 @@ ObservationType, ) from torch.ao.quantization.utils import ( - activation_dtype, get_combined_dict, Pattern, NodePattern, @@ -16,14 +15,12 @@ from .quantization_patterns import QuantizeHandler from .fusion_patterns import DefaultFuseHandler -from typing import Dict, Any, Callable, Optional +from typing import Callable, Dict def get_quantize_handler_cls( observation_type, dtype_configs, num_tensor_args_to_observation_type, - overwrite_output_fake_quantizer, - overwrite_output_observer, input_output_observed): class ConfigurableQuantizeHandler(QuantizeHandler): @@ -41,35 +38,11 @@ def __init__( else: self.observation_type = observation_type self.dtype_configs = dtype_configs - self.overwrite_output_fake_quantizer = overwrite_output_fake_quantizer - self.overwrite_output_observer = overwrite_output_observer self.input_output_observed_ = input_output_observed def is_general_tensor_value_op(self) -> bool: return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT - # TODO: change this to output activation - def get_activation_ctr( - self, - qconfig: Any, - pattern: Pattern, - is_training: bool, - ) -> Optional[Callable]: - """ - Returns the constructor for the activation observer which should be - used for the pattern matched to this handler. Some handlers override - this to a different value than what is specified in the qconfig. 
- """ - act_dtype = activation_dtype(qconfig) - # TODO: change to is_qat - if is_training: - if act_dtype == torch.quint8 and self.overwrite_output_fake_quantizer is not None: - return self.overwrite_output_fake_quantizer - else: - if act_dtype == torch.quint8 and self.overwrite_output_observer is not None: - return self.overwrite_output_observer - return qconfig.activation - # This is temporary, and will be removed soon def input_output_observed(self): return self.input_output_observed_ @@ -89,8 +62,6 @@ def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Patt observation_type = config.observation_type dtype_configs = config.dtype_configs num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type - overwrite_fake_quantizer = config._overwrite_output_fake_quantize - overwrite_observer = config._overwrite_output_observer input_output_observed = config._input_output_observed if input_output_observed is None: input_output_observed = True @@ -99,8 +70,6 @@ def get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Patt observation_type, dtype_configs, num_tensor_args_to_observation_type, - overwrite_fake_quantizer, - overwrite_observer, input_output_observed) return pattern_to_quantize_handlers diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 281bd960ed7b..c908e3f3b764 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -18,9 +18,6 @@ ObserverBase, ) from ..qconfig import ( - _obs_or_fq_ctr_equals, - float16_dynamic_qconfig, - float16_static_qconfig, _is_reuse_input_qconfig, QConfigAny, ) @@ -45,8 +42,6 @@ NodePattern, ) -from torch.ao.quantization import FixedQParamsFakeQuantize - from ._equalize import ( is_equalization_observer, node_supports_equalization, @@ -91,14 +86,12 @@ get_qconfig_dtypes, get_swapped_custom_module_class, activation_is_statically_quantized, - activation_is_int8_quantized, ) from ..backend_config.utils import ( get_pattern_to_dtype_configs, get_module_to_qat_module, get_fusion_pattern_to_root_node_getter, - get_fixed_qparams_op_to_overwrite_output_observer, ) from ..backend_config import ( BackendConfig, @@ -826,13 +819,7 @@ def maybe_insert_output_observer_for_node( (not is_standalone_module) if should_insert_observer: - act_post_process_ctr = qconfig.activation - if activation_is_int8_quantized(qconfig): - act_post_process_ctr = qhandler.get_activation_ctr( - qconfig, - matched_pattern, - is_qat) - observer = act_post_process_ctr() + observer = qconfig.activation() return insert_observer(node, observer, model, modules, graph) else: return None @@ -1392,51 +1379,6 @@ def insert_observers_for_model( return results_node -def _validate_fixed_qparams_qconfigs( - model: GraphModule, - node_name_to_qconfig: Dict[str, QConfigAny], - backend_config: BackendConfig): - """ - Validate whether the correct observers are configured for fixed qparams ops in the model, if any. 
- """ - # TODO: handle fp16 qconfigs properly - allowed_observer_ctrs = [ - float16_dynamic_qconfig.activation, - float16_static_qconfig.activation, - ] - named_modules = dict(model.named_modules(remove_duplicate=False)) - fixed_qparams_op_to_overwrite_output_observer = \ - get_fixed_qparams_op_to_overwrite_output_observer(backend_config) - for node in model.graph.nodes: - if node.op == "call_function": - module_type_or_function_or_method = node.target - elif node.op == "call_module": - module_type_or_function_or_method = type(named_modules[node.target]) - else: - module_type_or_function_or_method = None - - if module_type_or_function_or_method in fixed_qparams_op_to_overwrite_output_observer: - bad_observer = True - qconfig = node_name_to_qconfig.get(node.name, None) - if qconfig is None: - bad_observer = False - else: - for observer_ctr in allowed_observer_ctrs + [ - fixed_qparams_op_to_overwrite_output_observer[module_type_or_function_or_method]]: - if _obs_or_fq_ctr_equals( - qconfig.activation, - FixedQParamsFakeQuantize.with_args(observer=observer_ctr)) or \ - _obs_or_fq_ctr_equals(qconfig.activation, observer_ctr): - bad_observer = False - if bad_observer: - raise ValueError("QConfigMapping must specify fixed qparams observer for fixed qparams op " - "'%s' type: '%s'. Please use torch.ao.quantization.get_default_qconfig_mapping or " - "torch.ao.quantization.get_default_qat_qconfig_mapping" - " instead. Example: \n" - " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\") \n" - " model = prepare_fx(model, qconfig_mapping, example_inputs)" - "" % (node.format_node(), module_type_or_function_or_method)) - def run_prepare_fx_on_standalone_modules( model: torch.nn.Module, is_qat: bool, @@ -1609,7 +1551,6 @@ def prepare( equalization_node_name_to_qconfig = generate_node_name_to_qconfig( model, modules, model.graph, _equalization_config, node_name_to_scope) node_name_to_qconfig = generate_node_name_to_qconfig(model, modules, model.graph, qconfig_mapping, node_name_to_scope) - _validate_fixed_qparams_qconfigs(model, node_name_to_qconfig, backend_config) # match the patterns that will get quantized standalone_module_names = list(prepare_custom_config.standalone_module_names.keys()) diff --git a/torch/ao/quantization/fx/quantization_patterns.py b/torch/ao/quantization/fx/quantization_patterns.py index c24adb9e11e9..f8d72de9c96a 100644 --- a/torch/ao/quantization/fx/quantization_patterns.py +++ b/torch/ao/quantization/fx/quantization_patterns.py @@ -6,13 +6,10 @@ from .utils import ( all_node_args_have_no_tensors, ) -from torch.ao.quantization.utils import ( - Pattern, - NodePattern, -) +from torch.ao.quantization.utils import NodePattern from abc import ABC -from typing import Any, Callable, Dict, Optional +from typing import Callable, Dict __all__ = [ "QuantizeHandler", @@ -98,19 +95,6 @@ def is_general_tensor_value_op(self) -> bool: """ return False - def get_activation_ctr( - self, - qconfig: Any, - pattern: Pattern, - is_training: bool, - ) -> Optional[Callable]: - """ - Returns the constructor for the activation observer which should be - used for the pattern matched to this handler. Some handlers override - this to a different value than what is specified in the qconfig. 
- """ - return qconfig.activation - def is_custom_module(self): return self.is_custom_module_ diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index a5a989ec2148..73fdb0700144 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -10,8 +10,19 @@ BackendConfig, DTypeWithConstraints, ) -from torch.ao.quantization.fake_quantize import FakeQuantizeBase -from torch.ao.quantization.observer import ObserverBase +from torch.ao.quantization.fake_quantize import ( + FakeQuantizeBase, + FixedQParamsFakeQuantize, +) +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + ObserverBase, +) +from torch.ao.quantization.qconfig import ( + float16_static_qconfig, + float16_dynamic_qconfig, + qconfig_equals, +) from torch.ao.quantization.stubs import DeQuantStub from torch.ao.quantization.utils import ( activation_is_statically_quantized, @@ -951,10 +962,13 @@ def _qconfig_satisfies_dtype_config_constraints( 1. QConfig specified a quantization range that falls within the backend's, if any 2. QConfig specified a min scale value that is >= the backend's, if any + 3. QConfig specified a FixedQParamsObserver or FixedQParamsFakeQuantize that has + scale and zero point that match the backend's, if any If `is_activation` is True, we check `qconfig.activation`, else we check `qconfig.weight`. If `qconfig` or `dtype_with_constraints.dtype` is None, or the dtypes do not match, return True. """ + # TODO: log warnings only when the user enabled a debug flag def _activation_post_process_satisfies_dtype_config_constraints( activation_post_process: Union[ObserverBase, FakeQuantizeBase], dtype_with_constraints: DTypeWithConstraints, @@ -968,6 +982,8 @@ def _activation_post_process_satisfies_dtype_config_constraints( backend_quant_min = dtype_with_constraints.quant_min_lower_bound backend_quant_max = dtype_with_constraints.quant_max_upper_bound backend_scale_min = dtype_with_constraints.scale_min_lower_bound + backend_scale_exact_match = dtype_with_constraints.scale_exact_match + backend_zero_point_exact_match = dtype_with_constraints.zero_point_exact_match # check quantization ranges if backend_quant_min is not None and backend_quant_max is not None: if app_quant_min is None or app_quant_max is None: @@ -990,6 +1006,30 @@ def _activation_post_process_satisfies_dtype_config_constraints( "the backend's min scale value (%s), ignoring %s") % (debug_string, app_scale_min, backend_scale_min, qconfig)) return False + # check fixed scale and zero point + if backend_scale_exact_match is not None and backend_zero_point_exact_match is not None: + # For tests only, accept the following qconfigs for now + # TODO: handle fp16 qconfigs properly + for accepted_qconfig in [float16_static_qconfig, float16_dynamic_qconfig]: + if qconfig_equals(qconfig, accepted_qconfig): + return True + suggestion_str = ( + "Please use torch.ao.quantization.get_default_qconfig_mapping or " + "torch.ao.quantization.get_default_qat_qconfig_mapping. 
Example:\n" + " qconfig_mapping = get_default_qconfig_mapping(\"fbgemm\")\n" + " model = prepare_fx(model, qconfig_mapping, example_inputs)" + ) + if not isinstance(activation_post_process, FixedQParamsObserver) and \ + not isinstance(activation_post_process, FixedQParamsFakeQuantize): + warnings.warn(("QConfig must specify a FixedQParamsObserver or a FixedQParamsFakeQuantize " + "for fixed qparams ops, ignoring %s.\n%s") % (qconfig, suggestion_str)) + return False + if observer.scale != backend_scale_exact_match or observer.zero_point != backend_zero_point_exact_match: + warnings.warn(("QConfig fixed scale (%s) and zero point (%s) do not match the backend's " + "(%s and %s), ignoring %s.\n%s") % + (observer.scale, observer.zero_point, backend_scale_exact_match, + backend_zero_point_exact_match, qconfig, suggestion_str)) + return False return True if qconfig is None or dtype_with_constraints.dtype is None: diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index e3410a52a9d8..65c85d033c5f 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -39,6 +39,7 @@ _MODULE_NAME_DICT_KEY = "module_name" _MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order" +# TODO: derive this map from the BackendConfig _FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = { torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer, torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer, From 848e7240a11c9fd82298bc5b5ae14534e1307627 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Nov 2022 19:08:49 +0000 Subject: [PATCH 245/453] [Dynamo] Add a dummy profiler to avoid activating real profiler (#88930) See context at https://github.com/pytorch/torchdynamo/issues/1721#issuecomment-1312396059 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88930 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 4 ++-- torch/_dynamo/variables/misc.py | 22 ++++++++++++++++------ torch/_dynamo/variables/torch.py | 4 ++-- torch/_dynamo/variables/user_defined.py | 10 +++++++--- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8f79f2476aee..aef364d76994 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1388,7 +1388,7 @@ def fn(): self.assertTrue(result[1] == fn.__code__.co_lnotab) def test_torch_profiler(self): - # wrap torch.profiler.* as ProfilerContextWrapperVariable and do nothing + # wrap torch.profiler.* as NullContextVariable and do nothing def fn(x): y = x**2 with torch.profiler.profile(): @@ -1408,7 +1408,7 @@ def fn(x): self.assertEqual(cnts.frame_count, 2) def test_autograd_profiler(self): - # wrap torch.autograd.profiler.* as ProfilerContextWrapperVariable and do nothing + # wrap torch.autograd.profiler.* as NullContextVariable and do nothing def fn(x): y = x**2 with torch.autograd.profiler.profile(): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 5d7336cefeae..298ddf24862b 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -116,6 +116,9 @@ def exit(self, tx, *args): self._call_func(tx, self.initial_values) return variables.ConstantVariable(None, **VariableTracker.propagate(self)) + def module_name(self): + return "torch" + def reconstruct(self, codegen, target_inst=None): """ Generate following Python Bytecode, with a 
`torch._C._set_grad_enable` call @@ -356,11 +359,15 @@ def exit_functional_autocast(mode): mode.__exit__(None, None, None) -class ProfilerContextWrapperVariable(ContextWrappingVariable): +class NullContextVariable(ContextWrappingVariable): + """ + This class represents Python contextlib.nullcontext. + It's used as a placeholder for other context managers that Dynamo doesn't + support yet, e.g, torch.autograd.profiler.record_function. + """ + def __init__(self, target_values=None, **kwargs): - super(ProfilerContextWrapperVariable, self).__init__( - target_values=target_values, **kwargs - ) + super(NullContextVariable, self).__init__(target_values=target_values, **kwargs) def enter(self, tx): return variables.ConstantVariable(None, **VariableTracker.propagate(self)) @@ -368,8 +375,11 @@ def enter(self, tx): def exit(self, tx, *args): return variables.ConstantVariable(None, **VariableTracker.propagate(self)) + def module_name(self): + return "contextlib" + def fn_name(self): - return "autograd.profiler.profile" + return "nullcontext" class WithExitFunctionVariable(VariableTracker): @@ -389,7 +399,7 @@ def reconstruct(self, codegen): # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. output = AttrSource( - codegen.tx.import_source("torch"), self.ctx.fn_name() + codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name() ).reconstruct(codegen) if codegen.tx.output.partial_convert: diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 3b9b552542ac..56e74503faca 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -26,7 +26,7 @@ ) from .base import VariableTracker from .lists import ListVariable, TupleVariable -from .misc import AutocastModeVariable, ProfilerContextWrapperVariable +from .misc import AutocastModeVariable, NullContextVariable from .nn_module import NNModuleVariable from .tensor import TensorWithTFOverrideVariable @@ -300,7 +300,7 @@ def call_function( torch.autograd.profiler.record_function, ): log.warning("Profiler will be ignored") - return ProfilerContextWrapperVariable(**options) + return NullContextVariable(**options) elif self.value is torch.autograd._profiler_enabled: unimplemented("torch.autograd._profiler_enabled not supported yet") elif self.value is torch.jit.annotate: diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 09d7893bef66..8cc9528ed67c 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -1,4 +1,5 @@ import collections +import contextlib import dataclasses import functools import importlib @@ -15,7 +16,7 @@ from ..source import AttrSource, ODictGetItemSource, RandomValueSource from ..utils import is_namedtuple_cls, namedtuple_fields from .base import MutableLocal, VariableTracker -from .misc import ProfilerContextWrapperVariable +from .misc import NullContextVariable class UserDefinedVariable(VariableTracker): @@ -77,8 +78,11 @@ def call_function( options = VariableTracker.propagate(self, args, kwargs.values()) - if self.value is torch.autograd.profiler.profile: - return ProfilerContextWrapperVariable() + if self.value in ( + contextlib.nullcontext, + torch.autograd.profiler.profile, + ): + return NullContextVariable(**options) elif is_namedtuple_cls(self.value): fields = namedtuple_fields(self.value) items = list(args) From 0581331963cb3dc18fa59a800661c800ebff92c2 Mon Sep 17 00:00:00 2001 From: BowenBao Date: Mon, 14 Nov 2022 
13:31:23 -0800 Subject: [PATCH 246/453] [ONNX] Document ONNX diagnostics (#88371) Reference pages: - Landing page: https://docs-preview.pytorch.org/88371/onnx_diagnostics.html - Individual rule: https://docs-preview.pytorch.org/88371/generated/onnx_diagnostics_rules/POE0004%3Aoperator-supported-in-newer-opset-version.html An initial PR to setup the document generation for ONNX diagnostics. * Add document page for ONNX diagnostics. * Add document generation for diagnostics rules from `rules.yaml`. * Add dependency on `myst-parser` for markdown to rst parsing. More content to be added. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88371 Approved by: https://github.com/abock, https://github.com/justinchuby, https://github.com/malfet, https://github.com/kit1980 --- docs/Makefile | 7 ++-- docs/requirements.txt | 1 + docs/source/conf.py | 3 +- docs/source/index.rst | 1 + docs/source/onnx_diagnostics.rst | 35 ++++++++++++++++++ .../onnx/build_onnx_diagnostics_rules_md.py | 37 +++++++++++++++++++ 6 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 docs/source/onnx_diagnostics.rst create mode 100644 docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py diff --git a/docs/Makefile b/docs/Makefile index 122bda6231e3..c506845fa92b 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -17,8 +17,9 @@ figures: @$(PYCMD) source/scripts/build_activation_images.py @$(PYCMD) source/scripts/build_quantization_configs.py -onnx_supported_aten_ops: +onnx: @$(PYCMD) source/scripts/onnx/build_onnx_supported_aten_op_csv_table.py + @$(PYCMD) source/scripts/onnx/build_onnx_diagnostics_rules_md.py $(SOURCEDIR)/generated/onnx_diagnostics_rules docset: html doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url https://pytorch.org/docs/ --force $(BUILDDIR)/html/ @@ -34,11 +35,11 @@ html-stable: # See conf.py for more details. RELEASE=1 make html -.PHONY: help Makefile docset onnx_supported_aten_ops +.PHONY: help Makefile docset onnx # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
-%: Makefile figures onnx_supported_aten_ops +%: Makefile figures onnx @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) clean: diff --git a/docs/requirements.txt b/docs/requirements.txt index 14c93adc22e9..fdbe10778bf9 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,3 +10,4 @@ tensorboard==2.10.0 python-etcd==0.4.5 sphinx-copybutton==0.5.0 sphinx-panels==0.4.1 +myst-parser==0.18.1 diff --git a/docs/source/conf.py b/docs/source/conf.py index 807f486ac0d6..f4d1d8b68eb9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -58,7 +58,8 @@ 'sphinxcontrib.katex', 'sphinx.ext.autosectionlabel', 'sphinx_copybutton', - 'sphinx_panels' + 'sphinx_panels', + 'myst_parser', ] # build the templated autosummary files diff --git a/docs/source/index.rst b/docs/source/index.rst index b9d097f55191..00f8e0967b73 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -85,6 +85,7 @@ Features described in this documentation are classified by release status: profiler nn.init onnx + onnx_diagnostics optim complex_numbers ddp_comm_hooks diff --git a/docs/source/onnx_diagnostics.rst b/docs/source/onnx_diagnostics.rst new file mode 100644 index 000000000000..ec2edd4cbdbe --- /dev/null +++ b/docs/source/onnx_diagnostics.rst @@ -0,0 +1,35 @@ +torch.onnx diagnostics +====================== + +.. contents:: :local: +.. automodule:: torch.onnx._internal.diagnostics +.. currentmodule:: torch.onnx._internal.diagnostics + +Overview +-------- + +NOTE: This feature is underdevelopment and is subject to change. + +The goal is to improve the diagnostics to help users debug and improve their model export to ONNX. + +- The diagnostics are emitted in machine parsable `Static Analysis Results Interchange Format (SARIF) `__. +- A new clearer, structured way to add new and keep track of diagnostic rules. +- Serve as foundation for more future improvements consuming the diagnostics. + + +Diagnostic Rules +---------------- + +.. toctree:: + :glob: + + generated/onnx_diagnostics_rules/* + +API Reference +------------- + +.. autoclass:: torch.onnx._internal.diagnostics.ExportDiagnostic + :members: + +.. autoclass:: torch.onnx._internal.diagnostics.infra.DiagnosticEngine + :members: diff --git a/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py b/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py new file mode 100644 index 000000000000..3c2895f6fe76 --- /dev/null +++ b/docs/source/scripts/onnx/build_onnx_diagnostics_rules_md.py @@ -0,0 +1,37 @@ +import argparse +import os +from dataclasses import fields + +from torch.onnx._internal import diagnostics +from torch.onnx._internal.diagnostics import infra + + +def gen_docs(out_dir: str): + os.makedirs(out_dir, exist_ok=True) + for field in fields(diagnostics.rules): + rule = getattr(diagnostics.rules, field.name) + if not isinstance(rule, infra.Rule): + continue + title = f"{rule.id}:{rule.name}" + full_description_markdown = rule.full_description_markdown + assert ( + full_description_markdown is not None + ), f"Expected {title} to have a full description in markdown" + with open(f"{out_dir}/{title}.md", "w") as f: + f.write(f"# {title}\n") + f.write(full_description_markdown) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Generate ONNX diagnostics rules doc in markdown." 
+ ) + parser.add_argument( + "out_dir", metavar="OUT_DIR", help="path to output directory for docs" + ) + args = parser.parse_args() + gen_docs(args.out_dir) + + +if __name__ == "__main__": + main() From 6b521bbf3589d763f9ad348ee24e54be12c44356 Mon Sep 17 00:00:00 2001 From: soulitzer Date: Wed, 16 Nov 2022 11:22:58 -0500 Subject: [PATCH 247/453] Prevent module full_backward_hook from erroring in double backward (#88357) Also clarifies documentation to say "execute if and only if gradients wrt outputs are computed" (previously, "execute every time gradients wrt inputs are computed") See https://docs.google.com/document/d/1tFZKYdsSzRBJ7Di7SWt8X8fSg-E3eiUPwomMF10UyhM/edit for more details regarding the question: 'should module full_backward_hooks be called every time the gradients wrt module inputs are called, or should module full_backward_hooks only be called when the "backward for the module" have been computed?' Fixes https://github.com/pytorch/pytorch/issues/88312 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88357 Approved by: https://github.com/albanD --- test/test_autograd.py | 19 +++++++++++++++++++ torch/nn/modules/module.py | 12 ++++++++---- torch/utils/hooks.py | 11 ++++------- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index 33cf188af065..6e26f67f6dc3 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6638,6 +6638,25 @@ def forward(self, x): gc.collect() self.assertIsNone(ref_()) + def test_full_backward_hook_double_backward(self): + x = torch.rand(1, requires_grad=True) + y = torch.rand_like(x) + + func = torch.nn.MSELoss() + counter = [0] + + def hook(module, grad_input, grad_output): + counter[0] += 1 + + func.register_full_backward_hook(hook) + + f = func(x, y) + + (gradx_f,) = torch.autograd.grad(f, x, create_graph=True) + self.assertEqual(counter[0], 1) + _ = torch.autograd.grad(gradx_f, x) + # We should not error, and counter should not be incremented + self.assertEqual(counter[0], 1) def test_input_buffer_accum(self): leaf = torch.rand(2, 2, requires_grad=True) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index fea0ca7b8de8..82389074f8a9 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -307,8 +307,10 @@ def register_module_full_backward_hook( This adds global state to the `nn.module` module and it is only intended for debugging/profiling purposes. - The hook will be called every time the gradients with respect to module - inputs are computed. The hook should have the following signature:: + The hook will be called every time the gradients with respect to a module + are computed, i.e. the hook will execute if and only if the gradients with + respect to module outputs are computed. The hook should have the following + signature:: hook(module, grad_input, grad_output) -> Tensor or None @@ -1197,8 +1199,10 @@ def register_full_backward_hook( ) -> RemovableHandle: r"""Registers a backward hook on the module. - The hook will be called every time the gradients with respect to module - inputs are computed. The hook should have the following signature:: + The hook will be called every time the gradients with respect to a module + are computed, i.e. the hook will execute if and only if the gradients with + respect to module outputs are computed. 
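To make the revised contract concrete, here is a minimal sketch (mirroring the new `test_full_backward_hook_double_backward` test above, not part of the diff itself) showing that the full backward hook fires once when gradients with respect to the module's outputs are computed, and that a second backward through the saved graph neither errors nor re-runs the hook:

```python
import torch

mse = torch.nn.MSELoss()
calls = []

def hook(module, grad_input, grad_output):
    # Called when gradients w.r.t. the module's outputs are computed.
    calls.append(1)

mse.register_full_backward_hook(hook)

x = torch.rand(1, requires_grad=True)
y = torch.rand_like(x)
loss = mse(x, y)

# First backward: the hook runs exactly once.
(grad_x,) = torch.autograd.grad(loss, x, create_graph=True)
assert len(calls) == 1

# Double backward through grad_x: previously this raised; now it is a no-op for the hook.
torch.autograd.grad(grad_x, x)
assert len(calls) == 1
```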
The hook should have the following + signature:: hook(module, grad_input, grad_output) -> tuple(Tensor) or None diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index 327b2143607c..133d2c0d2ceb 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -99,13 +99,10 @@ def _unpack_none(self, indices, values): def _set_user_hook(self, grad_fn): def hook(grad_input, _): if self.grad_outputs is None: - raise RuntimeError("Module backward hook for grad_input is called before " - "the grad_output one. This happens because the gradient " - "in your nn.Module flows to the Module's input without " - "passing through the Module's output. Make sure that the " - "output depends on the input and that the loss is computed " - "based on the output.") - + # This happens because the gradient in your nn.Module flows to + # the Module's input without " passing through the Module's + # output, e.g. when you're doing double backward. + return res = self._pack_with_none(self.input_tensors_index, grad_input, self.n_inputs) for hook in self.user_hooks: From aee96bbf5a34b7d9b12b8d03fa1904e595c6a329 Mon Sep 17 00:00:00 2001 From: Iris Date: Wed, 16 Nov 2022 21:06:35 +0000 Subject: [PATCH 248/453] [PT-D][Checkpointing] Move distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (#88698) Context in RFC: https://github.com/pytorch/pytorch/issues/86620 .rst file will be finalized in subsequent PRs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88698 Approved by: https://github.com/wanchaol --- docs/source/distributed.checkpoint.rst | 4 + docs/source/index.rst | 1 + .../checkpoint/test_checkpoint.py | 6 +- .../checkpoint/test_file_system_checkpoint.py | 2 +- .../test_file_system_checkpoint_cpu.py | 2 +- .../{_shard => }/checkpoint/test_planner.py | 6 +- .../{_shard => }/checkpoint/test_utils.py | 4 +- .../distributed/_shard/checkpoint/__init__.py | 29 ++--- torch/distributed/checkpoint/__init__.py | 21 ++++ .../{_shard => }/checkpoint/api.py | 10 +- .../checkpoint/default_planner.py | 82 ++++++++++---- .../{_shard => }/checkpoint/filesystem.py | 82 +++++++++----- .../{_shard => }/checkpoint/metadata.py | 23 +++- .../{_shard => }/checkpoint/planner.py | 41 ++++++- .../checkpoint/planner_helpers.py | 94 +++++++++------- .../{_shard => }/checkpoint/resharding.py | 7 +- .../checkpoint/state_dict_loader.py | 6 +- .../checkpoint/state_dict_saver.py | 13 +-- .../{_shard => }/checkpoint/storage.py | 14 ++- .../{_shard => }/checkpoint/utils.py | 101 +++++++++++++----- 20 files changed, 389 insertions(+), 159 deletions(-) create mode 100644 docs/source/distributed.checkpoint.rst rename test/distributed/{_shard => }/checkpoint/test_checkpoint.py (98%) rename test/distributed/{_shard => }/checkpoint/test_file_system_checkpoint.py (99%) rename test/distributed/{_shard => }/checkpoint/test_file_system_checkpoint_cpu.py (99%) rename test/distributed/{_shard => }/checkpoint/test_planner.py (97%) rename test/distributed/{_shard => }/checkpoint/test_utils.py (96%) create mode 100644 torch/distributed/checkpoint/__init__.py rename torch/distributed/{_shard => }/checkpoint/api.py (90%) rename torch/distributed/{_shard => }/checkpoint/default_planner.py (76%) rename torch/distributed/{_shard => }/checkpoint/filesystem.py (82%) rename torch/distributed/{_shard => }/checkpoint/metadata.py (87%) rename torch/distributed/{_shard => }/checkpoint/planner.py (95%) rename torch/distributed/{_shard => }/checkpoint/planner_helpers.py (74%) rename 
torch/distributed/{_shard => }/checkpoint/resharding.py (91%) rename torch/distributed/{_shard => }/checkpoint/state_dict_loader.py (98%) rename torch/distributed/{_shard => }/checkpoint/state_dict_saver.py (96%) rename torch/distributed/{_shard => }/checkpoint/storage.py (96%) rename torch/distributed/{_shard => }/checkpoint/utils.py (77%) diff --git a/docs/source/distributed.checkpoint.rst b/docs/source/distributed.checkpoint.rst new file mode 100644 index 000000000000..380ec0e6022a --- /dev/null +++ b/docs/source/distributed.checkpoint.rst @@ -0,0 +1,4 @@ +Distributed Checkpoint +======================== + +.. automodule:: torch.distributed.checkpoint diff --git a/docs/source/index.rst b/docs/source/index.rst index 00f8e0967b73..20214466328a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -70,6 +70,7 @@ Features described in this documentation are classified by release status: torch.distributed.elastic torch.distributed.fsdp torch.distributed.optim + torch.distributed.checkpoint torch.distributions torch.fft futures diff --git a/test/distributed/_shard/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py similarity index 98% rename from test/distributed/_shard/checkpoint/test_checkpoint.py rename to test/distributed/checkpoint/test_checkpoint.py index 1b3cf04eb2cc..167fdc5e7154 100644 --- a/test/distributed/_shard/checkpoint/test_checkpoint.py +++ b/test/distributed/checkpoint/test_checkpoint.py @@ -20,17 +20,17 @@ from torch.distributed._shard import sharded_tensor -from torch.distributed._shard.checkpoint.default_planner import ( +from torch.distributed.checkpoint.default_planner import ( _create_default_local_metadata, ) -from torch.distributed._shard.checkpoint.metadata import ( +from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, Metadata, TensorStorageMetadata, ) -from torch.distributed._shard.checkpoint.planner import ( +from torch.distributed.checkpoint.planner import ( SavePlan, SavePlanner, LoadPlan, diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py similarity index 99% rename from test/distributed/_shard/checkpoint/test_file_system_checkpoint.py rename to test/distributed/checkpoint/test_file_system_checkpoint.py index b5cc38767c96..7ef4e72e4fe0 100644 --- a/test/distributed/_shard/checkpoint/test_file_system_checkpoint.py +++ b/test/distributed/checkpoint/test_file_system_checkpoint.py @@ -31,7 +31,7 @@ run_tests, ) -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, load_state_dict, diff --git a/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py similarity index 99% rename from test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py rename to test/distributed/checkpoint/test_file_system_checkpoint_cpu.py index 321dc2f54688..2ff2d9d12791 100644 --- a/test/distributed/_shard/checkpoint/test_file_system_checkpoint_cpu.py +++ b/test/distributed/checkpoint/test_file_system_checkpoint_cpu.py @@ -31,7 +31,7 @@ run_tests, ) -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, load_state_dict, diff --git a/test/distributed/_shard/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py similarity index 97% rename from 
test/distributed/_shard/checkpoint/test_planner.py rename to test/distributed/checkpoint/test_planner.py index 56373bd67c6d..334fba237a9b 100644 --- a/test/distributed/_shard/checkpoint/test_planner.py +++ b/test/distributed/checkpoint/test_planner.py @@ -3,7 +3,7 @@ import sys import torch -from torch.distributed._shard.checkpoint.planner import LoadItemType, WriteItemType +from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType from torch.distributed._shard.sharded_tensor import ( Shard, @@ -18,13 +18,13 @@ TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed._shard.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata +from torch.distributed.checkpoint.metadata import BytesStorageMetadata, MetadataIndex, TensorStorageMetadata from torch.testing._internal.distributed.distributed_utils import ( with_fake_comms, with_dist ) -from torch.distributed._shard.checkpoint.default_planner import ( +from torch.distributed.checkpoint.default_planner import ( create_default_global_save_plan, create_default_local_save_plan, create_default_local_load_plan, diff --git a/test/distributed/_shard/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py similarity index 96% rename from test/distributed/_shard/checkpoint/test_utils.py rename to test/distributed/checkpoint/test_utils.py index e99a9cf863e4..e2b4aac605bf 100644 --- a/test/distributed/_shard/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -17,8 +17,8 @@ TEST_WITH_DEV_DBG_ASAN, run_tests, ) -from torch.distributed._shard.checkpoint.utils import find_state_dict_object -from torch.distributed._shard.checkpoint.metadata import MetadataIndex +from torch.distributed.checkpoint.utils import find_state_dict_object +from torch.distributed.checkpoint.metadata import MetadataIndex from torch.testing._internal.distributed.distributed_utils import ( with_fake_comms ) diff --git a/torch/distributed/_shard/checkpoint/__init__.py b/torch/distributed/_shard/checkpoint/__init__.py index febc953f9b60..166c6f9254cf 100644 --- a/torch/distributed/_shard/checkpoint/__init__.py +++ b/torch/distributed/_shard/checkpoint/__init__.py @@ -1,21 +1,12 @@ -from .metadata import ( - TensorStorageMetadata, - BytesStorageMetadata, - ChunkStorageMetadata, - Metadata, -) -from .state_dict_loader import load_state_dict -from .state_dict_saver import save_state_dict -from .storage import StorageReader, StorageWriter -from .filesystem import FileSystemReader, FileSystemWriter -from .api import CheckpointException - +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed.checkpoint` package. 
+import sys +import torch +import warnings -from .planner import ( - SavePlanner, - LoadPlanner, - SavePlan, - LoadPlan, - ReadItem, - WriteItem, +from torch.distributed.checkpoint import * # noqa: F403 +warnings.warn( + "torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead", + DeprecationWarning ) +sys.modules['torch.distributed._shard.checkpoint'] = torch.distributed.checkpoint diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py new file mode 100644 index 000000000000..febc953f9b60 --- /dev/null +++ b/torch/distributed/checkpoint/__init__.py @@ -0,0 +1,21 @@ +from .metadata import ( + TensorStorageMetadata, + BytesStorageMetadata, + ChunkStorageMetadata, + Metadata, +) +from .state_dict_loader import load_state_dict +from .state_dict_saver import save_state_dict +from .storage import StorageReader, StorageWriter +from .filesystem import FileSystemReader, FileSystemWriter +from .api import CheckpointException + + +from .planner import ( + SavePlanner, + LoadPlanner, + SavePlan, + LoadPlan, + ReadItem, + WriteItem, +) diff --git a/torch/distributed/_shard/checkpoint/api.py b/torch/distributed/checkpoint/api.py similarity index 90% rename from torch/distributed/_shard/checkpoint/api.py rename to torch/distributed/checkpoint/api.py index e74b34d9f233..d7bfa18ecd79 100644 --- a/torch/distributed/_shard/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -3,20 +3,28 @@ WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary] +__all__ = ["CheckpointException"] + + def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: return (exc, tb.extract_tb(exc.__traceback__)) + def _is_wrapped_exception(obj: Any) -> bool: if not isinstance(obj, tuple): return False if len(obj) != 2: return False - return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary) + return isinstance(obj[0], BaseException) and isinstance( + obj[1], tb.StackSummary + ) + class CheckpointException(BaseException): """ Exception raised if failure was detected as part of a checkpoint load or save. 
""" + def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]): super().__init__(msg, failures) self._failures = failures diff --git a/torch/distributed/_shard/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py similarity index 76% rename from torch/distributed/_shard/checkpoint/default_planner.py rename to torch/distributed/checkpoint/default_planner.py index 8f6a0c2be7ed..aa531a62d235 100644 --- a/torch/distributed/_shard/checkpoint/default_planner.py +++ b/torch/distributed/checkpoint/default_planner.py @@ -24,18 +24,26 @@ MetadataIndex, Metadata, STATE_DICT_TYPE, - STORAGE_TYPES + STORAGE_TYPES, ) from .planner_helpers import ( _create_read_items, _create_write_items, - _create_default_metadata_only_plan + _create_default_metadata_only_plan, ) -from .utils import ( - find_state_dict_object -) +from .utils import find_state_dict_object + +__all__ = [ + "DefaultSavePlanner", + "DefaultLoadPlanner", + "create_default_local_load_plan", + "create_default_global_load_plan", + "create_default_local_save_plan", + "create_default_global_save_plan", +] + class DefaultSavePlanner(SavePlanner): def init(self, state_dict: Dict[str, Any], is_coordinator: bool) -> None: @@ -43,18 +51,26 @@ def init(self, state_dict: Dict[str, Any], is_coordinator: bool) -> None: self.is_coordinator = is_coordinator def create_local_plan(self) -> SavePlan: - self.plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + self.plan = create_default_local_save_plan( + self.state_dict, self.is_coordinator + ) return self.plan - def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: - self.global_plan, self.metadata = create_default_global_save_plan(all_plans) + def create_global_plan( + self, all_plans: List[SavePlan] + ) -> Tuple[List[SavePlan], Metadata]: + self.global_plan, self.metadata = create_default_global_save_plan( + all_plans + ) return self.global_plan, self.metadata def finish_plan(self, new_plan: SavePlan) -> SavePlan: self.plan = new_plan return new_plan - def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + def resolve_data( + self, write_item: WriteItem + ) -> Union[torch.Tensor, io.BytesIO]: object = self.lookup_object(write_item.index) return self.transform_object(write_item, object) @@ -76,7 +92,12 @@ def transform_object(self, write_item: WriteItem, object: Any): class DefaultLoadPlanner(LoadPlanner): - def init(self, state_dict: STATE_DICT_TYPE, metadata: Metadata, is_coordinator: bool) -> None: + def init( + self, + state_dict: STATE_DICT_TYPE, + metadata: Metadata, + is_coordinator: bool, + ) -> None: self.state_dict = state_dict self.metadata = metadata self.is_coordinator = is_coordinator @@ -110,7 +131,9 @@ def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor): """ This is an extension from the planner interface to make it easy to extend the default planner """ - return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths) + return narrow_tensor_by_index( + tensor, read_item.dest_offsets, read_item.lengths + ) def create_default_local_load_plan( @@ -133,7 +156,10 @@ def create_default_local_load_plan( return LoadPlan(requests) -def create_default_global_load_plan(all_plans: List[LoadPlan]) -> List[LoadPlan]: + +def create_default_global_load_plan( + all_plans: List[LoadPlan], +) -> List[LoadPlan]: """ Create global load plan used by DefaultLoadPlanner. 
@@ -142,7 +168,10 @@ def create_default_global_load_plan(all_plans: List[LoadPlan]) -> List[LoadPlan] """ return all_plans -def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan: + +def create_default_local_save_plan( + state_dict: Dict[str, Any], is_coordinator: bool +) -> SavePlan: """ Create the ``SavePlan`` used by DefaultSavePlanner. @@ -157,7 +186,10 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b requests += _create_write_items(fqn, obj) return SavePlan(requests) -def create_default_global_save_plan(all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + +def create_default_global_save_plan( + all_plans: List[SavePlan], +) -> Tuple[List[SavePlan], Metadata]: """ Create the global plan and metadata used by DefaultSavePlanner. @@ -180,21 +212,29 @@ def create_default_global_save_plan(all_plans: List[SavePlan]) -> Tuple[List[Sav assert item.tensor_data is not None tensor_md = cast( TensorStorageMetadata, - md.setdefault(item.index.fqn, TensorStorageMetadata( - properties=item.tensor_data.properties, - size=item.tensor_data.size, - chunks=[], - )) + md.setdefault( + item.index.fqn, + TensorStorageMetadata( + properties=item.tensor_data.properties, + size=item.tensor_data.size, + chunks=[], + ), + ), + ) + new_index = dataclasses.replace( + item.index, index=len(tensor_md.chunks) ) - new_index = dataclasses.replace(item.index, index=len(tensor_md.chunks)) new_item = dataclasses.replace(item, index=new_index) new_items.append(new_item) - assert item.tensor_data.chunk is not None, f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}" + assert ( + item.tensor_data.chunk is not None + ), f"Cannot create MD for tensor without bounds. FQN: {item.index.fqn}" tensor_md.chunks.append(item.tensor_data.chunk) new_plans.append(dataclasses.replace(plan, items=new_items)) return (new_plans, Metadata(md)) + def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata: """ Return the ``Metadata`` if DefaultSavePlanner was used to checkpoint ``state_dict``. 
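Since the public entry points keep their signatures after the move, here is a short usage sketch of the relocated package (single-process for brevity, hence `no_dist=True`; the checkpoint directory is illustrative):

```python
import torch
import torch.distributed.checkpoint as dist_cp

state_dict = {"weight": torch.rand(4, 4)}

# Write a checkpoint with the filesystem-backed writer.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/ckpt"),
    no_dist=True,  # no process group in this single-process sketch
)

# Load back in place into pre-allocated tensors.
target = {"weight": torch.zeros(4, 4)}
dist_cp.load_state_dict(
    state_dict=target,
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
    no_dist=True,
)
```

The old `torch.distributed._shard.checkpoint` import path keeps working for now and emits a DeprecationWarning, as the shim above shows.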
diff --git a/torch/distributed/_shard/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py similarity index 82% rename from torch/distributed/_shard/checkpoint/filesystem.py rename to torch/distributed/checkpoint/filesystem.py index 9788853d9aa6..0e679c303921 100644 --- a/torch/distributed/_shard/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -34,32 +34,46 @@ from torch.distributed._shard._utils import narrow_tensor_by_index +__all__ = [ + "FileSystemWriter", + "SlicedBufferedReader", + "FileSystemReader", +] + + @dataclass class _StorageInfo: """ This is the per entry storage info """ + relative_path: str offset: int length: int + @dataclass class _StoragePrefix: prefix: str + DEFAULT_SUFIX = ".distcp" + def _trim(tensor: torch.Tensor) -> torch.Tensor: tensor = tensor.detach().cpu() if tensor._typed_storage()._size() != tensor.numel(): tensor = tensor.clone() return tensor -def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult: + +def _result_from_write_item( + item: WriteItem, size_in_bytes, storage_data +) -> WriteResult: return WriteResult( - index=item.index, - size_in_bytes=size_in_bytes, - storage_data=storage_data) + index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data + ) + def _write_item(stream, data, write_item, storage_key): offset = stream.tell() @@ -74,11 +88,10 @@ def _write_item(stream, data, write_item, storage_key): length = stream.tell() - offset return _result_from_write_item( - write_item, - length, - _StorageInfo(storage_key, offset, length) + write_item, length, _StorageInfo(storage_key, offset, length) ) + def _write_files_from_queue( file_queue: List, planner: SavePlanner, @@ -87,24 +100,33 @@ def _write_files_from_queue( write_results = [] for file_path, file_name, write_items in file_queue: - tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + tensor_w = [ + wi for wi in write_items if wi.type != WriteItemType.BYTE_IO + ] bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] with open(file_path, "wb") as stream: for write_item in bytes_w: data = planner.resolve_data(write_item) - write_results.append(_write_item(stream, data, write_item, file_name)) + write_results.append( + _write_item(stream, data, write_item, file_name) + ) for write_item in tensor_w: - tensor = _trim(cast(torch.Tensor, planner.resolve_data(write_item))) + tensor = _trim( + cast(torch.Tensor, planner.resolve_data(write_item)) + ) assert not tensor.is_cuda - write_results.append(_write_item(stream, tensor, write_item, file_name)) + write_results.append( + _write_item(stream, tensor, write_item, file_name) + ) if use_fsync: os.fsync(stream.fileno()) return write_results + class FileSystemWriter(StorageWriter): """ Basic implementation of StorageWriter using file IO. @@ -118,6 +140,7 @@ class FileSystemWriter(StorageWriter): a `.metadata` file with the serialized metadata. 
""" + def __init__( self, path: Union[str, os.PathLike], @@ -146,11 +169,14 @@ def prepare_local_plan(self, plan: SavePlan) -> SavePlan: # There's no storage input in the local plan return plan - def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: + def prepare_global_plan( + self, global_plan: List[SavePlan] + ) -> List[SavePlan]: self.path.mkdir(parents=True, exist_ok=True) new_plans = [ - dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) + for i, plan in enumerate(global_plan) ] return new_plans @@ -187,12 +213,12 @@ def gen_file(): fut.set_result(results) return fut - def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + def finish( + self, metadata: Metadata, results: List[List[WriteResult]] + ) -> None: storage_md = dict() for wr_list in results: - storage_md.update({ - wr.index: wr.storage_data for wr in wr_list - }) + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) metadata.storage_data = storage_md with (self.path / ".metadata.tmp").open("wb") as metadata_file: pickle.dump(metadata, metadata_file) @@ -220,6 +246,7 @@ def seek(self, __offset: int, __whence: int = os.SEEK_SET) -> int: def tell(self) -> int: return super().tell() - self.offset + class FileSystemReader(StorageReader): def __init__(self, path: Union[str, os.PathLike]) -> None: super().__init__() @@ -228,15 +255,10 @@ def __init__(self, path: Union[str, os.PathLike]) -> None: def _slice_file(self, file, sinfo: _StorageInfo): return SlicedBufferedReader( - io.FileIO(file.fileno(), closefd=False), - sinfo.offset, sinfo.length + io.FileIO(file.fileno(), closefd=False), sinfo.offset, sinfo.length ) - def read_data( - self, - plan: LoadPlan, - planner: LoadPlanner - ) -> Future[None]: + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: # group requests by file per_file: Dict[str, List[ReadItem]] = dict() for read_item in plan.items: @@ -255,8 +277,12 @@ def read_data( bytes.seek(0) planner.load_bytes(req, bytes) else: - tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) - tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + tensor = cast( + Tensor, torch.load(file_slice, map_location="cpu") + ) + tensor = narrow_tensor_by_index( + tensor, req.storage_offsets, req.lengths + ) target_tensor = planner.resolve_tensor(req).detach() assert ( @@ -281,5 +307,7 @@ def init(self, metadata: Metadata, is_coordinator: bool) -> None: def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: return plan - def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + def prepare_global_plan( + self, global_plan: List[LoadPlan] + ) -> List[LoadPlan]: return global_plan diff --git a/torch/distributed/_shard/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py similarity index 87% rename from torch/distributed/_shard/checkpoint/metadata.py rename to torch/distributed/checkpoint/metadata.py index 2321f0276623..1a03f16ff473 100644 --- a/torch/distributed/_shard/checkpoint/metadata.py +++ b/torch/distributed/checkpoint/metadata.py @@ -7,28 +7,42 @@ ShardedTensor, ) +__all__ = [ + "ChunkStorageMetadata", + "TensorStorageMetadata", + "BytesStorageMetadata", + "Metadata", + "MetadataIndex", +] + + @dataclass class ChunkStorageMetadata: """ Each chunk is expected to have the same properties of the TensorStorageMetadata that includes it. 
""" + offsets: torch.Size sizes: torch.Size + @dataclass class TensorStorageMetadata: properties: TensorProperties size: torch.Size chunks: List[ChunkStorageMetadata] + @dataclass class BytesStorageMetadata: pass + TENSOR_TYPE = Union[torch.Tensor, ShardedTensor] STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata] STATE_DICT_TYPE = Dict[str, Any] + @dataclass class Metadata: # Keys are the same from the `state_dict` used. @@ -36,11 +50,13 @@ class Metadata: planner_data: Any = None storage_data: Any = None + @dataclass(frozen=True) class MetadataIndex: """ This class represents a lookup key for items in a state dict or Metadata. """ + fqn: str """Fully Qualified Name of the object""" @@ -59,7 +75,12 @@ class MetadataIndex: the linear search and thus making it significantly faster. """ - def __init__(self, fqn: str, offset: Optional[Sequence[int]] = None, index: Optional[int] = None): + def __init__( + self, + fqn: str, + offset: Optional[Sequence[int]] = None, + index: Optional[int] = None, + ): # We must use object.__setattr__ due to frozen=True object.__setattr__(self, "fqn", fqn) object.__setattr__(self, "index", index) diff --git a/torch/distributed/_shard/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py similarity index 95% rename from torch/distributed/_shard/checkpoint/planner.py rename to torch/distributed/checkpoint/planner.py index f3692cc11395..cb94a40df732 100644 --- a/torch/distributed/_shard/checkpoint/planner.py +++ b/torch/distributed/checkpoint/planner.py @@ -12,24 +12,41 @@ ChunkStorageMetadata, MetadataIndex, Metadata, - STATE_DICT_TYPE + STATE_DICT_TYPE, ) + +__all__ = [ + "WriteItemType", + "LoadItemType", + "TensorWriteData", + "WriteItem", + "ReadItem", + "SavePlan", + "LoadPlan", + "SavePlanner", + "LoadPlanner", +] + + class WriteItemType(Enum): TENSOR = auto() SHARD = auto() BYTE_IO = auto() + class LoadItemType(Enum): TENSOR = auto() BYTE_IO = auto() + @dataclass(frozen=True) class TensorWriteData: chunk: ChunkStorageMetadata properties: TensorProperties size: torch.Size + @dataclass(frozen=True) class WriteItem: index: MetadataIndex @@ -38,6 +55,7 @@ class WriteItem: # Value present if it's a tensor write tensor_data: Optional[TensorWriteData] = None + @dataclass(frozen=True) class ReadItem: # Read Item @@ -56,18 +74,21 @@ class ReadItem: # Size of the hypercube to copy lengths: torch.Size + @dataclass(frozen=True) class SavePlan: items: List[WriteItem] storage_data: Any = None planner_data: Any = None + @dataclass class LoadPlan: items: List[ReadItem] storage_data: Any = None planner_data: Any = None + class SavePlanner(abc.ABC): """ Abstract class defining the protocol used by save_state_dict to plan the save process. @@ -156,6 +177,7 @@ class SavePlanner(abc.ABC): >>> metadata = replace(metadata, planner_data=merged_data) >>> return global_plan, metadata """ + @abc.abstractmethod def init(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None: """ @@ -179,7 +201,9 @@ def create_local_plan(self) -> SavePlan: pass @abc.abstractmethod - def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + def create_global_plan( + self, all_plans: List[SavePlan] + ) -> Tuple[List[SavePlan], Metadata]: """ Compute the global checkpoint plan and return the local plan of each rank. 
@@ -197,7 +221,9 @@ def finish_plan(self, new_plan: SavePlan) -> SavePlan: pass @abc.abstractmethod - def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + def resolve_data( + self, write_item: WriteItem + ) -> Union[torch.Tensor, io.BytesIO]: """ Lookup the object associated with ``write_item``in `state_dict` and apply any transformation (such as serialization) prior to the storage layer consuming it. @@ -215,6 +241,7 @@ def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO] """ pass + class LoadPlanner: """ Abstract class defining the protocol used by load_state_dict to plan the load process. @@ -273,8 +300,14 @@ class LoadPlanner: >>> def commit_tensor(self, read_item, tensor): >>> self.state_dict[read_item.dest_index.fqn] = tensor """ + @abc.abstractmethod - def init(self, state_dict: STATE_DICT_TYPE, metadata: Metadata, is_coordinator: bool) -> None: + def init( + self, + state_dict: STATE_DICT_TYPE, + metadata: Metadata, + is_coordinator: bool, + ) -> None: """ Initialize this instance to load data into ``state_dict`` diff --git a/torch/distributed/_shard/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py similarity index 74% rename from torch/distributed/_shard/checkpoint/planner_helpers.py rename to torch/distributed/checkpoint/planner_helpers.py index fce7699b953f..23fbcd0d7e78 100644 --- a/torch/distributed/_shard/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -26,12 +26,13 @@ TensorStorageMetadata, MetadataIndex, STATE_DICT_TYPE, - STORAGE_TYPES + STORAGE_TYPES, ) -from .resharding import ( - _shards_get_overlap_region_wrt_saved_tensor -) +from .resharding import _shards_get_overlap_region_wrt_saved_tensor + +__all__: List[str] = [] + def _create_shard_metadata(size: torch.Size) -> ShardMetadata: return ShardMetadata( @@ -39,26 +40,31 @@ def _create_shard_metadata(size: torch.Size) -> ShardMetadata: shard_sizes=list(size), ) + def _create_shard_from_tensor(tensor: torch.Tensor) -> Shard: - return Shard( - tensor=tensor, - metadata=_create_shard_metadata(tensor.size()) - ) + return Shard(tensor=tensor, metadata=_create_shard_metadata(tensor.size())) + def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: return ChunkStorageMetadata( offsets=torch.Size(shard_md.shard_offsets), - sizes=torch.Size(shard_md.shard_sizes) + sizes=torch.Size(shard_md.shard_sizes), ) -def _sharded_tensor_metadata(sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> TensorWriteData: + +def _sharded_tensor_metadata( + sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> TensorWriteData: return TensorWriteData( chunk=_chunk_for_shard(shard_md), properties=sharded_tensor.metadata().tensor_properties, size=sharded_tensor.metadata().size, ) -def _create_write_item_for_shard(fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> WriteItem: + +def _create_write_item_for_shard( + fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata +) -> WriteItem: offsets = torch.Size(shard_md.shard_offsets) return WriteItem( index=MetadataIndex(fqn, offsets), @@ -66,28 +72,30 @@ def _create_write_item_for_shard(fqn: str, sharded_tensor: ShardedTensor, shard_ tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), ) + def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: offsets = torch.Size([0] * len(tensor.size())) return WriteItem( index=MetadataIndex(fqn, offsets), type=WriteItemType.TENSOR, 
tensor_data=TensorWriteData( - chunk=ChunkStorageMetadata( - offsets=offsets, - sizes=tensor.size() - ), + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), properties=TensorProperties.create_from_tensor(tensor), size=tensor.size(), - ) + ), ) + def _create_write_item_for_bytesio(fqn: str, bytes: Any): return WriteItem( index=MetadataIndex(fqn), type=WriteItemType.BYTE_IO, ) -def _create_read_item_for_byteio(dest_index, dest_offset, storage_index, storage_offset, length): + +def _create_read_item_for_byteio( + dest_index, dest_offset, storage_index, storage_offset, length +): return ReadItem( type=LoadItemType.BYTE_IO, dest_index=dest_index, @@ -97,7 +105,10 @@ def _create_read_item_for_byteio(dest_index, dest_offset, storage_index, storage lengths=torch.Size((length,)), ) -def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storage_offsets, lengths): + +def _create_read_item_for_tensor( + dest_index, dest_offsets, storage_index, storage_offsets, lengths +): return ReadItem( type=LoadItemType.TENSOR, dest_index=dest_index, @@ -107,6 +118,7 @@ def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storag lengths=torch.Size(lengths), ) + def _create_sharded_read_items( fqn: str, checkpoint_md: TensorStorageMetadata, @@ -144,56 +156,66 @@ def _create_sharded_read_items( read_items.append( _create_read_item_for_tensor( - dest_index=MetadataIndex(fqn, shard.metadata.shard_offsets, idx), + dest_index=MetadataIndex( + fqn, shard.metadata.shard_offsets, idx + ), dest_offsets=dest_offsets, - storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_index=MetadataIndex( + fqn, storage_md.offsets, storage_idx + ), storage_offsets=storage_offsets, lengths=lengths, ) ) return read_items + def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan: requests = [] for fqn, obj in state_dict.items(): if isinstance(obj, ShardedTensor): for shard_md in obj.metadata().shards_metadata: - requests.append(_create_write_item_for_shard(fqn, obj, shard_md)) + requests.append( + _create_write_item_for_shard(fqn, obj, shard_md) + ) elif isinstance(obj, torch.Tensor): requests.append(_create_write_item_for_tensor(fqn, obj)) else: requests.append(_create_write_item_for_bytesio(fqn, obj)) return SavePlan(requests) + def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: if isinstance(object, ShardedTensor): - return [_create_write_item_for_shard(fqn, object, shard.metadata) for shard in object.local_shards()] + return [ + _create_write_item_for_shard(fqn, object, shard.metadata) + for shard in object.local_shards() + ] elif isinstance(object, torch.Tensor): return [_create_write_item_for_tensor(fqn, object)] else: return [_create_write_item_for_bytesio(fqn, object)] + def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: if isinstance(md, BytesStorageMetadata): - return [_create_read_item_for_byteio( - dest_index=MetadataIndex(fqn), - dest_offset=0, - storage_index=MetadataIndex(fqn), - storage_offset=0, - length=0 - )] + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] elif isinstance(obj, ShardedTensor): local_shards = obj.local_shards() elif isinstance(obj, torch.Tensor): local_shards = [_create_shard_from_tensor(obj)] else: raise ValueError( - f"Invalid checkpoint metadata for {fqn}, " + - f"expected BytesStorageMetadata but found {type(md)}" + f"Invalid checkpoint 
metadata for {fqn}, " + + f"expected BytesStorageMetadata but found {type(md)}" ) - return _create_sharded_read_items( - fqn, - md, - local_shards - ) + return _create_sharded_read_items(fqn, md, local_shards) diff --git a/torch/distributed/_shard/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py similarity index 91% rename from torch/distributed/_shard/checkpoint/resharding.py rename to torch/distributed/checkpoint/resharding.py index f98248f5367b..c00def73b14d 100644 --- a/torch/distributed/_shard/checkpoint/resharding.py +++ b/torch/distributed/checkpoint/resharding.py @@ -4,6 +4,9 @@ ShardMetadata, ) +__all__: List[str] = [] + + def _shards_get_overlap_region_wrt_saved_tensor( saved_shard: ShardMetadata, current_shard: ShardMetadata ) -> List[Tuple[int, int, int, int]]: @@ -38,7 +41,9 @@ def _shards_get_overlap_region_wrt_saved_tensor( if saved_shard_offset > current_shard_offset: offset_for_saved_tensor = 0 - offset_for_current_tensor = saved_shard_offset - current_shard_offset + offset_for_current_tensor = ( + saved_shard_offset - current_shard_offset + ) else: offset_for_saved_tensor = current_shard_offset - saved_shard_offset offset_for_current_tensor = 0 diff --git a/torch/distributed/_shard/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py similarity index 98% rename from torch/distributed/_shard/checkpoint/state_dict_loader.py rename to torch/distributed/checkpoint/state_dict_loader.py index b9ea55c180c7..de94ffabf663 100644 --- a/torch/distributed/_shard/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -10,13 +10,16 @@ from .utils import _DistWrapper +__all__ = ["load_state_dict"] + + def load_state_dict( state_dict: Dict[str, Any], storage_reader: StorageReader, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, - planner: LoadPlanner = None + planner: LoadPlanner = None, ) -> None: """ Load a distributed state_dict in SPMD style. @@ -79,7 +82,6 @@ def load_state_dict( if planner is None: planner = DefaultLoadPlanner() - def local_step(): assert planner is not None metadata = storage_reader.read_metadata() diff --git a/torch/distributed/_shard/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py similarity index 96% rename from torch/distributed/_shard/checkpoint/state_dict_saver.py rename to torch/distributed/checkpoint/state_dict_saver.py index c4792e0c42ef..af18fd0c11dd 100644 --- a/torch/distributed/_shard/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -9,12 +9,11 @@ StorageWriter, ) -from .metadata import ( - Metadata, - STATE_DICT_TYPE -) +from .metadata import Metadata, STATE_DICT_TYPE from .utils import _DistWrapper +__all__ = ["save_state_dict"] + def save_state_dict( state_dict: STATE_DICT_TYPE, @@ -22,7 +21,7 @@ def save_state_dict( process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, - planner: SavePlanner = None + planner: SavePlanner = None, ) -> Metadata: """ Save a distributed model in SPMD style. 
@@ -92,7 +91,9 @@ def global_step(all_local_plans): nonlocal global_metatadata assert planner is not None - all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + all_local_plans, global_metatadata = planner.create_global_plan( + all_local_plans + ) all_local_plans = storage_writer.prepare_global_plan(all_local_plans) return all_local_plans diff --git a/torch/distributed/_shard/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py similarity index 96% rename from torch/distributed/_shard/checkpoint/storage.py rename to torch/distributed/checkpoint/storage.py index 56bd757765f2..dbc8fda59eac 100644 --- a/torch/distributed/_shard/checkpoint/storage.py +++ b/torch/distributed/checkpoint/storage.py @@ -16,6 +16,9 @@ LoadPlanner, ) +__all__ = ["WriteResult", "StorageWriter", "StorageReader"] + + @dataclass(frozen=True) class WriteResult: index: MetadataIndex @@ -23,6 +26,7 @@ class WriteResult: size_in_bytes: int storage_data: Any + class StorageWriter(abc.ABC): """ Interface used by ``save_state_dict`` to write to storage. @@ -87,9 +91,7 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: @abc.abstractmethod def write_data( - self, - plan: SavePlan, - planner: SavePlanner + self, plan: SavePlan, planner: SavePlanner ) -> Future[List[WriteResult]]: """ Write all items from ``plan`` using ``planner`` to resolve the data. @@ -113,7 +115,9 @@ def write_data( pass @abc.abstractmethod - def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + def finish( + self, metadata: Metadata, results: List[List[WriteResult]] + ) -> None: """ Writes the metadata and marks the current checkpoint as sucessful. @@ -130,6 +134,7 @@ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: """ pass + class StorageReader(abc.ABC): """ Interface used by ``load_state_dict`` to read from storage. @@ -146,6 +151,7 @@ class StorageReader(abc.ABC): 4) (coordinator) prepare_global_plan 5) (all ranks) read_data """ + @abc.abstractmethod def read_metadata(self) -> Metadata: """ diff --git a/torch/distributed/_shard/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py similarity index 77% rename from torch/distributed/_shard/checkpoint/utils.py rename to torch/distributed/checkpoint/utils.py index e82855672c22..a8d2a42d0fca 100644 --- a/torch/distributed/_shard/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -4,7 +4,7 @@ CheckpointException, _wrap_exception, _is_wrapped_exception, - WRAPPED_EXCEPTION + WRAPPED_EXCEPTION, ) import torch @@ -20,12 +20,20 @@ MetadataIndex, ) +__all__ = ["find_tensor_shard", "find_state_dict_object"] -T = TypeVar('T') -R = TypeVar('R') +T = TypeVar("T") +R = TypeVar("R") + + +def _get_failure_dict( + results: List[Union[T, WRAPPED_EXCEPTION]] +) -> Dict[int, WRAPPED_EXCEPTION]: + return cast( + Dict[int, WRAPPED_EXCEPTION], + {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}, + ) -def _get_failure_dict(results: List[Union[T, WRAPPED_EXCEPTION]]) -> Dict[int, WRAPPED_EXCEPTION]: - return cast(Dict[int, WRAPPED_EXCEPTION], {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)}) class _DistWrapper: """ @@ -36,7 +44,13 @@ class _DistWrapper: All variants that take functions are exception robust, meaning that if one or more ranks raise errors, all ranks will observe those. 
""" - def __init__(self, group: Optional[dist.ProcessGroup], use_dist: bool, coordinator_rank: int): + + def __init__( + self, + group: Optional[dist.ProcessGroup], + use_dist: bool, + coordinator_rank: int, + ): self.group = group self.use_dist = use_dist self.coordinator_rank = coordinator_rank @@ -64,7 +78,8 @@ def broadcast_object(self, object: Optional[T]) -> T: dist.broadcast_object_list( object_list=object_list, group=self.group, - src=self.coordinator_rank) + src=self.coordinator_rank, + ) return cast(T, object_list[0]) def gather_object(self, object: T) -> Optional[List[T]]: @@ -72,13 +87,17 @@ def gather_object(self, object: T) -> Optional[List[T]]: Same as c10d::gather_object but works without distributed enabled. """ if self.use_dist: - gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) if self.is_coordinator else None + gather_objs = ( + cast(List[T], [None] * dist.get_world_size(self.group)) + if self.is_coordinator + else None + ) dist.gather_object( obj=object, object_gather_list=gather_objs if self.is_coordinator else None, dst=self.coordinator_rank, - group=self.group + group=self.group, ) result = gather_objs else: @@ -90,12 +109,12 @@ def all_gather_object(self, object: T) -> List[T]: Same as c10d::all_gather_object but works without distributed enabled. """ if self.use_dist: - gather_objs = cast(List[T], [None] * dist.get_world_size(self.group)) + gather_objs = cast( + List[T], [None] * dist.get_world_size(self.group) + ) dist.all_gather_object( - object_list=gather_objs, - obj=object, - group=self.group + object_list=gather_objs, obj=object, group=self.group ) else: gather_objs = [object] @@ -109,9 +128,11 @@ def scatter_object(self, object_list: Optional[List[T]]) -> T: gather_result = cast(List[T], [None]) dist.scatter_object_list( scatter_object_output_list=gather_result, - scatter_object_input_list=object_list if self.is_coordinator else None, + scatter_object_input_list=object_list + if self.is_coordinator + else None, src=self.coordinator_rank, - group=self.group + group=self.group, ) local_reply = gather_result[0] @@ -124,7 +145,7 @@ def reduce_scatter( self, step: str, map_fun: Callable[[], T], - reduce_fun: Callable[[List[T]], List[R]] + reduce_fun: Callable[[List[T]], List[R]], ) -> R: """ Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter. @@ -150,12 +171,17 @@ def reduce_scatter( if len(node_failures) == 0: try: # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]? - all_results = cast(List[Union[R, CheckpointException]], reduce_fun(cast(List[T], all_data))) + all_results = cast( + List[Union[R, CheckpointException]], + reduce_fun(cast(List[T], all_data)), + ) except BaseException as e: node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: - all_results = [CheckpointException(step, node_failures)] * self.get_world_size() + all_results = [ + CheckpointException(step, node_failures) + ] * self.get_world_size() result = self.scatter_object(all_results) if isinstance(result, CheckpointException): @@ -166,7 +192,7 @@ def all_reduce( self, step: str, map_fun: Callable[[], T], - reduce_fun: Callable[[List[T]], R] + reduce_fun: Callable[[List[T]], R], ) -> R: """ Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast. 
@@ -244,43 +270,64 @@ def broadcast( try: result = map_fun() except BaseException as e: - result = CheckpointException(step, {self.rank: _wrap_exception(e)}) + result = CheckpointException( + step, {self.rank: _wrap_exception(e)} + ) final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): raise final_result return cast(T, final_result) + def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: if index.offset is None: - raise ValueError(f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided") + raise ValueError( + f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided" + ) shards = tensor.local_shards() # index fast path if index.index is not None: - if len(shards) > index.index and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset: + if ( + len(shards) > index.index + and torch.Size(shards[index.index].metadata.shard_offsets) + == index.offset + ): return shards[index.index] for shard in shards: if torch.Size(shard.metadata.shard_offsets) == index.offset: return shard - raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + -def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: +def find_tensor_shard( + tensor: torch.Tensor, index: MetadataIndex +) -> torch.Tensor: if isinstance(tensor, ShardedTensor): return _find_shard(tensor, index).tensor if index.offset is not None: # special case looking up a tensor by origin if index.offset == torch.Size([0] * len(tensor.size())): return tensor - raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'") + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) return tensor -def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + +def find_state_dict_object( + state_dict: STATE_DICT_TYPE, index: MetadataIndex +) -> Any: if index.fqn not in state_dict: raise ValueError(f"Could not find FQN: '{index.fqn}'") obj = state_dict[index.fqn] if isinstance(obj, torch.Tensor): return find_tensor_shard(obj, index) elif index.offset is not None: - raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'") + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'" + ) return obj From 45c62a337756ff9db97cd64d2d42d9e65dda0a85 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 16 Nov 2022 10:07:14 -0800 Subject: [PATCH 249/453] [ao] making _is_activation_post_process private (#87520) Summary: same function in observer and quantize, consolidated to a single function. 
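For downstream code the practical change is just the import location (a sketch of the edit this PR applies across callers):

```python
# Before: public helper duplicated across modules.
# from torch.ao.quantization.quantize import is_activation_post_process

# After: a single private helper living next to the observers.
from torch.ao.quantization.observer import _is_activation_post_process
```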
Note the definitions were slightly different, I've changed the definition to be maximally inclusive so that the name of the function is more accurate Test Plan: python test/test_public_bindings.py python test/test_quantization.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D40709276](https://our.internmc.facebook.com/intern/diff/D40709276) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87520 Approved by: https://github.com/jcaip --- test/allowlist_for_publicAPI.json | 4 ++-- test/quantization/ao_migration/test_ao_migration.py | 2 +- test/quantization/ao_migration/test_quantization.py | 2 +- test/quantization/fx/test_quantize_fx.py | 6 +++--- torch/ao/ns/fx/graph_passes.py | 4 ++-- torch/ao/ns/fx/utils.py | 8 ++++---- torch/ao/quantization/__init__.py | 1 - torch/ao/quantization/fx/_model_report/detector.py | 4 ++-- torch/ao/quantization/fx/convert.py | 6 +++--- torch/ao/quantization/fx/prepare.py | 4 ++-- torch/ao/quantization/fx/qconfig_mapping_utils.py | 6 +++--- torch/ao/quantization/fx/utils.py | 6 +++--- torch/ao/quantization/observer.py | 2 +- torch/ao/quantization/quantize.py | 9 ++------- torch/quantization/quantize.py | 2 +- 15 files changed, 30 insertions(+), 36 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 94ff57700af6..2e1394a72e17 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -786,7 +786,7 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", - "is_activation_post_process", + "_is_activation_post_process", "load_observer_state_dict", "no_observer_set", "prepare", @@ -894,7 +894,7 @@ "convert", "get_observer_dict", "get_unique_devices_", - "is_activation_post_process", + "_is_activation_post_process", "prepare", "prepare_qat", "propagate_qconfig_", diff --git a/test/quantization/ao_migration/test_ao_migration.py b/test/quantization/ao_migration/test_ao_migration.py index accb13da0dcb..260ab32056f6 100644 --- a/test/quantization/ao_migration/test_ao_migration.py +++ b/test/quantization/ao_migration/test_ao_migration.py @@ -19,7 +19,7 @@ def test_function_import_quantize(self): 'convert', 'get_observer_dict', 'get_unique_devices_', - 'is_activation_post_process', + '_is_activation_post_process', 'prepare', 'prepare_qat', 'propagate_qconfig_', diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 9c246e1b7cd8..95c5c7bd6015 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -22,7 +22,7 @@ def test_function_import_quantize(self): 'convert', 'get_observer_dict', 'get_unique_devices_', - 'is_activation_post_process', + '_is_activation_post_process', 'prepare', 'prepare_qat', 'propagate_qconfig_', diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 6c631a24abc6..6cee5e95f21c 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -55,7 +55,6 @@ get_default_qat_qconfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping, - is_activation_post_process, fuse_modules, fuse_modules_qat, prepare, @@ -148,6 +147,7 @@ default_fixed_qparams_range_0to1_observer, default_fixed_qparams_range_neg1to1_observer, MinMaxObserver, + _is_activation_post_process, ) # test utils @@ -3249,7 +3249,7 @@ def _check_node_not_observed(model, arg_node, node): 
_check_node_not_observed(model, new_node, node) elif arg_node.op == "call_module": self.assertTrue( - not is_activation_post_process(getattr(model, arg_node.target)), + not _is_activation_post_process(getattr(model, arg_node.target)), "Arg: {0} of node: {1} is observed but is not a float tensor".format( arg_node, node ), @@ -4933,7 +4933,7 @@ def forward(self, x): qconfig_dict = func(backend) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) for name, mod in m.named_modules(): - if is_activation_post_process(mod) and mod.dtype == torch.quint8: + if _is_activation_post_process(mod) and mod.dtype == torch.quint8: if backend == "fbgemm": lower_bnd = 0 upper_bnd = 127 diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index c78b19d2701b..3f4e15685902 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -24,7 +24,7 @@ from torch.ao.ns.fx.mappings import ( get_node_type_to_io_type_map, ) -from torch.ao.quantization.quantize import is_activation_post_process +from torch.ao.quantization.observer import _is_activation_post_process from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set @@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]: if node.op == 'call_module': assert isinstance(node.target, str) module = getattr_from_fqn(gm, node.target) - if is_activation_post_process(module): + if _is_activation_post_process(module): node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0) fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index] return fqn # type: ignore[return-value] diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index 2993764b8a12..90574dc20248 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -13,10 +13,10 @@ from torch.fx.graph import Node from torch.ao.quantization import ( ObserverBase, - FakeQuantizeBase, + FakeQuantizeBase ) +from torch.ao.quantization.observer import _is_activation_post_process from torch.ao.quantization.utils import getattr_from_fqn -from torch.ao.quantization.quantize import is_activation_post_process from .ns_types import NSNodeTargetType, NSResultsType @@ -256,14 +256,14 @@ def return_first_non_observer_node( """ if node.op == "call_module": node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] - if is_activation_post_process(node_obj): + if _is_activation_post_process(node_obj): assert len(node.args) == 1 assert isinstance(node.args[0], Node) node = node.args[0] # code duplication intended, not worth refactoring assert isinstance(node.target, str) node_obj = getattr_from_fqn(gm, node.target) - if is_activation_post_process(node_obj): + if _is_activation_post_process(node_obj): assert len(node.args) == 1 assert isinstance(node.args[0], Node) node = node.args[0] diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 1ba2a60ed3d1..bc8403f32af8 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -114,7 +114,6 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", - "is_activation_post_process", "load_observer_state_dict", "no_observer_set", "per_channel_weight_observer_range_neg_127_to_127", diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index c92733bbc1c3..d398819ddcdd 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ 
-23,7 +23,7 @@ default_equalization_qconfig, EqualizationQConfig, ) -from torch.ao.quantization.quantize import is_activation_post_process +from torch.ao.quantization.observer import _is_activation_post_process # Names for observer insert keys DETECTOR_TARGET_NODE_KEY = "target_node" @@ -1273,7 +1273,7 @@ def _supports_insertion(self, module: nn.Module) -> bool: # case for insertion of module # check if the module has any children and isn't observer num_children = len(list(module.children())) - return num_children == 0 and not is_activation_post_process(module) + return num_children == 0 and not _is_activation_post_process(module) def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]: r""" Returns the DetectorQConfigInfo for each module_fqn relavent diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index b5e9cf3bbcb3..0c1249b4858d 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -61,7 +61,6 @@ ) from torch.ao.quantization.quantize import ( _remove_qconfig, - is_activation_post_process, ) from torch.ao.quantization.stubs import DeQuantStub from .custom_config import ( @@ -71,6 +70,7 @@ from .lower_to_fbgemm import lower_to_fbgemm # importing the lib so that the quantized_decomposed ops are registered from ._decomposed import quantized_decomposed_lib # noqa: F401 +from torch.ao.quantization.observer import _is_activation_post_process # TODO: revisit this list. Many helper methods shouldn't be public @@ -218,7 +218,7 @@ def maybe_get_observer_for_node( for maybe_obs_node, _ in node.users.items(): if maybe_obs_node.op == 'call_module': maybe_obs = modules[str(maybe_obs_node.target)] - if is_activation_post_process(maybe_obs): + if _is_activation_post_process(maybe_obs): return maybe_obs return None @@ -725,7 +725,7 @@ def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Gra elif node.op == "call_module": mod = _get_module(node, modules) assert mod is not None - if is_activation_post_process(mod): + if _is_activation_post_process(mod): observed_node = node.args[0] if observed_node in statically_quantized_custom_module_nodes: replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index c908e3f3b764..005a9cef45e3 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -16,6 +16,7 @@ ) from ..observer import ( ObserverBase, + _is_activation_post_process ) from ..qconfig import ( _is_reuse_input_qconfig, @@ -78,7 +79,6 @@ ) from torch.ao.quantization.quantize import ( - is_activation_post_process, convert ) @@ -148,7 +148,7 @@ def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool: return isinstance(node, torch.fx.Node) and node.op == "call_module" and \ - is_activation_post_process(modules[str(node.target)]) + _is_activation_post_process(modules[str(node.target)]) def is_input_arg_dtype_supported_by_backend( arg: Argument, diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 0b0407c0b106..26c7effd44db 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -3,8 +3,8 @@ from typing import Callable, Any, Dict, Tuple, Set, List from torch.ao.quantization import QConfig from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals -from 
torch.ao.quantization.quantize import ( - is_activation_post_process, +from torch.ao.quantization.observer import ( + _is_activation_post_process, ) from torch.ao.quantization.backend_config import ( DTypeConfig, @@ -158,7 +158,7 @@ def generate_node_name_to_qconfig( elif node.op == 'call_module': # if the node is an observer, just continue - don't add it to the qconfig_map - if is_activation_post_process(modules[node.target]): + if _is_activation_post_process(modules[node.target]): continue qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[node.target]), node.target, global_qconfig) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 73fdb0700144..b8bfa4c9d053 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -30,7 +30,7 @@ is_per_channel, to_underlying_dtype, ) -from torch.ao.quantization.quantize import is_activation_post_process +from torch.ao.quantization.observer import _is_activation_post_process from torch.fx import GraphModule, map_arg @@ -447,7 +447,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module result = False elif node.op == 'call_module': assert isinstance(node.target, str) - if is_activation_post_process(modules[node.target]): + if _is_activation_post_process(modules[node.target]): result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] elif node.op == 'call_module': result = False @@ -1040,7 +1040,7 @@ def _activation_post_process_satisfies_dtype_config_constraints( satisfies_constraints = True if activation_post_process_ctr is not None: activation_post_process = activation_post_process_ctr() - assert is_activation_post_process(activation_post_process) + assert _is_activation_post_process(activation_post_process) # If dtypes don't match, don't check the activation_post_process and return True early if activation_post_process.dtype != dtype_with_constraints.dtype: return True diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index e704444d0a6d..26a39c8c2e02 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1437,7 +1437,7 @@ def _is_observer_script_module(mod, obs_type_name): def _is_activation_post_process(module): return ( isinstance(module, torch.ao.quantization.ObserverBase) - or isinstance(module, torch.ao.quantization.FakeQuantize) + or isinstance(module, torch.ao.quantization.FakeQuantizeBase) or _is_observer_script_module(module, "quantization.observer") ) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 9f5537ec8561..b9ef24e35fdb 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -27,10 +27,10 @@ float_qparams_weight_only_qconfig_4bit, _activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations +from torch.ao.quantization.observer import _is_activation_post_process __all__ = [ "get_default_custom_config_dict", - "is_activation_post_process", "propagate_qconfig_", "register_activation_post_process_hook", "add_observer_", @@ -62,11 +62,6 @@ def get_default_custom_config_dict(): """ return _DEFAULT_CUSTOM_CONFIG_DICT -def is_activation_post_process(module): - return (isinstance(module, torch.ao.quantization.ObserverBase) or - isinstance(module, torch.ao.quantization.FakeQuantizeBase)) - - def _propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix='', 
prepare_custom_config_dict=None): r"""This is a helper function for `propagate_qconfig_` @@ -322,7 +317,7 @@ def _remove_activation_post_process(module): # TODO: maybe we should change activation_post_process to _activation_post_process # to prevent it from being used by user if hasattr(module, 'activation_post_process') and \ - is_activation_post_process(module.activation_post_process): + _is_activation_post_process(module.activation_post_process): delattr(module, 'activation_post_process') # remove activation_post_proceess pre and post hooks diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index d9fcf1d04d8b..24d7049ec50e 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -17,7 +17,7 @@ from torch.ao.quantization.quantize import convert from torch.ao.quantization.quantize import get_observer_dict from torch.ao.quantization.quantize import get_unique_devices_ -from torch.ao.quantization.quantize import is_activation_post_process +from torch.ao.quantization.quantize import _is_activation_post_process from torch.ao.quantization.quantize import prepare from torch.ao.quantization.quantize import prepare_qat from torch.ao.quantization.quantize import propagate_qconfig_ From 4908a12542798a3e8641faae6b74f068fdfc6778 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 16 Nov 2022 11:59:40 -0500 Subject: [PATCH 250/453] Reland "SymIntify convolution backend calculation (#89069)"" (#89142) This reverts commit 90db86be108184a6c86c73e1b01012352c72e66b. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89142 Approved by: https://github.com/albanD, https://github.com/malfet --- aten/src/ATen/native/ConvUtils.h | 79 ++++-- aten/src/ATen/native/Convolution.cpp | 319 +++++++++++++----------- aten/src/ATen/native/utils/ParamUtils.h | 21 +- c10/core/SymInt.h | 13 + torch/csrc/Module.cpp | 12 +- 5 files changed, 270 insertions(+), 174 deletions(-) diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index b8e2b0842a00..880ce0c2af54 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -110,8 +110,8 @@ enum class ConvBackend { // This overload is exposed to python for testing, etc. TORCH_API ConvBackend select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias_opt, - IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, - bool transposed, IntArrayRef output_padding, int64_t groups, const at::OptionalIntArrayRef bias_sizes_opt); + IntArrayRef stride, SymIntArrayRef padding, IntArrayRef dilation, + bool transposed, SymIntArrayRef output_padding, int64_t groups, const at::OptionalSymIntArrayRef bias_sizes_opt); TORCH_API at::MemoryFormat _determine_backend_memory_format(const Tensor& input, const Tensor& weight, @@ -200,15 +200,16 @@ static void convolution_shape_check( // as conv_output_size loses information; this is why conv_input_size // takes an extra output_padding argument to resolve the ambiguity. 
-static inline std::vector conv_output_size( - IntArrayRef input_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +template +static inline std::vector _conv_output_size( + ArrayRef input_size, ArrayRef weight_size, + ArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() ) { // ASSERT(input_size.size() > 2) // ASSERT(input_size.size() == weight_size.size()) bool has_dilation = dilation.size() > 0; auto dim = input_size.size(); - std::vector output_size(dim); + std::vector output_size(dim); output_size[0] = input_size[input_batch_size_dim]; output_size[1] = weight_size[weight_output_channels_dim]; for (const auto d : c10::irange(2, dim)) { @@ -219,40 +220,84 @@ static inline std::vector conv_output_size( return output_size; } -static inline std::vector conv_input_size( - IntArrayRef output_size, IntArrayRef weight_size, - IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +static inline std::vector conv_output_size( + IntArrayRef input_size, IntArrayRef weight_size, + IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +static inline std::vector conv_output_size( + SymIntArrayRef input_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, IntArrayRef stride, IntArrayRef dilation = IntArrayRef() +) { + return _conv_output_size(input_size, weight_size, padding, stride, dilation); +} + +template +std::vector _conv_input_size( + ArrayRef output_size, ArrayRef weight_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { // ASSERT(output_size.size() > 2) // ASSERT(output_size.size() == weight_size.size()) auto dim = output_size.size(); - std::vector input_size(dim); + std::vector input_size(dim); input_size[0] = output_size[output_batch_size_dim]; input_size[1] = weight_size[weight_input_channels_dim] * groups; for (const auto d : c10::irange(2, dim)) { - int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1; - input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) + + auto kernel = (weight_size[d] - 1) * dilation[d - 2] + 1; + input_size[d] = (output_size[d] - 1) * stride[d - 2] - (padding[d - 2] * 2) + kernel + output_padding[d - 2]; } return input_size; } -static inline std::vector conv_weight_size( - IntArrayRef input_size, IntArrayRef output_size, +static inline std::vector conv_input_size( + SymIntArrayRef output_size, SymIntArrayRef weight_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_input_size( + IntArrayRef output_size, IntArrayRef weight_size, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_input_size(output_size, weight_size, padding, output_padding, stride, dilation, groups); +} + +template +std::vector _conv_weight_size( + ArrayRef input_size, ArrayRef output_size, + ArrayRef padding, ArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups ) { auto dim = input_size.size(); - std::vector weight_size(dim); + std::vector weight_size(dim); weight_size[0] = output_size[1]; weight_size[1] = input_size[1] / groups; for (const auto d : 
c10::irange(2, dim)) { - int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] - + 2 * padding[d - 2] - output_padding[d - 2]; + auto kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2] + + padding[d - 2] * 2 - output_padding[d - 2]; weight_size[d] = (kernel - 1) / dilation[d - 2] + 1; } return weight_size; } +static inline std::vector conv_weight_size( + SymIntArrayRef input_size, SymIntArrayRef output_size, + SymIntArrayRef padding, SymIntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + +static inline std::vector conv_weight_size( + IntArrayRef input_size, IntArrayRef output_size, + IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups +) { + return _conv_weight_size(input_size, output_size, padding, output_padding, stride, dilation, groups); +} + static inline Tensor reshape_bias(int64_t dim, const Tensor& bias) { std::vector shape(dim, 1); shape[1] = -1; diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 29b2ce804c80..bf7017f20a4f 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -83,10 +83,11 @@ constexpr int MIOPEN_DIM_MAX = 5; namespace at { namespace native { // Check workload to activate fast depthwise FP16 cudnn conv kernels +template bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { - int w = input.size(3); // same as h - int ch = input.size(1); - int bs = input.size(0); + auto w = at::symint::size(input, 3); // same as h + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); if (stride==1) { if (w >= 7) { // All batch sizes and nb_channels @@ -205,27 +206,28 @@ bool check_cudnn_depthwise_workload(const at::Tensor& input, int stride) { } // simplified version for cudnn 8.2 and above +template bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int stride, const at::Tensor& weight) { // 1D conv - if(input.size(2) == 1 && stride == 1){ + if(at::symint::size(input, 2) == 1 && stride == 1){ return true; } // 2d conv // only square filters - if (weight.size(2) != weight.size(3)) return false; - int filter = weight.size(3); + if (at::symint::size(weight, 2) != at::symint::size(weight, 3)) return false; + auto filter = at::symint::size(weight, 3); // only 1/3/5 filter if (filter != 1 && filter != 3 && filter != 5) return false; // we don't enforce square input but only check width to reduce heuristic space - if (input.size(3) < 7) return false; // min width 7 - int w = input.size(3); + if (at::symint::size(input, 3) < 7) return false; // min width 7 + auto w = at::symint::size(input, 3); // only 1/2 stride, use cudnn for all stride 1 if (stride == 1) return true; if (stride != 2) return false; - int ch = input.size(1); - int bs = input.size(0); + auto ch = at::symint::size(input, 1); + auto bs = at::symint::size(input, 0); // special case since bs1 show good perf in lots of cases if (bs == 1) { if (filter == 1 && w <= 28) return true; @@ -240,13 +242,42 @@ bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, int str } +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const IntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + return 
xnnpack::use_convolution2d(input, weight, bias_sizes_opt, padding, stride, dilation, groups, transposed); +} + +bool xnnpack_use_convolution2d( + const Tensor& input, + const Tensor& weight, + const at::OptionalSymIntArrayRef bias_sizes_opt, + const SymIntArrayRef padding, + const IntArrayRef stride, + const IntArrayRef dilation, + const int64_t groups, + const bool transposed) { + // Never use xnnpack for symbolic tracing + return false; +} + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +// This struct is templated so that we can run backend selection in a dynamic +// shapes context; all of the real kernel selection in eager mode runs with +// int64_t +template struct ConvParams { std::vector stride; - std::vector padding; + std::vector padding; std::vector dilation; bool transposed; - std::vector output_padding; + std::vector output_padding; int groups; bool benchmark; bool deterministic; @@ -322,12 +353,12 @@ struct ConvParams { #if defined(__ARM_NEON__) // Currently only 3x3 depthwise convolutions on tensors of float are supported. return (input.ndimension() == 4) && - (input.size(1) == groups) && + (at::symint::size(input, 1) == groups) && (weight.ndimension() == 4 ) && - (weight.size(0) % input.size(1) == 0) && - (weight.size(1) == 1) && - (weight.size(2) == 3) && - (weight.size(3) == 3) && + (at::symint::size(weight, 0) % at::symint::size(input, 1) == 0) && + (at::symint::size(weight, 1) == 1) && + (at::symint::size(weight, 2) == 3) && + (at::symint::size(weight, 3) == 3) && (input.device().is_cpu()) && (input.scalar_type() == at::kFloat) && input.is_contiguous() && @@ -345,23 +376,23 @@ struct ConvParams { bool needs_64bit_indexing_no_split(const at::Tensor& input, const at::Tensor& weight) const { constexpr int64_t int_max = std::numeric_limits::max(); - int64_t numel_input = input.numel(); + auto numel_input = at::symint::numel(input); // empty input if (numel_input == 0) { return false; } // input size can not be reduced to the range of int by splitting the batch dim - int64_t n = input.size(0); + auto n = at::symint::size(input, 0); if (numel_input / n > int_max) { return true; } // output size can not be reduced to the range of int by splitting the batch dim - int64_t outsize = 1; + T outsize = 1; if (transposed) { - std::vector o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups); + auto o = conv_input_size(at::symint::sizes(input), at::symint::sizes(weight), padding, output_padding, stride, dilation, groups); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } else { - std::vector o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation); + auto o = conv_output_size(at::symint::sizes(input), at::symint::sizes(weight), padding, stride, dilation); outsize = c10::multiply_integers(o.begin() + 1, o.end()); } return outsize > int_max; @@ -417,10 +448,10 @@ struct ConvParams { is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks !is_dilated() && // no dilation supported - (stride[0] == stride[1] || input.size(2) == 1) && // square or 1d - input.size(1) >= 32); // min 32 channels supported) + (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d + at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); + return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); } } // 
keep (7600 <= cudnn < 8200) code unchanged @@ -430,14 +461,14 @@ struct ConvParams { weight.scalar_type() == kHalf && is_depthwise(input, weight) && input.ndimension() == 4 && // TODO: 5-D contiguous depthwise is not supported yet, need benchmarks - weight.size(2) == weight.size(3) && // only square kernels - input.size(2) >= 7 && // min width/height 7 + at::symint::size(weight, 2) == at::symint::size(weight, 3) && // only square kernels + at::symint::size(input, 2) >= 7 && // min width/height 7 !is_dilated() && // no dilation supported stride[0] == stride[1] && // equal strides - ((weight.size(3) == 3) || (weight.size(3) == 1)) && - input.size(1) >= 32); // min 32 channels supported) + ((at::symint::size(weight, 3) == 3) || (at::symint::size(weight, 3) == 1)) && + at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { - return check_cudnn_depthwise_workload(input, stride[0]); + return check_cudnn_depthwise_workload(input, stride[0]); } else { return false; } @@ -473,12 +504,12 @@ struct ConvParams { !transposed && // or transposed tensors // For 1x1 filters, MKLDNN is faster than THNN when multi-threaded, // but THNN is faster when single-threaded. - (is_strided() || is_dilated() || input.size(0) >= 16 || - weight.size(-1) != 1 || weight.size(-2) != 1 || at::get_num_threads() > 1) && + (is_strided() || is_dilated() || at::symint::size(input, 0) >= 16 || + at::symint::size(weight, -1) != 1 || at::symint::size(weight, -2) != 1 || at::get_num_threads() > 1) && (groups > 1 - || (weight.size(-1) > 3 && weight.size(-2) > 3) - || input.size(0) > 1 - || input.size(0)*input.size(1)*input.size(2)*input.size(3) > 20480) // for some case, native is faster + || (at::symint::size(weight, -1) > 3 && at::symint::size(weight, -2) > 3) + || at::symint::size(input, 0) > 1 + || at::symint::size(input, 0)*at::symint::size(input, 1)*at::symint::size(input, 2)*at::symint::size(input, 3) > 20480) // for some case, native is faster ); #endif @@ -493,20 +524,23 @@ struct ConvParams { !transposed && // or transposed tensors input.ndimension() == 4 && // must be in NCHW format weight.ndimension() == 4 && - (weight.size(2) < 17) && (weight.size(3) < 17) // NNPACK only supports kernels up to 16x16 + (at::symint::size(weight, 2) < 17) && (at::symint::size(weight, 3) < 17) // NNPACK only supports kernels up to 16x16 #if !defined(C10_MOBILE) - && input.size(0) >= 16 // ensure large enough batch size to ensure perf, tuneable + && at::symint::size(input, 0) >= 16 // ensure large enough batch size to ensure perf, tuneable #endif ; #endif return false; } bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt) const { + const at::OptionalArrayRef bias_sizes_opt) const { #if defined(C10_MOBILE) if (!transposed) { - return (input.size(1) == groups) && - xnnpack::use_convolution2d( + // NB: for the call here, it MATTERS that we are templated. 
If you + // untemplate this to always use SymInt, the function + // xnnpack_use_convolution2d will always return false + return (at::symint::size(input, 1) == groups) && + xnnpack_use_convolution2d( input, weight, bias_sizes_opt, @@ -543,33 +577,12 @@ struct ConvParams { return input.is_cuda() && !transposed && (input.ndimension() == 4 || input.ndimension() == 5) && - input.size(1) == groups && + at::symint::size(input, 1) == groups && groups > 1 && // no point if there is only a single group - weight.size(0) % input.size(1) == 0; // output channels must be a multiple of input channels + at::symint::size(weight, 0) % at::symint::size(input, 1) == 0; // output channels must be a multiple of input channels } }; -// Function to select the convolution backend based on the inputs and params. -// This overload is used within the convolution internals but not exposed to python. -// NB: The forward pass provides a bias tensor while the backward pass provides -// a bool indicating whether the bias is defined. This is done to save memory by -// avoiding saving the full bias tensor for backward. -ConvBackend _select_conv_backend( - const Tensor& input, - const Tensor& weight, - const c10::optional& bias_opt, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - -// For BC reasons, have a copy that does not require bias_opt -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params); - DEFINE_DISPATCH(conv_depthwise2d_backward_stub); DEFINE_DISPATCH(conv_depthwise3d_backward_stub); DEFINE_DISPATCH(cudnn_convolution_backward_stub); @@ -591,13 +604,14 @@ REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub); REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub); -std::ostream& operator<<(std::ostream & out, const ConvParams& params) { +template +std::ostream& operator<<(std::ostream & out, const ConvParams& params) { out << "ConvParams {" << " stride = " << IntArrayRef{params.stride} - << " padding = " << IntArrayRef{params.padding} + << " padding = " << ArrayRef{params.padding} << " dilation = " << IntArrayRef{params.dilation} << " transposed = " << params.transposed - << " output_padding = " << IntArrayRef{params.output_padding} + << " output_padding = " << ArrayRef{params.output_padding} << " groups = " << params.groups << " benchmark = " << params.benchmark << " deterministic = " << params.deterministic @@ -607,9 +621,10 @@ std::ostream& operator<<(std::ostream & out, const ConvParams& params) { return out; } +template static void check_shape_forward(const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, const at::Tensor& bias, - const ConvParams& params) { + const c10::ArrayRef& weight_sizes, const at::Tensor& bias, + const ConvParams& params) { int64_t k = input.ndimension(); int64_t weight_dim = weight_sizes.size(); int64_t groups = params.groups; @@ -624,7 +639,7 @@ static void check_shape_forward(const at::Tensor& input, TORCH_CHECK(weight_dim == k, "Expected ", weight_dim, "-dimensional input for ", weight_dim, "-dimensional weight ", weight_sizes, ", but got ", k, "-dimensional input of size ", - input.sizes(), " instead"); + at::symint::sizes(input), " instead"); TORCH_CHECK(weight_sizes[0] >= groups, "Given groups=", groups, ", expected weight to be at least ", groups, " at dimension 0, but got weight of size ", 
weight_sizes, " instead"); @@ -634,23 +649,23 @@ static void check_shape_forward(const at::Tensor& input, "] instead"); if (!transposed) { - std::vector input_shape; - std::vector kernel_shape; + std::vector input_shape; + std::vector kernel_shape; bool kernel_size_correct = true; - TORCH_CHECK(input.size(1) == (weight_sizes[1] * groups), + TORCH_CHECK(at::symint::size(input, 1) == (weight_sizes[1] * groups), "Given groups=", groups, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", - (weight_sizes[1] * groups), " channels, but got ", input.size(1), + (weight_sizes[1] * groups), " channels, but got ", at::symint::size(input, 1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[0]), + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[0]), "Given weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[0], " elements", - ", but got bias of size ", bias.sizes(), " instead"); + ", but got bias of size ", at::symint::sizes(bias), " instead"); for (const auto i : c10::irange(2, k)) { - input_shape.push_back(input.size(i) + 2 * padding[i-2]); + input_shape.push_back(at::symint::size(input, i) + 2 * padding[i-2]); // log new kernel size considering dilation kernel_shape.push_back(dilation[i-2] * (weight_sizes[i]-1) + 1); if (input_shape.back() < kernel_shape.back()) { @@ -676,22 +691,23 @@ static void check_shape_forward(const at::Tensor& input, "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); } } else { // transposed - TORCH_CHECK(input.size(1) == weight_sizes[0], + TORCH_CHECK(at::symint::size(input, 1) == weight_sizes[0], "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected input", input.sizes(), " to have ", weight_sizes[0], - " channels, but got ", input.size(1), " channels instead"); - TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && bias.size(0) == weight_sizes[1] * groups), + " channels, but got ", at::symint::size(input, 1), " channels instead"); + TORCH_CHECK(!bias.defined() || (bias.ndimension() == 1 && at::symint::size(bias, 0) == weight_sizes[1] * groups), "Given transposed=", transposed, ", weight of size ", weight_sizes, ", expected bias to be 1-dimensional with ", weight_sizes[1] * groups, " elements", ", but got bias of size ", bias.sizes(), " instead"); } } +template static void check_shape_backward( const at::Tensor& input, - const c10::IntArrayRef& weight_sizes, - const ConvParams& params) { - check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); + const c10::ArrayRef& weight_sizes, + const ConvParams& params) { + check_shape_forward(input, weight_sizes, /*bias=*/ Tensor(), params); } // Given an input tensor and an expected number of spatial dimensions, checks that the @@ -1149,71 +1165,25 @@ at::Tensor convolution_overrideable( TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); } -// Selects a backend for convolution based on the inputs and params. 
-ConvBackend select_conv_backend( - const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, - IntArrayRef stride_, IntArrayRef padding_, IntArrayRef dilation_, - bool transposed_, IntArrayRef output_padding_, int64_t groups_, const at::OptionalIntArrayRef bias_sizes_opt) { - c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); - const Tensor& bias = *bias_maybe_owned; - - auto& ctx = at::globalContext(); - auto k = weight_r.ndimension(); - int64_t dim = k - 2; - ConvParams params; - params.stride = expand_param_if_needed(stride_, "stride", dim); - params.padding = expand_param_if_needed(padding_, "padding", dim); - params.dilation = expand_param_if_needed(dilation_, "dilation", dim); - params.transposed = transposed_; - params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); - params.groups = groups_; - params.benchmark = ctx.benchmarkCuDNN(); - params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); - params.cudnn_enabled = ctx.userEnabledCuDNN(); - params.allow_tf32 = ctx.allowTF32CuDNN(); - - auto input = input_r; - auto weight = weight_r; - check_shape_forward(input, weight.sizes(), bias, params); - - // Expand 1d -> 2d. - // This is only done for backends that don't natively support 1d spatial input. - if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { - // avoid accidentally going through NHWC for permuted 3d input. - input = input.contiguous(); - params.view1d_as_2d(); - input = view4d(input); - weight = view4d(weight); - } - - auto bias_sizes = bias.defined() ? c10::optional(bias.sizes()) : bias_sizes_opt; - bool need_backward = GradMode::is_enabled() && - (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); -} - -ConvBackend select_conv_backend( - const Tensor& input, - const Tensor& weight, - const at::OptionalIntArrayRef bias_sizes_opt, - const bool need_backward, - const ConvParams& params) { - return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); -} - +// Function to select the convolution backend based on the inputs and params. +// This overload is used within the convolution internals but not exposed to python. +// NB: The forward pass provides a bias tensor while the backward pass provides +// a bool indicating whether the bias is defined. This is done to save memory by +// avoiding saving the full bias tensor for backward. +template ConvBackend _select_conv_backend( const Tensor& input, const Tensor& weight, const c10::optional& bias, - const at::OptionalIntArrayRef bias_sizes_opt, + const at::OptionalArrayRef bias_sizes_opt, const bool need_backward, - const ConvParams& params) { + const ConvParams& params) { // don't send empty inputs through backends - if (input.size(0) == 0 || input.size(1) == 0) { + if (at::symint::size(input, 0) == 0 || at::symint::size(input, 1) == 0) { return input.is_mkldnn() ? 
ConvBackend::MkldnnEmpty : ConvBackend::Empty; - } else if (input.numel() == 0) { - TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", input.sizes()); + } else if (at::symint::numel(input) == 0) { + TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes(input)); } if (params.is_depthwise(input, weight)) { @@ -1305,12 +1275,65 @@ ConvBackend _select_conv_backend( AT_ERROR("unsupported ConvNd parameters"); } +// Selects a backend for convolution based on the inputs and params. +ConvBackend select_conv_backend( + const Tensor& input_r, const Tensor& weight_r, const c10::optional& bias_opt, + IntArrayRef stride_, SymIntArrayRef padding_, IntArrayRef dilation_, + bool transposed_, SymIntArrayRef output_padding_, int64_t groups_, const at::OptionalSymIntArrayRef bias_sizes_opt) { + c10::MaybeOwned bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + auto& ctx = at::globalContext(); + auto k = weight_r.ndimension(); + int64_t dim = k - 2; + ConvParams params; + params.stride = expand_param_if_needed(stride_, "stride", dim); + params.padding = expand_param_if_needed(padding_, "padding", dim); + params.dilation = expand_param_if_needed(dilation_, "dilation", dim); + params.transposed = transposed_; + params.output_padding = expand_param_if_needed(output_padding_, "output_padding", dim); + params.groups = groups_; + params.benchmark = ctx.benchmarkCuDNN(); + params.deterministic = ctx.deterministicCuDNN() || ctx.deterministicAlgorithms(); + params.cudnn_enabled = ctx.userEnabledCuDNN(); + params.allow_tf32 = ctx.allowTF32CuDNN(); + + auto input = input_r; + auto weight = weight_r; + check_shape_forward(input, weight.sym_sizes(), bias, params); + + // Expand 1d -> 2d. + // This is only done for backends that don't natively support 1d spatial input. + if (k == 3 && !input.is_mkldnn() && !input.is_xpu()) { + // avoid accidentally going through NHWC for permuted 3d input. + input = input.contiguous(); + params.view1d_as_2d(); + input = view4d(input); + weight = view4d(weight); + } + + auto bias_sizes = bias.defined() ? c10::optional(bias.sym_sizes()) : bias_sizes_opt; + bool need_backward = GradMode::is_enabled() && + (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); + return _select_conv_backend(input, weight, bias, bias_sizes, need_backward, params); +} + +// For BC reasons, have a copy that does not require bias_opt +ConvBackend select_conv_backend( + const Tensor& input, + const Tensor& weight, + const at::OptionalIntArrayRef bias_sizes_opt, + const bool need_backward, + const ConvParams& params) { + return _select_conv_backend(input, weight, {}, bias_sizes_opt, need_backward, params); +} + at::Tensor _convolution_nogroup_backend( const Tensor& input, const Tensor& weight, const Tensor& bias, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::NnpackSpatial: @@ -1341,7 +1364,7 @@ at::Tensor _convolution_nogroup_backend( static inline std::vector calc_output_size( const Tensor& input, const Tensor& weight, - const ConvParams& params) { + const ConvParams& params) { std::vector output_size = params.transposed ? 
conv_input_size(input.sizes(), weight.sizes(), params.padding, params.output_padding, params.stride, params.dilation, params.groups) : @@ -1422,7 +1445,7 @@ at::Tensor _convolution( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); TORCH_CHECK(groups_ > 0, "non-positive groups is not supported"); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1450,7 +1473,7 @@ at::Tensor _convolution( auto bias_sizes_opt = bias.defined() ? c10::optional(bias.sizes()) : c10::nullopt; bool need_backward = GradMode::is_enabled() && (input.requires_grad() || weight.requires_grad() || (bias.defined() && bias.requires_grad())); - ConvBackend backend = _select_conv_backend(input, weight, bias, bias_sizes_opt, need_backward, params); + ConvBackend backend = _select_conv_backend(input, weight, bias, c10::OptionalIntArrayRef(bias_sizes_opt), need_backward, params); at::MemoryFormat backend_memory_format = determine_backend_memory_format(input, weight, backend); // Call the backend. @@ -1663,7 +1686,7 @@ std::tuple _convolution_double_backward( const c10::option auto weight = weight_r; int64_t dim = weight.ndimension() - 2; - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride_, "stride", dim); params.padding = expand_param_if_needed(padding_, "padding", dim); params.dilation = expand_param_if_needed(dilation_, "dilation", dim); @@ -1726,7 +1749,7 @@ std::tuple _convolution_double_backward( const c10::option if (ggI.defined()) { // Modified params with correct padding - ConvParams gw_conv_params(params); + ConvParams gw_conv_params(params); // Disable groups as they are handled separately auto groups = gw_conv_params.groups; @@ -1795,7 +1818,7 @@ std::tuple _convolution_double_backward( const c10::option Tensor gI; if (input.numel() != 0) { if (ggW.defined()) { - ConvParams gi_conv_params(params); + ConvParams gi_conv_params(params); gi_conv_params.transposed = !params.transposed; if (params.transposed) { @@ -1851,7 +1874,7 @@ std::tuple _convolution_backward_nogroup_bac const Tensor& weight, const std::array output_mask, const ConvBackend backend, - const ConvParams& params) { + const ConvParams& params) { auto kernel_size = weight.sizes().slice(2); switch(backend) { case ConvBackend::Slow2d: @@ -1916,7 +1939,7 @@ std::tuple convolution_backward( TORCH_CHECK(dim > 0, "weight should have at least three dimensions"); auto& ctx = at::globalContext(); - ConvParams params; + ConvParams params; params.stride = expand_param_if_needed(stride, "stride", dim); params.padding = expand_param_if_needed(padding, "padding", dim); params.dilation = expand_param_if_needed(dilation, "dilation", dim); diff --git a/aten/src/ATen/native/utils/ParamUtils.h b/aten/src/ATen/native/utils/ParamUtils.h index 376467ff79cf..adb5f1cfa49f 100644 --- a/aten/src/ATen/native/utils/ParamUtils.h +++ b/aten/src/ATen/native/utils/ParamUtils.h @@ -6,12 +6,13 @@ namespace at { namespace native { -inline std::vector expand_param_if_needed( - IntArrayRef list_param, +template +inline std::vector _expand_param_if_needed( + ArrayRef list_param, const char* param_name, int64_t expected_dim) { if (list_param.size() == 1) { - return std::vector(expected_dim, list_param[0]); + return std::vector(expected_dim, list_param[0]); } else if ((int64_t)list_param.size() != expected_dim) { std::ostringstream ss; ss << 
"expected " << param_name << " to be a single integer value or a " @@ -23,5 +24,19 @@ inline std::vector expand_param_if_needed( } } +inline std::vector expand_param_if_needed( + IntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + +inline std::vector expand_param_if_needed( + SymIntArrayRef list_param, + const char* param_name, + int64_t expected_dim) { + return _expand_param_if_needed(list_param, param_name, expected_dim); +} + } // namespace native } // namespace at diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index 9ab72a077680..6355f1339505 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -235,6 +235,19 @@ inline c10::SymInt multiply_integers(const C& container) { [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); } +template < + typename Iter, + typename = std::enable_if_t::value_type, + c10::SymInt>::value>> +inline c10::SymInt multiply_integers(Iter begin, Iter end) { + return std::accumulate( + begin, + end, + c10::SymInt(1), + [](const c10::SymInt& a, const c10::SymInt& b) { return a * b; }); +} + inline SymInt operator+(int64_t a, const SymInt& b) { return c10::SymInt(a) + b; } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index b8693a484ed9..607373625724 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1408,10 +1408,10 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias_opt, at::IntArrayRef stride_, - at::IntArrayRef padding_, + at::SymIntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::IntArrayRef output_padding_, + at::SymIntArrayRef output_padding_, int64_t groups_) { return at::native::select_conv_backend( input, @@ -1442,13 +1442,13 @@ Call this whenever a new thread is created in order to propagate values from const at::Tensor& weight, const c10::optional& bias, at::IntArrayRef stride_, - at::IntArrayRef padding_, + at::SymIntArrayRef padding_, at::IntArrayRef dilation_, bool transposed_, - at::IntArrayRef output_padding_, + at::SymIntArrayRef output_padding_, int64_t groups_, - c10::optional> bias_sizes_opt) { - c10::OptionalArrayRef ref = c10::nullopt; + c10::optional> bias_sizes_opt) { + c10::OptionalArrayRef ref = c10::nullopt; if (bias_sizes_opt) { ref = (*bias_sizes_opt); } From 305b9b1f0e5802437a7ed8169e0ff3fb5c06d4ec Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Wed, 16 Nov 2022 21:54:20 +0000 Subject: [PATCH 251/453] Fix XLASymNode.str() no str() attribute error (#89093) This fixes https://github.com/pytorch/xla/issues/4199 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89093 Approved by: https://github.com/ezyang --- torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/__init__.py b/torch/__init__.py index 6def80d1dc59..02765c4aeee8 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -235,7 +235,7 @@ def __sym_float__(self): raise AssertionError("type stub not overridden") def __repr__(self): - return self.node.str() + return str(self.node) # For BC; direct access of node is OK too def get_pyobj(self): From 640af8d70a3adc7727661c15260d42fe931e9de4 Mon Sep 17 00:00:00 2001 From: William Wen Date: Wed, 16 Nov 2022 21:54:24 +0000 Subject: [PATCH 252/453] More dynamo dashboard improvements (#89155) A number of dashboard improvements: - Add accuracy failures to warnings section - Add regression detection to all 
metrics (speedup, compile time, peak memory), not just accuracy - Add testing flag to update-dashboard to prevent image/comment uploads - Add section for comparing summary statistics (passrate, speedup) between 2 most recent reports - Show names of reports for summary stats diff and regression detection sections - Remove metric graphs from the comment (they can still be found in the generated text file) Sample comment: https://github.com/pytorch/torchdynamo/issues/1831#issuecomment-1317565972 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89155 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 352 ++++++++++++++++++++++++------------ 1 file changed, 233 insertions(+), 119 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 319ff677db4f..8012e82607cf 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -36,6 +36,7 @@ import re import shutil import subprocess +import sys import tempfile from collections import defaultdict from datetime import datetime @@ -133,10 +134,15 @@ def flag_compression_ratio(x): return x < 0.9 +def flag_accuracy(x): + return "pass" not in x + + FLAG_FNS = { "speedup": flag_speedup, "compilation_latency": flag_compilation_latency, "compression_ratio": flag_compression_ratio, + "accuracy": flag_accuracy, } @@ -216,6 +222,12 @@ def parse_args(): default=False, help="Updates to dashboard", ) + parser.add_argument( + "--update-dashboard-test", + action="store_true", + default=False, + help="Do not udpate lookup file or upload images/comments when --update-dashboard is specified", + ) parser.add_argument( "--dashboard-image-uploader", default=DASHBOARD_DEFAULTS["dashboard_image_uploader"], @@ -412,6 +424,20 @@ def archive(src_dir, dest_dir_prefix, archive_name, dtype): print(f"copied contents of {src_dir} to {dest}") +def get_metric_title(metric): + if metric == "speedup": + return "Performance speedup" + elif metric == "accuracy": + return "Accuracy" + elif metric == "compilation_latency": + return "Compilation latency (sec)" + elif metric == "compression_ratio": + return "Peak Memory Compression Ratio" + elif metric == "abs_latency": + return "Absolute latency (ms)" + raise RuntimeError("unknown metric") + + class Parser: def __init__( self, suites, devices, dtypes, compilers, flag_compilers, mode, output_dir @@ -693,28 +719,18 @@ def flag_bad_entries(self, suite, metric, flag_fn): df = df.assign(suite=suite) return df.reindex(columns=["suite", "name"] + self.flag_compilers) - def get_metric_title(self, metric): - if metric == "speedup": - return "Performance speedup" - elif metric == "accuracy": - return "Accuracy" - elif metric == "compilation_latency": - return "Compilation latency (sec)" - elif metric == "compression_ratio": - return "Peak Memory Compression Ratio" - elif metric == "abs_latency": - return "Absolute latency (ms)" - raise RuntimeError("unknown metric") - def generate_warnings(self): title = "## Warnings ##" body = ( "We flag models where:\n\n" - " - speedup < 0.95x\n" + " - accuracy fails\n" + " - speedup < 0.95x (NOTE: 0.0 speedup typically signifies a failure in the performance test)\n" " - compilation latency > 120 sec.\n" - " - compression ratio < 0.9\n\n" + " - compression ratio < 0.9\n" + "\n" ) for metric in [ + "accuracy", "speedup", "compilation_latency", "compression_ratio", @@ -728,7 +744,7 @@ def generate_warnings(self): tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") str_io = 
io.StringIO() str_io.write("\n") - str_io.write(self.get_metric_title(metric) + " warnings\n") + str_io.write(get_metric_title(metric) + " warnings\n") str_io.write("~~~\n") str_io.write(f"{tabform}\n") str_io.write("~~~\n") @@ -753,7 +769,7 @@ def prepare_message(self, suite): tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") str_io = io.StringIO() str_io.write("\n") - str_io.write(self.get_metric_title(metric) + "\n") + str_io.write(get_metric_title(metric) + "\n") str_io.write("~~~\n") str_io.write(f"{tabform}\n") str_io.write("~~~\n") @@ -779,18 +795,15 @@ def gen_summary_files(self): with open(f"{self.output_dir}/gh_executive_summary.txt", "w") as gh_fh: gh_fh.write(self.executive_summary) - print(self.executive_summary) with open(f"{self.output_dir}/gh_warnings.txt", "w") as gh_fh: warnings_body = self.generate_warnings() gh_fh.write(warnings_body) - print(warnings_body) str_io = io.StringIO() for suite in self.suites: str_io.write(self.prepare_message(suite)) str_io.write("\n") - print(str_io.getvalue()) with open(f"{self.output_dir}/gh_{self.mode}.txt", "w") as gh_fh: gh_fh.write(str_io.getvalue()) @@ -820,10 +833,86 @@ def get_date(log_info): return datetime.strptime(f"{log_info.day}", "%j").strftime("%m-%d") -class AccuracyRegressionTracker: +def find_last_2_with_filenames(lookup_file, dashboard_archive_path, dtype, filenames): + df = pd.read_csv(lookup_file, names=("day", "mode", "prec", "path")) + df = df[df["mode"] == "performance"] + df = df[df["prec"] == dtype] + df = df[::-1] + last2 = [] + for path in df["path"]: + output_dir = os.path.join(dashboard_archive_path, path) + fullpaths = [ + os.path.join(dashboard_archive_path, path, name) for name in filenames + ] + if all([os.path.exists(fullpath) for fullpath in fullpaths]): + last2.append(output_dir) + if len(last2) >= 2: + return last2 + return None + + +class SummaryStatDiffer: + def __init__(self, args): + self.args = args + self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") + assert os.path.exists(self.lookup_file) + + def generate_diff(self, last2, filename, caption): + df_cur, df_prev = [pd.read_csv(os.path.join(path, filename)) for path in last2] + df_merge = df_cur.merge(df_prev, on="Compiler", suffixes=("_cur", "_prev")) + data = {col: [] for col in ("compiler", "suite", "prev_value", "cur_value")} + for _, row in df_merge.iterrows(): + if row["Compiler"] in self.args.flag_compilers: + for suite in self.args.suites: + data["compiler"].append(row["Compiler"]) + data["suite"].append(suite) + data["prev_value"].append(row[suite + "_prev"]) + data["cur_value"].append(row[suite + "_cur"]) + + df = pd.DataFrame(data) + tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") + str_io = io.StringIO() + str_io.write("\n") + str_io.write(f"{caption}\n") + str_io.write("~~~\n") + str_io.write(f"{tabform}\n") + str_io.write("~~~\n") + return str_io.getvalue() + + def generate_comment(self): + title = "## Summary Statistics Diff ##\n" + body = ( + "For each relevant compiler, we compare the summary statistics " + "for the most 2 recent reports that actually run the compiler.\n\n" + ) + dtype = self.args.dtypes[0] + last2 = find_last_2_with_filenames( + self.lookup_file, + self.args.dashboard_archive_path, + dtype, + ["geomean.csv", "passrate.csv"], + ) + + if last2 is None: + body += "Could not find most 2 recent reports.\n\n" + else: + for state, path in zip(("Current", "Previous"), last2): + body += f"{state} report name: {path}\n\n" + body += 
self.generate_diff(last2, "passrate.csv", "Passrate diff") + body += self.generate_diff( + last2, "geomean.csv", "Geometric mean speedup diff" + ) + + comment = generate_dropdown_comment(title, body) + + with open(f"{self.args.output_dir}/gh_summary_diff.txt", "w") as gh_fh: + gh_fh.write(comment) + + +class RegressionDetector: """ - Compares the most recent 2 accuracy benchmarks to find previously - passing models that now fail. + Compares the most recent 2 benchmarks to find previously unflagged models + that are now flagged. """ def __init__(self, args): @@ -831,97 +920,113 @@ def __init__(self, args): self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") assert os.path.exists(self.lookup_file) - def find_last_2(self, suite, device, dtype, compiler): - df = pd.read_csv(self.lookup_file, names=("day", "mode", "prec", "path")) - df = df[df["mode"] == "performance"] - df = df[df["prec"] == dtype] - df = df[::-1] - parsers = [] - for path in df["path"]: - output_dir = os.path.join(self.args.dashboard_archive_path, path) - if os.path.exists( - os.path.join( - output_dir, - generate_csv_name( - self.args, dtype, suite, device, compiler, "accuracy" - ), - ) - ): - parsers.append( - ParsePerformanceLogs( - [suite], - [device], - [dtype], - [compiler], - [compiler], - get_mode(self.args), - output_dir, - ) - ) - if len(parsers) >= 2: - return parsers - return None - def generate_comment(self): - title = "## Accuracy Regressions ##\n" + title = "## Recent Regressions ##\n" body = ( "For each relevant compiler, we compare the most recent 2 reports " - "(that actually run the compiler) to find models where previously " - "successful accuracy tests now fail.\n\n" + "(that actually run the compiler) to find previously unflagged " + "models that are now flagged as problematic (according to the " + "'Warnings' section).\n\n" ) dtype = self.args.dtypes[0] device = self.args.devices[0] - regressions_present = False for suite in self.args.suites: - dfs = [] - for compiler in self.args.flag_compilers: - last2 = self.find_last_2(suite, device, dtype, compiler) - if last2 is None: - continue + body += f"### Regressions for {suite} ###\n" + last2 = {} - df_cur, df_prev = [ - last2[i].untouched_parsed_frames[suite]["accuracy"] for i in (0, 1) + for compiler in self.args.flag_compilers: + filenames = [ + generate_csv_name( + self.args, dtype, suite, device, compiler, testing + ) + for testing in ["performance", "accuracy"] ] - df_merge = df_cur.merge(df_prev, on="name", suffixes=("_cur", "_prev")) - flag = np.logical_and( - df_merge[compiler + "_prev"].apply(lambda x: "pass" in x), - df_merge[compiler + "_cur"].apply(lambda x: "pass" not in x), + compiler_last2 = find_last_2_with_filenames( + self.lookup_file, self.args.dashboard_archive_path, dtype, filenames ) - df_bad = df_merge[flag] - dfs.append( - pd.DataFrame( - data={ - "compiler": compiler, - "name": df_bad["name"], - "prev_status": df_bad[compiler + "_prev"], - "cur_status": df_bad[compiler + "_cur"], - } + if compiler_last2 is not None: + last2[compiler] = [ + ParsePerformanceLogs( + [suite], + [device], + [dtype], + [compiler], + [compiler], + get_mode(self.args), + output_dir, + ) + for output_dir in compiler_last2 + ] + for state, path in zip(("Current", "Previous"), compiler_last2): + body += ( + f"{state} report name (compiler: {compiler}, " + f"suite: {suite}): {path}\n\n" + ) + + for metric in [ + "accuracy", + "speedup", + "compilation_latency", + "compression_ratio", + ]: + regressions_present = False + dfs = [] + 
for compiler in self.args.flag_compilers: + if last2[compiler] is None: + continue + + df_cur, df_prev = [ + last2[compiler][i].untouched_parsed_frames[suite][metric] + for i in (0, 1) + ] + df_merge = df_cur.merge( + df_prev, on="name", suffixes=("_cur", "_prev") + ) + flag_fn = FLAG_FNS[metric] + flag = np.logical_and( + df_merge[compiler + "_prev"].apply( + lambda x: not pd.isna(x) and not flag_fn(x) + ), + df_merge[compiler + "_cur"].apply( + lambda x: not pd.isna(x) and flag_fn(x) + ), + ) + df_bad = df_merge[flag] + dfs.append( + pd.DataFrame( + data={ + "compiler": compiler, + "name": df_bad["name"], + "prev_status": df_bad[compiler + "_prev"], + "cur_status": df_bad[compiler + "_cur"], + } + ) ) - ) - if not dfs: - continue - df = pd.concat(dfs, axis=0) - if df.empty: - continue - regressions_present = True - tabform = tabulate(df, headers="keys", tablefmt="pretty", showindex="never") - str_io = io.StringIO() - str_io.write("\n") - str_io.write(f"Accuracy regressions for {suite}\n") - str_io.write("~~~\n") - str_io.write(f"{tabform}\n") - str_io.write("~~~\n") - body += str_io.getvalue() + if not dfs: + continue + df = pd.concat(dfs, axis=0) + if df.empty: + continue + regressions_present = True + tabform = tabulate( + df, headers="keys", tablefmt="pretty", showindex="never" + ) + str_io = io.StringIO() + str_io.write("\n") + str_io.write(f"{get_metric_title(metric)} regressions\n") + str_io.write("~~~\n") + str_io.write(f"{tabform}\n") + str_io.write("~~~\n") + body += str_io.getvalue() - if not regressions_present: - body += "No accuracy regressions found.\n" + if not regressions_present: + body += "No regressions found.\n" comment = generate_dropdown_comment(title, body) - with open(f"{self.args.output_dir}/gh_accuracy_regression.txt", "w") as gh_fh: + with open(f"{self.args.output_dir}/gh_metric_regression.txt", "w") as gh_fh: gh_fh.write(comment) - print(comment) class RegressionTracker: @@ -955,13 +1060,14 @@ def find_last_k(self): def generate_comment(self): title = "## Metrics over time ##\n" str_io = io.StringIO() - for name in glob.glob(self.args.output_dir + "/*over_time.png"): - output = ( - subprocess.check_output([self.args.dashboard_image_uploader, name]) - .decode("ascii") - .rstrip() - ) - str_io.write(f"\n{name} : ![]({output})\n") + if not self.args.update_dashboard_test: + for name in glob.glob(self.args.output_dir + "/*over_time.png"): + output = ( + subprocess.check_output([self.args.dashboard_image_uploader, name]) + .decode("ascii") + .rstrip() + ) + str_io.write(f"\n{name} : ![]({output})\n") comment = generate_dropdown_comment(title, str_io.getvalue()) with open(f"{self.args.output_dir}/gh_regression.txt", "w") as gh_fh: @@ -1032,9 +1138,10 @@ def __init__(self, args): self.lookup_file = os.path.join(self.args.dashboard_archive_path, "lookup.csv") assert os.path.exists(self.lookup_file) try: - self.update_lookup_file() + if not self.args.update_dashboard_test: + self.update_lookup_file() except subprocess.CalledProcessError: - print("failed to update lookup file") + sys.stderr.write("failed to update lookup file\n") def update_lookup_file(self): dtype = self.args.dtypes[0] @@ -1063,14 +1170,17 @@ def archive(self): def upload_graphs(self): title = "## Performance graphs ##\n" str_io = io.StringIO() - for name in glob.glob(self.output_dir + "/*png"): - if "over_time" not in name: - output = ( - subprocess.check_output([self.args.dashboard_image_uploader, name]) - .decode("ascii") - .rstrip() - ) - str_io.write(f"\n{name} : ![]({output})\n") + if not 
self.args.update_dashboard_test: + for name in glob.glob(self.output_dir + "/*png"): + if "over_time" not in name: + output = ( + subprocess.check_output( + [self.args.dashboard_image_uploader, name] + ) + .decode("ascii") + .rstrip() + ) + str_io.write(f"\n{name} : ![]({output})\n") comment = generate_dropdown_comment(title, str_io.getvalue()) with open(f"{self.output_dir}/gh_graphs.txt", "w") as gh_fh: @@ -1080,9 +1190,10 @@ def gen_comment(self): files = [ "gh_title.txt", "gh_executive_summary.txt", + "gh_summary_diff.txt", "gh_warnings.txt", - "gh_regression.txt", - "gh_accuracy_regression.txt", + # "gh_regression.txt", + "gh_metric_regression.txt", "gh_training.txt", "gh_graphs.txt", ] @@ -1120,7 +1231,8 @@ def comment_on_gh(self, comment): def update(self): self.upload_graphs() - AccuracyRegressionTracker(self.args).generate_comment() + SummaryStatDiffer(self.args).generate_comment() + RegressionDetector(self.args).generate_comment() try: RegressionTracker(self.args).diff() except Exception as e: @@ -1129,9 +1241,11 @@ def update(self): gh_fh.write("") comment = self.gen_comment() - self.comment_on_gh(comment) + print(comment) - self.archive() + if not self.args.update_dashboard_test: + self.comment_on_gh(comment) + self.archive() if __name__ == "__main__": From e70f446a16f25b7f344d256c8fa0b78769920d00 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 16 Nov 2022 21:59:31 +0000 Subject: [PATCH 253/453] [Dynamo] Fix bug in NamedTupleVariable (#89110) Fixes https://github.com/pytorch/torchdynamo/issues/1866 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89110 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 14 ++++++++++++++ torch/_dynamo/variables/lists.py | 3 +++ 2 files changed, 17 insertions(+) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index aef364d76994..b3cddcbf1dff 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -495,6 +495,20 @@ def fn(packed): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 3) + def test_namedtuple3(self): + def fn(x, packed): + if isinstance(packed, mytuple): + return x + 1 + else: + return x - 1 + + x = torch.rand([2, 3]) + packed = mytuple(1, 2, 3) + ref = fn(x, packed) + opt_fn = torch._dynamo.optimize("eager")(fn) + res = opt_fn(x, packed) + self.assertTrue(same(ref, res)) + def test_range_input(self): def fn(a, rng): x = a diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 151619d0e4ab..70c6da07adb5 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -378,6 +378,9 @@ def __init__(self, items, tuple_cls, **kwargs): def python_type(self): return self.tuple_cls + def as_python_constant(self): + return self.python_type()(*[x.as_python_constant() for x in self.items]) + def reconstruct(self, codegen): create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls) codegen.append_output(codegen._create_load_const(create_fn)) From ee1d375bf98f6e4c69b2d6f3aa1c702cb652d2f2 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 16 Nov 2022 18:36:24 +0000 Subject: [PATCH 254/453] [FSDP] Add fast path for `NO_SHARD` `clip_grad_norm_()` (#89137) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89137 Approved by: https://github.com/rohan-varma --- .../fsdp/test_fsdp_clip_grad_norm.py | 29 +++++++++++++++++++ .../fsdp/fully_sharded_data_parallel.py | 14 +++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git 
a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py
index 1a742da889ac..97b37ff2f185 100644
--- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py
+++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py
@@ -209,6 +209,35 @@ def _test_ddp_parity(
         self.assertEqual(n1, n2)
         self.assertEqual(p1, p2)

+        if offload_params:
+            # TODO: Gradient computation on CPU and GPU differ slightly causing
+            # drift unrelated to `clip_grad_norm_()`.
+            # https://github.com/pytorch/pytorch/issues/89133
+            return
+
+        # Run a few more iterations
+        # TODO: We cannot run too many iterations, or else there is drift:
+        # https://github.com/pytorch/pytorch/issues/89136
+        for i in range(3):
+            set_to_none = i % 2 == 0  # exercise both
+            ddp_optim.zero_grad(set_to_none=set_to_none)
+            fsdp_optim.zero_grad(set_to_none=set_to_none)
+            inp = ddp_model.module.get_input(device)
+            for model in (ddp_model, fsdp_model):
+                out = model(*inp)
+                out.sum().backward()
+            ddp_total_norm = torch.nn.utils.clip_grad_norm_(
+                ddp_model.parameters(),
+                max_norm=max_norm,
+                norm_type=norm_type,
+            )
+            fsdp_total_norm = fsdp_model.clip_grad_norm_(
+                max_norm=max_norm, norm_type=norm_type
+            )
+            self.assertEqual(ddp_total_norm, fsdp_total_norm)
+            ddp_optim.step()
+            fsdp_optim.step()
+

 instantiate_parametrized_tests(TestClipGradNorm)

diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 3e84315a4e11..d2d4fbf229b6 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1161,10 +1161,20 @@ def clip_grad_norm_(
             self._streams["unshard"],
             self._streams["pre_unshard"],
         )
+        # If every FSDP instance uses `NO_SHARD`, then we can directly use
+        # the normal `nn.utils` one targeting local gradients
+        all_no_shard = all(
+            not handle.uses_sharded_strategy
+            for handle in FullyShardedDataParallel._fsdp_handles(self)
+        )
+        if all_no_shard:
+            return torch.nn.utils.clip_grad_norm_(
+                self.parameters(), max_norm, norm_type
+            )
+        # Otherwise, there exists some FSDP instance using a sharded strategy,
+        # where sharded and non-sharded parameters must be handled separately
         max_norm = float(max_norm)
         norm_type = float(norm_type)
-        # Perform local gradient norm computation, where sharded and
-        # non-sharded parameters must be handled separately
         sharded_params = set()
         nonsharded_params = set()  # `NO_SHARD` or not FSDP-managed
         for handle in FullyShardedDataParallel._fsdp_handles(self):

From 5848704ef8feba9fff3ec4f8ce7d1d3189ec5af8 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Wed, 16 Nov 2022 19:00:49 +0000
Subject: [PATCH 255/453] Removed unnecessary check in `select_nested` (#89150)

Implementation in #88585 should work for all dimensions.
Removed unnecessary check that constrained select to dims 0 and 1 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89150 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/nested/NestedTensorMath.cpp | 14 ++++++-------- docs/source/nested.rst | 2 +- test/test_nestedtensor.py | 5 ++++- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 9a47322644ca..5842c3b8b217 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -502,11 +502,6 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { int64_t ntensors = self_ptr->size(0); TORCH_CHECK_INDEX(ntensors > 0, "You can only select when the NT is not empty."); int64_t ndims = static_cast(sizes[0].size()); - TORCH_CHECK( - positive_dim == 0 || positive_dim == 1, - "NestedTensor can only be selected along dimension 0 or 1", - "got dimension ", dim, " instead." - ); if (positive_dim == 0) { TORCH_CHECK_INDEX( index >= -ntensors && index < ntensors, @@ -534,13 +529,16 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { size_ptr[dim_idx] = sizes[i][j]; stride_ptr[dim_idx] = strides[i][j]; ++dim_idx; - } - else { + } else { TORCH_CHECK_INDEX( index >= 0 && index < sizes[i][j], "index ", index, - " is out of bounds for irregular dimension 1 with size ", + " is out of bounds for dimension ", + j, + " of the ", + i, + "th constituent tensor with size ", sizes[i][j]); new_offsets[i] = offsets[i] + index * strides[i][j]; } diff --git a/docs/source/nested.rst b/docs/source/nested.rst index 07712e0376f1..ac07f8acb5a2 100644 --- a/docs/source/nested.rst +++ b/docs/source/nested.rst @@ -201,7 +201,7 @@ NestedTensor and any constraints they have. Supports addition of a scalar to a nested tensor." :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors. Supports multiplication of a nested tensor by a scalar." - :func:`torch.select`; "Supports selecting along ``dim=0`` only (analogously ``nt[i]``)." + :func:`torch.select`; "Supports selecting along all dimensions." :func:`torch.clone`; "Behavior is the same as on regular tensors." :func:`torch.detach`; "Behavior is the same as on regular tensors." :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only." 
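(As an aside, not part of the diff: a minimal sketch of the user-facing behavior this change enables, selecting along non-batch dimensions of a nested tensor; the shapes below are illustrative.)

```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
row0 = nt.select(1, 0)  # first row of each constituent tensor
col0 = nt.select(2, 0)  # first column of each constituent tensor
print(row0[0].shape)  # expected: torch.Size([5])
print(col0[0].shape)  # expected: torch.Size([2])
```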
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index f1f211cdafca..710753886315 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -800,10 +800,13 @@ def test_nested_tensor_indexing(self, device, dtype): self.assertEqual(nt[1, ...], x1) self.assertRaises(IndexError, lambda: nt[1, 4, 2]) self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1]) - # test select on the irregular dimension only + # test select on non-batch dimensions self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0)) self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0)) self.assertRaises(IndexError, lambda: nt.select(1, 3)) + self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0)) + self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0)) + self.assertRaises(IndexError, lambda: nt.select(2, 5)) # make sure indexing returns a view nt[0].fill_(100.0) answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5)) From ec61951f0771e70de12e6e46bd131ace98486238 Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Wed, 16 Nov 2022 19:17:08 +0000 Subject: [PATCH 256/453] Fix inaccuracy in nt constructor documentation + broken rendering (#89152) Rendering was broken and docstring seemed to be inaccurate ![Screen Shot 2022-11-16 at 2 16 28 PM](https://user-images.githubusercontent.com/35276741/202273588-a2da5b7b-1a6d-46bb-a74e-c0de9a0fd064.png) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89152 Approved by: https://github.com/cpuhrsch --- torch/nested/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nested/__init__.py b/torch/nested/__init__.py index 71498187298d..151d44ab66e1 100644 --- a/torch/nested/__init__.py +++ b/torch/nested/__init__.py @@ -125,8 +125,8 @@ def as_nested_tensor( :ref:`Autograd mechanics `) from :attr:`tensor_list` a list of tensors. Args: - tensor_list (List[array_like]): a list of tensors (or anything that can be passed to torch.tensor) - where their first dimension can be of irregular size, but all other dimensions have to be equal. + tensor_list (List[array_like]): a list of tensors, or anything that can be passed to torch.tensor, + where each element of the list has the same dimensionality. Keyword arguments: dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor. From 8ba62bdff5441b65938ad27e944aa91e4f7eb61a Mon Sep 17 00:00:00 2001 From: Fuzzkatt Date: Wed, 16 Nov 2022 22:50:11 +0000 Subject: [PATCH 257/453] add test_c10d_spawn_ucc.py (#86508) Initial PR to create UCC equivalent of https://github.com/pytorch/pytorch/blob/master/test/distributed/test_c10d_spawn_gloo.py and https://github.com/pytorch/pytorch/blob/master/test/distributed/test_c10d_spawn_nccl.py. Currently only added common ops. 
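For reference, the pattern the new test uses to build a UCC group looks roughly like the sketch below (only meaningful when c10d is built with UCC; the helper name and the file path are illustrative, not part of the test):

```python
import torch.distributed as c10d

def init_pg_ucc(rank: int, world_size: int, filename: str):
    # Same pattern as ProcessGroupShareTensorTest._init_pg_ucc in the new test:
    # every rank opens the same FileStore, then builds a UCC process group on it.
    store = c10d.FileStore(filename, world_size)
    return c10d.ProcessGroupUCC(store, rank, world_size)

# e.g. on rank 0 of a 2-rank job (illustrative path):
# pg = init_pg_ucc(0, 2, "/tmp/ucc_rendezvous_store")
```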
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86508 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_spawn_ucc.py | 110 ++++++++++++++++++ test/run_test.py | 1 + torch/testing/_internal/common_distributed.py | 5 + 3 files changed, 116 insertions(+) create mode 100644 test/distributed/test_c10d_spawn_ucc.py diff --git a/test/distributed/test_c10d_spawn_ucc.py b/test/distributed/test_c10d_spawn_ucc.py new file mode 100644 index 000000000000..eabd7e1cf45b --- /dev/null +++ b/test/distributed/test_c10d_spawn_ucc.py @@ -0,0 +1,110 @@ +# Owner(s): ["oncall: distributed"] + +import sys +import test_c10d_spawn +import torch +import torch.distributed as c10d +from test_c10d_spawn import _torch_dist_nn_available, TestDistributedNNFunctions +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_distributed import ( + requires_ucc, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, + sandcastle_skip, + sandcastle_skip_if, + TEST_WITH_DEV_DBG_ASAN, +) + +NO_UCC = not hasattr(c10d, "ProcessGroupUCC") + +# Fails on Python-3.9, see https://github.com/pytorch/pytorch/issues/51619 +if sys.version_info < (3, 9): + + class ProcessGroupShareTensorTest( + test_c10d_spawn.AbstractProcessGroupShareTensorTest, TestCase + ): + @classmethod + def _init_pg_ucc(cls, rank, filename, world_size): + store = c10d.FileStore(filename, world_size) + return c10d.ProcessGroupUCC(store, rank, world_size) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_broadcast_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_broadcast_process, + [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + 1, + ) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_allreduce_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_allreduce_process, + [torch.ones(2, 2).to(i) for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + 1, + ) + + @sandcastle_skip_if(not TEST_MULTIGPU, "At least 2 CUDA GPUS needed") + @sandcastle_skip_if(NO_UCC, "UCC needed") + def test_shared_allgather_ucc(self): + self._test_multiprocess( + ProcessGroupShareTensorTest._test_allgather_process, + [torch.ones(2, 2).to(i) * i for i in range(self.world_size)], + ProcessGroupShareTensorTest._init_pg_ucc, + self.world_size, + ) + + +# Skip dev-asan as torch + multiprocessing spawn have known issues +if not TEST_WITH_DEV_DBG_ASAN: + + class TestDistributedNNFunctionsUcc(TestDistributedNNFunctions): + # Test Common Ops First. 
+ @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if( + not _torch_dist_nn_available, "torch.distributed.nn is not available" + ) + def test_broadcast(self): + self._test_broadcast("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_reduce(self): + self._test_reduce("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_allreduce(self): + self._test_allreduce("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + @sandcastle_skip("runs into illegal memory access on first assertEqual check when run locally") + def test_all_gather(self): + self._test_all_gather("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all(self): + self._test_all_to_all("ucc") + + @requires_ucc() + @skip_if_lt_x_gpu(2) + @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available") + def test_all_to_all_single(self): + self._test_all_to_all_single("ucc") + +if __name__ == "__main__": + run_tests() diff --git a/test/run_test.py b/test/run_test.py index 8a25a2e70785..6bf98a01a44d 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -785,6 +785,7 @@ def run_test_ops(test_module, test_directory, options): "distributed/test_c10d_common": get_run_test_with_subprocess_fn(), "distributed/test_c10d_spawn_gloo": get_run_test_with_subprocess_fn(), "distributed/test_c10d_spawn_nccl": get_run_test_with_subprocess_fn(), + "distributed/test_c10d_spawn_ucc": get_run_test_with_subprocess_fn(), "distributed/test_store": get_run_test_with_subprocess_fn(), "distributed/test_pg_wrapper": get_run_test_with_subprocess_fn(), "distributed/rpc/test_faulty_agent": get_run_test_with_subprocess_fn(), diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 883a48a5a5fe..9dcb71ae0907 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -304,6 +304,11 @@ def requires_nccl(): "c10d was not compiled with the NCCL backend", ) +def requires_ucc(): + return sandcastle_skip_if( + not c10d.is_ucc_available(), + "c10d was not compiled with the UCC backend", + ) def requires_mpi(): return sandcastle_skip_if( From f920bfaf2a6bfb4bc7966f8417309d94164ff86f Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 16 Nov 2022 18:40:41 +0000 Subject: [PATCH 258/453] Use torchrun for dynamo/distributed.py (#89149) Mainly wanted to confirm torchrun works fine with dynamo/ddp, but it is also a better system than manually launching processes. Partially addresses issue #1779 New run commands ------------ single process: python benchmarks/dynamo/distributed.py [args] multi-gpu (e.g. 
2 gpu on one host): torchrun --nproc_per_node 2 benchmarks/dynamo/distributed.py [args] Pull Request resolved: https://github.com/pytorch/pytorch/pull/89149 Approved by: https://github.com/aazzolini --- benchmarks/dynamo/dist_util.py | 11 +++++-- benchmarks/dynamo/distributed.py | 51 ++++++++------------------------ 2 files changed, 21 insertions(+), 41 deletions(-) diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index d0267cbca307..9957ef6139df 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -25,9 +25,14 @@ def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - dist.init_process_group("nccl", rank=rank, world_size=world_size) + # set defaults in case torchrun isn't used; no idea why the if is needed, but it hangs torchrun otherwise + if not os.getenv("MASTER_ADDR"): + os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") + if not os.getenv("MASTER_PORT"): + os.environ["MASTER_PORT"] = os.getenv("MASETER_PORT", "12355") + os.environ["RANK"] = os.getenv("RANK", "0") + os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1") + dist.init_process_group("nccl") def cleanup(): diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py index 32e3b544d87d..360fd846dbe8 100644 --- a/benchmarks/dynamo/distributed.py +++ b/benchmarks/dynamo/distributed.py @@ -1,12 +1,10 @@ import argparse +import logging +import os from functools import partial -import numpy as np -import tabulate import torch - import torch._dynamo as dynamo -import torch.multiprocessing as mp import torch.utils._pytree as pytree from torch._dynamo.testing import reduce_to_scalar_loss from torch.nn.parallel import DistributedDataParallel as DDP @@ -32,7 +30,11 @@ def profile_model(args, model, inputs, rank): prof.export_chrome_trace(args.trace_file) -def run_model(args, model, inputs, rank, world_size, key, result_q): +def run_model(args, model, inputs, key): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + # result_q = [] + setup(rank, world_size) if args.device == "cuda": # needed for FSDP @@ -62,8 +64,10 @@ def move_tensor(maybe_tensor): print(model) if args.dynamo: + dynamo.reset() if args.verbose: dynamo.config.verbose = True + dynamo.config.log_level = logging.DEBUG if args.dynamo_optimize_ddp: dynamo.config.optimize_ddp = True @@ -80,40 +84,15 @@ def print_compile(gm, ex): # warmup _ = timed(model, model_iter_fn, inputs, times=3, return_result=False) - times = [] t_total = timed( model, model_iter_fn, inputs, times=args.repeat, return_result=False ) - times.append(t_total / args.repeat) - - if rank == 0: - result_q.put(times) if args.profile: profile_model(args, model, inputs, rank) cleanup() - - -def experiment(fn, key, world_size, results): - key = f"{key}_{world_size}" - dynamo.reset() - ctx = mp.get_context("spawn") - result_q = ctx.SimpleQueue() - f_args = (world_size, key, result_q) - if world_size > 1: - mp.spawn( - fn, - args=f_args, - nprocs=world_size, - join=True, - ) - else: - # rank 0 - fn(0, *f_args) - times = result_q.get() - - results.append((key, np.median(times))) + return t_total if __name__ == "__main__": @@ -129,9 +108,6 @@ def experiment(fn, key, world_size, results): parser.add_argument("--profile", action="store_true", help="Run the profiler") parser.add_argument("--trace_file", default="profile.json", help="Run the profiler") parser.add_argument("--repeat", default=10, help="Repeats for timing 
run") - parser.add_argument( - "--world_size", type=int, default=2, help="Number of ranks/gpus for experiments" - ) parser.add_argument( "--dynamo_optimize_ddp", action="store_true", @@ -168,7 +144,6 @@ def experiment(fn, key, world_size, results): fn = partial(run_model, args, model, inputs) - times = [] - experiment(fn, model_name, args.world_size, times) - print("\nExperiment Results:") - print(tabulate.tabulate(times, headers=("key", "time"))) + world_size = os.getenv("WORLD_SIZE", 1) + t_total = fn(f"{model_name}_{world_size}") + print(f"mean latency {t_total / args.repeat} across {args.repeat} runs") From 98379a3949ed4b4f4a76bd9fed2806f82b6c0aa0 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 16 Nov 2022 19:50:02 +0000 Subject: [PATCH 259/453] [ONNX] Add onnx-script test cases (#86907) The test cases for #86906 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86907 Approved by: https://github.com/BowenBao --- .jenkins/caffe2/test.sh | 2 + test/onnx/test_onnxscript_no_runtime.py | 164 ++++++++++++++++++++++++ test/onnx/test_onnxscript_runtime.py | 132 +++++++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 test/onnx/test_onnxscript_no_runtime.py create mode 100644 test/onnx/test_onnxscript_runtime.py diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh index 42111ea22bdd..d245dabda4da 100755 --- a/.jenkins/caffe2/test.sh +++ b/.jenkins/caffe2/test.sh @@ -177,6 +177,8 @@ fi if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then pip install -q --user --no-use-pep517 "git+https://github.com/pytorch/vision.git@$(cat .github/ci_commit_pins/vision.txt)" pip install -q --user ninja flatbuffers==2.0 numpy==1.21.5 onnxruntime==1.12.1 beartype==0.10.4 onnx==1.12.0 + # TODO: change this when onnx-script is on testPypi + pip install 'onnx-script @ git+https://github.com/microsoft/onnx-script' # numba requires numpy <= 1.20, onnxruntime requires numpy >= 1.21. # We don't actually need it for our tests, but it's imported if it's present, so uninstall. pip uninstall -q --yes numba diff --git a/test/onnx/test_onnxscript_no_runtime.py b/test/onnx/test_onnxscript_no_runtime.py new file mode 100644 index 000000000000..125e899af944 --- /dev/null +++ b/test/onnx/test_onnxscript_no_runtime.py @@ -0,0 +1,164 @@ +# Owner(s): ["module: onnx"] + +"""Test the support on onnxscript in PyTorch-ONNX converter.""" +import io +from typing import List + +import onnx +import onnxscript +import torch +from onnxscript.onnx_types import FLOAT +from torch.onnx._internal import jit_utils +from torch.testing._internal import common_utils + + +class TestONNXScriptExport(common_utils.TestCase): + + # opset version is + # 1. local function is supported after opset 15 + # 2. onnx-script requires users to determine opset in local function + opset_version = 15 + + def test_onnxscript_registration_with_multiple_models(self): + + from onnxscript.onnx_opset import opset15 as op + + # 1. 
Register Selu onnxscript function as custom Op + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + # TODO: onnx/ort doesn't support default values for now + # move this when they do + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=self.opset_version, + ) + + # 2. Register layer_norm onnxscript function as custom Op + @onnxscript.script(custom_opset) + def layer_norm( + X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float + ): + mean = op.ReduceMean(X, axes=axes) + D = X - mean # op.Sub(X, mean) + DD = D * D # op.Mul(D, D) + var = op.ReduceMean(DD, axes=axes) + vareps = var + eps # op.Add(var, eps) + stddev = op.Sqrt(vareps) + invstddev = op.Reciprocal(stddev) + normalized = D * invstddev # op.Mul(D, invstddev) + normalizedw = op.CastLike( + normalized, weight + ) # Type issue if missing this Op + normalizedscaled = normalizedw * weight # op.Mul(normalized, weight) + return normalizedscaled + bias + + @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") + def custom_layer_norm( + g, input, normalized_shape, weight, bias, eps, cudnn_enable + ): + # TODO: move the comprehension into local function once + # it's supported by onnxscript + axes = [-i for i in range(len(normalized_shape), 0, -1)] + return g.onnxscript_op( + layer_norm, input, weight, bias, axes_i=axes, eps_f=eps + ).setType(input.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::layer_norm", + symbolic_fn=custom_layer_norm, + opset_version=self.opset_version, + ) + + # 3. export two models + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model_selu = torch.nn.SELU() + selu_onnx = io.BytesIO() + torch.onnx.export(model_selu, x, selu_onnx, opset_version=self.opset_version) + + N, C = 3, 4 + y = torch.randn(N, C) + model_layer_norm = torch.nn.LayerNorm(C) + layer_norm_onnx = io.BytesIO() + torch.onnx.export( + model_layer_norm, y, layer_norm_onnx, opset_version=self.opset_version + ) + + # 4. 
test on models + selu_proto = onnx.load(io.BytesIO(selu_onnx.getvalue())) + layer_norm_proto = onnx.load(io.BytesIO(layer_norm_onnx.getvalue())) + + self.assertEqual(len(selu_proto.functions), 1) + self.assertEqual(len(layer_norm_proto.functions), 1) + self.assertEqual(selu_proto.functions[0].name, "Selu") + self.assertEqual(layer_norm_proto.functions[0].name, "layer_norm") + + def test_loop_registration(self): + # Control flow is tested for _find_onnxscript_op function in torch/onnx/utils.py, + # which has recursive logic to go through every nodes with subgraph in model proto + class NestedLoopsModel(torch.jit.ScriptModule): + def __init__(self): + super().__init__() + self.selu = torch.nn.SELU() + + @torch.jit.script_method + def forward(self, x): + y = x + for i in range(x.size(3)): + if i == 0: + y = self.selu(x) + else: + y += i + return y + + model = NestedLoopsModel() + inputs = torch.zeros(1, 2, 3, 4) + + from onnxscript.onnx_opset import opset15 as op + + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=2) + + @onnxscript.script(custom_opset) + def Selu(X): + alpha = 1.6732632423543772848170429916717 + gamma = 1.0507009873554804934193349852946 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g, X): + # domain of the Op should be aligned with onnx-script + # setType API is required for custom Op to support + # torchscript shape type inference + print("custom_selu is used!") + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=15, + ) + + saved_model = io.BytesIO() + torch.onnx.export( + torch.jit.script(model), inputs, f=saved_model, opset_version=15 + ) + loop_selu_proto = onnx.load(io.BytesIO(saved_model.getvalue())) + self.assertEqual(len(loop_selu_proto.functions), 1) diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py new file mode 100644 index 000000000000..2d0d1e3a5357 --- /dev/null +++ b/test/onnx/test_onnxscript_runtime.py @@ -0,0 +1,132 @@ +# Owner(s): ["module: onnx"] + +"""Test the support on onnxscript in PyTorch-ONNX converter with onnxruntime.""" +from typing import List + +import onnx_test_common +import onnxscript +import torch +from onnxscript.onnx_types import FLOAT +from torch.onnx._internal import jit_utils +from torch.testing._internal import common_utils + + +class TestONNXScriptRuntime(onnx_test_common._TestONNXRuntime): + + # opset version is + # 1. local function is supported after opset 15 + # 2. 
onnx-script requires users to determine opset in local function + opset_version = 15 + + def test_selu_from_onnxscript_example(self): + + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model = torch.nn.SELU() + + from onnxscript.onnx_opset import opset15 as op + + # custom domain is needed for custom Op domain name should be + # aligned to the one in symbolic_fn + # TODO(titaiwang): make an official domain for onnxscript usage + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + # TODO: onnx/ort doesn't support default values for now + # move this when they do + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=self.opset_version, + ) + self.run_test(model, x) + + def test_layer_norm(self): + + x = torch.randn(2, 3) + y = torch.randn(2, 3) + z = torch.randn(2, 3) + + class N(torch.nn.Module): + def __init__(self, prob): + super().__init__() + self.dropout = torch.nn.Dropout(prob) + + def forward(self, x): + return self.dropout(x) + + class M(torch.nn.Module): + def __init__(self, num_layers): + super().__init__() + self.num_layers = num_layers + self.lns = torch.nn.ModuleList( + [torch.nn.LayerNorm(3, eps=i) for i in range(num_layers)] + ) + self.celu1 = torch.nn.CELU(1.0) + self.celu2 = torch.nn.CELU(2.0) + self.dropout = N(0.5) + + def forward(self, x, y, z): + res1 = self.celu1(x) + res2 = self.celu2(y) + for ln in self.lns: + z = ln(z) + return res1 + res2, self.dropout(z) + + model = M(3) + + from onnxscript.onnx_opset import opset15 as op + + custom_opset = onnxscript.values.Opset(domain="onnxscript", version=1) + + @onnxscript.script(custom_opset) + def layer_norm( + X, axes: List[int], weight: FLOAT[...], bias: FLOAT[...], eps: float + ): + mean = op.ReduceMean(X, axes=axes) + D = X - mean # op.Sub(X, mean) + DD = D * D # op.Mul(D, D) + var = op.ReduceMean(DD, axes=axes) + vareps = var + eps # op.Add(var, eps) + stddev = op.Sqrt(vareps) + invstddev = op.Reciprocal(stddev) + normalized = D * invstddev # op.Mul(D, invstddev) + normalizedw = op.CastLike( + normalized, weight + ) # Type issue if missing this Op + normalizedscaled = normalizedw * weight # op.Mul(normalized, weight) + return normalizedscaled + bias + + @torch.onnx.symbolic_helper.parse_args("v", "is", "v", "v", "f", "none") + def custom_layer_norm( + g, input, normalized_shape, weight, bias, eps, cudnn_enable + ): + # TODO: move the comprehension into local function once it's supported by onnxscript + axes = [-i for i in range(len(normalized_shape), 0, -1)] + return g.onnxscript_op( + layer_norm, input, weight, bias, axes_i=axes, eps_f=eps + ).setType(input.type()) + + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::layer_norm", + symbolic_fn=custom_layer_norm, + opset_version=self.opset_version, + ) + + self.run_test(model, (x, y, z)) + + +if __name__ == "__main__": + common_utils.run_tests() From 0c835e25bbde7869101023ebfaab9b7ec01ece25 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 17 Nov 2022 00:30:12 +0000 Subject: [PATCH 260/453] Fix nightly build binary errors (#89153) This is pretty much self 
explanatory issues Two typo's in generate generate binary script caused workflows to be generated with invalid parameters: 1 .generated-linux-binary-libtorch-pre-cxx11-master.yml 2 .generated-macos-arm64-binary-wheel-nightly.yml Pull Request resolved: https://github.com/pytorch/pytorch/pull/89153 Approved by: https://github.com/malfet --- .github/scripts/generate_ci_workflows.py | 4 +- ...linux-binary-libtorch-pre-cxx11-master.yml | 18 +-- ...rated-macos-arm64-binary-wheel-nightly.yml | 110 ------------------ 3 files changed, 11 insertions(+), 121 deletions(-) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 1ef3142286bf..35680e30ee6a 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -154,7 +154,7 @@ class OperatingSystem: package_type="libtorch", abi_version=generate_binary_build_matrix.PRE_CXX11_ABI, build_configs=generate_binary_build_matrix.generate_libtorch_matrix( - OperatingSystem.LINUX, generate_binary_build_matrix.CXX11_ABI, + OperatingSystem.LINUX, generate_binary_build_matrix.PRE_CXX11_ABI, arches=["cpu"], libtorch_variants=["shared-with-deps"], ), @@ -277,7 +277,7 @@ class OperatingSystem: BinaryBuildWorkflow( os=OperatingSystem.MACOS_ARM64, package_type="wheel", - build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS), + build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS_ARM64), cross_compile_arm64=True, ciflow_config=CIFlowConfig( labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml index edacb2e949b0..39e41e67853a 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-master.yml @@ -31,7 +31,7 @@ concurrency: cancel-in-progress: true jobs: - libtorch-cpu-shared-with-deps-cxx11-abi-build: + libtorch-cpu-shared-with-deps-pre-cxx11-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml with: @@ -42,17 +42,17 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu LIBTORCH_VARIANT: shared-with-deps - DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-cpu-shared-with-deps-cxx11-abi + DESIRED_DEVTOOLSET: pre-cxx11 + build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - libtorch-cpu-shared-with-deps-cxx11-abi-test: # Testing + libtorch-cpu-shared-with-deps-pre-cxx11-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} - needs: libtorch-cpu-shared-with-deps-cxx11-abi-build + needs: libtorch-cpu-shared-with-deps-pre-cxx11-build uses: ./.github/workflows/_binary-test-linux.yml with: PYTORCH_ROOT: /pytorch @@ -62,10 +62,10 @@ jobs: # favor of GPU_ARCH_VERSION DESIRED_CUDA: cpu GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/libtorch-cxx11-builder:cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu LIBTORCH_VARIANT: shared-with-deps - DESIRED_DEVTOOLSET: cxx11-abi - build_name: libtorch-cpu-shared-with-deps-cxx11-abi + DESIRED_DEVTOOLSET: pre-cxx11 + build_name: libtorch-cpu-shared-with-deps-pre-cxx11 build_environment: linux-binary-libtorch-pre-cxx11 runs_on: 
linux.4xlarge secrets: diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 6bc3894a00be..7a7df02efe89 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -34,116 +34,6 @@ concurrency: cancel-in-progress: true jobs: - wheel-py3_7-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-12-xl - timeout-minutes: 240 - env: - PYTORCH_ROOT: ${{ github.workspace }}/pytorch - BUILDER_ROOT: ${{ github.workspace }}/builder - PACKAGE_TYPE: wheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - SKIP_ALL_TESTS: 1 - DESIRED_PYTHON: "3.7" - # For sccache access (only on non-forked PRs) - AWS_ACCESS_KEY_ID: ${{ secrets.MACOS_SCCACHE_S3_ACCESS_KEY_ID }} - AWS_SECRET_ACCESS_KEY: ${{ secrets.MACOS_SCCACHE_S3_SECRET_ACCESS_KEY }} - steps: - # NOTE: These environment variables are put here so that they can be applied on every job equally - # They are also here because setting them at a workflow level doesn't give us access to the - # runner.temp variable, which we need. - - name: Populate binary env - shell: bash - run: | - # shellcheck disable=SC2129 - echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" - # shellcheck disable=SC2129 - echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" - # shellcheck disable=SC2129 - echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}" - - name: Install conda and dependencies - run: | - # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh - chmod +x "${RUNNER_TEMP}/conda.sh" - /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" - echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" - echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" - - name: Checkout PyTorch - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - submodules: recursive - path: pytorch - - name: Clean PyTorch checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: pytorch - - name: Checkout pytorch/builder - uses: zhouzhuojie/checkout@05b13c9a0d21f08f6d5e64a1d5042246d13619d9 - with: - ref: main - submodules: recursive - repository: pytorch/builder - path: builder - - name: Clean pytorch/builder checkout - run: | - # Remove any artifacts from the previous checkouts - git clean -fxd - working-directory: builder - - name: Install sccache (only for non-forked PRs, and pushes to trunk) - uses: nick-fields/retry@v2.8.2 - if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} - with: - timeout_minutes: 5 - max_attempts: 3 - retry_wait_seconds: 90 - command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache - sudo chmod +x /usr/local/bin/sccache - echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - - name: Populate binary env - run: | - # shellcheck disable=SC1091 - source "${RUNNER_TEMP}/anaconda/bin/activate" - 
"${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" - - name: Build PyTorch binary - run: | - # shellcheck disable=SC1091 - source "${RUNNER_TEMP}/anaconda/bin/activate" - "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" - - uses: actions/upload-artifact@v3 - if: always() - with: - name: wheel-py3_7-cpu - retention-days: 14 - if-no-files-found: error - path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - wheel-py3_7-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - needs: wheel-py3_7-cpu-build - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: wheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu - DESIRED_PYTHON: "3.7" - build_name: wheel-py3_7-cpu - use_s3: False - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - aws-access-key-id: ${{ secrets.AWS_PYTORCH_UPLOADER_ACCESS_KEY_ID }} - aws-pytorch-uploader-secret-access-key: ${{ secrets.AWS_PYTORCH_UPLOADER_SECRET_ACCESS_KEY }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - uses: ./.github/workflows/_binary-upload.yml wheel-py3_8-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} runs-on: macos-12-xl From 8506b305df531f7567a430854cbe7fcfa539416a Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Thu, 17 Nov 2022 00:38:44 +0000 Subject: [PATCH 261/453] handle scatter(Scalar) overload in inductor (#88894) Relanding https://github.com/pytorch/pytorch/pull/88210 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88894 Approved by: https://github.com/desertfire --- torch/_inductor/lowering.py | 44 +++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 9924396075f6..75d4e471e5bb 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2029,14 +2029,37 @@ def scatter(x, dim: int, index, src, **kwargs): return scatter_(clone(x), dim, index, src, **kwargs) +def scatter_fallback( + fn, self, dim: int, index, src, *, reduce: str = None, include_self: bool = True +): + + if reduce not in {None, "sum"} or ( + reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64} + ): + self.realize() + return fallback_handler(fn)( + self, dim, index, src, reduce=reduce, include_self=include_self + ) + + return None + + @register_lowering(aten.scatter_, type_promotion_kind=None) def scatter_(self, dim: int, index, src, *, reduce: str = None): + if reduce == "add": reduce = "sum" elif reduce == "multiply": reduce = "prod" else: assert reduce is None + + fallback_result = scatter_fallback( + aten.scatter_, self, dim, index, src, reduce=reduce + ) + + if fallback_result: + return fallback_result return scatter_reduce_(self, dim, index, src, reduce) @@ -2062,15 +2085,18 @@ def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs): def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True): assert reduce in {None, "sum", "prod", "mean", "amax", "amin"} - # TODO: Need to support more reduction type - # For reduction of "sum", tl.atomic_add doesn't support bool or int64 - if reduce not in {None, "sum"} or ( - reduce == "sum" and self.get_dtype() in {torch.bool, torch.int64} - ): - self.realize() - return fallback_scatter_reduce_( - self, dim, index, src, reduce, include_self=include_self - ) + fallback_result = 
scatter_fallback( + aten.scatter_reduce_, + self, + dim, + index, + src, + reduce=reduce, + include_self=include_self, + ) + + if fallback_result: + return fallback_result assert isinstance(self, TensorBox) assert "int" in str(index.get_dtype()) From cfd552547f106f4a7841976dad8b795b82d161c8 Mon Sep 17 00:00:00 2001 From: Charlie West-Taylor Date: Thu, 17 Nov 2022 00:59:12 +0000 Subject: [PATCH 262/453] Use the Python frame safely in _pythonCallstack (#88993) Currently, the result of `PyEval_GetFrame()` is piped straight to `Py_INCREF`. However, `PyEval_GetFrame` [may return null](https://docs.python.org/3/c-api/reflection.html#c.PyEval_GetFrame), which seems to be the case sometimes, when calling `_pythonCallstack` from another thread. This is handled in the subsequent `while (nullptr != frame)` block, but `Py_INCREF`, called before it, [doesn't handle this case](https://docs.python.org/3/c-api/refcounting.html#c.Py_INCREF), so the program segfaults. The safe form of `Py_INCREF` is `Py_XINCREF`, so use that instead ([docs](https://docs.python.org/3/c-api/refcounting.html#c.Py_XINCREF)). Pull Request resolved: https://github.com/pytorch/pytorch/pull/88993 Approved by: https://github.com/albanD --- torch/csrc/jit/python/python_tracer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 83570c85e9b4..c89d54872a07 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -27,7 +27,7 @@ namespace tracer { std::vector _pythonCallstack() { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); - Py_INCREF(frame); + Py_XINCREF(frame); std::vector entries; while (nullptr != frame) { From 3af5cf4de16e4e9256be6439a3539e3e52e3a879 Mon Sep 17 00:00:00 2001 From: R Max Espinoza Date: Thu, 17 Nov 2022 01:03:31 +0000 Subject: [PATCH 263/453] doc(typo): memroy -> memory (#89126) Minor typo in comments. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89126 Approved by: https://github.com/kit1980 --- torch/csrc/jit/codegen/cuda/executor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index 25e87c91cd25..23be5f4232aa 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -238,7 +238,7 @@ void FusionExecutor::compileFusion( #ifndef USE_ROCM device_smem_limit_ = properties->sharedMemPerBlockOptin; #else - // don't know if rocm supports opt-in shared memroy reconfiguration + // don't know if rocm supports opt-in shared memory reconfiguration device_smem_limit_ = properties->sharedMemPerBlock; #endif warp_size_ = properties->warpSize; From 80b6761863407a8cf1ca780fcf97d135743f7812 Mon Sep 17 00:00:00 2001 From: John Detloff Date: Thu, 17 Nov 2022 01:06:12 +0000 Subject: [PATCH 264/453] Update README.md (#85534) Our jenkins builds are gone, so this badge is broken and should be removed Pull Request resolved: https://github.com/pytorch/pytorch/pull/85534 Approved by: https://github.com/ngimel, https://github.com/kit1980 --- caffe2/README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/caffe2/README.md b/caffe2/README.md index 0b69eec8191b..13171fca23bb 100644 --- a/caffe2/README.md +++ b/caffe2/README.md @@ -1,7 +1,5 @@ # Caffe2 -[![Jenkins Build Status](https://ci.pytorch.org/jenkins/job/caffe2-master/lastCompletedBuild/badge/icon)](https://ci.pytorch.org/jenkins/job/caffe2-master) - Caffe2 is a lightweight, modular, and scalable deep learning framework. Building on the original [Caffe](http://caffe.berkeleyvision.org), Caffe2 is designed with expression, speed, and modularity in mind. ## Questions and Feedback From 0d87a4fec89fc78e568224935897ec585a6368a6 Mon Sep 17 00:00:00 2001 From: keineahnung2345 Date: Thu, 17 Nov 2022 01:09:55 +0000 Subject: [PATCH 265/453] Fix typo in Dispatcher.h (#89045) Fix typo in Dispatcher.h: hamespace -> namespace Pull Request resolved: https://github.com/pytorch/pytorch/pull/89045 Approved by: https://github.com/bdhirsh, https://github.com/kit1980 --- aten/src/ATen/core/dispatch/Dispatcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 6e1c7d754d72..5af8ef1e52de 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -735,4 +735,4 @@ struct hash { } }; -} // hamespace std +} // namespace std From 251fdda77b8f60667e016c89f65f798ea5f3eaea Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 17 Nov 2022 01:45:48 +0000 Subject: [PATCH 266/453] Add pytest-flakefinder as a test dependency (#89103) This is used to re-run tests multiple times to determine their flakiness status. 
The way re-run is handled in https://github.com/pytorch/pytorch/pull/88646 only applies to unittest Per their documentation, `pytest-repeat` doesn't work with `unittest.Testcase` it seems, so trying https://github.com/dropbox/pytest-flakefinder instead Pull Request resolved: https://github.com/pytorch/pytorch/pull/89103 Approved by: https://github.com/clee2000 --- .circleci/docker/requirements-ci.txt | 7 ++++++- .github/requirements/pip-requirements-macOS.txt | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.circleci/docker/requirements-ci.txt b/.circleci/docker/requirements-ci.txt index 018a7f6544fd..e527d29d4989 100644 --- a/.circleci/docker/requirements-ci.txt +++ b/.circleci/docker/requirements-ci.txt @@ -159,8 +159,13 @@ pytest-shard #Pinned versions: #test that import: +pytest-flakefinder==1.1.0 +#Description: plugin for rerunning tests a fixed number of times in pytest +#Pinned versions: 1.1.0 +#test that import: + pytest-rerunfailures -#Description: plugin for rerunning tests in pytest +#Description: plugin for rerunning failure tests in pytest #Pinned versions: #test that import: diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 7aa2306b1309..dfbaea260116 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -14,6 +14,7 @@ pygments==2.12.0 pytest==7.2.0 pytest-xdist==3.0.2 pytest-rerunfailures==10.2 +pytest-flakefinder==1.1.0 pytest-shard==0.1.2 scipy==1.9.0 sympy==1.11.1 From 716f70f19a4b63268da2a753afdbe9b385a831ab Mon Sep 17 00:00:00 2001 From: Horace He Date: Wed, 16 Nov 2022 19:58:30 +0000 Subject: [PATCH 267/453] Added conv constraint that infers layouts (#89031) The core problem that we often have with contiguous/channels-last layouts and convolutions is that Inductor often doesn't do a great job of "preserving" the eager-mode layouts. So, for example, we'll often have something like ``` a: channels-last b = foo(a) c = convolution(a) ``` In eager-mode, `a` would stay channels-last, and we would avoid two transpose copies (one into NHWC and one back into NCHW) within the convolution kernel. However, Inductor currently sometimes loses the "correct" layout of `b` (not in this simple example, but others). Then, not only will we do a transpose within `foo`, but we'll then immediately transpose it back to do the convolution (and then again once the convolution is done). This is particularly egregious in `convnext_base`, where there's a lot of mixing of non-channels last tensors and channels-last tensors. The solution in this PR is to constrain the inputs to `aten.convolution`/`aten.convolution_backward` to match the layouts from eager-mode. This ensures that we'll never do extra transposes *within* `aten.convolution`, which are particularly bad (since Inductor can't fuse them). 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89031 Approved by: https://github.com/ngimel, https://github.com/jansel --- test/inductor/test_torchinductor.py | 4 +- torch/_inductor/graph.py | 29 ++++++- torch/_inductor/ir.py | 3 + torch/_inductor/lowering.py | 118 +++++++++----------------- torch/fx/experimental/proxy_tensor.py | 5 ++ 5 files changed, 78 insertions(+), 81 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 1265ca3e7872..651ef9ec016f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -65,7 +65,6 @@ from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten - requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow @@ -5088,6 +5087,8 @@ def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: return kernels def test_divisibile_by_16_covers_numel_args(self): + torch._dynamo.reset() + def fn(a: torch.Tensor) -> torch.Tensor: return torch.sum(a) @@ -5107,6 +5108,7 @@ def fn(a: torch.Tensor) -> torch.Tensor: kernels[1].meta["configs"][0].divisible_by_16 ) self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1)) + torch._dynamo.reset() if __name__ == "__main__": diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index e0e41fd8afa5..5114ffa76111 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -20,7 +20,12 @@ MissingOperatorWithoutDecomp, ) from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox -from .lowering import lowerings, make_fallback, needs_realized_inputs +from .lowering import ( + layout_constraints, + lowerings, + make_fallback, + needs_realized_inputs, +) from .sizevars import SizeVarAllocator from .utils import dynamo_utils, gather_origins from .virtualized import V @@ -301,7 +306,12 @@ def finalize(self): def run_node(self, n: torch.fx.Node): with ir.IRNode.current_origins({n}): - result = super().run_node(n) + if n.op == "call_function" and n.target in layout_constraints: + args, kwargs = self.fetch_args_kwargs_from_env(n) + args, kwargs = layout_constraints[n.target](n, *args, **kwargs) + result = self.call_function(n.target, args, kwargs) + else: + result = super().run_node(n) # Realize if (1) any user need inputs realized, or (2) there is # already too many reads and rematerializing can be bad. @@ -310,7 +320,20 @@ def run_node(self, n: torch.fx.Node): for user in n.users: if user.target in needs_realized_inputs: result.realize_hint() - elif user.op == "output": + # This inclusion is somewhat controversial (from + # discussion between Horace, Natalia, and Elias). + # Currently, it's not very clear why this is helpful. + # The general idea here is that even though a node may + # have FlexibleLayout, we still often *treat* it as if + # it was contiguous. This appears to sometime result in + # suboptimal behavior. + # + # When we do a better job selecting layout, we should + # revisit this. 
+ result = ir.ExternKernel.require_stride_order( + result, ir.get_stride_order(n.meta["val"].stride()) + ) + if user.op == "output": if isinstance(result.data.data, (Pointwise, Reduction)): result.realize() diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 8327fe0d7b52..d54724671768 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2478,6 +2478,9 @@ def require_stride1(cls, x): @classmethod def require_stride_order(cls, x, order): + if x.get_numel() == 0: # Layout doesn't matter + return x + # require x to have the layout as strided_ordered as order if is_storage_and_layout(x): if isinstance( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 75d4e471e5bb..5168f37cd392 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -23,7 +23,6 @@ from .decomposition import decompositions, get_decompositions from .ir import ( ExpandView, - get_stride_order, IndexingConstant, IndexingDiv, PermuteView, @@ -38,6 +37,7 @@ log = logging.getLogger(__name__) lowerings = {} +layout_constraints = {} fallbacks = set() aten = torch.ops.aten prims = torch.ops.prims @@ -53,6 +53,14 @@ def add_needs_realized_inputs(fn): needs_realized_inputs.add(getattr(fn, overload)) +def add_layout_constraint(fn, constraint): + if isinstance(fn, torch._ops.OpOverloadPacket): + for overload in fn.overloads(): + layout_constraints[getattr(fn, overload)] = constraint + else: + layout_constraints[fn] = constraint + + add_needs_realized_inputs( [ aten.as_strided, @@ -1013,12 +1021,10 @@ def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): register_onednn_fusion_ops() -def fallback_handler(kernel, inps_hook=None): +def fallback_handler(kernel): fallbacks.add(kernel) def handler(*args, **kwargs): - if inps_hook is not None: - args, kwargs = inps_hook(*args, **kwargs) return pytree.tree_map( TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs) ) @@ -1026,7 +1032,7 @@ def handler(*args, **kwargs): return handler -def make_fallback(kernel, inps_hook=None): +def make_fallback(kernel, layout_constraint=None): assert ( kernel not in decompositions ), f"both a fallback and a decomp for same kernel: {kernel}" @@ -1036,9 +1042,9 @@ def make_fallback(kernel, inps_hook=None): ) add_needs_realized_inputs(kernel) - return register_lowering(kernel, type_promotion_kind=None)( - fallback_handler(kernel, inps_hook) - ) + if layout_constraint is not None: + add_layout_constraint(kernel, layout_constraint) + return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel)) @register_lowering(aten.native_dropout, type_promotion_kind=None) @@ -1189,72 +1195,14 @@ def inner_fn(index): ) -def conv_backward(*args, **kwargs): - # output striding complex and has a lot of build dependent options, - # take the output strides to determine what to set the inputs - with torch._subclasses.FakeTensorMode(): - args_fake, kwargs_fake = pytree.tree_map_only( - ir.IRNode, - lambda t: ir.ir_node_to_tensor(t, guard_shape=False), - (args, kwargs), - ) - output = aten.convolution_backward(*args_fake, **kwargs_fake) - - def constraints( - grad_output, - input, - weight, - bias_sizes, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_mask, - ): - out = ( - output[0] - if output[0] is not None - else output[1] - if output[1] is not None - else output[2] - ) - if out is not None: - stride_order = get_stride_order(out.stride()) - grad_output = ir.ExternKernel.require_stride_order( - grad_output, 
stride_order - ) - weight = ir.ExternKernel.require_stride_order(weight, stride_order) - # Only make input contiguous when it is necessary for the backwards computation - if output_mask[1]: - input = ir.ExternKernel.require_stride_order(input, stride_order) - - return ( - grad_output, - input, - weight, - bias_sizes, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_mask, - ), {} - - return constraints(*args, **kwargs) - - -def require_dense(*args, **kwargs): +def require_dense(_, *args, **kwargs): args, kwargs = pytree.tree_map_only( ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs) ) return args, kwargs -def require_contiguous(*args, **kwargs): +def require_contiguous(_, *args, **kwargs): args, kwargs = pytree.tree_map_only( ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs) ) @@ -1264,26 +1212,42 @@ def require_contiguous(*args, **kwargs): if has_torchvision_roi_align(): make_fallback(torch.ops.torchvision.roi_align) + +def constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, ir.IRNode): + stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) + return ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + + args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)] + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + # TODO(jansel): we should implement decomps or lowerings for these # https://github.com/pytorch/torchdynamo/issues/327 make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) -make_fallback(aten.convolution_backward, inps_hook=conv_backward) +make_fallback(aten.convolution_backward, constrain_to_fx_strides) make_fallback(aten._cudnn_rnn, require_dense) -make_fallback(aten._cudnn_rnn_backward, inps_hook=require_contiguous) -make_fallback(aten.cumsum, inps_hook=require_dense) -make_fallback(aten._embedding_bag, inps_hook=require_contiguous) -make_fallback(aten._embedding_bag_forward_only, inps_hook=require_contiguous) +make_fallback(aten._cudnn_rnn_backward, require_contiguous) +make_fallback(aten.cumsum, require_dense) +make_fallback(aten._embedding_bag, require_contiguous) +make_fallback(aten._embedding_bag_forward_only, require_contiguous) make_fallback(aten._fused_moving_avg_obs_fq_helper) make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) -make_fallback(aten.grid_sampler_2d_backward, inps_hook=require_dense) +make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) -make_fallback(aten._thnn_fused_lstm_cell, inps_hook=require_dense) +make_fallback(aten._thnn_fused_lstm_cell, require_dense) make_fallback(aten.topk) -make_fallback(aten.upsample_bicubic2d_backward, inps_hook=require_contiguous) -make_fallback(aten.upsample_bilinear2d_backward, inps_hook=require_dense) +make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) +make_fallback(aten.upsample_bilinear2d_backward, require_dense) + + +add_layout_constraint(aten.convolution, constrain_to_fx_strides) @register_lowering(aten.convolution) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index c83560754890..8a51294c5a8f 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -118,6 +118,11 @@ def set_meta(proxy, 
val): elif isinstance(val, torch.Tensor): if not val.is_sparse: proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype) return proxy def thunkify(f, *args, **kwargs): From 088f2fa567fcf74aa746886e3e90fd3e6c58fa61 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 17 Nov 2022 01:55:03 +0000 Subject: [PATCH 268/453] Fix typos in messages under test (#89121) This PR fixes typos of messages in `.cpp` and `.py` files under test directory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89121 Approved by: https://github.com/mruberry, https://github.com/kit1980 --- test/cpp/c10d/ProcessGroupNCCLTest.cpp | 2 +- test/cpp/jit/test_custom_class_registrations.cpp | 2 +- test/inductor/test_torchinductor.py | 2 +- test/jit/test_hooks.py | 2 +- test/lazy/test_extract_compiled_graph.py | 2 +- test/mobile/test_lite_script_module.py | 4 ++-- test/onnx/test_pytorch_onnx_onnxruntime.py | 2 +- test/quantization/fx/test_quantize_fx.py | 2 +- test/scripts/run_cuda_memcheck.py | 2 +- test/test_sparse.py | 2 +- test/test_type_promotion.py | 2 +- 11 files changed, 12 insertions(+), 12 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index 0d566344f2ce..083c4770e0ae 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -355,7 +355,7 @@ void testAllreduce(const std::string& path, int rank, int size) { const auto* const data = tensor.data_ptr(); for (const auto k : c10::irange(tensor.numel())) { EXPECT_EQ(data[k], expected) - << "Allreduce ouputs do not match expected outputs"; + << "Allreduce outputs do not match expected outputs"; } } } diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index 63c6b7013306..16e690d99d8a 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -222,7 +222,7 @@ struct ElementwiseInterpreter : torch::CustomClassHolder { } if (!output_name_) { - throw std::runtime_error("Output name not specififed!"); + throw std::runtime_error("Output name not specified!"); } return environment.at(*output_name_); diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 651ef9ec016f..fb7ca1fc92b7 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3298,7 +3298,7 @@ def fn(in_ptr0, in_ptr1, in_ptr2): ), ) - @unittest.skipIf(not has_torchvision_roi_align(), "requirs torchvision") + @unittest.skipIf(not has_torchvision_roi_align(), "requires torchvision") def test_roi_align(self): def fn(a, b): return torch.ops.torchvision.roi_align(a, b, 0.25, 7, 7, 2, False) diff --git a/test/jit/test_hooks.py b/test/jit/test_hooks.py index 109a5e3f1b71..2963837a638a 100644 --- a/test/jit/test_hooks.py +++ b/test/jit/test_hooks.py @@ -229,7 +229,7 @@ def pre_hook(self, input: Tuple[str]) -> Tuple[str]: with self.assertRaisesRegex( RuntimeError, - "This error occured while scripting the forward pre-hook 'pre_hook'", + "This error occurred while scripting the forward pre-hook 'pre_hook'", ): torch.jit.script(m) diff --git a/test/lazy/test_extract_compiled_graph.py 
b/test/lazy/test_extract_compiled_graph.py index f4152d0af68b..b27a11bf49b6 100644 --- a/test/lazy/test_extract_compiled_graph.py +++ b/test/lazy/test_extract_compiled_graph.py @@ -141,7 +141,7 @@ def verify_reusing_compiled_graph(mod, exception_msg_pattern, ncase=10): raise e # reraise the exception exception_message = str(e) if not re.search(exception_msg_pattern, exception_message): - raise RuntimeError(f"Expection message does not match the required pattern: {exception_message}") + raise RuntimeError(f"Exception message does not match the required pattern: {exception_message}") else: # We are done for the test case that expects an exception return diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 638ac37eb88b..9089977b77f1 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -241,7 +241,7 @@ def forward(self): script_module = torch.jit.script(MyTestModuleForListWithModuleClass()) with self.assertRaisesRegex(RuntimeError, - r"^Returining a list or dictionary with pytorch class type " + r"^Returning a list or dictionary with pytorch class type " r"is not supported in mobile module " r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " r"Workaround\: instead of using pytorch class as their element type\, " @@ -264,7 +264,7 @@ def forward(self): script_module = torch.jit.script(MyTestModuleForDictWithModuleClass()) with self.assertRaisesRegex(RuntimeError, - r"^Returining a list or dictionary with pytorch class type " + r"^Returning a list or dictionary with pytorch class type " r"is not supported in mobile module " r"\(List\[Foo\] or Dict\[int\, Foo\] for class Foo\(torch\.nn\.Module\)\)\. " r"Workaround\: instead of using pytorch class as their element type\, " diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 7ae9d8edaccc..16839dded0c4 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -9168,7 +9168,7 @@ def forward(self, x, y, cond): ) @skipScriptTest( - skip_before_opset_version=11, reason="dynamic split support addded in 11" + skip_before_opset_version=11, reason="dynamic split support added in 11" ) def test_split_tensor_scalar(self): class SplitModel(torch.nn.Module): diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 6cee5e95f21c..b03b7fb0cf0e 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -491,7 +491,7 @@ def forward(self, x): self.checkGraphModuleNodes(m, expected_node=ns.call_module(torch.nn.intrinsic.modules.fused.LinearReLU)) - @unittest.skip("Temprorarily skipping the test case, will enable after the simple" + @unittest.skip("Temporarily skipping the test case, will enable after the simple" "pattern format is supported") def test_fuse_addtional_fuser_method(self): class MyConvReLU(torch.nn.Module): diff --git a/test/scripts/run_cuda_memcheck.py b/test/scripts/run_cuda_memcheck.py index 10202e416d00..7d882b8c1fff 100755 --- a/test/scripts/run_cuda_memcheck.py +++ b/test/scripts/run_cuda_memcheck.py @@ -119,7 +119,7 @@ async def run1(coroutine_id): gpuid = coroutine_id % GPUS else: gpu_assignments = args.gpus.split(':') - assert args.nproc == len(gpu_assignments), 'Please specify GPU assignmnent for each process, separated by :' + assert args.nproc == len(gpu_assignments), 'Please specify GPU assignment for each process, separated by :' 
gpuid = gpu_assignments[coroutine_id] while progress < len(ALL_TESTS): diff --git a/test/test_sparse.py b/test/test_sparse.py index a2b623e2508e..4bfccaff0e2c 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -3708,7 +3708,7 @@ def check_empty(sparse_shape, nnz, dense_shape, coalesce): check(self, s, d) check_empty(shape, nnz, sub_shape, coalesced) - @unittest.skipIf(not TEST_NUMPY, "NumPy is not availible") + @unittest.skipIf(not TEST_NUMPY, "NumPy is not available") @onlyCPU @dtypes(*all_types_and_complex_and(torch.bool)) def test_sparse_spdiags(self, device, dtype): diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py index b351f2d6d494..1d80556a7d48 100644 --- a/test/test_type_promotion.py +++ b/test/test_type_promotion.py @@ -473,7 +473,7 @@ def _get_dtype(x): elif isinstance(x, complex): return torch.complex64 else: - raise AssertionError(f"Unkonwn type {x}") + raise AssertionError(f"Unknown type {x}") # tensor against tensor a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0]) From f5e2cb52496ab51edaa25ac35908b6832e23dadb Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 17 Nov 2022 02:02:26 +0000 Subject: [PATCH 269/453] Add comprehensive minifier tests (#88022) Adds tests for https://github.com/pytorch/torchdynamo/issues/1241. To run: `pytest test/dynamo/test_minifier.py`. Actually runs minifier launcher script and repro scripts, rather than just checking for existence of the minifier launcher script. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88022 Approved by: https://github.com/mlazos, https://github.com/anijain2305 --- test/dynamo/test_minifier.py | 388 ++++++++++++++++++++------ test/inductor/test_minifier.py | 211 ++++++++++++++ torch/_dynamo/debug_utils.py | 78 +++++- torch/_dynamo/test_minifier_common.py | 131 +++++++++ 4 files changed, 704 insertions(+), 104 deletions(-) create mode 100644 test/inductor/test_minifier.py create mode 100644 torch/_dynamo/test_minifier_common.py diff --git a/test/dynamo/test_minifier.py b/test/dynamo/test_minifier.py index 0cec7d202a9d..c1a56f070be5 100644 --- a/test/dynamo/test_minifier.py +++ b/test/dynamo/test_minifier.py @@ -1,111 +1,315 @@ # Owner(s): ["module: dynamo"] -import os -import shutil +import functools +import re +import textwrap import unittest -from unittest.mock import patch import torch import torch._dynamo -import torch._dynamo.test_case -import torch._dynamo.testing -from torch._dynamo.optimizations.backends import create_backend +from torch._dynamo.test_minifier_common import MinifierTestBase +requires_cuda = functools.partial( + unittest.skipIf, not torch.cuda.is_available(), "requires cuda" +) -class MockModule(torch.nn.Module): - def __init__(self): - super().__init__() +RELU_COMPILE_ERROR_BACKEND = """\ +from torch._dynamo.optimizations.backends import register_backend - def forward(self, x): - for _ in range(10): - x = torch.sin(x) - x = torch._foobar(x) - for _ in range(10): - x = torch.cos(x) - return x +class DynamoCompileError(Exception): + pass +@register_backend +def test_relu_compile_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise DynamoCompileError("relu found") + return gm +""" -class MinfierTests(torch._dynamo.test_case.TestCase): +RELU_RUNTIME_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_runtime_error(gm: torch.fx.GraphModule, example_inputs): 
+ gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch._assert + node.args = (False, "DynamoRuntimeError") + gm.recompile() + return gm +""" + +RELU_ACCURACY_ERROR_BACKEND = """\ +import copy +from torch._dynamo.optimizations.backends import register_backend + +@register_backend +def test_relu_accuracy_error(gm: torch.fx.GraphModule, example_inputs): + gm = copy.deepcopy(gm) + for node in gm.graph.nodes: + if node.target == torch.relu: + node.target = torch.add + node.args = (node.args[0], 1) + gm.recompile() + + return gm +""" + +RELU_CUSTOM_ERROR_BACKEND = """\ +class CustomError(Exception): + pass + +def test_relu_custom_error(gm: torch.fx.GraphModule, example_inputs): + for node in gm.graph.nodes: + if node.target == torch.relu: + raise CustomError("relu found") + return gm +""" + + +class MinifierTests(MinifierTestBase): @classmethod def setUpClass(cls): super().setUpClass() - cls._exit_stack.enter_context( - unittest.mock.patch.object( - torch._dynamo.config, - "debug_dir_root", - "/tmp/_torchdynamo_debug_/", - ) - ) @classmethod def tearDownClass(cls): - shutil.rmtree(torch._dynamo.config.debug_dir_root, ignore_errors=True) - cls._exit_stack.close() - - def setUp(self): - super().setUp() - - def tearDown(self): - super().tearDown() - - def test_after_dynamo(self): - @create_backend - def bad_dynamo_backend(subgraph): - import sys - - def f(*args): - # Shifted the forced exception to runtime as this is more common - # in JIT compilers. - for node in subgraph.model.graph.nodes: - if node.op == "call_function" and node.target is torch._foobar: - sys.stdout.write("Dynamo compiled failed\n") - raise NotImplementedError("foobar is not implemented") - return subgraph.model(*args) - - return f - - mod = MockModule() - opt_mod = torch._dynamo.optimize("bad_dynamo_backend")(mod) - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - - @patch.object(torch._dynamo.config, "repro_after", "dynamo") - def inner(): - x = torch.randn(4) - try: - opt_mod(x) - except Exception: - pass - - inner() - self.assertTrue(os.path.exists(repro_file)) - - # If error_at_aot is True, an error will be produced when AOTAutograd - # attempts to generate the backward graph. - # If error_after_aot is False, an error will be produced in inductor. 
- def _test_around_aot(self, error_at_aot): - mod = MockModule() - opt_mod = torch._dynamo.optimize("inductor")(mod) - - repro_file = torch._dynamo.debug_utils.get_minifier_repro_path() - repro_after = "dynamo" if error_at_aot else "aot" - - @patch.object(torch._dynamo.config, "repro_after", repro_after) - def inner(): - x = torch.randn(4) - x.requires_grad = error_at_aot - try: - opt_mod(x) - except Exception: - pass - - inner() - - self.assertTrue(os.path.exists(repro_file)) - - def test_at_aot(self): - self._test_around_aot(True) - - def test_after_aot(self): - self._test_around_aot(False) + super().tearDownClass() + + # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA) + def _test_after_dynamo(self, device, repro_level, backend_code, error_name): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{self._get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "dynamo", repro_level, backend_code + ) + + self.assertIn(error_name, test_proc.stderr.decode("utf-8")) + self.assertIn(error_name, repro_proc.stderr.decode("utf-8")) + + def test_after_dynamo_cpu_compile_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + def test_after_dynamo_cpu_runtime_error(self): + self._test_after_dynamo( + "cpu", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + def test_after_dynamo_cpu_accuracy_error(self): + self._test_after_dynamo("cpu", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + @requires_cuda() + def test_after_dynamo_cuda_compile_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_COMPILE_ERROR_BACKEND, "DynamoCompileError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_error(self): + self._test_after_dynamo( + "cuda", 2, RELU_RUNTIME_ERROR_BACKEND, "DynamoRuntimeError" + ) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_error(self): + self._test_after_dynamo("cuda", 4, RELU_ACCURACY_ERROR_BACKEND, "AccuracyError") + + # Ensure that the testing backends pass when relu is not present. 
+ def _test_after_dynamo_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{self._get_fn_name(backend_code)}") + def inner(x): + for _ in range(10): + x = torch.sin(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + + test_code = self._gen_test_code(run_code, "dynamo", repro_level, backend_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_dynamo_cpu_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_COMPILE_ERROR_BACKEND) + + def test_after_dynamo_cpu_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 2, RELU_RUNTIME_ERROR_BACKEND) + + def test_after_dynamo_cpu_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cpu", 4, RELU_ACCURACY_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_compile_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_COMPILE_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_runtime_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 2, RELU_RUNTIME_ERROR_BACKEND) + + @requires_cuda() + def test_after_dynamo_cuda_accuracy_backend_passes(self): + self._test_after_dynamo_backend_passes("cuda", 4, RELU_ACCURACY_ERROR_BACKEND) + + # Ensure that generated code with a custom backends generates a runnable minifier + # launcher script that results in a RuntimeError + def test_after_dynamo_custom_backend(self): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize({self._get_fn_name(RELU_CUSTOM_ERROR_BACKEND)}) + def inner(x): + for _ in range(10): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(10): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + test_code = self._gen_test_code( + run_code, "dynamo", 2, RELU_CUSTOM_ERROR_BACKEND + ) + _, repro_dir = self._run_test_code(test_code) + launch_proc, _ = self._run_minifier_launcher("", repro_dir) + self.assertIn("RuntimeError", launch_proc.stderr.decode("utf-8")) + + # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd + @requires_cuda() + def test_cpu_cuda_module_after_dynamo(self): + backend_name = self._get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + class CpuCudaModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.m_x = torch.nn.Linear(20, 20).cuda() + self.m_y = torch.nn.Linear(20, 20) + self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda()) + self.p_y = torch.nn.Parameter(torch.randn(20, 20)) + self.register_buffer("b_x", torch.ones(20, 20).cuda()) + self.register_buffer("b_y", torch.ones(20, 20)) + + def forward(self, x, y): + return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y + + mod = CpuCudaModule() + + @torch._dynamo.optimize("{backend_name}") + def inner(x1, y1): + x2 = torch.randn(20, 20).cuda() + y2 = torch.randn(20, 20) + x3, y3 = mod(x1 + x2, y1 + y2) + return torch.relu(x3.cpu() + y3) + + inner(torch.randn(20, 20).cuda(), torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, _) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + # Check if generated minifier code covers all cpu/cuda cases + 
self.assertIsNotNone(re.search(r"args.*cuda", launch_code)) + self.assertIsNotNone(re.search(r"args.*cpu", launch_code)) + # search for Linear(...).cuda() + self.assertIsNotNone(re.search(r"Linear.*cuda", launch_code)) + # search for Linear(...) + self.assertIsNotNone( + re.search(r"Linear(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"register_buffer.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"register_buffer(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + self.assertIsNotNone(re.search(r"Parameter.*cuda", launch_code)) + self.assertIsNotNone( + re.search(r"Parameter(?!.*cuda.*$)", launch_code, re.MULTILINE) + ) + # search for + # = torch.randn(...) + # ... = .cuda() + self.assertIsNotNone( + re.search(r"(\w+) = torch.randn.*\1\.cuda", launch_code, re.DOTALL) + ) + # search for + # = torch.randn(...) + # no followup call to .cuda() + self.assertIsNotNone( + re.search( + r"(\w+) = torch.randn(?!.*\1\.cuda\(\).*$)", launch_code, re.DOTALL + ) + ) + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # Test if we can actually get a minified graph + def test_if_graph_minified(self): + backend_name = self._get_fn_name(RELU_COMPILE_ERROR_BACKEND) + + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("{backend_name}") + def inner(x): + for _ in range(20): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(20): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20)) + """ + ) + + (test_proc, _, repro_proc), (launch_code, repro_code) = self._run_full_test( + run_code, "dynamo", 2, RELU_COMPILE_ERROR_BACKEND + ) + + tb1 = test_proc.stderr.decode("utf-8") + tb2 = repro_proc.stderr.decode("utf-8") + + self.assertIn(backend_name, tb1) + self.assertIn(backend_name, tb2) + + # compare the length of the forward functions + match = re.search(r"def forward.*return", launch_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertGreater(match.group(0).count("\n"), 40) + + match = re.search(r"def forward.*return", repro_code, re.DOTALL) + self.assertIsNotNone(match) + self.assertLess(match.group(0).count("\n"), 5) if __name__ == "__main__": diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py new file mode 100644 index 000000000000..55c0a1b6bb05 --- /dev/null +++ b/test/inductor/test_minifier.py @@ -0,0 +1,211 @@ +# Owner(s): ["module: inductor"] +import functools +import textwrap +import unittest + +import torch +import torch._dynamo +import torch._inductor.utils +from torch._dynamo.test_minifier_common import MinifierTestBase +from torch.testing._internal.common_utils import IS_MACOS + +_HAS_TRITON = torch._inductor.utils.has_triton() +requires_cuda = functools.partial(unittest.skipIf, not _HAS_TRITON, "requires cuda") + +CPP_COMPILE_ERROR = """\ +def cpp_compile_error(x): + return "compile error!" +""" + +CPP_RUNTIME_ERROR = """\ +def cpp_runtime_error(x): + return f"{x}; throw 1" +""" + +CPP_ACCURACY_ERROR = """\ +def cpp_accuracy_error(x): + return f"{x} + 1" +""" + +TRITON_COMPILE_ERROR = """\ +def triton_compile_error(x): + return "compile error!" +""" + +# NOTE: there is currently not an easy way to cause a triton runtime error. +TRITON_RUNTIME_ERROR = """\ +def triton_runtime_error(x): + return f"{x}; assert?" 
+""" + +TRITON_ACCURACY_ERROR = """\ +def triton_accuracy_error(x): + return f"{x} + 1" +""" + + +class MinifierTests(MinifierTestBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Generates code that patches CppOverrides/TritonOverrides. + def _gen_codegen_fn_patch_code(self, old_fn_name, new_fn_code, device): + new_fn_name = self._get_fn_name(new_fn_code) + if new_fn_name is not None: + patch_code = f"""\ +import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen +overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} +{new_fn_code} +overrides.{old_fn_name} = staticmethod({new_fn_name}) +""" + return f"""\ +{patch_code} +isolate_fails_code_str = \"\"\"\\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +\"\"\" +""" + + # Test that compile and accuracy errors after aot can be repro'd (both CPU and CUDA) + def _test_after_aot(self, device, backend_code, repro_level): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", repro_level, patch_code + ) + return ( + (test_proc.stderr.decode("utf-8"), repro_proc.stderr.decode("utf-8")), + (test_proc.returncode, repro_proc.returncode), + ) + + def test_after_aot_cpu_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_COMPILE_ERROR, 2) + self.assertIn("CppCompileError", tb1) + self.assertIn("CppCompileError", tb2) + + def test_after_aot_cpu_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cpu", CPP_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + @requires_cuda() + def test_after_aot_cuda_compile_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_COMPILE_ERROR, 2) + self.assertIn("SyntaxError", tb1) + self.assertIn("SyntaxError", tb2) + + @requires_cuda() + def test_after_aot_cuda_accuracy_error(self): + (tb1, tb2), _ = self._test_after_aot("cuda", TRITON_ACCURACY_ERROR, 4) + self.assertIn("AccuracyError", tb1) + self.assertIn("AccuracyError", tb2) + + # Test that runtime errors after aot can be repro'd (CPU only for now) + def _test_after_aot_runtime_error(self, device, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + x = torch.relu(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + (test_proc, _, repro_proc), _ = self._run_full_test( + run_code, "aot", 3, patch_code + ) + + self.assertNotIn("CompilerError", test_proc.stderr.decode("utf-8")) + + self.assertEqual(test_proc.returncode, repro_proc.returncode) + self.assertNotEqual(test_proc.returncode, 0) + + def test_after_aot_cpu_runtime_error(self): + self._test_after_aot_runtime_error("cpu", CPP_RUNTIME_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. 
+ @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_error(self): + self._test_after_aot_runtime_error("cuda", TRITON_RUNTIME_ERROR) + + # Ensure that inductor codegen patches pass when relu is not present. + def _test_after_aot_backend_passes(self, device, repro_level, backend_code): + run_code = textwrap.dedent( + f"""\ + @torch._dynamo.optimize("inductor") + def inner(x): + for _ in range(3): + x = torch.sin(x) + for _ in range(3): + x = torch.cos(x) + return x + + inner(torch.randn(20, 20).to("{device}")) + """ + ) + patch_code = self._gen_codegen_fn_patch_code("relu", backend_code, device) + self.assertIsNotNone(patch_code) + + test_code = self._gen_test_code(run_code, "aot", repro_level, patch_code) + proc, repro_dir = self._run_test_code(test_code) + self.assertEqual(proc.returncode, 0) + self.assertIsNone(repro_dir) + + def test_after_aot_cpu_compile_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_COMPILE_ERROR) + + def test_after_aot_cpu_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 2, CPP_RUNTIME_ERROR) + + def test_after_aot_cpu_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cpu", 4, CPP_ACCURACY_ERROR) + + @requires_cuda() + def test_after_aot_cuda_compile_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_COMPILE_ERROR) + + # NOTE: there is currently not an easy way to cause a triton runtime error. + @unittest.skip + @requires_cuda() + def test_after_aot_cuda_runtime_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 2, TRITON_RUNTIME_ERROR) + + @requires_cuda() + def test_after_aot_cuda_accuracy_backend_passes(self): + self._test_after_aot_backend_passes("cuda", 4, TRITON_ACCURACY_ERROR) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + # skip CI tests on mac since CPU inductor does not seem to work due to C++ compile errors + if not IS_MACOS: + run_tests() diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index f09991f9bf34..98a269fe8c9e 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -84,6 +84,11 @@ def __init__(self): for module_name, module in gm.named_children(): module_str = f"{module.__repr__()}" + # module should be a core torch.nn.Module, so all parameters + # should be on the same device. + example_param = next(module.parameters(), None) + if example_param is not None and example_param.is_cuda: + module_str = f"{module_str}.cuda()" model_str += f"{tab*2}self.{module_name} = {module_str}\n" for buffer_name, buffer in gm._buffers.items(): @@ -95,12 +100,16 @@ def __init__(self): tensor_str = ( f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})" ) + if buffer.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n" for param_name, param in gm._parameters.items(): if param is None: continue tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))" + if param.is_cuda: + tensor_str = f"{tensor_str}.cuda()" model_str += f"{tab*2}self.{param_name} = {tensor_str}\n" # TODO - Keep this code for now. But, I don't think we will need this. 
@@ -145,6 +154,9 @@ def _cuda_system_info_comment(): return model_str +TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES" + + def generate_compiler_repro_string(gm, args): model_str = textwrap.dedent( f""" @@ -155,6 +167,8 @@ def generate_compiler_repro_string(gm, args): from math import inf from torch.fx.experimental.proxy_tensor import make_fx + {TEST_REPLACEABLE_COMMENT} + """ ) model_str += f"# torch version: {torch.version.__version__}\n" @@ -170,7 +184,7 @@ def generate_compiler_repro_string(gm, args): model_str += ( "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n" ) - model_str += 'mod = make_fx(Repro().to(device="cuda"))(*args)\n' + model_str += "mod = make_fx(Repro())(*args)\n" return model_str @@ -197,7 +211,8 @@ def dump_compiler_graph_state(gm, args, compiler_name): log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}") with open(file_name, "w") as fd: save_graph_repro(fd, gm, args, compiler_name) - repro_path = os.path.join(config.base_dir, "repro.py") + curdir = os.getcwd() + repro_path = os.path.join(curdir, "repro.py") try: shutil.copyfile(file_name, repro_path) log.warning(f"Copying repro file for convenience to {repro_path}") @@ -216,7 +231,10 @@ def save_graph_repro(fd, gm, args, compiler_name): textwrap.dedent( f""" compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args) - assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed" + class AccuracyError(Exception): + pass + if not same_two_models(mod, compiled, args, only_fwd=True): + raise AccuracyError("Bad accuracy detected") """ ) ) @@ -231,7 +249,7 @@ def save_graph_repro(fd, gm, args, compiler_name): ) -def isolate_fails(fx_g, args, compiler_name: str, env=None): +def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None): if env is None: env = {} subdir = os.path.join(os.getcwd(), "isolate") @@ -239,7 +257,10 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): os.makedirs(subdir, exist_ok=True) file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py") with open(file_name, "w") as fd: - fd.write(generate_compiler_repro_string(fx_g, args)) + repro_code = generate_compiler_repro_string(fx_g, args) + if patch_code is not None: + repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + fd.write(repro_code) fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2] fd.write( textwrap.dedent( @@ -263,6 +284,7 @@ def isolate_fails(fx_g, args, compiler_name: str, env=None): stdout, stderr = TemporaryFile(), TemporaryFile() p = subprocess.Popen( ["python", file_name], + cwd=subdir, stdout=stdout, stderr=stderr, env=new_env, @@ -329,6 +351,8 @@ def dump_to_minify(gm, args, compiler_name: str): contents = textwrap.dedent( f""" +isolate_fails_code_str = None + {generate_compiler_repro_string(gm, args)} from functools import partial @@ -343,7 +367,7 @@ def dump_to_minify(gm, args, compiler_name: str): minifier( mod, args, - module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}"), + module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str), dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"), ) """ @@ -351,6 +375,10 @@ def dump_to_minify(gm, args, compiler_name: str): return helper_for_dump_minify(contents) +class AccuracyError(Exception): + pass + + def wrap_compiler_debug(compiler_fn, compiler_name: str): """ Minifier for Fx Graph modules after Aot Autograd has 
finished. We wrap both @@ -410,7 +438,7 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, f"{compiler_name}_accuracy", ) - raise ValueError("Bad accuracy detected") + raise AccuracyError("Bad accuracy detected") else: # Call the compiled function with real inputs return inner_compiled_fn(real_inputs) @@ -435,7 +463,8 @@ def deferred_for_real_inputs(real_inputs): copy_tensor_attrs, compiler_name, ) - raise e + log.error("CompilerError") + raise if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs @@ -544,9 +573,14 @@ def generate_dynamo_fx_repro_string( f""" mod.eval() opt_mod.eval() + +class AccuracyError(Exception): + pass + with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}): assert same_two_models(mod, mod, args), "Eager itself failed" - assert same_two_models(mod, opt_mod, args), "Dynamo failed" + if not same_two_models(mod, opt_mod, args): + raise AccuracyError("Dynamo failed") """ ) @@ -561,12 +595,14 @@ def generate_dynamo_fx_repro_string( from {config.dynamo_import}.debug_utils import run_fwd_maybe_bwd from {config.dynamo_import}.debug_utils import same_two_models +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() opt_mod = {config.dynamo_import}.optimize("{compiler_name}")(mod) {run_code} @@ -705,6 +741,21 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): if config.repro_level == 4: minifier_backend = "dynamo_accuracy_minifier_backend" + custom_compiler_error = ( + textwrap.dedent( + """\ + raise RuntimeError( + 'Compiler name is None - this likely means that a custom compiler ' + 'was called by torchdynamo. 
Please remove this error, import your ' + 'custom compiler function, and replace the compiler_name="None" ' + 'line below to compiler_name=' + ) + """ + ) + if compiler_name is None + else "" + ) + contents = textwrap.dedent( f""" import os @@ -718,14 +769,17 @@ def dump_to_minify_after_dynamo(gm, args, compiler_name): from {config.dynamo_import}.optimizations.backends import BACKENDS from {config.dynamo_import}.testing import rand_strided +{TEST_REPLACEABLE_COMMENT} + args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]} args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args] {model_str} -mod = Repro().cuda() +mod = Repro() # Setup debug minifier compiler compiler_fn = BACKENDS["{minifier_backend}"] +{custom_compiler_error} dynamo_minifier_backend = functools.partial( compiler_fn, compiler_name="{compiler_name}", @@ -769,7 +823,7 @@ def debug_wrapper(gm, example_inputs, **kwargs): example_inputs, compiler_name, ) - exc = ValueError("Bad accuracy detected.") + exc = AccuracyError("Bad accuracy detected.") exc.minifier_path = os.path.join( minifier_dir(), "minifier_launcher.py" ) diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py new file mode 100644 index 000000000000..8fb0688f2c3e --- /dev/null +++ b/torch/_dynamo/test_minifier_common.py @@ -0,0 +1,131 @@ +import os +import re +import subprocess +import tempfile +import unittest + +import torch +import torch._dynamo +import torch._dynamo.test_case +from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT + + +class MinifierTestBase(torch._dynamo.test_case.TestCase): + _debug_dir_obj = tempfile.TemporaryDirectory() + DEBUG_DIR = _debug_dir_obj.name + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._exit_stack.enter_context( + unittest.mock.patch.object( + torch._dynamo.config, + "debug_dir_root", + cls.DEBUG_DIR, + ) + ) + os.makedirs(cls.DEBUG_DIR, exist_ok=True) + + @classmethod + def tearDownClass(cls): + cls._debug_dir_obj.cleanup() + cls._exit_stack.close() + + def setUp(self): + super().setUp() + + def tearDown(self): + super().tearDown() + + # Search for the name of the first function defined in a code string. + def _get_fn_name(self, code): + fn_name_match = re.search(r"def (\w+)\(", code) + if fn_name_match is not None: + return fn_name_match.group(1) + return None + + # Run `code` in a separate python process. + # Returns the completed process state and the directory containing the + # minifier launcher script, if `code` outputted it. + def _run_test_code(self, code): + proc = subprocess.run( + ["python3", "-c", code], capture_output=True, cwd=self.DEBUG_DIR + ) + + repro_dir_match = re.search( + r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8") + ) + if repro_dir_match is not None: + # Print repro directory for debugging generated code. + # Make sure to comment out `shutil.rmtree...` above as well. + print("repro dir:", repro_dir_match.group(1)) + return proc, repro_dir_match.group(1) + return proc, None + + # Patch generated files with testing patches + def _inject_code(self, patch_code, filename): + patch_code = f"""\ +{patch_code} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +""" + with open(filename, "r") as f: + code = f.read() + code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code) + with open(filename, "w") as f: + f.write(code) + return code + + # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`. 
+ def _run_minifier_launcher(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + launch_file = os.path.join(repro_dir, "minifier_launcher.py") + self.assertTrue(os.path.exists(launch_file)) + launch_code = self._inject_code(patch_code, launch_file) + + launch_proc = subprocess.run( + ["python3", launch_file], + capture_output=True, + cwd=repro_dir, + ) + + return launch_proc, launch_code + + # Runs the repro script in `repro_dir`, patched with `patch_code` + def _run_repro(self, patch_code, repro_dir): + self.assertIsNotNone(repro_dir) + repro_file = os.path.join(repro_dir, "repro.py") + self.assertTrue(os.path.exists(repro_file)) + repro_code = self._inject_code(patch_code, repro_file) + + repro_proc = subprocess.run( + ["python3", repro_file], capture_output=True, cwd=repro_dir + ) + + return repro_proc, repro_code + + # Template for testing code. + # `run_code` is the code to run for the test case. + # `patch_code` is the code to be patched in every generated file. + def _gen_test_code(self, run_code, repro_after, repro_level, patch_code): + return f"""\ +import torch +import torch._dynamo +{patch_code} +torch._dynamo.config.repro_after = "{repro_after}" +torch._dynamo.config.repro_level = {repro_level} +torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}" +{run_code} +""" + + # Runs a full minifier test. + # Minifier tests generally consist of 3 stages: + # 1. Run the problematic code (in a separate process since it could segfault) + # 2. Run the generated minifier launcher script + # 3. Run the generated repro script + def _run_full_test(self, run_code, repro_after, repro_level, patch_code): + test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code) + test_proc, repro_dir = self._run_test_code(test_code) + self.assertIsNotNone(repro_dir) + launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir) + repro_proc, repro_code = self._run_repro(patch_code, repro_dir) + return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code)) From 30d9fb9157b59db27cd2c0c6e6b0b6221efda571 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 17 Nov 2022 02:03:45 +0000 Subject: [PATCH 270/453] [dynamo][reland] API Support for nn.Module (#89113) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/89113 Approved by: https://github.com/ezyang --- test/dynamo/test_modules.py | 135 +++++++++++++++++++++++++++++++++++ torch/_dynamo/__init__.py | 2 + torch/_dynamo/debug_utils.py | 8 +++ torch/_dynamo/eval_frame.py | 79 ++++++++++++++------ torch/_dynamo/testing.py | 14 ++++ 5 files changed, 218 insertions(+), 20 deletions(-) diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 2fb83b3add6c..ed3b715f72f9 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -904,6 +904,141 @@ def forward(self, x): self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) +class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.relu(self.linear(x) + self.buf0) + + +class OptimizedModuleTest(torch._dynamo.test_case.TestCase): + def test_nn_module(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + + x = torch.randn(10, 10) + 
self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_to(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + x = torch.randn(10, 10) + self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 1) + + opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) + self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) + x = torch.randn(10, 10).to(dtype=torch.float64) + opt_mod(x) + # Ensure that there is a recompilation + self.assertEqual(cnt.frame_count, 2) + + # Ensure that there is no recompilation + opt_mod(x) + self.assertEqual(cnt.frame_count, 2) + + torch._dynamo.reset() + opt_mod(x) + self.assertEqual(cnt.frame_count, 3) + + def test_attr(self): + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.register_buffer("buf0", torch.randn(10, 10)) + + def forward(self, x): + return self.r(torch.sin(x)) + self.buf0 + + mod = MockModule() + opt_mod = torch._dynamo.optimize("eager")(mod) + + # Check parameteres and buffers + for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): + self.assertTrue(id(p1) == id(p2)) + + def test_recursion(self): + mod = MockModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt)(mod) + + for _ in range(5): + opt_mod = torch._dynamo.optimize(cnt)(opt_mod) + opt_mod(torch.randn(10, 10)) + self.assertEqual(cnt.frame_count, 1) + + def test_composition(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + opt_inner_mod = InnerModule() + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + self.assertEqual(cnt.frame_count, 1) + + def test_composition_with_opt_mod(self): + class InnerModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(torch.sin(x)) + + inner_mod = InnerModule() + cnt = torch._dynamo.testing.CompileCounter() + opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mod = opt_inner_mod + + def forward(self, x): + return self.mod(torch.cos(x)) + + outer_mod = OuterModule() + opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) + + x = torch.randn(4) + self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) + self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) + # There will be a graph break for the inner mod being OptimizedModule + self.assertEqual(cnt.frame_count, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 80f927aeef2f..5eee609b0852 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ 
-7,6 +7,7 @@ export, optimize, optimize_assert, + OptimizedModule, reset_code, run, skip, @@ -25,6 +26,7 @@ "reset", "list_backends", "skip", + "OptimizedModule", ] diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 98a269fe8c9e..29d830167b10 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -515,8 +515,16 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False): """ Check two models have same accuracy. """ + from .eval_frame import OptimizedModule + from .testing import named_parameters_for_optimized_module from .utils import same + if isinstance(gm, OptimizedModule): + gm.named_parameters = named_parameters_for_optimized_module(gm) + + if isinstance(opt_gm, OptimizedModule): + opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm) + ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd) try: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index cb3cffaa73d1..1188bfd74fc2 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -5,6 +5,7 @@ import logging import os import sys +import textwrap import threading import traceback import types @@ -44,6 +45,27 @@ most_recent_backend = None +class OptimizedModule(torch.nn.Module): + """ + Wraps the original nn.Module object and later patches its + forward method to optimized self.forward method. + """ + + def __init__(self, mod, dynamo_ctx): + super().__init__() + # Installs the params/buffer + self._orig_mod = mod + self.dynamo_ctx = dynamo_ctx + + def __getattr__(self, name): + if name == "_orig_mod": + return self._modules["_orig_mod"] + return getattr(self._orig_mod, name) + + def forward(self, *args, **kwargs): + return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs) + + def remove_from_cache(f): """ Make sure f.__code__ is not cached to force a recompile @@ -118,31 +140,14 @@ def __call__(self, fn): # Optimize the forward method of torch.nn.Module object if isinstance(fn, torch.nn.Module): mod = fn - optimized_forward = self(mod.forward) - - class TorchDynamoNNModuleWrapper: - """ - A wrapper that redirects the forward call to the optimized - forward, while for rest it redirects the calls to the original - module. - """ - - def __getattr__(self, name): - return getattr(mod, name) - - def forward(self, *args, **kwargs): - return optimized_forward(*args, **kwargs) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - new_mod = TorchDynamoNNModuleWrapper() + new_mod = OptimizedModule(mod, self) # Save the function pointer to find the original callable while nesting # of decorators. - new_mod._torchdynamo_orig_callable = mod + new_mod._torchdynamo_orig_callable = mod.forward return new_mod assert callable(fn) + callback = self.callback on_enter = self.on_enter backend_ctx_ctor = self.extra_ctx_ctor @@ -184,6 +189,40 @@ def _fn(*args, **kwargs): # If the function is called using torch._dynamo.optimize decorator, we # should prevent any type of skipping. if callback not in (None, False): + if not hasattr(fn, "__code__"): + raise RuntimeError( + textwrap.dedent( + """ + + torch._dynamo.optimize is called on a non function object. + If this is a callable class, please wrap the relevant code into a function and optimize the + wrapper function. 
+ + >> class CallableClass: + >> def __init__(self): + >> super().__init__() + >> self.relu = torch.nn.ReLU() + >> + >> def __call__(self, x): + >> return self.relu(torch.sin(x)) + >> + >> def print_hello(self): + >> print("Hello world") + >> + >> mod = CallableClass() + + If you want to optimize the __call__ function and other code, wrap that up in a function + + >> def wrapper_fn(x): + >> y = mod(x) + >> return y.sum() + + and then optimize the wrapper_fn + + >> opt_wrapper_fn = torch._dynamo.optimize(wrapper_fn) + """ + ) + ) always_optimize_code_objects[fn.__code__] = True return _fn diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d6082ce48acf..6e0d32d21f97 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -32,6 +32,18 @@ def clone_me(x): return x.detach().clone().requires_grad_(x.requires_grad) +def named_parameters_for_optimized_module(mod): + assert isinstance(mod, eval_frame.OptimizedModule) + return mod._orig_mod.named_parameters + + +def remove_optimized_module_prefix(name): + prefix = "_orig_mod." + assert name.startswith(prefix) + name = name[len(prefix) :] + return torch.distributed.fsdp._common_utils.clean_tensor_name(name) + + def collect_results(model, prediction, loss, example_inputs): results = [] results.append(prediction) @@ -44,6 +56,8 @@ def collect_results(model, prediction, loss, example_inputs): grads = dict() params = dict() for name, param in model.named_parameters(): + if isinstance(model, eval_frame.OptimizedModule): + name = remove_optimized_module_prefix(name) param_copy = param grad = param.grad # Treat None and zero grad as same From ac0a6f381de06b58aa583daf7771c410c69709fd Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Wed, 16 Nov 2022 22:28:36 +0000 Subject: [PATCH 271/453] [dtensor] disable op db tests for now (#89162) context: https://github.com/pytorch/pytorch/issues/89160 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89162 Approved by: https://github.com/fduwjj --- test/run_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/run_test.py b/test/run_test.py index 6bf98a01a44d..94bee60cc24e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -113,6 +113,7 @@ def skip_test_p(name: str) -> bool: "distributed/launcher/bin/test_script_is_torchelastic_launched", "distributed/launcher/bin/test_script_local_rank", "distributed/test_c10d_spawn", + "distributed/_tensor/test_dtensor_ops", 'distributions/test_transforms', 'distributions/test_utils', ], From f73d9a79fe8d52be27c3c28cd93ce690bdc4f9b7 Mon Sep 17 00:00:00 2001 From: Riley Dulin Date: Thu, 17 Nov 2022 02:43:33 +0000 Subject: [PATCH 272/453] [torch][fx] Fix PassManager to not use a class variable mutable list (#89108) Summary: I found a confusing bug in the PassManager that only happens when you instantiate one multiple times: it will use old passes and constraints! This occurs because the class-level declarations initialize it to an empty list, but the problem is that class initializers only run once, and are creating class variables. This means the same empty list was being reused every time, except after the first time it isn't empty. The empty list has to be created in `__init__` newly each time or else it'll be shared. Note that this is the same type of bug as using an empty list as a default parameter, where it'll reuse the same list pointer and not make it empty each time. 
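As a minimal illustration of the pitfall (a standalone sketch with a made-up `Holder` class, not code from this patch), sharing one class-level list across instances looks like this:

    class Holder:
        items: list = []  # class attribute: the [] is evaluated once, at class definition time

        def __init__(self, items=None):
            if items:
                self.items = items  # only rebinds when a non-empty list is passed in

    a = Holder()
    a.items.append("pass_1")  # mutates the single shared class-level list
    b = Holder()
    print(b.items)            # ['pass_1'] -- the "fresh" instance already sees a's state

Every instance constructed without arguments ends up aliasing the same list object, which is exactly what happened with `passes` and `constraints` here.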
The better way to do this is with either: * An immutable default parameter like an empty tuple, that you create a new list from: `self.passes = list(passes)` * Use None and then create the empty list inside `__init__` I chose the latter as it's less likely to cause a behavior change due to the changed default. Note that for immutable values like `False` and `1` this doesn't apply as you can't mutate that value for everyone. Test Plan: Added a test to ensure that the pass state is not saved. Without my change, this test would fail as it would run all of the `2 * x` passes first, then all of the `3 * x` passes. Differential Revision: D41327056 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89108 Approved by: https://github.com/angelayi --- torch/fx/passes/infra/pass_manager.py | 10 ++++------ torch/fx/passes/pass_manager.py | 10 ++++------ torch/fx/passes/tests/test_pass_manager.py | 22 ++++++++++++++++++++++ 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 265c6263da54..e649acfb28f5 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -165,8 +165,8 @@ class PassManager: checks """ - passes: List[Callable[[nn.Module], PassResult]] = [] - constraints: List[Callable[[Callable, Callable], bool]] = [] + passes: List[Callable[[nn.Module], PassResult]] + constraints: List[Callable[[Callable, Callable], bool]] _validated: bool = False steps: int = 1 @@ -178,10 +178,8 @@ def __init__( run_checks_after_each_pass: bool = False, suppress_check_failures: bool = False, ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints + self.passes = passes or [] + self.constraints = constraints or [] if steps: self.steps = steps diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 5a34c5bca362..cf002b3611bf 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -184,8 +184,8 @@ class PassManager: `this_before_that_pass_constraint` for example. 
""" - passes: List[Callable] = [] - constraints: List[Callable] = [] + passes: List[Callable] + constraints: List[Callable] _validated: bool = False def __init__( @@ -193,10 +193,8 @@ def __init__( passes=None, constraints=None, ): - if passes: - self.passes = passes - if constraints: - self.constraints = constraints + self.passes = passes or [] + self.constraints = constraints or [] @classmethod def build_from_passlist(cls, passes): diff --git a/torch/fx/passes/tests/test_pass_manager.py b/torch/fx/passes/tests/test_pass_manager.py index 4ed0cfce89de..60ed6671179b 100644 --- a/torch/fx/passes/tests/test_pass_manager.py +++ b/torch/fx/passes/tests/test_pass_manager.py @@ -34,3 +34,25 @@ def test_these_before_those_pass_constraint(self) -> None: pm.add_constraint(constraint) self.assertRaises(RuntimeError, pm.validate) + + def test_two_pass_managers(self) -> None: + """Make sure we can construct the PassManager twice and not share any + state between them""" + + passes = [lambda x: 2 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm1 = PassManager() + for p in passes: + pm1.add_pass(p) + pm1.add_constraint(constraint) + output1 = pm1(1) + self.assertEqual(output1, 2 ** 3) + + passes = [lambda x: 3 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm2 = PassManager() + for p in passes: + pm2.add_pass(p) + pm2.add_constraint(constraint) + output2 = pm2(1) + self.assertEqual(output2, 3 ** 3) From f3af5ba48effeb7785df2049348d83467c5fb986 Mon Sep 17 00:00:00 2001 From: Charlie Yan Date: Tue, 15 Nov 2022 23:33:05 +0000 Subject: [PATCH 273/453] [WIP] Composable API: `replicate` and `DistributedState` (#87649) This PR adds the first version of the `replicate()` composable API. For this prototype version, I try to reuse as much code from existing `DistributedDataParallel` as possible, and iterate on it in later changes. The basic idea of this prototype is: - create a `ReplicateState` object. It internally uses a `ParameterList` module to hold all parameters of modules marked by `replicate()` API. 
- create an internal `_ddp` object, which reuses existing `DistributedDataParallel` implementation, and wraps the `ParameterList` object - install pre-forward and after-forward hooks on the root module, which calls methods of `_ddp` to run initialization and forward Pull Request resolved: https://github.com/pytorch/pytorch/pull/87649 Approved by: https://github.com/zhaojuanmao --- .../distributed/_composable/test_replicate.py | 100 ++++++++++++++++ torch/distributed/_composable/__init__.py | 1 + torch/distributed/_composable/_ddp.py | 20 +++- torch/distributed/_composable/contract.py | 11 +- torch/distributed/_composable/replicate.py | 107 ++++++++++++++++++ 5 files changed, 232 insertions(+), 7 deletions(-) create mode 100644 test/distributed/_composable/test_replicate.py create mode 100644 torch/distributed/_composable/replicate.py diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py new file mode 100644 index 000000000000..831ccc3376af --- /dev/null +++ b/test/distributed/_composable/test_replicate.py @@ -0,0 +1,100 @@ +# Owner(s): ["oncall: distributed"] + +import os +from copy import deepcopy + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn +from torch.distributed._composable.replicate import mark_root_module, replicate +from torch.testing._internal.common_distributed import MultiProcessTestCase +from torch.testing._internal.common_utils import run_tests + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.fc2 = nn.Linear(10, 50, bias=False) + self.fc3 = nn.Linear(50, 4, bias=False) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.fc3(x) + return F.softmax(x, dim=1) + + +class ReplicateTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + def tearDown(self): + super().tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + def _prepare_module(self, global_batch_size): + model = Net() + input = torch.randn(global_batch_size, 2) + target = torch.randn(global_batch_size, 4) + return model, input, target + + def test_replicate(self): + dist.init_process_group( + backend="gloo", + rank=self.rank, + world_size=self.world_size, + store=dist.FileStore(self.file_name, self.world_size), + ) + + local_batch_size = 1 + global_batch_size = self.world_size * local_batch_size + model, input, target = self._prepare_module(global_batch_size) + replicate_model = mark_root_module(replicate(deepcopy(model))) + + def step_model(model, input, target): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + loss.backward() + for param in model.parameters(): + with torch.no_grad(): + param -= param.grad + param.grad = None + + for iteration in range(2): + step_model(model, input, target) + step_model( + replicate_model, + input[ + self.rank + * local_batch_size : (self.rank + 1) + * local_batch_size + ], + target[ + self.rank + * local_batch_size : (self.rank + 1) + * local_batch_size + ], + ) + + self.assertEqual( + len(list(model.parameters())), + len(list(replicate_model.parameters())), + ) + for i, j in zip(model.parameters(), replicate_model.parameters()): + self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5) + + # Shuffle the input so that DDP input is different + torch.manual_seed(iteration) + input = 
input[torch.randperm(global_batch_size)] + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_composable/__init__.py b/torch/distributed/_composable/__init__.py index 5b0d8e77e5cc..2952426a09fd 100644 --- a/torch/distributed/_composable/__init__.py +++ b/torch/distributed/_composable/__init__.py @@ -1,3 +1,4 @@ from .checkpoint_activation import checkpoint from .contract import contract from .fully_shard import fully_shard +from .replicate import replicate diff --git a/torch/distributed/_composable/_ddp.py b/torch/distributed/_composable/_ddp.py index 76a4aa70c422..9e94ec3d53cd 100644 --- a/torch/distributed/_composable/_ddp.py +++ b/torch/distributed/_composable/_ddp.py @@ -1058,9 +1058,9 @@ def _run_ddp_forward(self, *inputs, **kwargs): with self._inside_ddp_forward(): return module_to_run(*inputs, **kwargs) - def forward(self, *inputs, **kwargs): + def pre_forward(self): with torch.autograd.profiler.record_function( - "DistributedDataParallel.forward" + "DistributedDataParallel.pre_forward" ): if torch.is_grad_enabled() and self.require_backward_grad_sync: assert self.logger is not None @@ -1090,7 +1090,6 @@ def forward(self, *inputs, **kwargs): # sync params according to location (before/after forward) user # specified as part of hook, if hook was specified. - buffer_hook_registered = hasattr(self, "buffer_hook") if self._check_sync_bufs_pre_fwd(): self._sync_buffers() @@ -1100,8 +1099,10 @@ def forward(self, *inputs, **kwargs): is_joined_rank=False ) - output = self._run_ddp_forward(*inputs, **kwargs) - + def post_forward(self, output): + with torch.autograd.profiler.record_function( + "DistributedDataParallel.post_forward" + ): # sync params according to location (before/after forward) user # specified as part of hook, if hook was specified. 
if self._check_sync_bufs_post_fwd(): @@ -1166,6 +1167,15 @@ def forward(self, *inputs, **kwargs): ) return output + def forward(self, *inputs, **kwargs): + self.pre_forward(*inputs, **kwargs) + with torch.autograd.profiler.record_function( + "DistributedDataParallel.forward" + ): + output = self._run_ddp_forward(*inputs, **kwargs) + output = self.post_forward(output) + return output + def scatter(self, inputs, kwargs, device_ids): return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index b75604872a59..fca817bcfc1d 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -121,7 +121,9 @@ def check_fqn(orig_fqns: List[str], new_fqns: List[str]): f"New FQNs: {new_only}" ) - check_fqn(list(orig_named_params.keys()), list(new_named_params.keys())) + check_fqn( + list(orig_named_params.keys()), list(new_named_params.keys()) + ) check_fqn( list(orig_named_buffers.keys()), list(new_named_buffers.keys()) ) @@ -138,7 +140,12 @@ def check_fqn(orig_fqns: List[str], new_fqns: List[str]): return updated def get_state(module: nn.Module) -> Optional[_State]: - return module.__dict__.get(STATE_KEY).get(func) # type: ignore[call-overload] + return module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, + {}, # TODO(@yhcharles): this is a temporary fix, need a better way + ).get( + func + ) # type: ignore[call-overload] wrapper.state = get_state # type: ignore[attr-defined] diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py new file mode 100644 index 000000000000..0e94427afee8 --- /dev/null +++ b/torch/distributed/_composable/replicate.py @@ -0,0 +1,107 @@ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from . import _ddp +from .contract import contract + + +class DistributedState: + ... 
+ + +class ReplicateState(DistributedState): + def __init__(self) -> None: + self.modules: List[nn.Module] = [] + self.has_initialized: bool = False + self._param_list: nn.ParameterList = nn.ParameterList() + + def mark_modules(self, *modules: nn.Module) -> None: + for module in modules: + self.modules.append(module) + replicate.state(module)._distributed_state = self + replicate.state(module)._params_collected = False + + def _recursive_collect_params(self, module: nn.Module) -> None: + # TODO: skip if managed by other APIs + + if hasattr(replicate.state(module), "_params_collected"): + if replicate.state(module)._params_collected: + return + replicate.state(module)._params_collected = True + + self._param_list.extend( + param + for param in module.parameters(recurse=False) + # for param in module.parameters() + if param.requires_grad + ) + for child in module.children(): + self._recursive_collect_params(child) + + def init_helper(self): + if self.has_initialized: + return + + self.has_initialized = True + for module in self.modules: + self._recursive_collect_params(module) + + self._ddp = _ddp.DistributedDataParallel(self._param_list) + + def root_module_forward_pre_hook( + self, module: nn.Module, input: Tuple[torch.Tensor] + ) -> None: + self.init_helper() + self._ddp.pre_forward() + + def root_module_forward_post_hook( + self, + module: nn.Module, + input: Tuple[torch.Tensor], + output: torch.Tensor, + ) -> torch.Tensor: + return self._ddp.post_forward(output) + + +# TODO(@yhcharles): use a per-model instance instead of a global one +_default_state = ReplicateState() + + +@contract +def replicate( + module: nn.Module, # NOTE: contract now supports single module only + dist_state: ReplicateState = _default_state, +) -> nn.Module: + r"""Replicates module(s) + + Args: + modules (torch.nn.Module): modules to replicate + + Example:: + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + dist_state.mark_modules(module) + return module + + +def mark_root_module( + module: nn.Module, dist_state: ReplicateState = _default_state +) -> nn.Module: + r"""Mark the root module. Its sub-modules can be replicated. + + Args: + modules (torch.nn.Module): root module + + Example:: + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + module.register_forward_pre_hook(dist_state.root_module_forward_pre_hook) + # TODO(@yhcharles): fix type error + module.register_forward_hook( + dist_state.root_module_forward_post_hook # type: ignore[arg-type] + ) + return module From c3acb9c8859fb5cfa1959ee49849f07942c40ccc Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Wed, 16 Nov 2022 19:50:02 +0000 Subject: [PATCH 274/453] [ONNX] Add Internal Utils: onnx_proto_utils.py for onnx/onnx-script/onnx_proto (#88376) Added `onnx_proto_utils.py` for onnx/onnx-script related process. 
The idea is like jit_utils.py, and to simplify what we have in `torch/onnx/utils.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88376 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- torch/onnx/_internal/onnx_proto_utils.py | 143 ++++++++++++++++++++++ torch/onnx/utils.py | 148 ++--------------------- 2 files changed, 152 insertions(+), 139 deletions(-) create mode 100644 torch/onnx/_internal/onnx_proto_utils.py diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py new file mode 100644 index 000000000000..f557089707b8 --- /dev/null +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -0,0 +1,143 @@ +"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto.""" + +import io +import os +import zipfile +from typing import List, Mapping, Set, Union + +import torch +import torch.jit._trace +import torch.serialization +from torch.onnx import _constants, _exporter_states, errors +from torch.onnx._internal import _beartype, jit_utils, registration + + +@_beartype.beartype +def _export_file( + model_bytes: bytes, + f: Union[io.BytesIO, str], + export_type: str, + export_map: Mapping[str, bytes], +) -> None: + """export/write model bytes into directory/protobuf/zip""" + # TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f, + # but beartype raises beartype.roar.BeartypeDecorHintNonpepException, + # as os.PathLike[str] uncheckable at runtime + if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) + elif export_type in { + _exporter_states.ExportTypes.ZIP_ARCHIVE, + _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, + }: + compression = ( + zipfile.ZIP_DEFLATED + if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE + else zipfile.ZIP_STORED + ) + with zipfile.ZipFile(f, "w", compression=compression) as z: + z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes) + for k, v in export_map.items(): + z.writestr(k, v) + elif export_type == _exporter_states.ExportTypes.DIRECTORY: + if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type] + raise ValueError( + f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}" + ) + if not os.path.exists(f): # type: ignore[arg-type] + os.makedirs(f) # type: ignore[arg-type] + + model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type] + with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file: + opened_file.write(model_bytes) + + for k, v in export_map.items(): + weight_proto_file = os.path.join(f, k) # type: ignore[arg-type] + with torch.serialization._open_file_like( + weight_proto_file, "wb" + ) as opened_file: + opened_file.write(v) + else: + raise ValueError("Unknown export type") + + +@_beartype.beartype +def _add_onnxscript_fn( + model_bytes: bytes, + custom_opsets: Mapping[str, int], +) -> bytes: + """Insert model-included custom onnx-script function into ModelProto""" + # TODO(titaiwang): remove this when onnx becomes dependency + try: + import onnx + except ImportError: + raise errors.OnnxExporterError("Module onnx is not installed!") + + # For > 2GB model, onnx.load_fromstring would fail. 
However, because + # in _export_onnx, the tensors should be saved separately if the proto + # size > 2GB, and if it for some reason did not, the model would fail on + # serialization anyway in terms of the protobuf limitation. So we don't + # need to worry about > 2GB model getting here. + model_proto = onnx.load_from_string(model_bytes) + + # Iterate graph nodes to insert only the included custom + # function_proto into model_proto + # TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction + # calling other ONNXFunction scenario, neither does it here + onnx_function_list = list() # type: ignore[var-annotated] + included_node_func = set() # type: Set[str] + # onnx_function_list and included_node_func are expanded in-place + _find_onnxscript_op( + model_proto.graph, included_node_func, custom_opsets, onnx_function_list + ) + + if onnx_function_list: + model_proto.functions.extend(onnx_function_list) + model_bytes = model_proto.SerializeToString() + return model_bytes + + +@_beartype.beartype +def _find_onnxscript_op( + graph_proto, + included_node_func: Set[str], + custom_opsets: Mapping[str, int], + onnx_function_list: List, +): + """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op.""" + for node in graph_proto.node: + node_kind = node.domain + "::" + node.op_type + # Recursive needed for control flow nodes: IF/Loop which has inner graph_proto + for attr in node.attribute: + if attr.g is not None: + _find_onnxscript_op( + attr.g, included_node_func, custom_opsets, onnx_function_list + ) + # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry + onnx_function_group = registration.registry.get_function_group(node_kind) + # Ruled out corner cases: onnx/prim in registry + if ( + node.domain + and not jit_utils.is_aten(node.domain) + and not jit_utils.is_prim(node.domain) + and not jit_utils.is_onnx(node.domain) + and onnx_function_group is not None + and node_kind not in included_node_func + ): + specified_version = custom_opsets.get(node.domain, 1) + onnx_fn = onnx_function_group.get(specified_version) + if onnx_fn is not None: + # TODO(titaiwang): to_function_proto is onnx-script API and can be annotated + # after onnx-script is dependency + onnx_function_list.append(onnx_fn.to_function_proto()) # type: ignore[attr-defined] + included_node_func.add(node_kind) + continue + raise errors.UnsupportedOperatorError( + node_kind, + specified_version, + onnx_function_group.get_min_supported() + if onnx_function_group + else None, + ) + return onnx_function_list, included_node_func diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 9d6ec0b32523..67dd719bae9f 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -9,12 +9,10 @@ import copy import inspect import io -import os import re import textwrap import typing import warnings -import zipfile from typing import ( Any, Callable, @@ -38,15 +36,19 @@ from torch import _C from torch.onnx import ( # noqa: F401 _constants, - _deprecation, _exporter_states, - _patch_torch, errors, symbolic_caffe2, symbolic_helper, ) from torch.onnx._globals import GLOBALS -from torch.onnx._internal import _beartype, diagnostics, jit_utils, registration +from torch.onnx._internal import ( + _beartype, + diagnostics, + jit_utils, + onnx_proto_utils, + registration, +) __all__ = [ "is_in_onnx_export", @@ -1598,13 +1600,13 @@ def _export( node_attr_to_name, ) # insert function_proto into model_proto. 
- proto = _add_onnxscript_fn( + proto = onnx_proto_utils._add_onnxscript_fn( proto, custom_opsets, ) if verbose: torch.onnx.log("Exported graph: ", graph) - _export_file(proto, f, export_type, export_map) + onnx_proto_utils._export_file(proto, f, export_type, export_map) # The ONNX checker only works for ONNX graph. So if the operator_export_type is not ONNX, # we can skip this check. # If large model format export is enabled, proto will only contain data location instead of @@ -1625,138 +1627,6 @@ def _export( return torch_out -@_beartype.beartype -def _export_file( - model_bytes: bytes, - f: Union[io.BytesIO, str], - export_type: str, - export_map: Mapping[str, bytes], -) -> None: - """export/write model bytes into directory/protobuf/zip""" - # TODO(titaiwang) MYPY asks for os.PathLike[str] type for parameter: f, - # but beartype raises beartype.roar.BeartypeDecorHintNonpepException, - # as os.PathLike[str] uncheckable at runtime - if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: - assert len(export_map) == 0 - with torch.serialization._open_file_like(f, "wb") as opened_file: - opened_file.write(model_bytes) - elif export_type in [ - _exporter_states.ExportTypes.ZIP_ARCHIVE, - _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, - ]: - compression = ( - zipfile.ZIP_DEFLATED - if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE - else zipfile.ZIP_STORED - ) - with zipfile.ZipFile(f, "w", compression=compression) as z: - z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes) - for k, v in export_map.items(): - z.writestr(k, v) - elif export_type == _exporter_states.ExportTypes.DIRECTORY: - if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type] - raise ValueError( - f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}" - ) - if not os.path.exists(f): # type: ignore[arg-type] - os.makedirs(f) # type: ignore[arg-type] - - model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type] - with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file: - opened_file.write(model_bytes) - - for k, v in export_map.items(): - weight_proto_file = os.path.join(f, k) # type: ignore[arg-type] - with torch.serialization._open_file_like( - weight_proto_file, "wb" - ) as opened_file: - opened_file.write(v) - else: - raise RuntimeError("Unknown export type") - - -@_beartype.beartype -def _add_onnxscript_fn( - model_bytes: bytes, - custom_opsets: Mapping[str, int], -) -> bytes: - """Insert model-included custom onnx-script function into ModelProto""" - - # TODO(titaiwang): remove this when onnx becomes dependency - try: - import onnx - except ImportError: - raise errors.OnnxExporterError("Module onnx is not installed!") - - # For > 2GB model, onnx.load_fromstring would fail. However, because - # in _export_onnx, the tensors should be saved separately if the proto - # size > 2GB, and if it for some reason did not, the model would fail on - # serialization anyway in terms of the protobuf limitation. So we don't - # need to worry about > 2GB model getting here. 
- model_proto = onnx.load_from_string(model_bytes) - - # Iterate graph nodes to insert only the included custom - # function_proto into model_proto - # TODO(titaiwang): Currently, onnxscript doesn't support ONNXFunction - # calling other ONNXFunction scenario, neither does it here - onnx_function_list = list() # type: ignore[var-annotated] - included_node_func = set() # type: Set[str] - # onnx_function_list and included_node_func are expanded in-place - _find_onnxscript_op( - model_proto.graph, included_node_func, custom_opsets, onnx_function_list - ) - - if onnx_function_list: - model_proto.functions.extend(onnx_function_list) - model_bytes = model_proto.SerializeToString() - return model_bytes - - -@_beartype.beartype -def _find_onnxscript_op( - graph_proto, - included_node_func: Set[str], - custom_opsets: Mapping[str, int], - onnx_function_list: List, -): - """Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op.""" - for node in graph_proto.node: - node_kind = node.domain + "::" + node.op_type - # Recursive is needed for control flow nodes: IF/Loop which has inner graph_proto - for attr in node.attribute: - if attr.g is not None: - _find_onnxscript_op( - attr.g, included_node_func, custom_opsets, onnx_function_list - ) - # Only custom Op with ONNX function and aten with symbolic_fn should be found in registry - onnx_function_group = registration.registry.get_function_group(node_kind) - # Ruled out corner cases: onnx/prim in registry - if ( - node.domain - and not jit_utils.is_aten(node.domain) - and not jit_utils.is_prim(node.domain) - and not jit_utils.is_onnx(node.domain) - and onnx_function_group is not None - and node_kind not in included_node_func - ): - specified_version = custom_opsets.get(node.domain, 1) - onnx_fn = onnx_function_group.get(specified_version) - if onnx_fn is not None: - # TODO(titaiwang): to_function_proto is onnx-script API and can be annotated - # after onnx-script is dependency - onnx_function_list.append(onnx_fn.to_function_proto()) # type: ignore[attr-defined] - included_node_func.add(node_kind) - continue - raise errors.UnsupportedOperatorError( - node_kind, - specified_version, - onnx_function_group.get_min_supported() - if onnx_function_group - else None, - ) - return onnx_function_list, included_node_func - - @_beartype.beartype def _apply_friendly_debug_names(graph, params): for n in graph.nodes(): From fce6d6b3dcc879720bc45143426b86232106818a Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Wed, 16 Nov 2022 23:58:11 +0000 Subject: [PATCH 275/453] Redefine the simdlen semantic: (#88482) This PR is targeting to automatically enable vectorization optimization for TorchInductor. It refined the semantics of `config.cpp.simdlen`. Originally, `None` means to disable vectorization while a specific value means the number of elements to be vectorized once time. But it depends on the data. Regarding 256bit SVE/SIMD ISA for ARM and X86, the `simdlen` should be 16 for Float while 32 for BFloat. Hence, this PR defined the `simdlen` as the bit width. The detailed semantics are as follows. - **_simdlen = None_**: Automatically determine the SIMD bit width. Detect HW information and pick the proper vectorization ISA. Specific for X86, the priority of AVX512 is higher than AVX2. - **_simdlen <=1_**: Explicitly disable SIMD - **_simdlen > 1_**: Explicitly specify the SIMD bit width. It equals the disabled semantic if the bit width does not match the ISA width. 
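To make the table above concrete, here is a small standalone sketch of the described decision rules (the function `choose_isa` and its inputs are illustrative only and are not part of this patch; the real logic lives in `pick_vec_isa()` in the diff below):

```python
# Illustrative paraphrase of the simdlen semantics described above.
# `choose_isa` and `valid_isas` are hypothetical names, not part of the patch.
def choose_isa(simdlen, valid_isas):
    # valid_isas: ISAs supported by the host, widest first, e.g. [("avx512", 512), ("avx2", 256)]
    if simdlen is None:                 # auto-detect: pick the widest supported ISA
        return valid_isas[0] if valid_isas else None
    if simdlen <= 1:                    # explicitly disable vectorization
        return None
    for isa in valid_isas:              # explicit bit width must match a supported ISA width
        if simdlen == isa[1]:
            return isa
    return None                         # mismatched width behaves like "disabled"

isas = [("avx512", 512), ("avx2", 256)]
print(choose_isa(None, isas))  # ('avx512', 512)
print(choose_isa(256, isas))   # ('avx2', 256)
print(choose_isa(513, isas))   # None
```

In user code the knob stays the same: setting `torch._inductor.config.cpp.simdlen` to `None`, `0`/`1`, or a bit width such as `256`/`512` selects between these three behaviors, as exercised by the new `test_auto_simd` test below.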
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88482 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_torchinductor.py | 94 +++++++++++- torch/_inductor/codecache.py | 215 +++++++++++++++++++++------- torch/_inductor/codegen/common.py | 6 + torch/_inductor/codegen/cpp.py | 92 +++++++++--- 4 files changed, 327 insertions(+), 80 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fb7ca1fc92b7..f9aa93f4a7e6 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4529,7 +4529,11 @@ def fn(x): v = torch.randn(10) result = fn(v) - assert same(result, mod(v)) + # TODO: OMP parallel reduction order is not deterministic. + # Hence, the accurarcy might vary up and down. For short term, + # we increase the tolerance and will fix it later by using + # aten parallel. + assert same(result, mod(v), tol=5e-1) def test_inplace_add_alpha(self): def fn(x, y): @@ -4599,7 +4603,79 @@ def test_complex_memory_overlap(self): self.assertFalse(complex_memory_overlap(gathered.t())) @unittest.skipIf( - not codecache.get_cpu_proc_info(), "Does not support vectorization" + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch.object(config, "dynamic_shapes", True) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + @patch.object(functorch_config, "use_dynamic_shapes", True) + def test_vec_dynamic_shapes(self): + def fn(x): + return torch.softmax(x, -1) + + value = torch.randn((2, 10)) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value) + + real_out = fn(value) + compiled_out = opt_fn(value) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count < 1 + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_simd(self): + vec_avx512 = codecache.supported_vec_isa_list[0] + vec_avx2 = codecache.supported_vec_isa_list[1] + self.assertTrue(vec_avx512.bit_width() == 512) + self.assertTrue(vec_avx2.bit_width() == 256) + self.assertTrue(vec_avx512.nelements() == 16) + self.assertTrue(vec_avx2.nelements() == 8) + self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) + self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) + + with patch.object(config.cpp, "simdlen", None): + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with patch.object(config.cpp, "simdlen", 0): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 1): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 257): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 513): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 512): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_avx512) + + with patch.object(config.cpp, "simdlen", 256): + isa_list = codecache.valid_vec_isa_list() + if vec_avx2 in isa_list: + isa = codecache.pick_vec_isa() + 
self.assertTrue(isa == vec_avx2) + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): @@ -4610,7 +4686,7 @@ def fn(x): x[0, 0] = torch.nan x[1, -1] = torch.nan - with patch.object(config.cpp, "simdlen", 8): + with patch.object(config.cpp, "simdlen", None): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x) @@ -4623,7 +4699,7 @@ def fn(x): # other platforms support, we just need to add the ISA info to the supported_vector_isa # and include proper aten vectorization head file. @unittest.skipIf( - not codecache.get_cpu_proc_info(), "Does not support vectorization" + not codecache.valid_vec_isa_list(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_vec_kernel_cpu_only(self): @@ -4662,7 +4738,15 @@ def fn(x1, x2): x1 = torch.randn((10, 20)) x2 = torch.randn((10, 20)) - with patch.object(config.cpp, "simdlen", 8): + with patch.object(config.cpp, "simdlen", 1): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, [x1, x2]) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 0 + + with patch.object(config.cpp, "simdlen", None): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x1, x2) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2826f3599912..232a611b06c6 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,5 +1,5 @@ import base64 -import enum +import dataclasses import functools import getpass import hashlib @@ -18,7 +18,7 @@ from ctypes import cdll from threading import Thread from time import sleep, time -from typing import Any, Dict +from typing import Any, Callable, Dict, List import torch from torch.utils import cpp_extension @@ -147,79 +147,181 @@ def is_gcc(): return re.search(r"(gcc|g\+\+)", cpp_compiler()) -class _SupportedVecIsa(enum.Enum): - AVX512 = 1 - AVX2 = 2 - INVALID = -1 +class VecISA(object): + _bit_width: int + _macro: str + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. 
And suppose ARM + # also needs the logic + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +#include +#include +#endif + +__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self): + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float): + return self._dtype_nelements[dtype] + def build_macro(self): + return self._macro + + def build_arch_flags(self): + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + @functools.lru_cache(None) def __bool__(self): - return self != _SupportedVecIsa.INVALID + key, input_path = write(VecISA._avx_code, "cpp", extra="") + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[:-3] + "so" + build_cmd = cpp_compile_command( + input_path, output_path, warning_all=False, vec_isa=self + ).split(" ") + try: + # Check build result + subprocess.check_output(build_cmd, stderr=subprocess.STDOUT) + subprocess.check_call( + [ + "python", + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + stderr=subprocess.DEVNULL, + ) + except Exception as e: + return False - @staticmethod - def isa_str(supported_isa: enum.Enum): - if supported_isa == _SupportedVecIsa.AVX512: - return "avx512" - elif supported_isa == _SupportedVecIsa.AVX2: - return "avx2" - else: - return "" + return True + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = "CPU_CAPABILITY_AVX512" + _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ - @staticmethod - def vec_macro(supported_isa: enum.Enum): - if supported_isa == _SupportedVecIsa.AVX512: - return "CPU_CAPABILITY_AVX512" - elif supported_isa == _SupportedVecIsa.AVX2: - return "CPU_CAPABILITY_AVX2" - else: - return "" + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = "CPU_CAPABILITY_AVX2" + _arch_flags = "-mavx2 -mfma" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = "" + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self): + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAVX512(), VecAVX2()] # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. 
-@functools.lru_cache(1) -def get_cpu_proc_info(): +@functools.lru_cache(None) +def valid_vec_isa_list(): if sys.platform != "linux": return [] - isa_info = [] + isa_list = [] with open("/proc/cpuinfo") as _cpu_info: _cpu_info_content = _cpu_info.read() - if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content: - isa_info.append(_SupportedVecIsa.AVX512) - - if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content: - isa_info.append(_SupportedVecIsa.AVX2) + for isa in supported_vec_isa_list: + if str(isa) in _cpu_info_content and isa: + isa_list.append(isa) + return isa_list - return isa_info +def pick_vec_isa(): + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa -def supported_vector_isa(): - # TODO: Add ARM Vec here. - # Dict(k: isa, v: number of float element) - vec_isa_info = { - _SupportedVecIsa.AVX512: 16, - _SupportedVecIsa.AVX2: 8, - } + # If the simdlen is None, it indicates determin the vectroization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] - if config.cpp.simdlen is None or config.cpp.simdlen <= 1: - return _SupportedVecIsa.INVALID - - cpu_info_content = get_cpu_proc_info() - for isa in vec_isa_info.keys(): - if isa in cpu_info_content and config.cpp.simdlen == vec_isa_info[isa]: + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): return isa - return _SupportedVecIsa.INVALID + return invalid_vec_isa -def cpp_compile_command(input, output, include_pytorch=False): - valid_isa = supported_vector_isa() - if include_pytorch or valid_isa: +def cpp_compile_command( + input, + output, + warning_all=True, + shared=True, + include_pytorch=False, + vec_isa: VecISA = invalid_vec_isa, +): + if include_pytorch or vec_isa != invalid_vec_isa: ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")] lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")] libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"] - macros = _SupportedVecIsa.vec_macro(valid_isa) + macros = vec_isa.build_macro() if macros: macros = f"-D{macros}" else: @@ -235,11 +337,13 @@ def cpp_compile_command(input, output, include_pytorch=False): lpaths = " ".join(["-L" + p for p in lpaths]) libs = " ".join(["-l" + p for p in libs]) + shared_lib = "-shared -fPIC" if shared else "" + warning_all_flag = "-Wall" if warning_all else "" return re.sub( r"[ \n]+", " ", f""" - {cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable + {cpp_compiler()} {input} {shared_lib} {warning_all_flag} -std=c++14 -Wno-unused-variable {ipaths} {lpaths} {libs} {macros} -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS @@ -266,7 +370,12 @@ def _load_library(path): @classmethod def load(cls, source_code): - key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o")) + picked_vec_isa = pick_vec_isa() + key, input_path = write( + source_code, + "cpp", + extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa), + ) if key not in cls.cache: from filelock import FileLock @@ -276,7 +385,7 @@ def load(cls, source_code): output_path = input_path[:-3] + "so" if not os.path.exists(output_path): cmd = cpp_compile_command( - input=input_path, output=output_path + input=input_path, output=output_path, vec_isa=picked_vec_isa ).split(" ") try: subprocess.check_output(cmd, stderr=subprocess.STDOUT) diff --git a/torch/_inductor/codegen/common.py 
b/torch/_inductor/codegen/common.py index 2803970295cc..cf98833964ca 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -417,6 +417,12 @@ def __init__(self, name): def __str__(self): return self.name + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + def update_on_args(self, args, kwargs): pass diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 65a9335d6cbf..38ef2179d5b7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -616,7 +616,7 @@ def codegen_loops(self, code, worksharing): ) reductions.mark_reduction(self.reduction_vars) - if config.cpp.simdlen: + if codecache.pick_vec_isa(): # TODO(jansel): detect stride-1 dimension and vectorize that if reductions: reductions.loops[-1].simd = True @@ -707,7 +707,8 @@ class CppVecKernel(CppKernel): def __init__(self, args, num_threads): super(CppVecKernel, self).__init__(args, num_threads) - self.simd_len = config.cpp.simdlen + assert codecache.pick_vec_isa() + self.simd_nelements = codecache.pick_vec_isa().nelements() self.reduction_omp_dec: Dict[str, str] = {} metrics.generated_cpp_vec_kernel_count += 1 @@ -723,10 +724,10 @@ def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr): def transform_index(self, index: sympy.Expr): expanded_index = sympy.expand(index) - assert self.simd_len - assert self.simd_len > 0 + assert self.simd_nelements + assert self.simd_nelements >= 1 most_inner_var = self.itervars[-1] - replacement = {most_inner_var: most_inner_var * self.simd_len} + replacement = {most_inner_var: most_inner_var * self.simd_nelements} new_index = sympy_subs(expanded_index, replacement) return new_index @@ -947,21 +948,24 @@ def __init__(self, args=None, num_threads=None): super(CppKernelProxy, self).__init__(args, num_threads) self.simd_vec_kernel: CppVecKernel = None self.simd_omp_kernel: CppKernel = None + self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() - def vectorize_most_inner_loop(self, loop_nest): - loop_nest.split_most_inner_loop(config.cpp.simdlen) + def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): + assert self.picked_vec_isa + nelements = self.picked_vec_isa.nelements(dtype) + loop_nest.split_most_inner_loop(nelements) loop_with_tail = loop_nest.loops[-1] assert isinstance(loop_with_tail, LoopLevelWithTail) loop_with_tail.main_loop.simd_vec = True loop_with_tail.tail_loop.simd_omp = True - # We chope the loop into two cubes by the config.cpp.simdlen - main loop and tail loop. + # We chope the loop into two cubes by the nelements - main loop and tail loop. # Regarding the main loop, it is straightforward that it could be vectorized with - # config.cpp.simdlen. But for the tail loop, it still could be vectorized. For example, - # if the config.cpp.simdlen is 8(256bits), then the tail loop still could be vectorized + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized # as 4(128bits). 
- loop_with_tail.tail_loop.simd_len = int(config.cpp.simdlen / 2) + loop_with_tail.tail_loop.simd_nelements = int(nelements / 2) loop_with_tail.tail_loop.simd_vec = False loop_with_tail.main_loop_body = self.simd_vec_kernel @@ -971,7 +975,7 @@ def vectorize_most_inner_loop(self, loop_nest): def codegen_loops(self, code, worksharing): threads = parallel_num_threads() - if self.simd_vec_kernel is None: + if self.simd_vec_kernel is None or not self.picked_vec_isa: assert self.simd_omp_kernel return self.simd_omp_kernel.codegen_loops(code, worksharing) @@ -993,12 +997,52 @@ def codegen_loops(self, code, worksharing): ), LoopNest(loops[reduction_depth:]) loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars) - if config.cpp.simdlen: - # TODO(jansel): detect stride-1 dimension and vectorize that - if loops_nest_reduce: - loops_nest_reduce.loops[-1].simd = True - elif loops_nest_non_reduce: - loops_nest_non_reduce.loops[-1].simd = True + assert self.picked_vec_isa + # Do not apply vectorization since the range of most inner is too small. Meanwhile, + # If the range of the most inner is less then the codecache.pick_vec_isa().nelements(), + # the generated code for some reduction will be as follows that leads to incrrect result. + # + # LINE01: float tmp1 = 0; + # LINE02: auto tmp1_vec = at::vec::Vectorized(tmp1); + # LINE03: for(long i1=0; i1<2; i1+=1) + # LINE04: { + # LINE05: for(long i2=0; i2<0; i2+=1) + # LINE06: { + # LINE07: auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1)); + # LINE08: tmp1_vec += tmp0; + # LINE09: } + # LINE10: tmp1 = vec_reduce_all([](Vectorized& x, Vectorized&y) {return x + y;}, tmp1_vec); + # LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1) + # LINE12: for(long i2=0; i2<8; i2+=1) + # LINE13: { + # LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)]; + # LINE15: tmp1 += tmp0; + # LINE16: } + # LINE17: } + # LINE18: out_ptr3[i0] = tmp1; + # + # tmp1_vec(LINE02) will always be zero as it is initialized with tmp1 value and the range(LINE05) + # is 0. Hence, the LINE10 will always reset tmp1 to 0. But tmp1(LINE01) is global value. So the result + # will be incorrect. We skip thie case. + most_inner_loop = ( + loops_nest_reduce.loops[-1] + if loops_nest_reduce + else loops_nest_non_reduce.loops[-1] + ) + main_loop_range = ir.IndexingDiv( + most_inner_loop.size, self.picked_vec_isa.nelements() + ) + loop_interval = sympy.simplify(main_loop_range) + # TODO(Eikan): To support dynamic shape. 
+ if not loop_interval.is_integer or loop_interval <= 0: + metrics.generated_cpp_vec_kernel_count -= 1 + return self.simd_omp_kernel.codegen_loops(code, worksharing) + + # TODO(jansel): detect stride-1 dimension and vectorize that + if loops_nest_reduce: + loops_nest_reduce.loops[-1].simd = True + elif loops_nest_non_reduce: + loops_nest_non_reduce.loops[-1].simd = True par_depth = 0 reduction_par_depth = 0 @@ -1138,8 +1182,7 @@ def can_fuse_vertical(cls, node1, node2): return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_vec(self, nodes): - # TODO: Query cpu arch and vec length from aten - if not codecache.supported_vector_isa(): + if not codecache.pick_vec_isa(): return False _, (group, reduction_group) = max( @@ -1349,7 +1392,8 @@ class LoopLevel: steps: sympy.Expr = sympy.Integer(1) parallel: int = 0 simd_omp: bool = False - simd_len: int = config.cpp.simdlen + picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 simd_vec: bool = False collapsed: bool = False reduction_vars: Dict[str, str] = None @@ -1363,7 +1407,11 @@ def lines(self): ) else: reduction = "" - simd = f"simd simdlen({self.simd_len}) " if self.simd_omp else "" + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) if self.parallel: # TODO(jansel): look into chunk size and other schedules line1 = f"#pragma omp for{reduction} " From 573eaf12258df8e87434ffa19a42b04fb873c6dc Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 17 Nov 2022 03:36:56 +0000 Subject: [PATCH 276/453] Analyze and upload disabled tests rerun to S3 (#89083) Analyze and upload disabled tests rerun to S3. Note that this only picks up `test-reports` from `rerun_disable_tests` workflows. ### Testing Running the script manually `python -m tools.stats.check_disabled_tests --workflow-run-id 3473068035 --workflow-run-attempt 1 --repo pytorch/pytorch` and see the files successfully uploaded to s3://ossci-raw-job-status/rerun_disabled_tests/3473068035/1 Rockset collection created https://console.rockset.com/collections/details/commons.rerun_disabled_tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/89083 Approved by: https://github.com/clee2000 --- .github/workflows/upload-test-stats.yml | 11 + tools/stats/check_disabled_tests.py | 290 ++++++++++++++++++++++++ 2 files changed, 301 insertions(+) create mode 100644 tools/stats/check_disabled_tests.py diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 27289983e270..3f3db80670d8 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -72,6 +72,17 @@ jobs: # anything on GitHub to upload. 
The command should return right away python3 -m tools.stats.upload_artifacts --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" + - name: Analyze disabled tests rerun + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + WORKFLOW_ARTIFACTS_URL: ${{ github.event.workflow_run.artifacts_url }} + WORKFLOW_RUN_ID: ${{ github.event.workflow_run.id }} + WORKFLOW_RUN_ATTEMPT: ${{ github.event.workflow_run.run_attempt }} + REPO_FULLNAME: ${{ github.event.workflow_run.repository.full_name }} + run: | + # Analyze the results from disable tests rerun and upload them to S3 + python3 -m tools.stats.check_disabled_tests --workflow-run-id "${WORKFLOW_RUN_ID}" --workflow-run-attempt "${WORKFLOW_RUN_ATTEMPT}" --repo "${REPO_FULLNAME}" + check-api-rate: if: ${{ always() }} runs-on: [self-hosted, linux.2xlarge] diff --git a/tools/stats/check_disabled_tests.py b/tools/stats/check_disabled_tests.py new file mode 100644 index 000000000000..75c4f236ef21 --- /dev/null +++ b/tools/stats/check_disabled_tests.py @@ -0,0 +1,290 @@ +import argparse +import json +import os +import xml.etree.ElementTree as ET +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Any, Dict, Generator, Tuple + +from tools.stats.upload_stats_lib import ( + download_gha_artifacts, + download_s3_artifacts, + unzip, + upload_to_s3, +) +from tools.stats.upload_test_stats import process_xml_element + +TESTCASE_TAG = "testcase" +TARGET_WORKFLOW = "--rerun-disabled-tests" +SEPARATOR = ";" + + +def is_rerun_disabled_tests(root: ET.ElementTree) -> bool: + """ + Check if the test report is coming from rerun_disabled_tests workflow + """ + skipped = root.find(".//*skipped") + # Need to check against None here, if not skipped doesn't work as expected + if skipped is None: + return False + + message = skipped.attrib.get("message", "") + return TARGET_WORKFLOW in message or "num_red" in message + + +def process_report( + report: Path, +) -> Dict[str, Dict[str, int]]: + """ + Return a list of disabled tests that should be re-enabled and those that are still + flaky (failed or skipped) + """ + root = ET.parse(report) + + # All rerun tests from a report are grouped here: + # + # * Success test should be re-enable if it's green after rerunning in all platforms + # where it is currently disabled + # * Failures from pytest because pytest-flakefinder is used to run the same test + # multiple times, some could fails + # * Skipped tests from unittest + # + # We want to keep track of how many times the test fails (num_red) or passes (num_green) + all_tests: Dict[str, Dict[str, int]] = {} + + if not is_rerun_disabled_tests(root): + return all_tests + + for test_case in root.iter(TESTCASE_TAG): + parsed_test_case = process_xml_element(test_case) + + # Under --rerun-disabled-tests mode, a test is skipped when: + # * it's skipped explicitly inside PyToch code + # * it's skipped because it's a normal enabled test + # * or it's falky (num_red > 0 and num_green > 0) + # * or it's failing (num_red > 0 and num_green == 0) + # + # We care only about the latter two here + skipped = parsed_test_case.get("skipped", None) + if skipped and "num_red" not in skipped.get("message", ""): + continue + + name = parsed_test_case.get("name", "") + classname = parsed_test_case.get("classname", "") + filename = parsed_test_case.get("file", "") + + if not name or not classname or not filename: + continue + + # Check if the test is a failure + failure = parsed_test_case.get("failure", None) + + 
disabled_test_id = SEPARATOR.join([name, classname, filename]) + if disabled_test_id not in all_tests: + all_tests[disabled_test_id] = { + "num_green": 0, + "num_red": 0, + } + + # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's + # counted as a success. Otherwise, it's still flaky or failing + if skipped: + try: + stats = json.loads(skipped.get("message", "")) + except json.JSONDecodeError: + stats = {} + + all_tests[disabled_test_id]["num_green"] += stats.get("num_green", 0) + all_tests[disabled_test_id]["num_red"] += stats.get("num_red", 0) + elif failure: + # As a failure, increase the failure count + all_tests[disabled_test_id]["num_red"] += 1 + else: + all_tests[disabled_test_id]["num_green"] += 1 + + return all_tests + + +def get_test_reports( + repo: str, workflow_run_id: int, workflow_run_attempt: int +) -> Generator[Path, None, None]: + """ + Gather all the test reports from S3 and GHA. It is currently not possible to guess which + test reports are from rerun_disabled_tests workflow because the name doesn't include the + test config. So, all reports will need to be downloaded and examined + """ + with TemporaryDirectory() as temp_dir: + print("Using temporary directory:", temp_dir) + os.chdir(temp_dir) + + artifact_paths = download_s3_artifacts( + "test-reports", workflow_run_id, workflow_run_attempt + ) + for path in artifact_paths: + unzip(path) + + artifact_paths = download_gha_artifacts( + "test-report", workflow_run_id, workflow_run_attempt + ) + for path in artifact_paths: + unzip(path) + + for report in Path(".").glob("**/*.xml"): + yield report + + +def get_disabled_test_name(test_id: str) -> Tuple[str, str, str, str]: + """ + Follow flaky bot convention here, if that changes, this will also need to be updated + """ + name, classname, filename = test_id.split(SEPARATOR) + return f"{name} (__main__.{classname})", name, classname, filename + + +def prepare_record( + workflow_id: int, + workflow_run_attempt: int, + name: str, + classname: str, + filename: str, + flaky: bool, + num_red: int = 0, + num_green: int = 0, +) -> Tuple[Any, Dict[str, Any]]: + """ + Prepare the record to save onto S3 + """ + key = ( + workflow_id, + workflow_run_attempt, + name, + classname, + filename, + ) + + record = { + "workflow_id": workflow_id, + "workflow_run_attempt": workflow_run_attempt, + "name": name, + "classname": classname, + "filename": filename, + "flaky": flaky, + "num_green": num_green, + "num_red": num_red, + } + + return key, record + + +def save_results( + workflow_id: int, + workflow_run_attempt: int, + all_tests: Dict[str, Dict[str, int]], +) -> None: + """ + Save the result to S3, so it can go to Rockset + """ + should_be_enabled_tests = { + name: stats + for name, stats in all_tests.items() + if "num_green" in stats + and stats["num_green"] + and "num_red" in stats + and stats["num_red"] == 0 + } + still_flaky_tests = { + name: stats + for name, stats in all_tests.items() + if name not in should_be_enabled_tests + } + + records = {} + for test_id, stats in all_tests.items(): + num_green = stats.get("num_green", 0) + num_red = stats.get("num_red", 0) + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + + key, record = prepare_record( + workflow_id=workflow_id, + workflow_run_attempt=workflow_run_attempt, + name=name, + classname=classname, + filename=filename, + flaky=test_id in still_flaky_tests, + num_green=num_green, + num_red=num_red, + ) + records[key] = record + + # Log the results + print(f"The following 
{len(should_be_enabled_tests)} tests should be re-enabled:") + for test_id, stats in should_be_enabled_tests.items(): + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + print(f" {disabled_test_name} from {filename}") + + print(f"The following {len(still_flaky_tests)} are still flaky:") + for test_id, stats in still_flaky_tests.items(): + num_green = stats.get("num_green", 0) + num_red = stats.get("num_red", 0) + + disabled_test_name, name, classname, filename = get_disabled_test_name(test_id) + print( + f" {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}" + ) + + upload_to_s3( + workflow_id, + workflow_run_attempt, + "rerun_disabled_tests", + list(records.values()), + ) + + +def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None: + """ + Find the list of all disabled tests that should be re-enabled + """ + # Aggregated across all jobs + all_tests: Dict[str, Dict[str, int]] = {} + + for report in get_test_reports( + args.repo, args.workflow_run_id, args.workflow_run_attempt + ): + tests = process_report(report) + for name, stats in tests.items(): + if name not in all_tests: + all_tests[name] = stats.copy() + else: + all_tests[name]["num_green"] += stats.get("num_green", 0) + all_tests[name]["num_red"] += stats.get("num_red", 0) + + save_results( + workflow_run_id, + workflow_run_attempt, + all_tests, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Upload test artifacts from GHA to S3") + parser.add_argument( + "--workflow-run-id", + type=int, + required=True, + help="id of the workflow to get artifacts from", + ) + parser.add_argument( + "--workflow-run-attempt", + type=int, + required=True, + help="which retry of the workflow this is", + ) + parser.add_argument( + "--repo", + type=str, + required=True, + help="which GitHub repo this workflow run belongs to", + ) + + args = parser.parse_args() + main(args.repo, args.workflow_run_id, args.workflow_run_attempt) From a5f04e9a915104692ae67ccd79768e8147cc0d2d Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 17 Nov 2022 03:36:59 +0000 Subject: [PATCH 277/453] Fix typos in .md and .rst files (#88962) This PR fixes typos `Github` in `.md` and `.rst` files. `Github` -> `GitHub` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88962 Approved by: https://github.com/kit1980 --- .github/scripts/README.md | 2 +- RELEASE.md | 2 +- caffe2/contrib/tensorrt/README.md | 2 +- docs/source/community/contribution_guide.rst | 2 +- docs/source/masked.rst | 2 +- docs/source/onnx.rst | 2 +- docs/source/sparse.rst | 4 ++-- torch/csrc/jit/operator_upgraders/README.md | 2 +- torch/csrc/lazy/tutorial.md | 2 +- 9 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/scripts/README.md b/.github/scripts/README.md index 73bec509c2c4..cc9e1617b11a 100644 --- a/.github/scripts/README.md +++ b/.github/scripts/README.md @@ -3,7 +3,7 @@ > NOTE: This README contains information for the `.github` directory but cannot be located there because it will overwrite the repo README. -This directory contains workflows and scripts to support our CI infrastructure that runs on Github Actions. +This directory contains workflows and scripts to support our CI infrastructure that runs on GitHub Actions. ## Workflows diff --git a/RELEASE.md b/RELEASE.md index e2b69b5bf82e..d13ca5d11e10 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -281,7 +281,7 @@ need to support these particular versions of software. 
In the event a submodule cannot be fast forwarded and a patch must be applied we can take two different approaches: -* (preferred) Fork the said repository under the pytorch Github organization, apply the patches we need there, and then switch our submodule to accept our fork. +* (preferred) Fork the said repository under the pytorch GitHub organization, apply the patches we need there, and then switch our submodule to accept our fork. * Get the dependencies maintainers to support a release branch for us Editing submodule remotes can be easily done with: (running from the root of the git repository) diff --git a/caffe2/contrib/tensorrt/README.md b/caffe2/contrib/tensorrt/README.md index f1e449e727e9..6ffe1dfb53bc 100644 --- a/caffe2/contrib/tensorrt/README.md +++ b/caffe2/contrib/tensorrt/README.md @@ -15,4 +15,4 @@ For further information please explore `caffe2/python/trt/test_trt.py` test show ## Questions and Feedback -Please use Github issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. +Please use GitHub issues (https://github.com/pytorch/pytorch/issues) to ask questions, report bugs, and request new features. diff --git a/docs/source/community/contribution_guide.rst b/docs/source/community/contribution_guide.rst index a2a89721b64e..30bd9c6cf975 100644 --- a/docs/source/community/contribution_guide.rst +++ b/docs/source/community/contribution_guide.rst @@ -138,7 +138,7 @@ A great deal of the tutorials on `pytorch.org `__ come from the community itself and we welcome additional contributions. To learn more about how to contribute a new tutorial you can learn more here: `PyTorch.org Tutorial Contribution Guide on -Github `__ +GitHub `__ Improving Documentation & Tutorials ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/masked.rst b/docs/source/masked.rst index 10ae8420425f..3655a6d79fd9 100644 --- a/docs/source/masked.rst +++ b/docs/source/masked.rst @@ -157,7 +157,7 @@ Binary Operators As you may have seen in the tutorial, :class:`MaskedTensor` also has binary operations implemented with the caveat that the masks in the two MaskedTensors must match or else an error will be raised. As noted in the error, if you need support for a particular operator or have proposed semantics for how they should behave instead, please open -an issue on Github. For now, we have decided to go with the most conservative implementation to ensure that users +an issue on GitHub. For now, we have decided to go with the most conservative implementation to ensure that users know exactly what is going on and are being intentional about their decisions with masked semantics. The available binary operators are: diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index 78ef3cd93663..fea0b3bc94d2 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -594,7 +594,7 @@ all of the unconvertible ops in one go you can:: The set is approximated because some ops may be removed during the conversion process and don't need to be converted. Some other ops may have partial support that will fail conversion with particular inputs, but this should give you a -general idea of what ops are not supported. Please feel free to open Github Issues +general idea of what ops are not supported. Please feel free to open GitHub Issues for op support requests. 
Frequently Asked Questions diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst index 29790312cb3b..77e8dabec274 100644 --- a/docs/source/sparse.rst +++ b/docs/source/sparse.rst @@ -10,7 +10,7 @@ torch.sparse .. warning:: The PyTorch API of sparse tensors is in beta and may change in the near future. - We highly welcome feature requests, bug reports and general suggestions as Github issues. + We highly welcome feature requests, bug reports and general suggestions as GitHub issues. Why and when to use sparsity ++++++++++++++++++++++++++++ @@ -40,7 +40,7 @@ Like many other performance optimization sparse storage formats are not always advantageous. When trying sparse formats for your use case you might find your execution time to decrease rather than increase. -Please feel encouraged to open a Github issue if you analytically +Please feel encouraged to open a GitHub issue if you analytically expected to see a stark increase in performance but measured a degradation instead. This helps us prioritize the implementation of efficient kernels and wider performance optimizations. diff --git a/torch/csrc/jit/operator_upgraders/README.md b/torch/csrc/jit/operator_upgraders/README.md index bf1350aa21f3..75639006e503 100644 --- a/torch/csrc/jit/operator_upgraders/README.md +++ b/torch/csrc/jit/operator_upgraders/README.md @@ -226,7 +226,7 @@ def foo(x, y, z=100): return x, y, z ``` -2. To help understanding the BC/FC breakage changes, here are some FC breaking changes examples. The solution to resolve it is not there yet. If it's desired, please report it in either [PyTorch Forum](https://discuss.pytorch.org/) or [PyTorch Github](https://github.com/pytorch/pytorch). We will prioritize it accordingly. +2. To help understanding the BC/FC breakage changes, here are some FC breaking changes examples. The solution to resolve it is not there yet. If it's desired, please report it in either [PyTorch Forum](https://discuss.pytorch.org/) or [PyTorch GitHub](https://github.com/pytorch/pytorch). We will prioritize it accordingly. - Adding new default argument: - Adding a new default argument not RIGHT BEFORE the out arguments which can be 0 or more. diff --git a/torch/csrc/lazy/tutorial.md b/torch/csrc/lazy/tutorial.md index 6d4e75affc38..e26c55d2c520 100644 --- a/torch/csrc/lazy/tutorial.md +++ b/torch/csrc/lazy/tutorial.md @@ -283,4 +283,4 @@ This concludes our brief introduction to LT. Hopefully, you'll remember the main * It's really tricky to produce such graphs without overburdening a user too much. Think, torch.jit.script, torch.jit.trace! Also, think ifs, fors, "Lions, and Tigers, and Bears, Oh My" We digressed. -Please give LT a try and tell us what you think on Github! We are **eager, not lazy** (haha!) to hear from you! +Please give LT a try and tell us what you think on GitHub! We are **eager, not lazy** (haha!) to hear from you! 
From 1adb7b9b845603a834f452da0e99790779740d83 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 15 Nov 2022 16:01:29 -0800 Subject: [PATCH 278/453] [nn][utils] Preserve requires_grad from original weight and bias in fuse conv/linear bn weights (#89100) Summary: att, previously we just call nn.Parameter which will have requires_grad=True by default, after this PR we will preserve the requires_grad Test Plan: python test/test_nn.py TestFusionUtils Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D41343694](https://our.internmc.facebook.com/intern/diff/D41343694) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89100 Approved by: https://github.com/ngimel --- test/test_nn.py | 28 ++++++++++++++++++++++++++++ torch/nn/utils/fusion.py | 8 ++++---- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 2b96838e3601..4231c19ed0da 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -36,6 +36,8 @@ import torch.nn.utils.parametrize as parametrize import torch.nn.utils.prune as prune from torch.nn.utils import parameters_to_vector, vector_to_parameters +from torch.nn.utils.fusion import fuse_conv_bn_weights +from torch.nn.utils.fusion import fuse_linear_bn_weights from torch.nn import Parameter from torch.nn.parallel._functions import Broadcast from torch.testing._internal.common_dtype import integral_types, get_all_math_dtypes @@ -16267,6 +16269,32 @@ def my_post_load_hook(mod, _): m.load_state_dict(sd) self.assertTrue(called) +class TestFusionUtils(TestCase): + def test_fuse_conv_bn_requires_grad(self): + conv = torch.nn.Conv2d(3, 3, 3) + bn = torch.nn.BatchNorm2d(3) + cases = itertools.product([True, False], [True, False]) + for w_rg, b_rg in cases: + conv.weight.requires_grad = w_rg + conv.bias.requires_grad = b_rg + weight, bias = \ + fuse_conv_bn_weights(conv.weight, conv.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + self.assertEqual(weight.requires_grad, w_rg) + self.assertEqual(bias.requires_grad, b_rg) + + def test_fuse_linear_bn_requires_grad(self): + linear = torch.nn.Linear(3, 3) + bn = torch.nn.BatchNorm1d(3) + cases = itertools.product([True, False], [True, False]) + for w_rg, b_rg in cases: + linear.weight.requires_grad = w_rg + linear.bias.requires_grad = b_rg + weight, bias = \ + fuse_linear_bn_weights(linear.weight, linear.bias, + bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) + self.assertEqual(weight.requires_grad, w_rg) + self.assertEqual(bias.requires_grad, b_rg) instantiate_device_type_tests(TestNNDeviceType, globals()) instantiate_parametrized_tests(TestNN) diff --git a/torch/nn/utils/fusion.py b/torch/nn/utils/fusion.py index e96c4f7d4426..81b1431c53c9 100644 --- a/torch/nn/utils/fusion.py +++ b/torch/nn/utils/fusion.py @@ -27,10 +27,10 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, trans else: shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) - conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape) - conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b + fused_conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape(shape) + fused_conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b - return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) + return torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad) def fuse_linear_bn_eval(linear, bn): assert(not (linear.training or bn.training)), "Fusion only for eval!" 
@@ -50,4 +50,4 @@ def fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b) fused_w = linear_w * bn_scale.unsqueeze(-1) fused_b = (linear_b - bn_rm) * bn_scale + bn_b - return torch.nn.Parameter(fused_w), torch.nn.Parameter(fused_b) + return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(fused_b, linear_b.requires_grad) From 366f1b2c2f6b273fcba5f071bf2297a963051894 Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 16 Nov 2022 10:46:27 -0800 Subject: [PATCH 279/453] [xnnpack][lite-int] Freeze/Inline module to remove reference to self (#88863) We need to inline graph before converting from torchscript to xnnpack flatubuffer. Remove graph dependence on self. This will later help us work with constant data. Differential Revision: [D41049858](https://our.internmc.facebook.com/intern/diff/D41049858/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88863 Approved by: https://github.com/digantdesai --- .../jit/backends/xnnpack/xnnpack_backend_preprocess.cpp | 9 ++++----- .../csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp | 1 - 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp index f2734a5e529a..b4b7c912554a 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp @@ -32,8 +32,9 @@ c10::IValue preprocess( const Module& mod, const c10::Dict& method_compile_spec, const BackendDebugHandleGenerator& generate_debug_handles) { - auto output_min = -std::numeric_limits::infinity(); - auto output_max = std::numeric_limits::infinity(); + auto eval_mod = mod.clone(); + eval_mod.eval(); + eval_mod = torch::jit::freeze(eval_mod); c10::Dict compiled(StringType::get(), TensorType::get()); @@ -62,7 +63,7 @@ c10::IValue preprocess( "method_compile_spec does not contain either a Tensor or TensorList, under it's \"outputs\" key."); // Graph preprocessing - const auto& forward_method = mod.get_method("forward"); + const auto& forward_method = eval_mod.get_method("forward"); auto graph = toGraphFunction(forward_method.function()).graph()->copy(); graph = tensorexpr::removeUnusedSelfArgument(graph); @@ -75,7 +76,6 @@ c10::IValue preprocess( example_inputs.reserve(inp_list.size()); for (const auto i : c10::irange(inp_list.size())) { - graph->inputs()[i]->setType(TensorType::create(inp_list[i])); example_inputs.emplace_back(inp_list[i]); } } else { @@ -83,7 +83,6 @@ c10::IValue preprocess( graph->inputs().size() == 1, "method_compile_spec inputs do not match expected number of forward inputs"); - graph->inputs()[0]->setType(TensorType::create(inp.toTensor())); example_inputs.emplace_back(inp.toTensor()); } diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp index ec740bd66c50..4eaefea56960 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp @@ -21,7 +21,6 @@ namespace delegate { std::shared_ptr XNNGraph::optimizeAndTraceGraph( std::shared_ptr graph, std::vector& example_inputs) { - graph = tensorexpr::removeUnusedSelfArgument(graph); OptimizeFrozenGraph(graph, true); RemoveListMutation(graph); RemoveTensorMutation(graph); From d1f48f05cef9e2b3b01c64a21a6e2abc3ddab323 Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 16 Nov 2022 10:46:28 -0800 Subject: 
[PATCH 280/453] [xnnpack][Bug Fix] Pass serialized model by reference (#89089) Two changes - Remove XNNCompiler Dependence on std::string by passing void* - Grab ser_model by reference: This bug was causing data pointers given to xnn_runtime to be freed because ser_model was on the stack. Differential Revision: [D41208380](https://our.internmc.facebook.com/intern/diff/D41208380/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89089 Approved by: https://github.com/digantdesai --- torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp | 6 +++--- torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h | 2 +- torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index 4147edf90e85..af9c68df31e8 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -11,9 +11,9 @@ namespace jit { namespace xnnpack { namespace delegate { -XNNExecutor XNNCompiler::compileModel(std::string ser_model) { - const char* buffer_pointer = ser_model.data(); - +XNNExecutor XNNCompiler::compileModel( + const void* buffer_pointer, + size_t num_bytes) { auto output_min = -std::numeric_limits::infinity(); auto output_max = std::numeric_limits::infinity(); diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h index 99eecfdcaa45..625b41e43c14 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -16,7 +16,7 @@ class XNNCompiler { // Takes Flatbuffer Serialized XNNPack Model and rebuilds the xnn-subgraph // returns an executor object that holds the xnn runtime object which we // can then use to set inputs and run inference using the xnn graph. - static XNNExecutor compileModel(std::string ser_model); + static XNNExecutor compileModel(const void* buffer_pointer, size_t num_bytes); }; } // namespace delegate diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index a5718820fc19..553e8350ddbd 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -40,8 +40,9 @@ class XNNPackBackend : public PyTorchBackendInterface { auto dict = processed.toGenericDict(); // Compiling and wrapping exeuction object - std::string ser_model = dict.at("ser_model").toStringRef(); - XNNExecutor executor = XNNCompiler::compileModel(ser_model); + const std::string& ser_model = dict.at("ser_model").toStringRef(); + XNNExecutor executor = + XNNCompiler::compileModel(ser_model.data(), ser_model.length()); auto model_ptr = c10::make_intrusive(std::move(executor)); auto runtime_handle = IValue::make_capsule(model_ptr); From 1cd6ebe0958ab8eff2b7ba715d9544f067dfe59e Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Thu, 17 Nov 2022 04:18:10 +0000 Subject: [PATCH 281/453] Fix typos in messages under torch (#89049) This PR fixes typos of messages in `.py` files under torch directory. Only in `torch/onnx/symbolic_opset16.py`, fix a typo in comment to make the operator name correct. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89049 Approved by: https://github.com/lezcano --- torch/_refs/nn/functional/__init__.py | 2 +- torch/ao/nn/intrinsic/qat/modules/linear_fused.py | 2 +- torch/ao/quantization/fx/_model_report/model_report.py | 2 +- torch/ao/quantization/observer.py | 2 +- torch/backends/xeon/run_cpu.py | 2 +- torch/cuda/memory.py | 2 +- torch/distributed/benchmarks/benchmark_ddp_rpc.py | 2 +- torch/distributed/elastic/multiprocessing/api.py | 4 ++-- torch/distributed/elastic/rendezvous/etcd_rendezvous.py | 2 +- torch/distributions/mixture_same_family.py | 2 +- torch/fx/experimental/accelerator_partitioner.py | 2 +- torch/fx/experimental/graph_gradual_typechecker.py | 4 ++-- torch/fx/passes/split_module.py | 4 ++-- torch/jit/frontend.py | 2 +- torch/nn/utils/parametrize.py | 2 +- torch/onnx/symbolic_helper.py | 2 +- torch/onnx/symbolic_opset16.py | 2 +- torch/profiler/_pattern_matcher.py | 2 +- torch/serialization.py | 2 +- torch/testing/_internal/common_distributed.py | 2 +- torch/testing/_internal/composite_compliance.py | 2 +- torch/utils/benchmark/examples/fuzzer.py | 2 +- torch/utils/benchmark/examples/sparse/fuzzer.py | 2 +- torch/utils/data/datapipes/dataframe/dataframes.py | 2 +- torch/utils/data/datapipes/iter/callable.py | 2 +- 25 files changed, 28 insertions(+), 28 deletions(-) diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 3848a738d534..12f44c4092a4 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -595,7 +595,7 @@ def _nll_loss_nd( ) -> TensorLikeType: utils.check( input.ndim > 0 and input.ndim <= 3, - lambda: f"Expected input dimension to be either [1, 2, 3] but recieved {input.ndim}.", + lambda: f"Expected input dimension to be either [1, 2, 3] but received {input.ndim}.", ) utils.check( diff --git a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py index f19dbd9a0f58..7c92c470ba5b 100644 --- a/torch/ao/nn/intrinsic/qat/modules/linear_fused.py +++ b/torch/ao/nn/intrinsic/qat/modules/linear_fused.py @@ -35,7 +35,7 @@ def __init__(self, freeze_bn=False, qconfig=None): nn.modules.linear.Linear.__init__(self, in_features, out_features, bias) - assert qconfig, 'qconfig must be provded for QAT module' + assert qconfig, 'qconfig must be provided for QAT module' self.qconfig = qconfig self.freeze_bn = freeze_bn if self.training else True self.bn = nn.BatchNorm1d(out_features, eps, momentum, True, True) diff --git a/torch/ao/quantization/fx/_model_report/model_report.py b/torch/ao/quantization/fx/_model_report/model_report.py index dfe777a54058..ee96dd4bf5a9 100644 --- a/torch/ao/quantization/fx/_model_report/model_report.py +++ b/torch/ao/quantization/fx/_model_report/model_report.py @@ -385,7 +385,7 @@ def _reformat_reports_for_visualizer(self) -> OrderedDict: module_fqns_to_features[module_fqn] = {**new_info, **present_info} else: error_str = "You have the same key with different values across detectors. " - error_str += "Someone incorrectly implemented a detector with conflicting keys to exisiting detectors." + error_str += "Someone incorrectly implemented a detector with conflicting keys to existing detectors." 
raise ValueError(error_str) else: # we just set it diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 26a39c8c2e02..3156b4245a12 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1019,7 +1019,7 @@ def _non_linear_param_search(self) -> Tuple[torch.Tensor, torch.Tensor]: This follows the implementation of NormMinimization::NonlinearQuantizationParamsSearch in caffe2/quantization/server/norm_minimization.cc """ - assert self.histogram.size()[0] == self.bins, "bins mistmatch" + assert self.histogram.size()[0] == self.bins, "bins mismatch" bin_width = (self.max_val - self.min_val) / self.bins # cumulative sum diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 69632cb20862..da55a9e605e1 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -598,7 +598,7 @@ def create_args(parser=None): _add_multi_instance_params(parser) # positional parser.add_argument("program", type=str, - help="The full path to the proram/script to be launched. " + help="The full path to the program/script to be launched. " "followed by all the arguments for the script") # rest from the training program diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 46bdda80bf87..9f9ae724a15d 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -61,7 +61,7 @@ def caching_allocator_alloc(size, device: Union[Device, int] = None, stream=None if not isinstance(stream, int): raise TypeError('Invalid type for stream argument, must be ' '`torch.cuda.Stream` or `int` representing a pointer ' - 'to a exisiting stream') + 'to a existing stream') with torch.cuda.device(device): return torch._C._cuda_cudaCachingAllocator_raw_alloc(size, stream) diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index e12556f396fb..6614d3969bfc 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -335,7 +335,7 @@ def run_worker(rank, world_size): "--embedding-dim", type=int, default=EMBEDDING_DIM, - help="Number of embedding dimentions.", + help="Number of embedding dimensions.", ) parser.add_argument( "--warmup-cycles", diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 208c9a070e9d..727566fc6039 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -537,7 +537,7 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: for proc in self._pc.processes: if proc.is_alive(): log.warning( - f"Unable to shutdown process {proc.pid} via {death_sig}, forcefully exitting via {_get_kill_signal()}" + f"Unable to shutdown process {proc.pid} via {death_sig}, forcefully exiting via {_get_kill_signal()}" ) try: os.kill(proc.pid, _get_kill_signal()) @@ -714,7 +714,7 @@ def _close(self, death_sig: signal.Signals, timeout: int = 30) -> None: for handler in self.subprocess_handlers.values(): if handler.proc.poll() is None: log.warning( - f"Unable to shutdown process {handler.proc.pid} via {death_sig}, forcefully exitting via {_get_kill_signal()}" + f"Unable to shutdown process {handler.proc.pid} via {death_sig}, forcefully exiting via {_get_kill_signal()}" ) handler.close(death_sig=_get_kill_signal()) handler.proc.wait() diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py 
index 5e11ad1e6d33..a7b682ccc89f 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -293,7 +293,7 @@ def rendezvous_barrier(self): time.sleep(1) except RendezvousTimeoutError: - log.info("Rendezvous timeout occured in EtcdRendezvousHandler") + log.info("Rendezvous timeout occurred in EtcdRendezvousHandler") raise except RendezvousClosedError: diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 8e0fdce3ada2..dd0beace1917 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -60,7 +60,7 @@ def __init__(self, if not isinstance(self._mixture_distribution, Categorical): raise ValueError(" The Mixture distribution needs to be an " - " instance of torch.distribtutions.Categorical") + " instance of torch.distributions.Categorical") if not isinstance(self._component_distribution, Distribution): raise ValueError("The Component distribution need to be an " diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 2b17ef2f86c3..5a007314d628 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -696,7 +696,7 @@ def find_partition_to_combine_based_on_size( return find_combination, partitions def reset_partition_in_sparse_nn(partition, new_partition=True): - """If crossing the boudary between non-embedding nodes and + """If crossing the boundary between non-embedding nodes and embedding nodes, create a new partition """ if in_embedding_region: diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index 6094952f1695..7ffabc9c6996 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -184,7 +184,7 @@ def get_attr_inference_rule(n: Node, traced): if attr_name == "shape": n.type = Dyn else: - raise TypeError("Not yet implelemted") + raise TypeError("Not yet implemented") # TODO. 
We leave it like this till we add a type to represent tensor sizes return n.type @@ -507,7 +507,7 @@ def flatten_check(tensor_type, start_dim, end_dim): new_type_list = lhs + mid + rhs return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Incompatable dimentions {start_dim}, {end_dim - 1} in type {tensor_type}') + raise TypeError(f'Incompatable dimensions {start_dim}, {end_dim - 1} in type {tensor_type}') @register_inference_rule(torch.flatten) def flatten_inference_rule(n: Node): diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 251fdadea7e2..c6954c2cc717 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -28,8 +28,8 @@ def __repr__(self) -> str: f" nodes: {self.node_names},\n" f" inputs: {self.inputs},\n" f" outputs: {self.outputs},\n" - f" partitions depenent on: {self.partitions_dependent_on},\n" - f" parition dependents: {self.partition_dependents}" + f" partitions dependent on: {self.partitions_dependent_on},\n" + f" partition dependents: {self.partition_dependents}" ) diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 4b5e3d68f75c..44a8628f77d5 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -614,7 +614,7 @@ def build_AugAssign(ctx, stmt): else: raise NotSupportedError( find_before(ctx, rhs.range().start, '=', offsets=(-1, 0)), - "unsupported kind of augumented assignment: " + op.__name__) + "unsupported kind of augmented assignment: " + op.__name__) return AugAssign(lhs, op_token, rhs) @staticmethod diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 17de23a97a4a..801a1e80c1aa 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -242,7 +242,7 @@ def right_inverse(self, value: Tensor) -> None: if len(value) != self.ntensors: raise ValueError( "'right_inverse' must return a sequence of tensors of length " - f"{self.ntensors}. Got a sequence of lenght {len(value)}." + f"{self.ntensors}. Got a sequence of length {len(value)}." ) for i, tensor in enumerate(value): original_i = getattr(self, f"original{i}") diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 84224e88d86e..a27db1e2a327 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1308,7 +1308,7 @@ def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index): from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef] if self.type().dim() is None: - return _unimplemented("index_fill", "input rank not accesible") + return _unimplemented("index_fill", "input rank not accessible") self_dim = self.type().dim() dim_value = _parse_arg(dim, "i") unsqueezed_index = _unsqueeze_helper( diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index a2d3505072ba..75cb96890a12 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -15,7 +15,7 @@ PRelu RoiAlign Scan - ScatterElemenets + ScatterElements ScatterND Where GreaterOrEqual diff --git a/torch/profiler/_pattern_matcher.py b/torch/profiler/_pattern_matcher.py index 3cec84df219b..6c06bf2b2861 100644 --- a/torch/profiler/_pattern_matcher.py +++ b/torch/profiler/_pattern_matcher.py @@ -161,7 +161,7 @@ class ExtraCUDACopyPattern(Pattern): def __init__(self, prof: profile, should_benchmark: bool = False): super().__init__(prof, should_benchmark) self.name = "Extra CUDA Copy Pattern" - self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initalize it on GPU." 
+ self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initialize it on GPU." self.url = "https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#create-tensors-directly-on-the-target-device" self.init_ops = { "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_" diff --git a/torch/serialization.py b/torch/serialization.py index 5f9eda67648b..b9fc92b5110c 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -773,7 +773,7 @@ def load( if weights_only: if pickle_module is not None: - raise RuntimeError("Can not safely load weights when expiclit picke_module is specified") + raise RuntimeError("Can not safely load weights when explicit picke_module is specified") else: if pickle_module is None: pickle_module = pickle diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 9dcb71ae0907..272dd7479ce5 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -333,7 +333,7 @@ def wrapper(*args, **kwargs): def skip_if_win32(): return sandcastle_skip_if( sys.platform == "win32", - "This unit test case is not supportted on Windows platform", + "This unit test case is not supported on Windows platform", ) diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index 0eaab2e1796d..5d7de4e2328a 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -311,7 +311,7 @@ def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): def raise_composite_compliance_error(err, additional_info=''): raise RuntimeError( - "Composite compilance check failed with " + "Composite compliance check failed with " "the above error.\n" f"{additional_info}" "If you are adding an OpInfo of an " diff --git a/torch/utils/benchmark/examples/fuzzer.py b/torch/utils/benchmark/examples/fuzzer.py index 4446e2d85c0a..9728bf3d26c9 100644 --- a/torch/utils/benchmark/examples/fuzzer.py +++ b/torch/utils/benchmark/examples/fuzzer.py @@ -65,7 +65,7 @@ def main(): print() # More string munging to make pretty output. - print(f"Average attemts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") + print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") def time_fn(m): return m.median / m.metadata["numel"] diff --git a/torch/utils/benchmark/examples/sparse/fuzzer.py b/torch/utils/benchmark/examples/sparse/fuzzer.py index 8e2bf554c42a..38421474ccf8 100644 --- a/torch/utils/benchmark/examples/sparse/fuzzer.py +++ b/torch/utils/benchmark/examples/sparse/fuzzer.py @@ -80,7 +80,7 @@ def main(): print() # More string munging to make pretty output. - print(f"Average attemts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") + print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}") def time_fn(m): return m.mean / m.metadata["nnz"] diff --git a/torch/utils/data/datapipes/dataframe/dataframes.py b/torch/utils/data/datapipes/dataframe/dataframes.py index fcbf15328e43..3a7cbb44feaf 100644 --- a/torch/utils/data/datapipes/dataframe/dataframes.py +++ b/torch/utils/data/datapipes/dataframe/dataframes.py @@ -408,7 +408,7 @@ def collate(self, *args, **kwargs): def __getattr__(self, attrname): # ? 
if attrname in UNIMPLEMENTED_ATTR: - raise AttributeError('Attemping to get ', attrname) + raise AttributeError('Attempting to get ', attrname) if attrname in DATAPIPES_OPS: return (self.as_datapipe()).__getattr__(attrname) return super().__getattr__(attrname) diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 30b04885787a..f0f91dee34b4 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -155,7 +155,7 @@ def _collate_helper(conversion, item): import torcharrow.pytorch as tap # type: ignore[import] collation_fn = tap.rec.Default() except Exception: - raise Exception("unable to import default collation function from the TorchArrrow") + raise Exception("unable to import default collation function from the TorchArrow") tuple_names.append(str(name)) value = collation_fn(df[name]) From 24b9890f0343a156a5785be859610316ecf8274e Mon Sep 17 00:00:00 2001 From: Colin Taylor Date: Thu, 17 Nov 2022 04:26:10 +0000 Subject: [PATCH 282/453] [torchrec] [composable] update ShardedEmbeddingBagCollection to be use registered EBCs with shardedTensors as registered modules (#758) (#88026) Summary: X-link: https://github.com/pytorch/torchrec/pull/758 This PR fixes a bug in FSDP/DDP, where ShardedTensors are not supported even if passed in as params to ignore. this is important for composability because TorchRec named_parameters() will return FQN of shardedTensors (as defined in goals) It defines device of ShardedTensor to be None when local_tensor() does not exist on rank update ShardedEmbeddingBagCollection to be composable according to https://docs.google.com/document/d/1TBJSd5zgEg6cRcXv3Okuj7bBkqQwGS2IPh4TLWNNzFI/edit Differential Revision: D40458625 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88026 Approved by: https://github.com/wanchaol, https://github.com/rohan-varma --- test/distributed/test_c10d_gloo.py | 62 +++++++++++++++++-- .../_shard/sharded_tensor/_ops/tensor_ops.py | 11 +++- .../distributed/_shard/sharded_tensor/api.py | 8 ++- torch/nn/parallel/distributed.py | 24 ++++--- 4 files changed, 82 insertions(+), 23 deletions(-) diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index ba214a02696f..c0a25fff9d82 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -23,28 +23,35 @@ import torch.nn.functional as F import torch.testing._internal.common_utils as common from test_c10d_common import ( - LOOPBACK, gpus_for_rank, - Task, + LOOPBACK, ModuleForDdpCommHook, SparseGradientModule, + Task, ) from torch import nn +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard, + ShardedTensor, + ShardMetadata, +) from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor from torch.testing._internal.common_distributed import ( + create_device, MultiProcessTestCase, requires_gloo, - skip_if_lt_x_gpu, simple_sparse_reduce_tests, + skip_if_lt_x_gpu, skip_if_win32, - create_device, verify_ddp_error_logged, ) from torch.testing._internal.common_utils import ( - TestCase, - run_tests, retry_on_connect_failures, + run_tests, sandcastle_skip, + TestCase, ) @@ -1754,6 +1761,49 @@ def forward(self, x): loss = criterion(output, target) loss.backward() + @requires_gloo() + @skip_if_lt_x_gpu(2) + def test_ignored_sharded_tensor(self): + class MyModule(nn.Module): + def 
__init__(self, shard_tensor: ShardedTensor) -> None: + super().__init__() + self.fc1 = nn.Linear(2, 10, bias=False) + self.st = nn.Parameter(shard_tensor) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + return F.softmax(x, dim=1) + pg = dist.init_process_group( + "gloo", + init_method=f"file://{self.file_name}", + world_size=self.world_size, + rank=self.rank, + ) + device = torch.device(f"cuda:{self.rank}") + local_shard_metadata = ShardMetadata( + shard_offsets=[(self.rank % 2) * 5, 0], + shard_sizes=[5, 10], + placement=f"rank:{self.rank}/cuda:{self.rank}" + ) + local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)] + st = init_from_local_shards(local_shards, [10, 10]) + m = MyModule(st) + with _ddp_replicated_tensor(False): + DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( + module=m, + params_and_buffers_to_ignore={'st'} + ) + # test to make DDP constructor will not fail when module includes a ShardedTensor when ignored + DistributedDataParallel( + m, + device_ids=[device] if device.type == "gpu" else None, + process_group=pg, + gradient_as_bucket_view=True, + broadcast_buffers=False, + static_graph=True, + ) + def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model): mult = 2 batch_size = mult * self.world_size diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index e52c29238a62..fbdeb553cc28 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -42,9 +42,14 @@ def tensor_device(types, args=(), kwargs=None, pg=None): # Validate types if not isinstance(self_st, ShardedTensor): raise TypeError("input needs to be a ShardedTensor") - - return self_st.local_shards()[0].tensor.device - + dev: torch.device + if self_st._local_shards: + dev = self_st._local_shards[0].tensor.device + elif pg and pg._get_backend_name() == "gloo": + dev = torch.device("cpu") + else: + dev = torch.device(torch.cuda.current_device()) + return dev @_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] def st_is_meta(types, args=(), kwargs=None, pg=None): diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 156423c65c11..36ab5d6969a3 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -630,7 +630,13 @@ def cuda( return st_cuda def to(self, *args, **kwargs) -> ShardedTensor: - current_device = self._local_shards[0].tensor.device + current_device: torch.device + if self._local_shards: + current_device = self._local_shards[0].tensor.device + elif self._process_group._get_backend_name() == "gloo": + current_device = torch.device("cpu") + else: + current_device = torch.device(torch.cuda.current_device()) current_dtype = self.dtype device_to = current_device dtype_to = current_dtype diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index c29a0a7ef46b..47eb6bb2ebf1 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -553,11 +553,15 @@ def __init__( gradient_as_bucket_view=False, static_graph=False, ): - super(DistributedDataParallel, self).__init__() Joinable.__init__(self) self.logger = None - if not any((p.requires_grad for p in module.parameters())): + if hasattr(module, "_ddp_params_and_buffers_to_ignore"): + self.parameters_to_ignore = 
set(module._ddp_params_and_buffers_to_ignore) + else: + self.parameters_to_ignore = set() + self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore] + if not any((p.requires_grad for p in self._module_parameters)): self._log_and_throw( RuntimeError, "DistributedDataParallel is not needed when a module " @@ -570,10 +574,8 @@ def __init__( "device_ids can only be None or contain a single element.", ) - self.is_multi_device_module = ( - len({p.device for p in module.parameters()}) > 1 - ) - distinct_device_types = {p.device.type for p in module.parameters()} + self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1 + distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None} if len(distinct_device_types) != 1: self._log_and_throw( ValueError, @@ -599,7 +601,7 @@ def __init__( "but got device_ids {}, output_device {}, and module parameters {}.".format( device_ids, output_device, - {p.device for p in module.parameters()}, + {p.device for p in self._module_parameters}, ), ) @@ -621,16 +623,12 @@ def __init__( self.static_graph = False self.dim = dim self.module = module - self.device = list(self.module.parameters())[0].device + self.device = list(self._module_parameters)[0].device self.broadcast_buffers = broadcast_buffers self.find_unused_parameters = find_unused_parameters self.require_backward_grad_sync = True self.require_forward_param_sync = True self.gradient_as_bucket_view = gradient_as_bucket_view - if hasattr(module, "_ddp_params_and_buffers_to_ignore"): - self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore - else: - self.parameters_to_ignore = [] self._use_replicated_tensor_module = ( _ddp_with_replicated_tensor_enabled() @@ -647,7 +645,7 @@ def __init__( ) # Check that a module does not have Uninitialized parameters - for param in module.parameters(): + for param in self._module_parameters: if isinstance(param, torch.nn.parameter.UninitializedParameter): self._log_and_throw( RuntimeError, From 637e764ec5d879a5cce0f63f747db3967b708517 Mon Sep 17 00:00:00 2001 From: maxren Date: Wed, 16 Nov 2022 10:46:30 -0800 Subject: [PATCH 283/453] [xnnpack][executorch] Pass xnnexecutor pointer to compileModel() (#89090) Here we pass XNNExecutor* to compile model so that XNNExecutor can be allocated by runtime. 
This signature change is for executorch: ``` XNNExecutor compileModel(void* buffer) --> void compileModel(void* buffer, XNNExecutor* executor) ``` The intended usecase for allocating Executor and Compiling the serialized flatbuffer: ``` XNNExecutor* executor = runtime_allocator->allocateList(1); XNNCompiler::compileModel(processed.buffer, executor); ``` Differential Revision: [D41208387](https://our.internmc.facebook.com/intern/diff/D41208387/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89090 Approved by: https://github.com/digantdesai --- .../backends/xnnpack/compiler/xnn_compiler.cpp | 15 ++++++++------- .../jit/backends/xnnpack/compiler/xnn_compiler.h | 5 ++++- .../jit/backends/xnnpack/executor/xnn_executor.h | 13 +++++++------ .../jit/backends/xnnpack/xnnpack_backend_lib.cpp | 4 ++-- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index af9c68df31e8..49e2804c99a9 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -11,9 +11,10 @@ namespace jit { namespace xnnpack { namespace delegate { -XNNExecutor XNNCompiler::compileModel( +void XNNCompiler::compileModel( const void* buffer_pointer, - size_t num_bytes) { + size_t num_bytes, + XNNExecutor* executor) { auto output_min = -std::numeric_limits::infinity(); auto output_max = std::numeric_limits::infinity(); @@ -109,17 +110,17 @@ XNNExecutor XNNCompiler::compileModel( status = xnn_create_runtime_v2(subgraph_ptr, nullptr, 0, &runtime_ptr); TORCH_CHECK(xnn_status_success == status); - XNNExecutor executor(runtime_ptr); + executor->runtime_ = + std::unique_ptr( + runtime_ptr, xnn_delete_runtime); for (auto old_id : *flatbuffer_graph->input_ids()) { - executor.input_ids_.push_back(remapped_ids.at(old_id)); + executor->input_ids_.emplace_back(remapped_ids.at(old_id)); } for (auto old_id : *flatbuffer_graph->output_ids()) { - executor.output_ids_.push_back(remapped_ids.at(old_id)); + executor->output_ids_.emplace_back(remapped_ids.at(old_id)); } - - return executor; }; } // namespace delegate diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h index 625b41e43c14..e87fcbcd063d 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -16,7 +16,10 @@ class XNNCompiler { // Takes Flatbuffer Serialized XNNPack Model and rebuilds the xnn-subgraph // returns an executor object that holds the xnn runtime object which we // can then use to set inputs and run inference using the xnn graph. - static XNNExecutor compileModel(const void* buffer_pointer, size_t num_bytes); + static void compileModel( + const void* buffer_pointer, + size_t num_bytes, + XNNExecutor* executor); }; } // namespace delegate diff --git a/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h index f82bde231c90..2521c0c7749d 100644 --- a/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h +++ b/torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h @@ -1,5 +1,5 @@ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
- +#pragma once #include #include #include @@ -11,14 +11,15 @@ namespace delegate { class XNNExecutor { private: - std::unique_ptr runtime_; + std::unique_ptr runtime_{ + nullptr, + &xnn_delete_runtime}; std::vector input_ids_; std::vector output_ids_; std::vector externals_; public: - XNNExecutor(xnn_runtime_t runtime_ptr) - : runtime_(runtime_ptr, xnn_delete_runtime){}; + XNNExecutor() = default; template bool set_inputs(std::vector& inputs, std::vector& outputs) { @@ -41,7 +42,7 @@ class XNNExecutor { } return true; - }; + } bool forward() { xnn_status status = @@ -58,7 +59,7 @@ class XNNExecutor { } return true; - }; + } friend class XNNCompiler; }; diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp index 553e8350ddbd..46c7458039d4 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp @@ -41,8 +41,8 @@ class XNNPackBackend : public PyTorchBackendInterface { // Compiling and wrapping exeuction object const std::string& ser_model = dict.at("ser_model").toStringRef(); - XNNExecutor executor = - XNNCompiler::compileModel(ser_model.data(), ser_model.length()); + XNNExecutor executor; + XNNCompiler::compileModel(ser_model.data(), ser_model.length(), &executor); auto model_ptr = c10::make_intrusive(std::move(executor)); auto runtime_handle = IValue::make_capsule(model_ptr); From 44c9185f91699b74c7953eb912f37fb24991958d Mon Sep 17 00:00:00 2001 From: ecao Date: Thu, 17 Nov 2022 04:47:45 +0000 Subject: [PATCH 284/453] Fix empty input issue of convolution for channels last memory format (#86521) Fixes empty input convolution issue : when input is empty e.g. shape of (0, 3, 3, 4) and weight is channels last format, at::_unsafe_view will raise "view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead." Pull Request resolved: https://github.com/pytorch/pytorch/pull/86521 Approved by: https://github.com/jgong5, https://github.com/malfet --- aten/src/ATen/native/Convolution.cpp | 22 +++++++++++++++++--- test/test_nn.py | 30 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index bf7017f20a4f..8584bae445ad 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -11,11 +11,15 @@ #include #include #include - #include - #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + #if AT_NNPACK_ENABLED() #include #endif @@ -1508,7 +1512,19 @@ at::Tensor _convolution( break; case ConvBackend::Empty: { - auto weight_view = at::_unsafe_view(weight, -1); + Tensor weight_view; + // Use permute and clone to avoid at::_unsafe_view(weight, -1) failure for non-contiguous cases where + // view size is not compatible with input tensor's size and stride. + if(weight.is_contiguous()) { + weight_view = at::_unsafe_view(weight, -1); + } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast)) { + weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 1}), -1); + } else if (weight.is_contiguous(at::MemoryFormat::ChannelsLast3d)) { + weight_view = at::_unsafe_view(at::permute(weight, {0, 2, 3, 4, 1}), -1); + } else { + weight_view = at::_unsafe_view(weight.clone(at::MemoryFormat::Contiguous), -1); + } + output = (input.size(1) == 0) ? 
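For illustration, a minimal sketch of the failing pattern (shapes follow the example above; on a build without this patch the channels-last weight path hits the `at::_unsafe_view` error, while with it the call returns an empty output):

```python
import torch
import torch.nn.functional as F

# Zero-batch input with a channels-last (non-contiguous) weight: the
# ConvBackend::Empty path used to flatten the weight with at::_unsafe_view(-1).
x = torch.randn(0, 3, 3, 4)
w = torch.randn(8, 3, 3, 3).to(memory_format=torch.channels_last)
out = F.conv2d(x, w)
print(out.shape)  # expected: torch.Size([0, 8, 1, 2])
```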
(input.view(-1) * weight_view) : (input * weight_view[0]); if (bias.defined()) { output.add_(bias[0]); diff --git a/test/test_nn.py b/test/test_nn.py index 4231c19ed0da..7d6a016a6f51 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11736,6 +11736,36 @@ def test_batchnorm_large_batch(self, device, dtype): data = torch.rand(880801, 1, 1, 1, device=device, dtype=dtype) out = bn(data).sum().backward() + @dtypesIfCUDA(torch.float, torch.double, torch.half, torch.complex128) + @dtypes(torch.float, torch.double, torch.bfloat16, torch.complex128) + def test_conv_empty_input(self, device, dtype): + def help(input, conv, memory_format): + ref_out = conv(input) + conv_cl = conv.to(memory_format=memory_format) + out_cl = conv_cl(input) + self.assertEqual(ref_out, out_cl) + input_cl = input.to(memory_format=memory_format) + out_cl2 = conv(input_cl) + self.assertEqual(out_cl, out_cl2) + out_cl3 = conv_cl(input_cl) + self.assertEqual(out_cl, out_cl3) + + # channels_last case + input2d = torch.randn((0, 4, 20, 20)).to(device=device, dtype=dtype) + conv2d = torch.nn.Conv2d(4, 4, 3, 1).to(device=device, dtype=dtype) + help(input2d, conv2d, torch.channels_last) + # channels_last_3d case + input3d = torch.randn((0, 4, 20, 20, 20)).to(device=device, dtype=dtype) + conv3d = torch.nn.Conv3d(4, 4, 3, 1).to(device=device, dtype=dtype) + help(input3d, conv3d, torch.channels_last_3d) + # non-contiguous case + weight = torch.rand(4, 8, 3, 3)[:, ::2, :, :].to(device=device, dtype=dtype) + bias = torch.rand(4).to(device=device, dtype=dtype) + out = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1) + weight = weight.contiguous() + out_ref = F.conv2d(input2d, weight, bias, (1, 1), 0, (1, 1), 1) + self.assertEqual(out_ref, out) + def test_InstanceNorm1d_general(self, device): b = random.randint(3, 5) c = random.randint(3, 5) From 81a8fdc40d7c504f99d5796a5b187551493685e4 Mon Sep 17 00:00:00 2001 From: Lukas Hoenig Date: Thu, 17 Nov 2022 04:54:23 +0000 Subject: [PATCH 285/453] [MPS] Add binary operations dtype precedence test case (#87545) See https://github.com/pytorch/pytorch/pull/84742 and https://github.com/pytorch/pytorch/pull/78319. The test case tests that - for the binary operations (add, sub, mul, div), - for all data types (dtypes), - for a range of representative values and their combinations, - for various shapes and ways of creating the test tensors, the contents and dtype of the result tensor is identical for the MPS and CPU backends. It adds about 15-18s runtime to `test_mps.py`. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87545 Approved by: https://github.com/kit1980 --- test/test_mps.py | 62 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/test/test_mps.py b/test/test_mps.py index 30546f50fd65..31e2e367e7de 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1688,6 +1688,68 @@ def test_copy_non_contiguous(self): y.permute(3, 2, 1, 0)[1::, ::2] = z self.assertEqual(x, y.to('cpu')) + # See https://github.com/pytorch/pytorch/pull/84742 + # and https://github.com/pytorch/pytorch/pull/78319 + def test_binops_dtype_precedence(self): + # Test dtype precedence (casting order) in binary operations by comparing to CPU result + # Example values for all dtypes supported on the MPS backend + sample_vals = { + torch.bool: [False, True], + torch.int16: [-15, 0, 1, 10], + torch.int32: [-376, 0, 1, 13], + torch.int64: [-8, 0, 1, 77], + torch.float16: [-234.5, 0.0, 1.0, 2.0], + torch.float32: [-1.0, 0.0, 0.1, 111.99], + } + # Test all combinations of dtypes, operations, dimensionality + for dtype1, dtype2, binop in itertools.product( + sample_vals.keys(), sample_vals.keys(), ['add', 'sub', 'mul', 'div']): + # bool minus bool is generally unsupported, so skip + if binop == 'sub' and (dtype1 == torch.bool or dtype2 == torch.bool): + continue + full_shape = (10,) + for val1, val2 in itertools.product(sample_vals[dtype1], sample_vals[dtype2]): + # print(f'{dtype1},{dtype2}: ({val1}).{binop}({val2})') + # print(getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) + # (torch.tensor(val2, dtype=dtype2, device='mps'))) + # print(getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) + # (torch.tensor(val2, dtype=dtype2, device='cpu'))) + self.assertEqual( + getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) + (torch.tensor(val2, dtype=dtype2, device='mps')), + getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) + (torch.tensor(val2, dtype=dtype2, device='cpu'))) + self.assertEqual( + getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop) + (torch.tensor([val2], dtype=dtype2, device='mps')), + getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop) + (torch.tensor([val2], dtype=dtype2, device='cpu'))) + self.assertEqual( + getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) + (torch.tensor([val2], dtype=dtype2, device='mps')), + getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) + (torch.tensor([val2], dtype=dtype2, device='cpu'))) + self.assertEqual( + getattr(torch.tensor([val1], dtype=dtype1, device='mps'), binop) + (torch.tensor(val2, dtype=dtype2, device='mps')), + getattr(torch.tensor([val1], dtype=dtype1, device='cpu'), binop) + (torch.tensor(val2, dtype=dtype2, device='cpu'))) + # Test tensors created with torch.full + x1 = torch.full(full_shape, val1, dtype=dtype1, device='mps') + y1 = torch.tensor(val2, dtype=dtype2, device='mps') + x2 = torch.full(full_shape, val1, dtype=dtype1, device='cpu') + y2 = torch.tensor(val2, dtype=dtype2, device='cpu') + self.assertEqual(getattr(x1, binop)(y1), getattr(x2, binop)(y2)) + x3 = torch.tensor(val1, dtype=dtype1, device='mps') + y3 = torch.full(full_shape, val2, dtype=dtype2, device='mps') + x4 = torch.tensor(val1, dtype=dtype1, device='cpu') + y4 = torch.full(full_shape, val2, dtype=dtype2, device='cpu') + self.assertEqual(getattr(x3, binop)(y3), getattr(x4, binop)(y4)) + self.assertEqual( + getattr(torch.tensor(val1, dtype=dtype1, device='mps'), binop) + 
(torch.full(full_shape, val2, dtype=dtype2, device='mps')), + getattr(torch.tensor(val1, dtype=dtype1, device='cpu'), binop) + (torch.full(full_shape, val2, dtype=dtype2, device='cpu'))) class TestLogical(TestCase): From 4e1d19c5a577b947a3dc84d9eec4a186ad3cd52f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Nov 2022 04:58:53 +0000 Subject: [PATCH 286/453] Revert "Redefine the simdlen semantic: (#88482)" This reverts commit fce6d6b3dcc879720bc45143426b86232106818a. Reverted https://github.com/pytorch/pytorch/pull/88482 on behalf of https://github.com/kit1980 due to Broke multiple tests in several trunk workflows, for example https://github.com/pytorch/pytorch/actions/runs/3485086792/jobs/5830429554 --- test/inductor/test_torchinductor.py | 94 +----------- torch/_inductor/codecache.py | 215 +++++++--------------------- torch/_inductor/codegen/common.py | 6 - torch/_inductor/codegen/cpp.py | 92 +++--------- 4 files changed, 80 insertions(+), 327 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f9aa93f4a7e6..fb7ca1fc92b7 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4529,11 +4529,7 @@ def fn(x): v = torch.randn(10) result = fn(v) - # TODO: OMP parallel reduction order is not deterministic. - # Hence, the accurarcy might vary up and down. For short term, - # we increase the tolerance and will fix it later by using - # aten parallel. - assert same(result, mod(v), tol=5e-1) + assert same(result, mod(v)) def test_inplace_add_alpha(self): def fn(x, y): @@ -4603,79 +4599,7 @@ def test_complex_memory_overlap(self): self.assertFalse(complex_memory_overlap(gathered.t())) @unittest.skipIf( - not codecache.valid_vec_isa_list(), "Does not support vectorization" - ) - @patch.object(config, "dynamic_shapes", True) - @patch.object(torch._dynamo.config, "dynamic_shapes", True) - @patch.object(functorch_config, "use_dynamic_shapes", True) - def test_vec_dynamic_shapes(self): - def fn(x): - return torch.softmax(x, -1) - - value = torch.randn((2, 10)) - with patch.object(config.cpp, "simdlen", None): - torch._dynamo.reset() - metrics.reset() - opt_fn = torch._dynamo.optimize("inductor")(fn) - opt_fn(value) - - real_out = fn(value) - compiled_out = opt_fn(value) - assert same(real_out, compiled_out, equal_nan=True) - assert metrics.generated_cpp_vec_kernel_count < 1 - - @unittest.skipIf( - not codecache.valid_vec_isa_list(), "Does not support vectorization" - ) - @patch("torch.cuda.is_available", lambda: False) - def test_auto_simd(self): - vec_avx512 = codecache.supported_vec_isa_list[0] - vec_avx2 = codecache.supported_vec_isa_list[1] - self.assertTrue(vec_avx512.bit_width() == 512) - self.assertTrue(vec_avx2.bit_width() == 256) - self.assertTrue(vec_avx512.nelements() == 16) - self.assertTrue(vec_avx2.nelements() == 8) - self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) - self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) - - with patch.object(config.cpp, "simdlen", None): - isa = codecache.pick_vec_isa() - if vec_avx512 in codecache.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - - with patch.object(config.cpp, "simdlen", 0): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with patch.object(config.cpp, "simdlen", 1): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) - - with patch.object(config.cpp, "simdlen", 257): - isa = codecache.pick_vec_isa() - self.assertFalse(isa) 
- - with patch.object(config.cpp, "simdlen", 513): - isa_list = codecache.valid_vec_isa_list() - if vec_avx512 in isa_list: - self.assertFalse(isa) - - with patch.object(config.cpp, "simdlen", 512): - isa_list = codecache.valid_vec_isa_list() - if vec_avx512 in isa_list: - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_avx512) - - with patch.object(config.cpp, "simdlen", 256): - isa_list = codecache.valid_vec_isa_list() - if vec_avx2 in isa_list: - isa = codecache.pick_vec_isa() - self.assertTrue(isa == vec_avx2) - - @unittest.skipIf( - not codecache.valid_vec_isa_list(), "Does not support vectorization" + not codecache.get_cpu_proc_info(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): @@ -4686,7 +4610,7 @@ def fn(x): x[0, 0] = torch.nan x[1, -1] = torch.nan - with patch.object(config.cpp, "simdlen", None): + with patch.object(config.cpp, "simdlen", 8): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x) @@ -4699,7 +4623,7 @@ def fn(x): # other platforms support, we just need to add the ISA info to the supported_vector_isa # and include proper aten vectorization head file. @unittest.skipIf( - not codecache.valid_vec_isa_list(), "Does not support vectorization" + not codecache.get_cpu_proc_info(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_vec_kernel_cpu_only(self): @@ -4738,15 +4662,7 @@ def fn(x1, x2): x1 = torch.randn((10, 20)) x2 = torch.randn((10, 20)) - with patch.object(config.cpp, "simdlen", 1): - torch._dynamo.reset() - metrics.reset() - traced = make_fx(fn)(x1, x2) - compiled = compile_fx_inner(traced, [x1, x2]) - assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) - assert metrics.generated_cpp_vec_kernel_count == 0 - - with patch.object(config.cpp, "simdlen", None): + with patch.object(config.cpp, "simdlen", 8): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x1, x2) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 232a611b06c6..2826f3599912 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,5 +1,5 @@ import base64 -import dataclasses +import enum import functools import getpass import hashlib @@ -18,7 +18,7 @@ from ctypes import cdll from threading import Thread from time import sleep, time -from typing import Any, Callable, Dict, List +from typing import Any, Dict import torch from torch.utils import cpp_extension @@ -147,181 +147,79 @@ def is_gcc(): return re.search(r"(gcc|g\+\+)", cpp_compiler()) -class VecISA(object): - _bit_width: int - _macro: str - _arch_flags: str - _dtype_nelements: Dict[torch.dtype, int] - - # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions - # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions - # like exp, pow, sin, cos and etc. - # But PyTorch and TorchInductor might use different compilers to build code. If - # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so - # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass - # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest - # gcc/g++ compiler by default while it could support the AVX512 compilation. - # Therefore, there would be a conflict sleef version between PyTorch and - # TorchInductor. Hence, we dry-compile the following code to check whether current - # HW platform and PyTorch both could support AVX512 or AVX2. 
And suppose ARM - # also needs the logic - _avx_code = """ -#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) -#include -#include -#endif - -__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; - -extern "C" void __avx_chk_kernel() { - auto tmp0 = at::vec::Vectorized(1); - auto tmp1 = tmp0.exp(); - tmp1.store(in_out_ptr0); -} -""" - - _avx_py_load = """ -import torch -from ctypes import cdll -cdll.LoadLibrary("__lib_path__") -""" - - def bit_width(self): - return self._bit_width - - def nelements(self, dtype: torch.dtype = torch.float): - return self._dtype_nelements[dtype] - - def build_macro(self): - return self._macro - - def build_arch_flags(self): - return self._arch_flags - - def __hash__(self) -> int: - return hash(str(self)) - - @functools.lru_cache(None) - def __bool__(self): - key, input_path = write(VecISA._avx_code, "cpp", extra="") - from filelock import FileLock - - lock_dir = get_lock_dir() - lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) - with lock: - output_path = input_path[:-3] + "so" - build_cmd = cpp_compile_command( - input_path, output_path, warning_all=False, vec_isa=self - ).split(" ") - try: - # Check build result - subprocess.check_output(build_cmd, stderr=subprocess.STDOUT) - subprocess.check_call( - [ - "python", - "-c", - VecISA._avx_py_load.replace("__lib_path__", output_path), - ], - stderr=subprocess.DEVNULL, - ) - except Exception as e: - return False - - return True - - -@dataclasses.dataclass -class VecAVX512(VecISA): - _bit_width = 512 - _macro = "CPU_CAPABILITY_AVX512" - _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" - _dtype_nelements = {torch.float: 16, torch.bfloat16: 32} - - def __str__(self) -> str: - return "avx512" - - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ - - -@dataclasses.dataclass -class VecAVX2(VecISA): - _bit_width = 256 - _macro = "CPU_CAPABILITY_AVX2" - _arch_flags = "-mavx2 -mfma" - _dtype_nelements = {torch.float: 8, torch.bfloat16: 16} - - def __str__(self) -> str: - return "avx2" - - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ - - -class InvalidVecISA(VecISA): - _bit_width = 0 - _macro = "" - _arch_flags = "" - _dtype_nelements = {} - - def __str__(self) -> str: - return "INVALID_VEC_ISA" +class _SupportedVecIsa(enum.Enum): + AVX512 = 1 + AVX2 = 2 + INVALID = -1 def __bool__(self): - return False - - __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + return self != _SupportedVecIsa.INVALID + @staticmethod + def isa_str(supported_isa: enum.Enum): + if supported_isa == _SupportedVecIsa.AVX512: + return "avx512" + elif supported_isa == _SupportedVecIsa.AVX2: + return "avx2" + else: + return "" -invalid_vec_isa = InvalidVecISA() -supported_vec_isa_list = [VecAVX512(), VecAVX2()] + @staticmethod + def vec_macro(supported_isa: enum.Enum): + if supported_isa == _SupportedVecIsa.AVX512: + return "CPU_CAPABILITY_AVX512" + elif supported_isa == _SupportedVecIsa.AVX2: + return "CPU_CAPABILITY_AVX2" + else: + return "" # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. 
-@functools.lru_cache(None) -def valid_vec_isa_list(): +@functools.lru_cache(1) +def get_cpu_proc_info(): if sys.platform != "linux": return [] - isa_list = [] + isa_info = [] with open("/proc/cpuinfo") as _cpu_info: _cpu_info_content = _cpu_info.read() - for isa in supported_vec_isa_list: - if str(isa) in _cpu_info_content and isa: - isa_list.append(isa) - return isa_list + if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content: + isa_info.append(_SupportedVecIsa.AVX512) + + if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content: + isa_info.append(_SupportedVecIsa.AVX2) + return isa_info -def pick_vec_isa(): - _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() - if not _valid_vec_isa_list: - return invalid_vec_isa - # If the simdlen is None, it indicates determin the vectroization length automatically - if config.cpp.simdlen is None: - assert _valid_vec_isa_list - return _valid_vec_isa_list[0] +def supported_vector_isa(): + # TODO: Add ARM Vec here. + # Dict(k: isa, v: number of float element) + vec_isa_info = { + _SupportedVecIsa.AVX512: 16, + _SupportedVecIsa.AVX2: 8, + } - for isa in _valid_vec_isa_list: - if config.cpp.simdlen == isa.bit_width(): + if config.cpp.simdlen is None or config.cpp.simdlen <= 1: + return _SupportedVecIsa.INVALID + + cpu_info_content = get_cpu_proc_info() + for isa in vec_isa_info.keys(): + if isa in cpu_info_content and config.cpp.simdlen == vec_isa_info[isa]: return isa - return invalid_vec_isa + return _SupportedVecIsa.INVALID -def cpp_compile_command( - input, - output, - warning_all=True, - shared=True, - include_pytorch=False, - vec_isa: VecISA = invalid_vec_isa, -): - if include_pytorch or vec_isa != invalid_vec_isa: +def cpp_compile_command(input, output, include_pytorch=False): + valid_isa = supported_vector_isa() + if include_pytorch or valid_isa: ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")] lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")] libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"] - macros = vec_isa.build_macro() + macros = _SupportedVecIsa.vec_macro(valid_isa) if macros: macros = f"-D{macros}" else: @@ -337,13 +235,11 @@ def cpp_compile_command( lpaths = " ".join(["-L" + p for p in lpaths]) libs = " ".join(["-l" + p for p in libs]) - shared_lib = "-shared -fPIC" if shared else "" - warning_all_flag = "-Wall" if warning_all else "" return re.sub( r"[ \n]+", " ", f""" - {cpp_compiler()} {input} {shared_lib} {warning_all_flag} -std=c++14 -Wno-unused-variable + {cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable {ipaths} {lpaths} {libs} {macros} -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS @@ -370,12 +266,7 @@ def _load_library(path): @classmethod def load(cls, source_code): - picked_vec_isa = pick_vec_isa() - key, input_path = write( - source_code, - "cpp", - extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa), - ) + key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o")) if key not in cls.cache: from filelock import FileLock @@ -385,7 +276,7 @@ def load(cls, source_code): output_path = input_path[:-3] + "so" if not os.path.exists(output_path): cmd = cpp_compile_command( - input=input_path, output=output_path, vec_isa=picked_vec_isa + input=input_path, output=output_path ).split(" ") try: subprocess.check_output(cmd, stderr=subprocess.STDOUT) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py 
index cf98833964ca..2803970295cc 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -417,12 +417,6 @@ def __init__(self, name): def __str__(self): return self.name - def __hash__(self) -> int: - return hash(self.name) - - def __eq__(self, other) -> bool: - return type(other) == type(self) and other.name == self.name - def update_on_args(self, args, kwargs): pass diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 38ef2179d5b7..65a9335d6cbf 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -616,7 +616,7 @@ def codegen_loops(self, code, worksharing): ) reductions.mark_reduction(self.reduction_vars) - if codecache.pick_vec_isa(): + if config.cpp.simdlen: # TODO(jansel): detect stride-1 dimension and vectorize that if reductions: reductions.loops[-1].simd = True @@ -707,8 +707,7 @@ class CppVecKernel(CppKernel): def __init__(self, args, num_threads): super(CppVecKernel, self).__init__(args, num_threads) - assert codecache.pick_vec_isa() - self.simd_nelements = codecache.pick_vec_isa().nelements() + self.simd_len = config.cpp.simdlen self.reduction_omp_dec: Dict[str, str] = {} metrics.generated_cpp_vec_kernel_count += 1 @@ -724,10 +723,10 @@ def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr): def transform_index(self, index: sympy.Expr): expanded_index = sympy.expand(index) - assert self.simd_nelements - assert self.simd_nelements >= 1 + assert self.simd_len + assert self.simd_len > 0 most_inner_var = self.itervars[-1] - replacement = {most_inner_var: most_inner_var * self.simd_nelements} + replacement = {most_inner_var: most_inner_var * self.simd_len} new_index = sympy_subs(expanded_index, replacement) return new_index @@ -948,24 +947,21 @@ def __init__(self, args=None, num_threads=None): super(CppKernelProxy, self).__init__(args, num_threads) self.simd_vec_kernel: CppVecKernel = None self.simd_omp_kernel: CppKernel = None - self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() - def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): - assert self.picked_vec_isa - nelements = self.picked_vec_isa.nelements(dtype) - loop_nest.split_most_inner_loop(nelements) + def vectorize_most_inner_loop(self, loop_nest): + loop_nest.split_most_inner_loop(config.cpp.simdlen) loop_with_tail = loop_nest.loops[-1] assert isinstance(loop_with_tail, LoopLevelWithTail) loop_with_tail.main_loop.simd_vec = True loop_with_tail.tail_loop.simd_omp = True - # We chope the loop into two cubes by the nelements - main loop and tail loop. + # We chope the loop into two cubes by the config.cpp.simdlen - main loop and tail loop. # Regarding the main loop, it is straightforward that it could be vectorized with - # nelements. But for the tail loop, it still could be vectorized. For example, - # if the nelements is 8(256bits), then the tail loop still could be vectorized + # config.cpp.simdlen. But for the tail loop, it still could be vectorized. For example, + # if the config.cpp.simdlen is 8(256bits), then the tail loop still could be vectorized # as 4(128bits). 
- loop_with_tail.tail_loop.simd_nelements = int(nelements / 2) + loop_with_tail.tail_loop.simd_len = int(config.cpp.simdlen / 2) loop_with_tail.tail_loop.simd_vec = False loop_with_tail.main_loop_body = self.simd_vec_kernel @@ -975,7 +971,7 @@ def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): def codegen_loops(self, code, worksharing): threads = parallel_num_threads() - if self.simd_vec_kernel is None or not self.picked_vec_isa: + if self.simd_vec_kernel is None: assert self.simd_omp_kernel return self.simd_omp_kernel.codegen_loops(code, worksharing) @@ -997,52 +993,12 @@ def codegen_loops(self, code, worksharing): ), LoopNest(loops[reduction_depth:]) loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars) - assert self.picked_vec_isa - # Do not apply vectorization since the range of most inner is too small. Meanwhile, - # If the range of the most inner is less then the codecache.pick_vec_isa().nelements(), - # the generated code for some reduction will be as follows that leads to incrrect result. - # - # LINE01: float tmp1 = 0; - # LINE02: auto tmp1_vec = at::vec::Vectorized(tmp1); - # LINE03: for(long i1=0; i1<2; i1+=1) - # LINE04: { - # LINE05: for(long i2=0; i2<0; i2+=1) - # LINE06: { - # LINE07: auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1)); - # LINE08: tmp1_vec += tmp0; - # LINE09: } - # LINE10: tmp1 = vec_reduce_all([](Vectorized& x, Vectorized&y) {return x + y;}, tmp1_vec); - # LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1) - # LINE12: for(long i2=0; i2<8; i2+=1) - # LINE13: { - # LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)]; - # LINE15: tmp1 += tmp0; - # LINE16: } - # LINE17: } - # LINE18: out_ptr3[i0] = tmp1; - # - # tmp1_vec(LINE02) will always be zero as it is initialized with tmp1 value and the range(LINE05) - # is 0. Hence, the LINE10 will always reset tmp1 to 0. But tmp1(LINE01) is global value. So the result - # will be incorrect. We skip thie case. - most_inner_loop = ( - loops_nest_reduce.loops[-1] - if loops_nest_reduce - else loops_nest_non_reduce.loops[-1] - ) - main_loop_range = ir.IndexingDiv( - most_inner_loop.size, self.picked_vec_isa.nelements() - ) - loop_interval = sympy.simplify(main_loop_range) - # TODO(Eikan): To support dynamic shape. 
- if not loop_interval.is_integer or loop_interval <= 0: - metrics.generated_cpp_vec_kernel_count -= 1 - return self.simd_omp_kernel.codegen_loops(code, worksharing) - - # TODO(jansel): detect stride-1 dimension and vectorize that - if loops_nest_reduce: - loops_nest_reduce.loops[-1].simd = True - elif loops_nest_non_reduce: - loops_nest_non_reduce.loops[-1].simd = True + if config.cpp.simdlen: + # TODO(jansel): detect stride-1 dimension and vectorize that + if loops_nest_reduce: + loops_nest_reduce.loops[-1].simd = True + elif loops_nest_non_reduce: + loops_nest_non_reduce.loops[-1].simd = True par_depth = 0 reduction_par_depth = 0 @@ -1182,7 +1138,8 @@ def can_fuse_vertical(cls, node1, node2): return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_vec(self, nodes): - if not codecache.pick_vec_isa(): + # TODO: Query cpu arch and vec length from aten + if not codecache.supported_vector_isa(): return False _, (group, reduction_group) = max( @@ -1392,8 +1349,7 @@ class LoopLevel: steps: sympy.Expr = sympy.Integer(1) parallel: int = 0 simd_omp: bool = False - picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() - simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 + simd_len: int = config.cpp.simdlen simd_vec: bool = False collapsed: bool = False reduction_vars: Dict[str, str] = None @@ -1407,11 +1363,7 @@ def lines(self): ) else: reduction = "" - simd = ( - f"simd simdlen({self.simd_nelements}) " - if self.simd_omp and self.simd_nelements > 1 - else "" - ) + simd = f"simd simdlen({self.simd_len}) " if self.simd_omp else "" if self.parallel: # TODO(jansel): look into chunk size and other schedules line1 = f"#pragma omp for{reduction} " From 54fca6a9da77b56b1a82373c814e61378b5d04c2 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Thu, 17 Nov 2022 05:01:08 +0000 Subject: [PATCH 287/453] Fix: prefer .is_none() over .is(py::none()) for pybind11 in caffe2 (#88199) Follow up to #88051 . I noticed that I missed a few spots in the caffe2 folder. Prefer `.is_none()` over `.is(py::none())` as `.is_none()` is more efficient since it avoid reference counting increments and decrements. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88199 Approved by: https://github.com/albanD, https://github.com/kit1980 --- caffe2/python/pybind_state.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/caffe2/python/pybind_state.cc b/caffe2/python/pybind_state.cc index 3103006774df..5b2c2f71a827 100644 --- a/caffe2/python/pybind_state.cc +++ b/caffe2/python/pybind_state.cc @@ -209,7 +209,7 @@ bool feedBlob( const py::object& arg, const py::object device_option) { DeviceOption option; - if (!device_option.is(py::none())) { + if (!device_option.is_none()) { // If we have a device option passed in, read it. 
CAFFE_ENFORCE(ParseProtoFromLargeString( py::bytes(device_option).cast(), &option)); @@ -752,7 +752,7 @@ void addObjectMethods(py::module& m) { .def( "reset", [](caffe2::onnx::DummyName& instance, const py::object& args) { - if (args.is(py::none())) { + if (args.is_none()) { instance.Reset(std::unordered_set()); } else { instance.Reset(args.cast>()); @@ -1130,7 +1130,7 @@ void addGlobalMethods(py::module& m) { m.def( "switch_workspace", [](const std::string& name, const py::object create_if_missing) { - if (create_if_missing.is(py::none())) { + if (create_if_missing.is_none()) { return caffe2::python::SwitchWorkspaceInternal(name, false); } return caffe2::python::SwitchWorkspaceInternal( @@ -1143,7 +1143,7 @@ void addGlobalMethods(py::module& m) { "reset_workspace", [](const py::object& root_folder) { VLOG(1) << "Resetting workspace."; - if (root_folder.is(py::none())) { + if (root_folder.is_none()) { caffe2::python::ResetWorkspace(new Workspace()); } else { caffe2::python::ResetWorkspace( @@ -1634,7 +1634,7 @@ void addGlobalMethods(py::module& m) { "register_python_op", [](py::object func, bool pass_workspace, std::string name) { using namespace python_detail; - CAFFE_ENFORCE(!func.is(py::none())); + CAFFE_ENFORCE(!func.is_none()); if (!name.empty()) { name += ":"; } @@ -1650,7 +1650,7 @@ void addGlobalMethods(py::module& m) { "register_python_gradient_op", [](const std::string& token, py::object func) { using namespace python_detail; - CAFFE_ENFORCE(!func.is(py::none())); + CAFFE_ENFORCE(!func.is_none()); CAFFE_ENFORCE(gRegistry().find(token) != gRegistry().end()); // For global sanity gradient ops shouldn't access workspace gRegistry()[token + "_gradient"] = Func{func, false}; From 70fb673e51decdd8bf4e55244d910a8e5680d12f Mon Sep 17 00:00:00 2001 From: Rachel030219 <13704467+Rachel030219@users.noreply.github.com> Date: Thu, 17 Nov 2022 05:55:25 +0000 Subject: [PATCH 288/453] Use software approach to catch overflow ( `c10/utils/safe_numerics.h` ) on ARM devices (#89042) Fixes #89040 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89042 Approved by: https://github.com/malfet --- c10/util/safe_numerics.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/c10/util/safe_numerics.h b/c10/util/safe_numerics.h index 7eb9ed39395d..e5c249dd1d2b 100644 --- a/c10/util/safe_numerics.h +++ b/c10/util/safe_numerics.h @@ -22,7 +22,13 @@ C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) { return __builtin_add_overflow(a, b, out); #else unsigned long long tmp; +#if defined(_M_IX86) || defined(_M_X64) auto carry = _addcarry_u64(0, a, b, &tmp); +#else + tmp = a + b; + unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp); + auto carry = vector >> 63; +#endif *out = tmp; return carry; #endif From a41f70603aededc414da58523361773dbf13bde2 Mon Sep 17 00:00:00 2001 From: "Andrew M. 
James" Date: Thu, 17 Nov 2022 02:01:13 +0000 Subject: [PATCH 289/453] Round out rad2deg sparse support (#88442) - Add sparse coo dispatch - Modify backward to work with sparse compressed layouts - Enable sparse_compressed autograd testing - Correct layout support attributes on OpInfo Pull Request resolved: https://github.com/pytorch/pytorch/pull/88442 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 3 +++ aten/src/ATen/native/sparse/SparseUnaryOps.cpp | 3 +++ test/test_sparse_csr.py | 3 ++- torch/csrc/autograd/FunctionsManual.cpp | 2 +- torch/testing/_internal/common_methods_invocations.py | 7 ++++++- 5 files changed, 15 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 726a54b5e225..8046b4f6ac4b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4020,17 +4020,20 @@ variants: function, method dispatch: CompositeExplicitAutograd: rad2deg + SparseCPU, SparseCUDA: rad2deg_sparse SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr - func: rad2deg_(Tensor(a!) self) -> Tensor(a!) variants: function, method dispatch: CompositeExplicitAutograd: rad2deg_ + SparseCPU, SparseCUDA: rad2deg_sparse_ SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_ - func: rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: CompositeExplicitAutograd: rad2deg_out + SparseCPU, SparseCUDA: rad2deg_sparse_out SparseCsrCPU, SparseCsrCUDA: rad2deg_sparse_csr_out - func: deg2rad(Tensor self) -> Tensor diff --git a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp index 084daed4df4e..9e0503337b5d 100644 --- a/aten/src/ATen/native/sparse/SparseUnaryOps.cpp +++ b/aten/src/ATen/native/sparse/SparseUnaryOps.cpp @@ -43,6 +43,8 @@ #include #include #include +#include +#include #include #include #include @@ -177,6 +179,7 @@ COALESCED_UNARY_UFUNC(floor); COALESCED_UNARY_UFUNC(frac); COALESCED_UNARY_UFUNC(log1p); COALESCED_UNARY_UFUNC(round); +COALESCED_UNARY_UFUNC(rad2deg); COALESCED_UNARY_UFUNC(sign); COALESCED_UNARY_UFUNC(sgn); COALESCED_UNARY_UFUNC(sin); diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index e83616489fc2..7ec2d4a79bf9 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -66,7 +66,8 @@ def _check_cusparse_sddmm_available(): 'positive', 'frac', 'nn.functional.relu', - 'log1p' + 'log1p', + 'rad2deg' ] # This should be just an import from test_linalg instead of code duplication diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index c0fbf5f6c0aa..05fcdea3e6b7 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -571,7 +571,7 @@ Tensor permute_backwards(const Tensor& grad, IntArrayRef fwd_dims) { Tensor rad2deg_backward(const Tensor& grad) { constexpr double M_180_PI = 57.295779513082320876798154814105170332405472466564; - return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_180_PI))); + return at::mul(grad, Scalar(M_180_PI)); } Tensor deg2rad_backward(const Tensor& grad) { diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index af4539ee5fec..5db917424a2f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12804,7 +12804,12 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 
dtypes=[torch.bfloat16]), ), supports_forward_ad=True, - supports_fwgrad_bwgrad=True), + supports_fwgrad_bwgrad=True, + supports_sparse=True, + supports_sparse_csr=True, + supports_sparse_csc=True, + supports_sparse_bsr=True, + supports_sparse_bsc=True), UnaryUfuncInfo('real', ref=np.real, dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half, torch.chalf), From 74610a1cedbab64e813f3b49535cd8691a3ec5c7 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 17 Nov 2022 06:14:21 +0000 Subject: [PATCH 290/453] [dynamo][benchmarks] HF - Fix seq len and batch sizes (#89165) Fixes many models in https://github.com/pytorch/torchdynamo/issues/1842 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89165 Approved by: https://github.com/ngimel --- benchmarks/dynamo/common.py | 4 +- benchmarks/dynamo/huggingface.py | 91 ++++++++++++++----- benchmarks/dynamo/huggingface_models_list.txt | 66 +++++++------- torch/_dynamo/testing.py | 8 +- 4 files changed, 105 insertions(+), 64 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index a6e66c4281b6..789ebc3683d3 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -144,6 +144,8 @@ "MT5ForConditionalGeneration", # OOM "PegasusForConditionalGeneration", # OOM "XGLMForCausalLM", # fp64_OOM + "DebertaV2ForMaskedLM", # OOM + "DebertaV2ForQuestionAnswering", # OOM # OOM "BigBird", "TrOCRForCausalLM", @@ -1038,7 +1040,7 @@ def decay_batch_exp(self, batch_size, factor=0.5, divisor=2): out_batch_size = batch_size - 1 return max(0, int(out_batch_size)) - def batch_size_finder(self, device, model_name, initial_batch_size=128): + def batch_size_finder(self, device, model_name, initial_batch_size=1024): batch_size = initial_batch_size while batch_size >= 1: torch.cuda.empty_cache() diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index c7ecd5f222ec..489fcd69df94 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -89,18 +89,13 @@ def pip_install(package): SKIP = { - # Difficult to run and compare - "Reformer", # Fails deepcopy - "BlenderbotForCausalLM", "BlenderbotForConditionalGeneration", - "GPTJForCausalLM", - "GPTJForQuestionAnswering", "GPTNeoForCausalLM", "GPTNeoForSequenceClassification", # Fails with even batch size = 1 - "DebertaV2ForMaskedLM", - "DebertaV2ForQuestionAnswering", + "GPTJForCausalLM", + "GPTJForQuestionAnswering", } # TODO - Fails even after fake tensors @@ -108,23 +103,54 @@ def pip_install(package): "AlbertForMaskedLM": 2, "AlbertForQuestionAnswering": 2, "AllenaiLongformerBase": 2, + "BartForCausalLM": 2, "BartForConditionalGeneration": 2, "BertForMaskedLM": 2, - "BlenderbotSmallForCausalLM": 2, + "BertForQuestionAnswering": 2, + "BlenderbotForCausalLM": 8, + # "BlenderbotForConditionalGeneration" : 16, + "BlenderbotSmallForCausalLM": 4, "BlenderbotSmallForConditionalGeneration": 2, + "CamemBert": 2, + "DebertaForMaskedLM": 8, + "DebertaForQuestionAnswering": 4, + "DebertaV2ForMaskedLM": 8, + "DebertaV2ForQuestionAnswering": 4, + "DistilBertForMaskedLM": 2, + "DistilBertForQuestionAnswering": 2, + "DistillGPT2": 2, "ElectraForCausalLM": 2, "ElectraForQuestionAnswering": 2, "GPT2ForSequenceClassification": 2, + # "GPTJForCausalLM" : 2, + # "GPTJForQuestionAnswering" : 2, + # "GPTNeoForCausalLM" : 2, + # "GPTNeoForSequenceClassification" : 2, + "GoogleFnet": 2, "LayoutLMForMaskedLM": 2, "LayoutLMForSequenceClassification": 2, + "M2M100ForConditionalGeneration": 4, 
+ "MBartForCausalLM": 2, + "MBartForConditionalGeneration": 2, + "MT5ForConditionalGeneration": 2, + "MegatronBertForCausalLM": 4, + "MegatronBertForQuestionAnswering": 2, + "MobileBertForMaskedLM": 4, + "MobileBertForQuestionAnswering": 2, + "OPTForCausalLM": 2, + "PLBartForCausalLM": 2, + "PLBartForConditionalGeneration": 2, + "PegasusForCausalLM": 4, + "PegasusForConditionalGeneration": 2, "RobertaForCausalLM": 2, + "RobertaForQuestionAnswering": 2, + "Speech2Text2ForCausalLM": 4, "T5ForConditionalGeneration": 2, - # Large footprint - "BartForCausalLM": 4, - "DebertaForQuestionAnswering": 4, - "XLNetLMHeadModel": 4, - # Very large footprint - "DebertaForMaskedLM": 8, + "T5Small": 2, + "TrOCRForCausalLM": 2, + "XGLMForCausalLM": 4, + "XLNetLMHeadModel": 2, + "YituTechConvBert": 2, } @@ -139,18 +165,33 @@ def get_module_cls_by_model_name(model_cls_name): def get_sequence_length(model_cls, model_name): - if model_name.startswith(("Bert", "Roberta", "Blenderbot")): + if model_name.startswith(("Blenderbot",)): seq_length = 128 - elif model_name.startswith(("GPT2", "Bart", "T5")): + elif model_name.startswith(("GPT2", "Bart", "T5", "PLBart", "MBart")): seq_length = 1024 elif model_name in ("AllenaiLongformerBase", "BigBird"): seq_length = 1024 + elif model_name.startswith("OPT"): + seq_length = 2048 elif "Reformer" in model_name: seq_length = 4096 elif model_name.startswith( - ("Albert", "Deberta", "Layout", "Electra", "XLNet") + ( + "Albert", + "Deberta", + "Layout", + "Electra", + "XLNet", + "MegatronBert", + "Bert", + "Roberta", + ) ) or model_name in ("DistillGPT2", "GoogleFnet", "YituTechConvBert", "CamemBert"): seq_length = 512 + elif model_name in ("TrOCRForCausalLM"): + seq_length = 256 + elif model_name.startswith("MobileBert"): + seq_length = 128 else: log.warning( f"Sequence Length not defined for {model_name}. 
Choosing 128 arbitrarily" @@ -287,10 +328,10 @@ def rand_int_tensor(device, low, high, shape): AutoConfig.from_pretrained("t5-small"), AutoModelForSeq2SeqLM, ), - "BigBird": ( - BigBirdConfig(attention_type="block_sparse"), - AutoModelForMaskedLM, - ), + # "BigBird": ( + # BigBirdConfig(attention_type="block_sparse"), + # AutoModelForMaskedLM, + # ), "DistillGPT2": ( AutoConfig.from_pretrained("distilgpt2"), AutoModelForCausalLM, @@ -461,10 +502,10 @@ def refresh_model_names_and_batch_sizes(): if model_cls in [ CLIPModel, CLIPVisionModel, - SwinForImageClassification, - SwinForImageClassification, - SwinForMaskedImageModeling, - SwinModel, + # SwinForImageClassification, + # SwinForImageClassification, + # SwinForMaskedImageModeling, + # SwinModel, ViTForImageClassification, ViTForMaskedImageModeling, ViTModel, diff --git a/benchmarks/dynamo/huggingface_models_list.txt b/benchmarks/dynamo/huggingface_models_list.txt index 8272c79b12bd..6e3cf19a783d 100644 --- a/benchmarks/dynamo/huggingface_models_list.txt +++ b/benchmarks/dynamo/huggingface_models_list.txt @@ -1,53 +1,51 @@ AlbertForMaskedLM,8 AlbertForQuestionAnswering,8 -AllenaiLongformerBase,1 -BartForCausalLM,16 +AllenaiLongformerBase,8 +BartForCausalLM,8 BartForConditionalGeneration,4 -BertForMaskedLM,128 -BertForQuestionAnswering,128 -BigBird,1 +BertForMaskedLM,32 +BertForQuestionAnswering,32 BlenderbotForCausalLM,32 -BlenderbotForConditionalGeneration,32 -BlenderbotSmallForCausalLM,128 +BlenderbotForConditionalGeneration,16 +BlenderbotSmallForCausalLM,256 BlenderbotSmallForConditionalGeneration,128 -CamemBert,1 +CamemBert,32 DebertaForMaskedLM,32 DebertaForQuestionAnswering,32 DebertaV2ForMaskedLM,8 DebertaV2ForQuestionAnswering,8 -DistilBertForMaskedLM,64 -DistilBertForQuestionAnswering,64 -DistillGPT2,1 +DistilBertForMaskedLM,256 +DistilBertForQuestionAnswering,512 +DistillGPT2,32 ElectraForCausalLM,64 ElectraForQuestionAnswering,128 GPT2ForSequenceClassification,8 GPTJForCausalLM,1 GPTJForQuestionAnswering,1 -GPTNeoForCausalLM,8 -GPTNeoForSequenceClassification,8 -GoogleFnet,1 +GPTNeoForCausalLM,32 +GPTNeoForSequenceClassification,32 +GoogleFnet,32 LayoutLMForMaskedLM,32 LayoutLMForSequenceClassification,32 -M2M100ForConditionalGeneration,8 -MBartForCausalLM,32 -MBartForConditionalGeneration,16 -MT5ForConditionalGeneration,8 +M2M100ForConditionalGeneration,64 +MBartForCausalLM,8 +MBartForConditionalGeneration,4 +MT5ForConditionalGeneration,32 MegatronBertForCausalLM,16 MegatronBertForQuestionAnswering,16 -MobileBertForMaskedLM,32 -MobileBertForQuestionAnswering,64 -OPTForCausalLM,32 -PLBartForCausalLM,32 -PLBartForConditionalGeneration,16 -PegasusForCausalLM,32 -PegasusForConditionalGeneration,16 -Reformer,1 -RobertaForCausalLM,128 -RobertaForQuestionAnswering,128 -Speech2Text2ForCausalLM,128 +MobileBertForMaskedLM,256 +MobileBertForQuestionAnswering,256 +OPTForCausalLM,4 +PLBartForCausalLM,16 +PLBartForConditionalGeneration,8 +PegasusForCausalLM,128 +PegasusForConditionalGeneration,64 +RobertaForCausalLM,32 +RobertaForQuestionAnswering,32 +Speech2Text2ForCausalLM,1024 T5ForConditionalGeneration,8 -T5Small,1 -TrOCRForCausalLM,32 -XGLMForCausalLM,8 -XLNetLMHeadModel,128 -YituTechConvBert,1 +T5Small,8 +TrOCRForCausalLM,64 +XGLMForCausalLM,32 +XLNetLMHeadModel,16 +YituTechConvBert,32 diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 6e0d32d21f97..eea4c26a171c 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -48,10 +48,10 @@ def collect_results(model, prediction, loss, example_inputs): 
results = [] results.append(prediction) results.append(loss) - if isinstance(loss, torch.Tensor) and loss.item() > 1: - log.warning( - f"High loss value alert - {loss:.2f}. Can result in unstable gradients." - ) + # if isinstance(loss, torch.Tensor) and loss.item() > 1: + # log.warning( + # f"High loss value alert - {loss:.2f}. Can result in unstable gradients." + # ) grads = dict() params = dict() From 126e44173d0dd4d942d8e20c73442048a46cfc24 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 17 Nov 2022 03:27:18 +0000 Subject: [PATCH 291/453] [ONNX] Add onnx-script into ONNX docs (#89078) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89078 Approved by: https://github.com/BowenBao --- docs/source/onnx.rst | 70 +++++++++++++++++++++++++++- test/onnx/test_onnxscript_runtime.py | 2 - 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index fea0b3bc94d2..8f52be124e2e 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -499,6 +499,7 @@ ONNX operators that represent the function's behavior in ONNX. For example:: Inline Autograd Function ~~~~~~~~~~~~~~~~~~~~~~~~ + In cases where a static symbolic method is not provided for its subsequent :class:`torch.autograd.Function` or where a function to register ``prim::PythonOp`` as custom symbolic functions is not provided, :func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function` such that @@ -526,6 +527,73 @@ If you need to avoid inlining of :class:`torch.autograd.Function`, you should ex Custom operators ^^^^^^^^^^^^^^^^ +You can export your model with custom operators that includes a combination of many standard ONNX ops, +or are driven by self-defined C++ backend. + +ONNX-script functions +~~~~~~~~~~~~~~~~~~~~~ + +If an operator is not a standard ONNX op, but can be composed of multiple existing ONNX ops, you can utilize +`ONNX-script `_ to create an external ONNX function to support the operator. +You can export it by following this example:: + + import onnxscript + # There are three opset version needed to be aligned + # This is (1) the opset version in ONNX function + from onnxscript.onnx_opset import opset15 as op + opset_version = 15 + + x = torch.randn(1, 2, 3, 4, requires_grad=True) + model = torch.nn.SELU() + + custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) + + @onnxscript.script(custom_opset) + def Selu(X): + alpha = 1.67326 # auto wrapped as Constants + gamma = 1.0507 + alphaX = op.CastLike(alpha, X) + gammaX = op.CastLike(gamma, X) + neg = gammaX * (alphaX * op.Exp(X) - alphaX) + pos = gammaX * X + zero = op.CastLike(0, X) + return op.Where(X <= zero, neg, pos) + + # setType API provides shape/type to ONNX shape/type inference + def custom_selu(g: jit_utils.GraphContext, X): + return g.onnxscript_op(Selu, X).setType(X.type()) + + # Register custom symbolic function + # There are three opset version needed to be aligned + # This is (2) the opset version in registry + torch.onnx.register_custom_op_symbolic( + symbolic_name="aten::selu", + symbolic_fn=custom_selu, + opset_version=opset_version, + ) + + # There are three opset version needed to be aligned + # This is (2) the opset version in exporter + torch.onnx.export( + model, + x, + "model.onnx", + opset_version=opset_version, + # only needed if you want to specify an opset version > 1. + custom_opsets={"onnx-script": 2} + ) + +The example above exports it as a custom operator in the "onnx-script" opset. 
+When exporting a custom operator, you can specify the custom domain version using the +``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. + +NOTE: Be careful to align the opset version mentioned in the above example, and make sure they are consumed in exporter step. +The example usage of how to write a onnx-script function is a beta version in terms of the active development on onnx-script. +Please follow the latest `ONNX-script `_ + +C++ Operators +~~~~~~~~~~~~~ + If a model uses a custom operator implemented in C++ as described in `Extending TorchScript with Custom C++ Operators `_, you can export it by following this example:: @@ -563,8 +631,6 @@ you can export it by following this example:: custom_opsets={"custom_domain": 2} ) -You can export your model as one or a combination of many standard ONNX ops, or as a custom ONNX operator. - The example above exports it as a custom operator in the "custom_domain" opset. When exporting a custom operator, you can specify the custom domain version using the ``custom_opsets`` dictionary at export. If not specified, the custom opset version defaults to 1. diff --git a/test/onnx/test_onnxscript_runtime.py b/test/onnx/test_onnxscript_runtime.py index 2d0d1e3a5357..e22e76c8315e 100644 --- a/test/onnx/test_onnxscript_runtime.py +++ b/test/onnx/test_onnxscript_runtime.py @@ -25,8 +25,6 @@ def test_selu_from_onnxscript_example(self): from onnxscript.onnx_opset import opset15 as op - # custom domain is needed for custom Op domain name should be - # aligned to the one in symbolic_fn # TODO(titaiwang): make an official domain for onnxscript usage custom_opset = onnxscript.values.Opset(domain="onnx-script", version=1) From b72f5b9ae3f7d1de74d9d2d40236fd09d606be0e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 17 Nov 2022 06:57:42 +0000 Subject: [PATCH 292/453] [Dynamo] Support typing.Mapping & Support function as argument (#88963) These missing features come from https://github.com/pytorch/benchmark/pull/1302, where we'd like to enable E2E hf_bert dynamo train/eval. The dependent [HuggingFace accelerate library](https://huggingface.co/docs/accelerate/index) requires these improvements. 
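For reference, a minimal sketch of the two patterns this PR teaches dynamo to handle, adapted from the unit tests added below (the helper names `add1`/`add2`/`gn`/`fn` are illustrative, not code from accelerate itself):

```python
import typing
import torch
import torch._dynamo

def add1(x):
    return x + 1

def add2(x):
    return x + 2

def gn(x, f=add1):
    # a plain Python function arriving as an argument (with a function default)
    return x + 1 if f is add1 else x + 2

def fn(x, f, m):
    # isinstance() against a typing alias such as typing.Mapping
    y = gn(x, f)
    return y + 1 if isinstance(m, typing.Mapping) else y - 1

opt_fn = torch._dynamo.optimize("eager")(fn)
out = opt_fn(torch.randn(2, 3), add2, {"k": torch.randn(3)})
```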
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88963 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 36 ++++++++++++++++++++++++++++ torch/_dynamo/utils.py | 8 +++++++ torch/_dynamo/variables/builder.py | 6 +++-- torch/_dynamo/variables/functions.py | 2 ++ torch/_dynamo/variables/misc.py | 6 +++++ 5 files changed, 56 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b3cddcbf1dff..2825b157bc68 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -2792,6 +2792,42 @@ def fn(x): res = opt_fn(x) self.assertTrue(torch.allclose(ref, res)) + def test_user_function_variable_supports_function_argument(self): + def add1(x): + return x + 1 + + def add2(x): + return x + 2 + + def gn(x, f=add1): + if f is add1: + return x + 1 + else: + return x + 2 + + def fn(x, f): + return gn(x, f) + + x = torch.randn(2, 3) + ref = fn(x, add2) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, add2) + self.assertTrue(torch.allclose(ref, res)) + + def test_typing_variable_isinstance(self): + def fn(x, m): + if isinstance(m, typing.Mapping): + return x + 1 + else: + return x - 1 + + x = torch.randn(2, 3) + m = {"x": torch.randn(3)} + ref = fn(x, m) + opt_fn = torch._dynamo.optimize("eager")(fn) + res = opt_fn(x, m) + self.assertTrue(torch.allclose(ref, res)) + def test_repro_graph_breaks_in__get_item_by_idx(self): class Mod(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 0b87be7393b5..f426ef691307 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -19,6 +19,7 @@ import sys import time import types +import typing import weakref from contextlib import contextmanager from functools import lru_cache @@ -275,6 +276,13 @@ def istype(obj, allowed_types): return type(obj) is allowed_types +def is_typing(value): + if sys.version_info < (3, 9): + return isinstance(value, typing._GenericAlias) + else: + return isinstance(value, typing._SpecialGenericAlias) + + def is_numpy_int_type(value): return istype( value, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 67e506b5b435..b1b691c41fc6 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -9,7 +9,7 @@ import re import types from abc import ABCMeta -from typing import Any, List, Union +from typing import Any, Union import numpy as np from functorch.experimental.ops import PyOperator @@ -43,6 +43,7 @@ global_key_name, is_namedtuple, is_numpy_int_type, + is_typing, istensor, istype, odict_values, @@ -360,7 +361,8 @@ def index_source(key): value, guards=make_guards(GuardBuilder.FUNCTION_MATCH), ) - elif value is List: + elif is_typing(value): + # typing.List, typing.Mapping, etc. 
return TypingVariable( value, guards=make_guards(GuardBuilder.ID_MATCH), diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 88be730c3423..a8bb8bd84c79 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -24,6 +24,8 @@ def wrap_bound_arg(val, options): return cls([wrap_bound_arg(x, options) for x in val], **options) elif variables.ConstantVariable.is_literal(val): return variables.ConstantVariable(val, **options) + elif isinstance(val, types.FunctionType): + return variables.UserFunctionVariable(val, **options) elif isinstance(val, enum.Enum): return variables.EnumVariable(val, **options) elif isinstance(val, (type, abc.ABCMeta)): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 298ddf24862b..952cbd2c6424 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -654,6 +654,12 @@ def call_method( ) unimplemented("typing") + def python_type(self): + return type(self.value) + + def as_python_constant(self): + return self.value + class NumpyVariable(VariableTracker): """ From 37c85cf5f2215da13d5836de46f44af72ed079ba Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 17 Nov 2022 07:24:55 +0000 Subject: [PATCH 293/453] Add warning if tensor cores are not used (#88844) Fixes https://github.com/pytorch/torchdynamo/issues/1839 Should I do this for all backends or just inductor? ## Test On a V100 I got from AWS ```python from torch._dynamo import optimize import torch def fn(x, y): a = torch.cos(x) b = torch.sin(y) return a + b new_fn = optimize("inductor")(fn) a = new_fn(torch.Tensor(1),torch.Tensor(1)) print(a) ``` ## New logs ``` (sourcetorch) ubuntu@ip-172-31-31-152:~/test$ python test.py /home/ubuntu/pytorch/torch/_dynamo/eval_frame.py:318: UserWarning: Tensor cores are available but not enabled. Consider setting torch.backends.cuda.matmul.allow_tf32 == True in your python script for speedups warnings.warn( tensor([1.3717]) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88844 Approved by: https://github.com/ngimel, https://github.com/mlazos, https://github.com/anijain2305 --- torch/_dynamo/eval_frame.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 1188bfd74fc2..6b500a87bc32 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -350,6 +350,16 @@ def get_compiler_fn(compiler_fn): def lookup_backend(compiler_fn): """Expand backend strings to functions""" if compiler_fn == "inductor": + if torch.cuda.is_available(): + if ( + torch.backends.cuda.matmul.allow_tf32 is False + and torch.cuda.get_device_capability() >= (8, 0) + ): + warnings.warn( + "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled." + "Consider setting `torch.set_float32_matmul_precision('high')`" + ) + compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx elif isinstance(compiler_fn, str): from .optimizations import BACKENDS From 3beccbc29939f7a34346ed1a3646f6464086eeb4 Mon Sep 17 00:00:00 2001 From: ecao Date: Thu, 17 Nov 2022 08:15:49 +0000 Subject: [PATCH 294/453] Add BFloat16 support and optimization for mish, hardtanh backward, and silu on CPU (#82460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description * add BFloat16 support for mish and hardtanh backward on CPU. 
* optimize the performance for silu ### Testing - optimize the performance for silu: bfloat16 single socket (28 cores): ``` before: 1x128x1024 forward 0.090 s backward 0.218 s 10x128x1024 forward 0.146 s backward 0.314 s after: 1x128x1024 forward 0.064 s backward 0.100 s 10x128x1024 forward 0.085 s backward 0.133 s ``` single core: ``` before: 1x128x1024 forward 0.300 s backward 0.606 s 10x128x1024 forward 2.825 s backward 5.834 s after: 1x128x1024 forward 0.156 s backward 0.239 s 10x128x1024 forward 1.447 s backward 2.165 s ``` - Add BFloat16 support for mish and backward of hardtanh on CPU. single socket (20 cores): op | shape | fp32 / s | fp32 / s | bf16 / s |  bf16 / s -- | -- | -- | -- | -- | --   |   | forward | backward | forward | backward silu | [10, 128, 10, 10] | 4.41E-05 | 7.67E-05 | 5.32E-05 | 9.38E-05   | [10, 128, 80, 80] | 0.0008 | 0.001788 | 0.00067 | 0.001031 mish | [10, 128, 10, 10] | 0.000356 | 0.000427 | 0.000367 | 0.000436   | [10, 128, 80, 80] | 0.004527 | 0.005807 | 0.004757 | 0.005393 hardtanh | [10, 128, 10, 10] | / | 3.97E-05 | / | 4.45E-05   | [10, 128, 80, 80] | / | 0.001748 | / | 0.000645 single core: op | shape | fp32 / s | fp32 / s | bf16 / s |  bf16 / s -- | -- | -- | -- | -- | --   |   | forward | backward | forward | backward silu | [10, 128, 10, 10] | 1.17E-04 | 1.91E-04 | 1.35E-04 | 2.23E-04   | [10, 128, 80, 80] | 0.007434 | 0.013141 | 0.008464 | 0.013044 mish | [10, 128, 10, 10] | 0.00103 | 0.00122 | 0.00106 | 0.001227   | [10, 128, 80, 80] | 0.065629 | 0.078418 | 0.067779 | 0.077214 hardtanh | [10, 128, 10, 10] | / | 1.18E-04 | / | 9.30E-05   | [10, 128, 80, 80] | / | 0.010773 | / | 0.005834 Pull Request resolved: https://github.com/pytorch/pytorch/pull/82460 Approved by: https://github.com/mingfeima, https://github.com/malfet --- aten/src/ATen/native/cpu/Activation.cpp | 114 ++++++++++++++++-- test/test_nn.py | 3 + .../_internal/common_methods_invocations.py | 6 +- 3 files changed, 113 insertions(+), 10 deletions(-) diff --git a/aten/src/ATen/native/cpu/Activation.cpp b/aten/src/ATen/native/cpu/Activation.cpp index 6f3eac783ccd..728ea62f1898 100644 --- a/aten/src/ATen/native/cpu/Activation.cpp +++ b/aten/src/ATen/native/cpu/Activation.cpp @@ -623,7 +623,25 @@ void shrink_backward_kernel(TensorIteratorBase& iter, const Scalar& lambd) { } void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Scalar& max) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] { + if (iter.dtype() == kBFloat16) { + auto min_val = min.to(); + auto max_val = max.to(); + cpu_kernel_vec( + iter, + [=](BFloat16 grad_val, BFloat16 self_val) -> BFloat16 { + return (float(self_val) <= min_val || float(self_val) >= max_val) ? 
BFloat16(0) : grad_val; + }, + [=](Vectorized grad_val, Vectorized self_val) -> Vectorized { + Vectorized grad_val0, grad_val1, self_val0, self_val1; + std::tie(grad_val0, grad_val1) = convert_bfloat16_float(grad_val); + std::tie(self_val0, self_val1) = convert_bfloat16_float(self_val); + return convert_float_bfloat16( + ((self_val0 > min_val) & (self_val0 < max_val)) & grad_val0, + ((self_val1 > min_val) & (self_val1 < max_val)) & grad_val1 + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardshrink_backward_cpu", [&] { auto min_val = min.to(); auto max_val = max.to(); cpu_kernel_vec( @@ -635,6 +653,7 @@ void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Sca return ((self_val > min_val) & (self_val < max_val)) & grad_val; }); }); + } } void hardswish_kernel(TensorIterator& iter) { @@ -1035,8 +1054,23 @@ void glu_backward_kernel(TensorIterator& iter) { } void silu_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( - kBFloat16, iter.dtype(), "silu_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + const Vectorized kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 x) -> BFloat16 { + return float(x) / (1.0f + std::exp(-float(x))); + }, + [kOneVec](Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + return convert_float_bfloat16( + x_vec0 / (kOneVec + x_vec0.neg().exp()), + x_vec1 / (kOneVec + x_vec1.neg().exp())); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + iter.dtype(), "silu_cpu", [&]() { const Vectorized kOneVec(scalar_t(1)); cpu_kernel_vec( iter, @@ -1047,11 +1081,34 @@ void silu_kernel(TensorIteratorBase& iter) { return x_vec / (kOneVec + x_vec.neg().exp()); }); }); + } } void silu_backward_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( - kBFloat16, iter.dtype(), "silu_backward_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + const Vectorized kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 dy, BFloat16 x) -> BFloat16 { + const float sigmoid = + 1.0f / (1.0f + std::exp(-float(x))); + return dy * sigmoid * (1.0f + x * (1.0f - sigmoid)); + }, + [kOneVec](Vectorized dy_vec, Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1, dy_vec0, dy_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + std::tie(dy_vec0, dy_vec1) = convert_bfloat16_float(dy_vec); + const Vectorized sigmoid0 = + kOneVec / (kOneVec + x_vec0.neg().exp()); + const Vectorized sigmoid1 = + kOneVec / (kOneVec + x_vec1.neg().exp()); + return convert_float_bfloat16( + dy_vec0 * sigmoid0 * (kOneVec + x_vec0 * (kOneVec - sigmoid0)), + dy_vec1 * sigmoid1 * (kOneVec + x_vec1 * (kOneVec - sigmoid1))); + }); + } else { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + iter.dtype(), "silu_backward_cpu", [&]() { const Vectorized kOneVec(scalar_t(1)); cpu_kernel_vec( iter, @@ -1066,10 +1123,26 @@ void silu_backward_kernel(TensorIteratorBase& iter) { return dy_vec * sigmoid * (kOneVec + x_vec * (kOneVec - sigmoid)); }); }); + } } void mish_kernel(TensorIteratorBase& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + cpu_kernel_vec( + iter, + [](BFloat16 x) -> BFloat16{ + return static_cast(float(x) * std::tanh(std::log1p(std::exp(float(x))))); + }, + [](Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + return convert_float_bfloat16( + x_vec0 * x_vec0.exp().log1p().tanh(), + x_vec1 * 
x_vec1.exp().log1p().tanh() + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_cpu", [&]() { using Vec = Vectorized; cpu_kernel_vec( iter, @@ -1080,10 +1153,36 @@ void mish_kernel(TensorIteratorBase& iter) { return x_vec * x_vec.exp().log1p().tanh(); }); }); + } } void mish_backward_kernel(TensorIterator& iter) { - AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() { + if (iter.dtype() == kBFloat16) { + using Vec = Vectorized; + const Vec kOneVec(1.0f); + cpu_kernel_vec( + iter, + [](BFloat16 dy, BFloat16 x) -> BFloat16 { + const float sigmoid = + 1.0f / (1.0f + std::exp(-float(x))); + const float tanh_softplus = std::tanh(std::log1p(std::exp(float(x)))); + return dy * (tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus)); + }, + [kOneVec](Vectorized dy_vec, Vectorized x_vec) -> Vectorized { + Vectorized x_vec0, x_vec1, dy_vec0, dy_vec1; + std::tie(x_vec0, x_vec1) = convert_bfloat16_float(x_vec); + std::tie(dy_vec0, dy_vec1) = convert_bfloat16_float(dy_vec); + const Vec sigmoid0 = kOneVec / (kOneVec + x_vec0.neg().exp()); + const Vec sigmoid1 = kOneVec / (kOneVec + x_vec1.neg().exp()); + const Vec tanh_softplus0 = x_vec0.exp().log1p().tanh(); + const Vec tanh_softplus1 = x_vec1.exp().log1p().tanh(); + return convert_float_bfloat16( + dy_vec0 * (tanh_softplus0 + x_vec0 * sigmoid0 * (kOneVec - tanh_softplus0 * tanh_softplus0)), + dy_vec1 * (tanh_softplus1 + x_vec1 * sigmoid1 * (kOneVec - tanh_softplus1 * tanh_softplus1)) + ); + }); + } else { + AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "mish_backward_cpu", [&]() { using Vec = Vectorized; const Vec kOneVec(scalar_t(1)); cpu_kernel_vec( @@ -1100,6 +1199,7 @@ void mish_backward_kernel(TensorIterator& iter) { return dy_vec * (tanh_softplus + x_vec * sigmoid * (kOneVec - tanh_softplus * tanh_softplus)); }); }); + } } void prelu_cpu_kernel(TensorIterator& iter) { diff --git a/test/test_nn.py b/test/test_nn.py index 7d6a016a6f51..25f85c60037b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -14640,6 +14640,9 @@ def test_bfloat16(fn, device, inp_dims, prec): test_bfloat16(torch.nn.Softshrink(), device, shape, prec=1e-2) test_bfloat16(torch.nn.Hardswish(), device, shape, prec=2e-2) test_bfloat16(torch.nn.Softplus(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.SiLU(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.Hardtanh(), device, shape, prec=1e-2) + test_bfloat16(torch.nn.Mish(), device, shape, prec=1e-2) @onlyCUDA def test_activations_bfloat16(self, device): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5db917424a2f..8fe70e71614d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12141,7 +12141,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'nn.functional.mish', aten_backward_name='mish_backward', ref=lambda x: x * np.tanh(reference_softplus(x)), - dtypes=floating_types(), + dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -12497,7 +12497,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): aten_name="hardtanh", aten_backward_name='hardtanh_backward', dtypes=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16), - backward_dtypes=all_types(), + backward_dtypes=all_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.int8, torch.int16, 
torch.int32, torch.int64, torch.float16, torch.bfloat16), backward_dtypesIfCUDA=floating_types_and(torch.float16), @@ -12530,7 +12530,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): UnaryUfuncInfo('nn.functional.relu6', aten_name="relu6", dtypes=all_types_and(torch.bfloat16), - backward_dtypes=floating_types(), + backward_dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16), backward_dtypesIfCUDA=floating_types_and(torch.float16), assert_autodiffed=True, From bdc9911575277848ccac56b344dd624aa97fb87d Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 16 Nov 2022 23:31:57 +0000 Subject: [PATCH 295/453] Fix typo in dist_util.py (#89167) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89167 Approved by: https://github.com/davidberard98 --- benchmarks/dynamo/dist_util.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index 9957ef6139df..24625c84e1a1 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -25,11 +25,8 @@ def setup(rank, world_size): - # set defaults in case torchrun isn't used; no idea why the if is needed, but it hangs torchrun otherwise - if not os.getenv("MASTER_ADDR"): - os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") - if not os.getenv("MASTER_PORT"): - os.environ["MASTER_PORT"] = os.getenv("MASETER_PORT", "12355") + os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost") + os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355") os.environ["RANK"] = os.getenv("RANK", "0") os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1") dist.init_process_group("nccl") From e686b8c3ba93cb7caa314c78bf84dbd2d7df9683 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 16 Nov 2022 21:31:02 -0800 Subject: [PATCH 296/453] Reland "Towards unifying symbolic and non symbolic fake tensor (#89038)" (#89143) This reverts commit cf6003f0469ae1440d4a8585860c2c5f4c738707. 
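As a quick way to exercise the entries that move off the symbolic-tracing xfail lists below (e.g. `deg2rad`), a minimal sketch, assuming a build with this reland applied:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.deg2rad(x)

# tracing_mode="symbolic" routes shapes through SymInts backed by fake tensors,
# the path this PR works toward unifying with the non-symbolic fake tensor path
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
print(gm.graph)
```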
Differential Revision: [D41363992](https://our.internmc.facebook.com/intern/diff/D41363992) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89143 Approved by: https://github.com/albanD --- aten/src/ATen/native/TensorFactories.cpp | 6 --- test/functorch/test_aotdispatch.py | 1 - test/test_proxy_tensor.py | 21 +++------ torch/_meta_registrations.py | 39 +++++++++++++++- torch/_ops.py | 1 + torch/_prims/__init__.py | 5 +- torch/_prims_common/__init__.py | 3 ++ torch/_subclasses/fake_tensor.py | 58 +++++++++--------------- 8 files changed, 71 insertions(+), 63 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 9d1c6d8a3633..7245cb77b1c5 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -325,12 +325,6 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = self.options() .merge_in(options_) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 1dc5476158f9..ae216f9be4a4 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1011,7 +1011,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8dc42be7fdfb..0a24807af55f 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1151,9 +1151,7 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1235,8 +1233,6 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # 
aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... @@ -1281,7 +1277,6 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... @@ -1298,7 +1293,6 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. 
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition @@ -1347,11 +1341,15 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) +outplace_symbolic_tensor_failures = { + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition +} + inplace_symbolic_tensor_failures = { - xfail('abs', ''), # aten.abs_.default - couldn't find symbolic meta function/decomposition xfail('acos', ''), # aten.acos_.default - couldn't find symbolic meta function/decomposition xfail('acosh', ''), # aten.acosh_.default - couldn't find symbolic meta function/decomposition - xfail('addbmm', ''), # aten.addbmm_.default - couldn't find symbolic meta function/decomposition xfail('addcdiv', ''), # aten.addcdiv_.default - couldn't find symbolic meta function/decomposition xfail('addcmul', ''), # aten.addcmul_.default - couldn't find symbolic meta function/decomposition xfail('addmm', ''), # aten.addmm_.default - couldn't find symbolic meta function/decomposition @@ -1365,7 +1363,6 @@ def f(a, b, c, d, e): xfail('clamp', ''), # aten.clamp_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_max', ''), # aten.clamp_max_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_min', ''), # aten.clamp_min_.Tensor - couldn't find symbolic meta function/decomposition - xfail('conj_physical', ''), # aten.conj_physical_.default - couldn't find symbolic meta function/decomposition xfail('copysign', ''), # aten.copysign_.Tensor - couldn't find symbolic meta function/decomposition xfail('cos', ''), # aten.cos_.default - couldn't find symbolic meta function/decomposition xfail('cosh', ''), # aten.cosh_.default - couldn't find symbolic meta function/decomposition @@ -1382,7 +1379,6 @@ def f(a, b, c, d, e): xfail('expm1', ''), # aten.expm1_.default - couldn't find symbolic meta function/decomposition xfail('float_power', ''), # the base given to float_power_ has dtype Float but the operation's result requires dtype Double xfail('floor', ''), # aten.floor_.default - couldn't find symbolic meta function/decomposition - xfail('floor_divide', ''), # aten.floor_divide_.Tensor - couldn't find symbolic meta function/decomposition xfail('fmod', ''), # aten.fmod_.Tensor - couldn't find symbolic meta function/decomposition xfail('frac', ''), # aten.frac_.default - couldn't find symbolic meta function/decomposition xfail('ge', ''), # aten.ge_.Tensor - couldn't find symbolic meta function/decomposition @@ -1398,7 +1394,6 @@ def f(a, b, c, d, e): xfail('log1p', ''), # aten.log1p_.default - couldn't find symbolic meta function/decomposition xfail('log2', ''), # aten.log2_.default - couldn't find symbolic meta function/decomposition xfail('log', ''), # aten.log_.default - couldn't find symbolic meta function/decomposition - xfail('logit', ''), # aten.logit_.default - couldn't find symbolic meta function/decomposition xfail('lt', ''), # aten.lt_.Tensor - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_1'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_3'), # aten.mvlgamma_.default - couldn't find symbolic meta 
function/decomposition @@ -1408,7 +1403,6 @@ def f(a, b, c, d, e): xfail('neg', ''), # aten.neg_.default - couldn't find symbolic meta function/decomposition xfail('nextafter', ''), # aten.nextafter_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.celu', ''), # aten.celu_.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.dropout3d', ''), # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition xfail('nn.functional.elu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.hardsigmoid', ''), # aten.hardsigmoid_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.mish', ''), # aten.mish_.default - couldn't find symbolic meta function/decomposition @@ -1426,7 +1420,6 @@ def f(a, b, c, d, e): xfail('sinh', ''), # aten.sinh_.default - couldn't find symbolic meta function/decomposition xfail('sqrt', ''), # aten.sqrt_.default - couldn't find symbolic meta function/decomposition xfail('square', ''), # aten.pow_.Scalar - couldn't find symbolic meta function/decomposition - xfail('squeeze', ''), # aten.squeeze_.default - couldn't find symbolic meta function/decomposition xfail('t', ''), # aten.t_.default - couldn't find symbolic meta function/decomposition xfail('tan', ''), # aten.tan_.default - couldn't find symbolic meta function/decomposition xfail('tanh', ''), # aten.tanh_.default - couldn't find symbolic meta function/decomposition @@ -1516,7 +1509,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 4fa3ab09d275..9849df0a58af 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1513,7 +1513,6 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, - aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1521,6 +1520,44 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + if layout == torch.sparse_coo: + check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 0) + + res._coalesced_(True) + return res + return aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): diff --git a/torch/_ops.py b/torch/_ops.py index 
9163932144d0..b20398a7f3ab 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -365,6 +365,7 @@ def handler(*args, **kwargs): return handler final_key = resolve_key(self, key) + # print(self, key, final_key) r = self.py_kernels.get(final_key, final_key) self._dispatch_cache[key] = r return r diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index da8d9af723ac..a4bac68f0ff1 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1150,9 +1150,6 @@ def _minimum_aten( # # View operations -# -# TODO: model view relationships -# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1170,7 +1167,7 @@ def _as_strided_meta( a._typed_storage(), size, stride, storage_offset ) - return TensorMeta(a, shape=size, strides=stride) + return torch.as_strided(a, size, stride, storage_offset) def _as_strided_aten( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 128796dfa3d0..041448e8102a 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -291,6 +291,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ + if a.is_sparse: + return False + # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 5d3d3a0e32fe..9a0ac050e6b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,7 +1,6 @@ import contextlib import functools import itertools -import sys import weakref from dataclasses import dataclass from functools import partial @@ -297,8 +296,9 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - # Not in_kernel_invocation_manager as no fake tensor inputs - with no_dispatch(): + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? hmmm) + with in_kernel_invocation_manager(fake_mode): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -821,40 +821,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - from torch._decomp import decomposition_table - - with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table - # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise - if has_symbolic_sizes and not self.cpp_meta_supports_symint(func): - from torch._decomp import meta_table as meta_table + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table - if func == aten.size.default: - sys.stderr.write( - "Trying to call aten.size on a tensor with symbolic shapes. 
" - "It's likely that this is from calling tensor.shape in C++" + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) ) - # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` - return None - - with self: - if func in meta_table: - r = meta_table[func](*args, **kwargs) - return r - if func in decomposition_table: + ): + with self: return decomposition_table[func](*args, **kwargs) - if ( - func in decomposition_table - and torch_decomp_decompositions(func) - and all(not e.is_sparse for e in flat_arg_fake_tensors) - ): with self: - return decomposition_table[func](*args, **kwargs) + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them @@ -865,12 +855,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with self: return func.prim_meta_impl(*args, **kwargs) - if has_symbolic_sizes: - if not self.cpp_meta_supports_symint(func): - raise RuntimeError( - f"{func} - couldn't find symbolic meta function/decomposition" - ) - # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors # and then afterwards wrapping them to a FakeTensor From 2b131b1d43b10a2a005f3f042f920a62501e4e2d Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Thu, 17 Nov 2022 03:33:32 +0000 Subject: [PATCH 297/453] Support masked_fill (#88736) Support `masked_fill` to address the GPT2 performance issue. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88736 Approved by: https://github.com/jansel, https://github.com/jgong5 --- test/inductor/test_torchinductor.py | 23 +++++++++++++ torch/_inductor/codegen/cpp.py | 51 ++++++++++++++++++++++++---- torch/_inductor/codegen/cpp_prefix.h | 12 +++++++ 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fb7ca1fc92b7..efedeca381f3 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4602,6 +4602,29 @@ def test_complex_memory_overlap(self): not codecache.get_cpu_proc_info(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) + def test_masked_fill_softmax(self): + def fn(value, mask): + mask = mask.to(torch.bool) + x = torch.masked_fill(value, mask, -33.0) + return torch.softmax(x, -1) + + value = torch.randn((2, 17)) + mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value, mask) + + real_out = fn(value, mask) + compiled_out = opt_fn(value, mask) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count >= 1 + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): def fn(x): return (torch.sign(x),) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 65a9335d6cbf..9f00563a954e 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -311,6 +311,10 @@ def maximum(a, b): def square(a): return f"{a}.pow(2)" + @staticmethod + def where(a, b, c): + return f"decltype({b})::blendv({c}, {b}, {a})" + @staticmethod def sign(x): code = BracesBuffer() @@ -330,6 +334,11 @@ def sign(x): V.kernel.compute.splice(code) return result + @staticmethod + def to_dtype(x, dtype): + assert dtype in [torch.bool], f"{__name__} does not support {dtype}" + return f"({x})" + class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -740,7 +749,16 @@ def load(self, name: str, index: sympy.Expr): if expanded_index == new_index: line = f"at::vec::Vectorized({var}[{cexpr(index)}])" else: - line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" + if V.graph.get_dtype(name) in [torch.bool, torch.uint8]: + g_tmp_buf = f"g_tmp_buffer_{var}" + nelements = codecache.pick_vec_isa().nelements() + self.loads.writeline(f"float {g_tmp_buf}[{nelements}] = {{0}};") + self.loads.writeline( + f"flag_to_float({var} + {cexpr(new_index)}, {g_tmp_buf}, {nelements});" + ) + line = f"at::vec::Vectorized::loadu({g_tmp_buf})" + else: + line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" return self.cse.generate(self.loads, line) @@ -837,9 +855,6 @@ def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr): return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index) def could_vec(self, name: str, index: sympy.Expr): - if V.graph.get_dtype(name) is not torch.float: - return False - assert self.itervars is not None # Not a loop if len(self.itervars) == 0: @@ -849,12 +864,24 @@ def could_vec(self, name: str, index: sympy.Expr): return self.is_legal_data_access(most_inner_var, index) def load(self, name: str, index: sympy.Expr): - index = 
self.rename_indexing(index) + if not V.graph.get_dtype(name) in [ + torch.float, + torch.float32, + torch.bool, + torch.uint8, + ]: + self.simd_vec = False + return self.simd_vec + index = self.rename_indexing(index) self.simd_vec = self.simd_vec and self.could_vec(name, index) return self.simd_vec def store(self, name, index, value, mode=None): + if not V.graph.get_dtype(name) in [torch.float, torch.float32]: + self.simd_vec = False + return self.simd_vec + assert "buf" in name index = self.rename_indexing(index) @@ -927,15 +954,24 @@ def constant(val, dtype): @staticmethod def index_expr(expr, dtype): self.simd_vec = False - return self.cse.newvar() + tmp_var = self.cse.newvar() + return tmp_var @staticmethod def indirect_indexing(index_var): + self.simd_vec = False return sympy.Symbol(str(index_var)) @staticmethod def masked(mask, body, other): - return V.kernel.cse.newvar() + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def to_dtype(x, dtype): + if dtype != torch.bool: + self.simd_vec = False + return x self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) @@ -1040,6 +1076,7 @@ def codegen_loops(self, code, worksharing): if reduction_par_depth > 0 and reduction_par_depth != len( loops_nest_reduce.loops ): + metrics.generated_cpp_vec_kernel_count -= 1 return self.simd_omp_kernel.codegen_loops(code, worksharing) with contextlib.ExitStack() as stack: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 1905aefcda5c..c1c9c3bae112 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -57,3 +57,15 @@ template void atomic_add(volatile T *addr, T offset) { } while (!atomic_addr->compare_exchange_weak(expected, desired, std::memory_order_relaxed)); } + +// This function is used to convert bool or uint8 to float mask for +// vectorization. The caller needs to make sure the src represents TRUE/FALSE +// correctly. +template +void flag_to_float(const T* src, float* dst, int64_t n) { +#pragma unroll + for (int64_t i = 0; i < n; i++) { + uint32_t* dst_u32 = (uint32_t*)dst; + dst_u32[i] = *(src + i) ? 0xFFFFFFFF : 0; + } +} From cd81a700ecfb84a039257896af7b8398435b089e Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Thu, 17 Nov 2022 16:43:16 +0000 Subject: [PATCH 298/453] Fix buffer overflow from AddressSanitizer checks due to inaccurate bfloat16 representation of large integer (#89210) Fixes #88939 The root cause of the issue is that BF16 cannot accurately represent big integer values. In the test case below, `539` as one of the corner pixel index is wrongly represented as `540` (from https://github.com/jgong5/pytorch/blob/fc60a1865eafc985217eccc0251f82014041e6a7/aten/src/ATen/native/UpSample.h#L271) and then the access out of the range with this index. Thanks to @malfet for the investigation and initial fix. I also reported an issue https://github.com/pytorch/pytorch/issues/89212 to track the issue of inaccurate integer representation of bf16 that need to be addressed in other places of PyTorch. 
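As a quick sanity check of that representation gap (an illustration added here, not part of the original report): bfloat16 keeps only 8 significand bits, so representable values in the [512, 1024) range are 4 apart and 539 rounds up to 540, one past the last valid index. The original reproducer follows.

```python
import torch

# 539 cannot be represented exactly in bfloat16; it rounds to the nearest
# representable value, 540 (bfloat16 values are spaced 4 apart in [512, 1024)).
idx = torch.tensor(539.0, dtype=torch.bfloat16)
print(idx)       # tensor(540., dtype=torch.bfloat16)
print(int(idx))  # 540
```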
```python import torch def test(): arg_1 = torch.rand([1, 10, 540, 540], dtype=torch.bfloat16).clone() res = torch.nn.functional.interpolate(arg_1,2,mode='bilinear',align_corners=True) test() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89210 Approved by: https://github.com/malfet --- aten/src/ATen/native/cpu/UpSampleKernel.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 7eb7cf5e58bb..8d418c264504 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -474,9 +474,9 @@ void cpu_upsample_linear_channels_last( using opmath_t = at::opmath_type; using Vec = vec::Vectorized; auto loop2d = [&](int64_t begin, int64_t end) { - const scalar_t height_scale = area_pixel_compute_scale( + const auto height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales[0]); - const scalar_t width_scale = area_pixel_compute_scale( + const auto width_scale = area_pixel_compute_scale( input_width, output_width, align_corners, scales[1]); auto input_indexr = [=](int64_t n, int64_t h, int64_t w) { @@ -486,7 +486,7 @@ void cpu_upsample_linear_channels_last( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t ih0, ih1, iw0, iw1; - scalar_t h0lambda, h1lambda, w0lambda, w1lambda; + opmath_t h0lambda, h1lambda, w0lambda, w1lambda; for (const auto n : c10::irange(begin, end)) { for (const auto oh : c10::irange(output_height)) { compute_source_index_and_lambda( @@ -521,11 +521,11 @@ void cpu_upsample_linear_channels_last( }; auto loop3d = [&](int64_t begin, int64_t end) { - const scalar_t depth_scale = area_pixel_compute_scale( + const auto depth_scale = area_pixel_compute_scale( input_depth, output_depth, align_corners, scales[0]); - const scalar_t height_scale = area_pixel_compute_scale( + const auto height_scale = area_pixel_compute_scale( input_height, output_height, align_corners, scales[1]); - const scalar_t width_scale = area_pixel_compute_scale( + const auto width_scale = area_pixel_compute_scale( input_width, output_width, align_corners, scales[2]); auto input_indexr = [=](int64_t n, int64_t d, int64_t h, int64_t w) { @@ -536,7 +536,7 @@ void cpu_upsample_linear_channels_last( // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t id0, id1, ih0, ih1, iw0, iw1; - scalar_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda; + opmath_t d0lambda, d1lambda, h0lambda, h1lambda, w0lambda, w1lambda; for (const auto n : c10::irange(begin, end)) { for (const auto od : c10::irange(output_depth)) { compute_source_index_and_lambda( From 8e4c9828f4c990f439179912159086aaed790493 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Nov 2022 17:02:36 +0000 Subject: [PATCH 299/453] Revert "Reland "Towards unifying symbolic and non symbolic fake tensor (#89038)" (#89143)" This reverts commit e686b8c3ba93cb7caa314c78bf84dbd2d7df9683. 
Reverted https://github.com/pytorch/pytorch/pull/89143 on behalf of https://github.com/ZainRizvi due to This seems to be causing the test_make_fx_symbolic_exhaustive_rad2deg_cpu_float32 and test_make_fx_symbolic_exhaustive_inplace_rad2deg_cpu_float32 test to fail across multiple jobs --- aten/src/ATen/native/TensorFactories.cpp | 6 +++ test/functorch/test_aotdispatch.py | 1 + test/test_proxy_tensor.py | 21 ++++++--- torch/_meta_registrations.py | 39 +--------------- torch/_ops.py | 1 - torch/_prims/__init__.py | 5 +- torch/_prims_common/__init__.py | 3 -- torch/_subclasses/fake_tensor.py | 58 +++++++++++++++--------- 8 files changed, 63 insertions(+), 71 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 7245cb77b1c5..9d1c6d8a3633 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -325,6 +325,12 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + + TORCH_CHECK( + !(options_.has_memory_format() && optional_memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + TensorOptions options = self.options() .merge_in(options_) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ae216f9be4a4..1dc5476158f9 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1011,6 +1011,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides + xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 0a24807af55f..8dc42be7fdfb 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1151,7 +1151,9 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition + xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition + xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1233,6 +1235,8 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - 
couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... @@ -1277,6 +1281,7 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... @@ -1293,6 +1298,7 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition + xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. 
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition @@ -1341,15 +1347,11 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) -outplace_symbolic_tensor_failures = { - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition -} - inplace_symbolic_tensor_failures = { + xfail('abs', ''), # aten.abs_.default - couldn't find symbolic meta function/decomposition xfail('acos', ''), # aten.acos_.default - couldn't find symbolic meta function/decomposition xfail('acosh', ''), # aten.acosh_.default - couldn't find symbolic meta function/decomposition + xfail('addbmm', ''), # aten.addbmm_.default - couldn't find symbolic meta function/decomposition xfail('addcdiv', ''), # aten.addcdiv_.default - couldn't find symbolic meta function/decomposition xfail('addcmul', ''), # aten.addcmul_.default - couldn't find symbolic meta function/decomposition xfail('addmm', ''), # aten.addmm_.default - couldn't find symbolic meta function/decomposition @@ -1363,6 +1365,7 @@ def f(a, b, c, d, e): xfail('clamp', ''), # aten.clamp_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_max', ''), # aten.clamp_max_.Tensor - couldn't find symbolic meta function/decomposition xfail('clamp_min', ''), # aten.clamp_min_.Tensor - couldn't find symbolic meta function/decomposition + xfail('conj_physical', ''), # aten.conj_physical_.default - couldn't find symbolic meta function/decomposition xfail('copysign', ''), # aten.copysign_.Tensor - couldn't find symbolic meta function/decomposition xfail('cos', ''), # aten.cos_.default - couldn't find symbolic meta function/decomposition xfail('cosh', ''), # aten.cosh_.default - couldn't find symbolic meta function/decomposition @@ -1379,6 +1382,7 @@ def f(a, b, c, d, e): xfail('expm1', ''), # aten.expm1_.default - couldn't find symbolic meta function/decomposition xfail('float_power', ''), # the base given to float_power_ has dtype Float but the operation's result requires dtype Double xfail('floor', ''), # aten.floor_.default - couldn't find symbolic meta function/decomposition + xfail('floor_divide', ''), # aten.floor_divide_.Tensor - couldn't find symbolic meta function/decomposition xfail('fmod', ''), # aten.fmod_.Tensor - couldn't find symbolic meta function/decomposition xfail('frac', ''), # aten.frac_.default - couldn't find symbolic meta function/decomposition xfail('ge', ''), # aten.ge_.Tensor - couldn't find symbolic meta function/decomposition @@ -1394,6 +1398,7 @@ def f(a, b, c, d, e): xfail('log1p', ''), # aten.log1p_.default - couldn't find symbolic meta function/decomposition xfail('log2', ''), # aten.log2_.default - couldn't find symbolic meta function/decomposition xfail('log', ''), # aten.log_.default - couldn't find symbolic meta function/decomposition + xfail('logit', ''), # aten.logit_.default - couldn't find symbolic meta function/decomposition xfail('lt', ''), # aten.lt_.Tensor - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_1'), # aten.mvlgamma_.default - couldn't find symbolic meta function/decomposition xfail('mvlgamma', 'mvlgamma_p_3'), # aten.mvlgamma_.default - couldn't find symbolic meta 
function/decomposition @@ -1403,6 +1408,7 @@ def f(a, b, c, d, e): xfail('neg', ''), # aten.neg_.default - couldn't find symbolic meta function/decomposition xfail('nextafter', ''), # aten.nextafter_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.celu', ''), # aten.celu_.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.dropout3d', ''), # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition xfail('nn.functional.elu', ''), # aten.elu_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.hardsigmoid', ''), # aten.hardsigmoid_.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.mish', ''), # aten.mish_.default - couldn't find symbolic meta function/decomposition @@ -1420,6 +1426,7 @@ def f(a, b, c, d, e): xfail('sinh', ''), # aten.sinh_.default - couldn't find symbolic meta function/decomposition xfail('sqrt', ''), # aten.sqrt_.default - couldn't find symbolic meta function/decomposition xfail('square', ''), # aten.pow_.Scalar - couldn't find symbolic meta function/decomposition + xfail('squeeze', ''), # aten.squeeze_.default - couldn't find symbolic meta function/decomposition xfail('t', ''), # aten.t_.default - couldn't find symbolic meta function/decomposition xfail('tan', ''), # aten.tan_.default - couldn't find symbolic meta function/decomposition xfail('tanh', ''), # aten.tanh_.default - couldn't find symbolic meta function/decomposition @@ -1509,7 +1516,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 9849df0a58af..4fa3ab09d275 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1513,6 +1513,7 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, + aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1520,44 +1521,6 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) -# zeros_like is special cased to work for sparse -@register_meta(aten.zeros_like.default) -def zeros_like( - self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None -): - if layout == torch.sparse_coo: - check( - memory_format is None, - lambda: "memory format option is only supported by strided tensors", - ) - - res = torch.empty( - 0, - dtype=self.dtype if dtype is None else dtype, - layout=layout, - device=self.device if device is None else device, - pin_memory=pin_memory, - ) - - if self.is_sparse: - res.sparse_resize_and_clear_( - self.size(), self.sparse_dim(), self.dense_dim() - ) - else: - res.sparse_resize_and_clear_(self.size(), self.dim(), 0) - - res._coalesced_(True) - return res - return aten.empty_like.default( - self, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - memory_format=memory_format, - ) - - # hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): diff --git a/torch/_ops.py b/torch/_ops.py index 
b20398a7f3ab..9163932144d0 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -365,7 +365,6 @@ def handler(*args, **kwargs): return handler final_key = resolve_key(self, key) - # print(self, key, final_key) r = self.py_kernels.get(final_key, final_key) self._dispatch_cache[key] = r return r diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index a4bac68f0ff1..da8d9af723ac 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1150,6 +1150,9 @@ def _minimum_aten( # # View operations +# +# TODO: model view relationships +# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1167,7 +1170,7 @@ def _as_strided_meta( a._typed_storage(), size, stride, storage_offset ) - return torch.as_strided(a, size, stride, storage_offset) + return TensorMeta(a, shape=size, strides=stride) def _as_strided_aten( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 041448e8102a..128796dfa3d0 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -291,9 +291,6 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ - if a.is_sparse: - return False - # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9a0ac050e6b9..5d3d3a0e32fe 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,6 +1,7 @@ import contextlib import functools import itertools +import sys import weakref from dataclasses import dataclass from functools import partial @@ -296,9 +297,8 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - # _like constructors have fake tensor inputs (maybe this causes the non-like - # to fail? hmmm) - with in_kernel_invocation_manager(fake_mode): + # Not in_kernel_invocation_manager as no fake tensor inputs + with no_dispatch(): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -821,30 +821,40 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - # If there's a Python meta, prefer that over the decomposition - from torch._decomp import meta_table as meta_table + from torch._decomp import decomposition_table + + with self: + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r - if func not in meta_table and not self.cpp_meta_supports_symint(func): - from torch._decomp import decomposition_table + # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise + if has_symbolic_sizes and not self.cpp_meta_supports_symint(func): + from torch._decomp import meta_table as meta_table - # Prefer Python decompositions over C++ ones - if func in decomposition_table and ( - has_symbolic_sizes - or ( - # TODO: Remove these exclusions, so that we can remove - # this leg entirely - torch_decomp_decompositions(func) - and all(not e.is_sparse for e in flat_arg_fake_tensors) + if func == aten.size.default: + sys.stderr.write( + "Trying to call aten.size on a tensor with symbolic shapes. 
" + "It's likely that this is from calling tensor.shape in C++" ) - ): - with self: - return decomposition_table[func](*args, **kwargs) + # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` + return None with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: + if func in meta_table: + r = meta_table[func](*args, **kwargs) return r + if func in decomposition_table: + return decomposition_table[func](*args, **kwargs) + + if ( + func in decomposition_table + and torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) + ): + with self: + return decomposition_table[func](*args, **kwargs) # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them @@ -855,6 +865,12 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with self: return func.prim_meta_impl(*args, **kwargs) + if has_symbolic_sizes: + if not self.cpp_meta_supports_symint(func): + raise RuntimeError( + f"{func} - couldn't find symbolic meta function/decomposition" + ) + # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors # and then afterwards wrapping them to a FakeTensor From 706f791a1912af62e5a605bf93e246b457506627 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 17 Nov 2022 18:27:08 +0000 Subject: [PATCH 300/453] Revert "Support masked_fill (#88736)" This reverts commit 2b131b1d43b10a2a005f3f042f920a62501e4e2d. Reverted https://github.com/pytorch/pytorch/pull/88736 on behalf of https://github.com/kit1980 due to Inductor tests are failing with AttributeError: module 'torch._inductor.codecache' has no attribute 'valid_vec_isa_list' --- test/inductor/test_torchinductor.py | 23 ------------- torch/_inductor/codegen/cpp.py | 51 ++++------------------------ torch/_inductor/codegen/cpp_prefix.h | 12 ------- 3 files changed, 7 insertions(+), 79 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index efedeca381f3..fb7ca1fc92b7 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4602,29 +4602,6 @@ def test_complex_memory_overlap(self): not codecache.get_cpu_proc_info(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) - def test_masked_fill_softmax(self): - def fn(value, mask): - mask = mask.to(torch.bool) - x = torch.masked_fill(value, mask, -33.0) - return torch.softmax(x, -1) - - value = torch.randn((2, 17)) - mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) - with patch.object(config.cpp, "simdlen", None): - torch._dynamo.reset() - metrics.reset() - opt_fn = torch._dynamo.optimize("inductor")(fn) - opt_fn(value, mask) - - real_out = fn(value, mask) - compiled_out = opt_fn(value, mask) - assert same(real_out, compiled_out, equal_nan=True) - assert metrics.generated_cpp_vec_kernel_count >= 1 - - @unittest.skipIf( - not codecache.valid_vec_isa_list(), "Does not support vectorization" - ) - @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): def fn(x): return (torch.sign(x),) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 9f00563a954e..65a9335d6cbf 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -311,10 +311,6 @@ def maximum(a, b): def square(a): return f"{a}.pow(2)" - 
@staticmethod - def where(a, b, c): - return f"decltype({b})::blendv({c}, {b}, {a})" - @staticmethod def sign(x): code = BracesBuffer() @@ -334,11 +330,6 @@ def sign(x): V.kernel.compute.splice(code) return result - @staticmethod - def to_dtype(x, dtype): - assert dtype in [torch.bool], f"{__name__} does not support {dtype}" - return f"({x})" - class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -749,16 +740,7 @@ def load(self, name: str, index: sympy.Expr): if expanded_index == new_index: line = f"at::vec::Vectorized({var}[{cexpr(index)}])" else: - if V.graph.get_dtype(name) in [torch.bool, torch.uint8]: - g_tmp_buf = f"g_tmp_buffer_{var}" - nelements = codecache.pick_vec_isa().nelements() - self.loads.writeline(f"float {g_tmp_buf}[{nelements}] = {{0}};") - self.loads.writeline( - f"flag_to_float({var} + {cexpr(new_index)}, {g_tmp_buf}, {nelements});" - ) - line = f"at::vec::Vectorized::loadu({g_tmp_buf})" - else: - line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" + line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" return self.cse.generate(self.loads, line) @@ -855,6 +837,9 @@ def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr): return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index) def could_vec(self, name: str, index: sympy.Expr): + if V.graph.get_dtype(name) is not torch.float: + return False + assert self.itervars is not None # Not a loop if len(self.itervars) == 0: @@ -864,24 +849,12 @@ def could_vec(self, name: str, index: sympy.Expr): return self.is_legal_data_access(most_inner_var, index) def load(self, name: str, index: sympy.Expr): - if not V.graph.get_dtype(name) in [ - torch.float, - torch.float32, - torch.bool, - torch.uint8, - ]: - self.simd_vec = False - return self.simd_vec - index = self.rename_indexing(index) + self.simd_vec = self.simd_vec and self.could_vec(name, index) return self.simd_vec def store(self, name, index, value, mode=None): - if not V.graph.get_dtype(name) in [torch.float, torch.float32]: - self.simd_vec = False - return self.simd_vec - assert "buf" in name index = self.rename_indexing(index) @@ -954,24 +927,15 @@ def constant(val, dtype): @staticmethod def index_expr(expr, dtype): self.simd_vec = False - tmp_var = self.cse.newvar() - return tmp_var + return self.cse.newvar() @staticmethod def indirect_indexing(index_var): - self.simd_vec = False return sympy.Symbol(str(index_var)) @staticmethod def masked(mask, body, other): - tmp_var = self.cse.newvar() - return tmp_var - - @staticmethod - def to_dtype(x, dtype): - if dtype != torch.bool: - self.simd_vec = False - return x + return V.kernel.cse.newvar() self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) @@ -1076,7 +1040,6 @@ def codegen_loops(self, code, worksharing): if reduction_par_depth > 0 and reduction_par_depth != len( loops_nest_reduce.loops ): - metrics.generated_cpp_vec_kernel_count -= 1 return self.simd_omp_kernel.codegen_loops(code, worksharing) with contextlib.ExitStack() as stack: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index c1c9c3bae112..1905aefcda5c 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -57,15 +57,3 @@ template void atomic_add(volatile T *addr, T offset) { } while (!atomic_addr->compare_exchange_weak(expected, desired, std::memory_order_relaxed)); } - -// This function is used to convert bool or uint8 to float mask 
for -// vectorization. The caller needs to make sure the src represents TRUE/FALSE -// correctly. -template -void flag_to_float(const T* src, float* dst, int64_t n) { -#pragma unroll - for (int64_t i = 0; i < n; i++) { - uint32_t* dst_u32 = (uint32_t*)dst; - dst_u32[i] = *(src + i) ? 0xFFFFFFFF : 0; - } -} From af448e84eb2978062dc6ca4d3d538ed46b58f3d6 Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 17 Nov 2022 19:20:49 +0000 Subject: [PATCH 301/453] Fix bug in dynamo dashboard summary stats diff (#89226) Fixes issue where a suite may not be present in one of the logs. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89226 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 8012e82607cf..843dbd12909a 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -864,6 +864,8 @@ def generate_diff(self, last2, filename, caption): for _, row in df_merge.iterrows(): if row["Compiler"] in self.args.flag_compilers: for suite in self.args.suites: + if suite + "_prev" not in row or suite + "_cur" not in row: + continue data["compiler"].append(row["Compiler"]) data["suite"].append(suite) data["prev_value"].append(row[suite + "_prev"]) From 04169c5b6e53e89e339f02b61287154034ee9fca Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Mon, 14 Nov 2022 23:26:15 -0800 Subject: [PATCH 302/453] Rewrite assert statement with torch._assert under config (#88246) This diff rewrites assert statement in python with torch._assert under config. The resulting graph looks something like: ``` SOURCE CODE: def f(x): assert x[0] == 3 return x.cos() CAPTURED GRAPH: graph(): %arg0 : [#users=2] = placeholder[target=arg0] %getitem : [#users=1] = call_function[target=operator.getitem](args = (%arg0, 0), kwargs = {}) %eq : [#users=1] = call_function[target=operator.eq](args = (%getitem, 3), kwargs = {}) %_assert : [#users=0] = call_function[target=torch._assert](args = (%eq, "assertion_error"), kwargs = {}) %cos : [#users=1] = call_method[target=cos](args = (%arg0,), kwargs = {}) return cos ``` Note that this introduces side-effect as it could error out while executing graph, but the assertion can eliminated via DCE if we choose to ignore it. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88246 Approved by: https://github.com/jansel --- test/dynamo/test_repros.py | 92 ++++++++++++++++++++++++++++++ torch/_dynamo/config.py | 3 + torch/_dynamo/symbolic_convert.py | 94 +++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 503231b4cb12..e30a1275ed13 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -1938,6 +1938,98 @@ def fn(x): self.assertEqual(cnt.frame_count, 1) self.assertEqual(cnt.op_count, 1) + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, "First dim need to be 3" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + cnt = torch._dynamo.testing.CompileCounter() + + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + self.assertEqual(cnt.op_count, 6) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_not_rewrite_assert_for_other_errors(self): + def f(x): + b = x.sin() + if not x.sum() <= 3: + raise ValueError("input sum needs to be 3") + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + opt_fn = torch._dynamo.optimize("eager")(f) + with self.assertRaisesRegex(ValueError, "input sum needs to be 3"): + opt_fn(*args) + + # TODO (tmanlaibaatar) handle data-dependent fstring in assert statement. 
+ @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_with_fstring_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3, f"First dim need to be {x[0]}" + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_without_msg(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + with self.assertRaisesRegex(AssertionError, ""): + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", True) + def test_rewrite_assert_noop(self): + def f(x): + b = x.sin() + assert True + assert x.dtype == torch.float32 + return x.cos() + b + + args = (torch.Tensor([3, 4, 5]),) + exported, _ = torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + cnt = torch._dynamo.testing.CompileCounter() + opt_f = torch._dynamo.optimize(cnt, nopython=True)(f) + self.assertTrue(same(f(*args), opt_f(*args))) + # torch._assert shouldn't be in the graph + self.assertEqual(cnt.op_count, 3) + self.assertEqual(cnt.frame_count, 1) + + exported, _ = torch._dynamo.export(f, torch.Tensor([4, 4, 5])) + self.assertTrue(same(exported(*args), f(*args))) + + @patch.object(torch._dynamo.config, "rewrite_assert_with_torch_assert", False) + def test_not_rewrite_assert(self): + def f(x): + b = x.sin() + assert x[0] == 3 + return x.cos() + b + + with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "generic_jump"): + torch._dynamo.export(f, torch.Tensor([3, 4, 5])) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 12088383e741..39a1a6433419 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -87,6 +87,9 @@ # if an exception is encountered replay_record_enabled = False +# Rewrite assert statement in python with torch._assert +rewrite_assert_with_torch_assert = True + # Show a warning on every graph break print_graph_breaks = False diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e64804cb68b2..d2bc5332719c 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -53,6 +53,7 @@ fake_tensors_available, graph_break_dup_warning_checker, istype, + proxy_args_kwargs, ) from .variables.base import MutableLocal, typestr, VariableTracker from .variables.builder import VariableBuilder, wrap_fx_proxy @@ -121,10 +122,103 @@ def impl(self: "InstructionTranslatorBase", inst: Instruction): return impl +def _detect_and_normalize_assert_statement( + self: "InstructionTranslatorBase", truth_fn: typing.Callable, push: bool +): + # Detect if this jump instruction is assert and normalize the assert + # by pushing dummy error message when nothing is given. 
+ # + # Python 3.9 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_ASSERTION_ERROR + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS + # + # Python 3.8 assertion is in following format: + # 18 POP_JUMP_IF_TRUE 28 + # 20 LOAD_GLOBAL 0 (Assertion type) + # 22 LOAD_CONST 3 ('Assert message') -> optional instruction + # 24 CALL_FUNCTION 1 -> optional instruction + # 26 RAISE_VARARGS 1 + + if (truth_fn is not operator.truth) or push: + return False + + current_instruction_pointer = self.instruction_pointer + inst = self.instructions[current_instruction_pointer] + # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0 + if sys.version_info < (3, 9): + if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": + return False + else: + if inst.opname != "LOAD_ASSERTION_ERROR": + return False + + current_instruction_pointer += 1 + + if current_instruction_pointer >= len(self.instructions): + return False + + inst = self.instructions[current_instruction_pointer] + has_error_msg = False + # DETECT RAISE_VARARGS or LOAD CONST + if inst.opname == "LOAD_CONST": + if not isinstance(inst.argval, str): + return False + self.LOAD_CONST(inst) + has_error_msg = True + + # if it is LOAD_CONSTANT, it must be followed by CALL_FUNCTION + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + if inst.opname != "CALL_FUNCTION": + return False + + # CALL_FUNCTION should be followed by RAISE_VARARGS + current_instruction_pointer += 1 + if current_instruction_pointer >= len(self.instructions): + return False + inst = self.instructions[current_instruction_pointer] + + if inst.opname != "RAISE_VARARGS": + return False + + if not has_error_msg: + # Push dummy value instead of error message + self.push(ConstantVariable("assertion error")) + + return True + + def generic_jump(truth_fn: typing.Callable, push: bool): def inner(self: "InstructionTranslatorBase", inst: Instruction): value: VariableTracker = self.pop() self.output.guards.update(value.guards) + if ( + config.rewrite_assert_with_torch_assert + and _detect_and_normalize_assert_statement(self, truth_fn, push) + ): + error_msg: VariableTracker = self.pop() + self.output.guards.update(error_msg.guards) + # Skip over things like `assert True` + if value.is_python_constant() and bool(value.as_python_constant()): + self.jump(inst) + return + + # Manually insert torch._assert instead of python assert and jump over + # assert related instructions as we don't need them anymore. 
+ self.output.create_proxy( + "call_function", + torch._assert, + *proxy_args_kwargs((value, error_msg), {}), + current_tx=self, + ) + self.jump(inst) + return + if value.is_python_constant(): if truth_fn(value.as_python_constant()): push and self.push(value) From e856a4d66bead8997a83f8714547c09fcbcdc263 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Thu, 17 Nov 2022 20:10:52 +0000 Subject: [PATCH 303/453] Add an env var to skip cudnn version compatibility check (#89184) skip the check by setting `PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK=1` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89184 Approved by: https://github.com/ngimel --- torch/backends/cudnn/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e187d6d26aed..2b63a6379665 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -37,6 +37,8 @@ def _init(): else: cudnn_compatible = runtime_minor >= compile_minor if not cudnn_compatible: + if os.environ.get('PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK', '0') == '1': + return True base_error_msg = (f'cuDNN version incompatibility: ' f'PyTorch was compiled against {compile_version} ' f'but found runtime version {runtime_version}. ' From f057a45fafcd5869d8f6f7e687fad1d36749b9d0 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Thu, 17 Nov 2022 06:09:55 -0500 Subject: [PATCH 304/453] reland "support running test_mobile_profiler with buck1/buck2 and OSS (#89001)" (#89091) We modify this to no longer use std::experimental::filesystem::path and use our own custom type instead. This reverts commit c53a5ac6cca7e2e7d7c47b1a816c7eaa2e7a7704. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89091 Approved by: https://github.com/r-barnes, https://github.com/malfet --- test/cpp/lite_interpreter_runtime/resources.h | 41 +++++++++++++++++++ .../test_mobile_profiler.cpp | 34 +++++++-------- 2 files changed, 55 insertions(+), 20 deletions(-) create mode 100644 test/cpp/lite_interpreter_runtime/resources.h diff --git a/test/cpp/lite_interpreter_runtime/resources.h b/test/cpp/lite_interpreter_runtime/resources.h new file mode 100644 index 000000000000..0be5928b299b --- /dev/null +++ b/test/cpp/lite_interpreter_runtime/resources.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace torch { +namespace testing { + +namespace detail { +class Path; +} + +/// Gets the path to the resource identified by name. +/// +/// @param name identifies a resource, relative path starting from the +/// repo root +inline auto getResourcePath(std::string name) -> detail::Path; + +// End interface: implementation details follow. 
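Stepping back to the cuDNN change above, a hedged sketch of how the new escape hatch is meant to be used; the variable only needs to be visible to the process before the cuDNN backend is first initialized, and the call used below to trigger initialization is illustrative.

```python
import os

# Opt out of the minor-version compatibility check. Typically exported in the
# shell instead, e.g. PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK=1 python train.py
os.environ["PYTORCH_SKIP_CUDNN_COMPATIBILITY_CHECK"] = "1"

import torch

# With the variable set, a compile-time/runtime cuDNN minor-version mismatch
# is tolerated instead of raising the incompatibility error.
print(torch.backends.cudnn.version())
```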
+ +namespace detail { + +class Path { + public: + explicit Path(std::string rep) : rep_(std::move(rep)) {} + + auto string() const -> std::string const& { + return rep_; + } + + private: + std::string rep_; +}; + +} // namespace detail + +inline auto getResourcePath(std::string name) -> detail::Path { + return detail::Path(std::move(name)); +} + +} // namespace testing +} // namespace torch diff --git a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp index 08cb81ae7876..df9cb9cea28c 100644 --- a/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp +++ b/test/cpp/lite_interpreter_runtime/test_mobile_profiler.cpp @@ -11,6 +11,8 @@ #include +#include "test/cpp/lite_interpreter_runtime/resources.h" + #ifdef EDGE_PROFILER_USE_KINETO namespace torch { namespace jit { @@ -42,16 +44,15 @@ bool checkMetaData( } // namespace TEST(MobileProfiler, ModuleHierarchy) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("to_be_profiled_module.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/to_be_profiled_module.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { KinetoEdgeCPUProfiler profiler( bc, @@ -95,16 +96,15 @@ TEST(MobileProfiler, ModuleHierarchy) { } TEST(MobileProfiler, Backend) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { KinetoEdgeCPUProfiler profiler( bc, @@ -130,16 +130,15 @@ TEST(MobileProfiler, Backend) { } TEST(MobileProfiler, BackendMemoryEvents) { - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); inputs.emplace_back(at::rand({64, 64})); std::string trace_file_name("/tmp/test_trace_backend_memory.trace"); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { mobile::KinetoEdgeCPUProfiler profiler( bc, @@ -163,13 +162,8 @@ TEST(MobileProfiler, BackendMemoryEvents) { } TEST(MobileProfiler, ProfilerEvent) { - /* - * TODO: Using __FILE__ is unreliable e.g. 
it fails to resolve correctly when - * using buck2, works ok with buck1 - */ - std::string filePath(__FILE__); - auto testModelFile = filePath.substr(0, filePath.find_last_of("/\\") + 1); - testModelFile.append("test_backend_for_profiling.ptl"); + auto testModelFile = torch::testing::getResourcePath( + "test/cpp/lite_interpreter_runtime/test_backend_for_profiling.ptl"); std::vector inputs; inputs.emplace_back(at::rand({64, 64})); @@ -180,7 +174,7 @@ TEST(MobileProfiler, ProfilerEvent) { torch::profiler::ProfilerPerfEvents.begin(), torch::profiler::ProfilerPerfEvents.end()); - mobile::Module bc = _load_for_mobile(testModelFile); + mobile::Module bc = _load_for_mobile(testModelFile.string()); { // Bail if something goes wrong here try { From fbbf3687453aed1b732eee6f6e9050258ce29561 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Thu, 17 Nov 2022 21:33:59 +0000 Subject: [PATCH 305/453] Fix distributed test paths when running periodic multigpu job (#89225) Some distributed tests are moved to a new location after https://github.com/pytorch/pytorch/pull/88698. This is currently failing periodic multigpu job: * https://github.com/pytorch/pytorch/actions/runs/3484486207/jobs/5829301159 * https://github.com/pytorch/pytorch/actions/runs/3484486207/jobs/5829301093 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89225 Approved by: https://github.com/clee2000 --- .jenkins/pytorch/multigpu-test.sh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index bbd1c370a638..9d7efc969823 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -8,11 +8,6 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" echo "Testing pytorch" -if [ -n "${CI}" ]; then - # TODO move this to docker - # Pin unittest-xml-reporting to freeze printing test summary logic, related: https://github.com/pytorch/pytorch/issues/69014 - pip_install "unittest-xml-reporting<=3.2.0,>=2.0.0" -fi # Disabling tests to see if they solve timeout issues; see https://github.com/pytorch/pytorch/issues/70015 # python tools/download_mnist.py --quiet -d test/cpp/api/mnist @@ -28,8 +23,8 @@ time python test/run_test.py --verbose -i distributed/rpc/cuda/test_tensorpipe_a # FSDP tests for f in test/distributed/fsdp/*.py ; do time python test/run_test.py --verbose -i "${f#*/}" ; done # ShardedTensor tests -time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_checkpoint -time python test/run_test.py --verbose -i distributed/_shard/checkpoint/test_file_system_checkpoint +time python test/run_test.py --verbose -i distributed/checkpoint/test_checkpoint +time python test/run_test.py --verbose -i distributed/checkpoint/test_file_system_checkpoint time python test/run_test.py --verbose -i distributed/_shard/sharding_spec/test_sharding_spec time python test/run_test.py --verbose -i distributed/_shard/sharding_plan/test_sharding_plan time python test/run_test.py --verbose -i distributed/_shard/sharded_tensor/test_megatron_prototype From 767f6aa49fe20a2766b9843d01e3b7f7793df6a3 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 17 Nov 2022 22:05:27 +0000 Subject: [PATCH 306/453] [JIT][Security] Do not blindly eval input string (#89189) Introduce `_eval_no_call` method, that evaluates statement only if it does not contain any calls(done by examining the bytecode), thus preventing command injection exploit Added simple unit test to check for 
that `torch.jit.annotations.get_signature` would not result in calling random code. Although, this code path exists for Python-2 compatibility, and perhaps should be simply removed. Fixes https://github.com/pytorch/pytorch/issues/88868 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89189 Approved by: https://github.com/suo --- test/test_jit.py | 8 ++++++++ torch/csrc/jit/frontend/script_type_parser.cpp | 2 +- torch/jit/annotations.py | 14 ++++++++++++-- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index 13c27b0efa55..6cbc091d506b 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -3951,6 +3951,14 @@ def invalid4(a): return a + 2 torch.jit.script(invalid4) + def test_calls_in_type_annotations(self): + with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): + def spooky(a): + # type: print("Hello") -> Tensor # noqa: F723 + return a + 2 + print(torch.__file__) + torch.jit.annotations.get_signature(spooky, None, 1, True) + def test_is_optional(self): ann = Union[List[int], List[float]] torch._jit_internal.is_optional(ann) diff --git a/torch/csrc/jit/frontend/script_type_parser.cpp b/torch/csrc/jit/frontend/script_type_parser.cpp index f5d6f640d413..d05ec95fb9fa 100644 --- a/torch/csrc/jit/frontend/script_type_parser.cpp +++ b/torch/csrc/jit/frontend/script_type_parser.cpp @@ -316,7 +316,7 @@ std::vector ScriptTypeParser::evaluateDefaults( // We then run constant prop on this graph and check the results are // constant. This approach avoids having to have separate handling of // default arguments from standard expressions by piecing together existing - // machinery for graph generation, constant propgation, and constant + // machinery for graph generation, constant propagation, and constant // extraction. auto tuple_type = Subscript::create( r, diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index a4a36ce36a5e..a6ff2d04d207 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -1,4 +1,5 @@ import ast +import dis import enum import inspect import re @@ -144,6 +145,15 @@ def check_fn(fn, loc): raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function") +def _eval_no_call(stmt, glob, loc): + """Evaluate statement as long as it does not contain any method/function calls""" + bytecode = compile(stmt, "", mode="eval") + for insn in dis.get_instructions(bytecode): + if "CALL" in insn.opname: + raise RuntimeError(f"Type annotation should not contain calls, but '{stmt}' does") + return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 + + def parse_type_line(type_line, rcb, loc): """Parses a type annotation specified as a comment. 
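A user-level sketch of the behaviour change, mirroring the new unit test above: a type comment that smuggles in a call is now rejected during parsing instead of being executed.

```python
import torch

def spooky(a):
    # type: print("Hello") -> Tensor  # noqa: F723
    return a + 2

# Before this change the print() call would actually run while the annotation
# was parsed; now parsing is rejected up front.
try:
    torch.jit.annotations.get_signature(spooky, None, 1, True)
except RuntimeError as e:
    print(e)  # Type annotation should not contain calls, but 'print("Hello")' does
```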
@@ -154,7 +164,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann_str, ret_ann_str = split_type_line(type_line) try: - arg_ann = eval(arg_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the argument list of a type annotation") from e @@ -162,7 +172,7 @@ def parse_type_line(type_line, rcb, loc): arg_ann = (arg_ann,) try: - ret_ann = eval(ret_ann_str, {}, EvalEnv(rcb)) # type: ignore[arg-type] # noqa: P204 + ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) except (NameError, SyntaxError) as e: raise RuntimeError("Failed to parse the return type of a type annotation") from e From a695fcf20103bb08ae660788d128cd924e6ec05b Mon Sep 17 00:00:00 2001 From: Charlie Yan Date: Thu, 17 Nov 2022 19:05:44 +0000 Subject: [PATCH 307/453] Add tests for replicate multiple modules (#89099) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099 Approved by: https://github.com/zhaojuanmao --- .../distributed/_composable/test_replicate.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py index 831ccc3376af..3e8bf44a1fde 100644 --- a/test/distributed/_composable/test_replicate.py +++ b/test/distributed/_composable/test_replicate.py @@ -39,13 +39,7 @@ def tearDown(self): except OSError: pass - def _prepare_module(self, global_batch_size): - model = Net() - input = torch.randn(global_batch_size, 2) - target = torch.randn(global_batch_size, 4) - return model, input, target - - def test_replicate(self): + def _compare_module(self, mod, replicate_mod): dist.init_process_group( backend="gloo", rank=self.rank, @@ -55,8 +49,8 @@ def test_replicate(self): local_batch_size = 1 global_batch_size = self.world_size * local_batch_size - model, input, target = self._prepare_module(global_batch_size) - replicate_model = mark_root_module(replicate(deepcopy(model))) + input = torch.randn(global_batch_size, 2) + target = torch.randn(global_batch_size, 4) def step_model(model, input, target): model.train() @@ -69,9 +63,9 @@ def step_model(model, input, target): param.grad = None for iteration in range(2): - step_model(model, input, target) + step_model(mod, input, target) step_model( - replicate_model, + replicate_mod, input[ self.rank * local_batch_size : (self.rank + 1) @@ -85,16 +79,29 @@ def step_model(model, input, target): ) self.assertEqual( - len(list(model.parameters())), - len(list(replicate_model.parameters())), + len(list(mod.parameters())), + len(list(replicate_mod.parameters())), ) - for i, j in zip(model.parameters(), replicate_model.parameters()): + for i, j in zip(mod.parameters(), replicate_mod.parameters()): self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5) # Shuffle the input so that DDP input is different torch.manual_seed(iteration) input = input[torch.randperm(global_batch_size)] + def test_replicate_single_module(self): + model = Net() + replicate_model = mark_root_module(replicate(deepcopy(model))) + self._compare_module(model, replicate_model) + + def test_replicate_multi_module(self): + model = Net() + replicate_model = mark_root_module(deepcopy(model)) + replicate(replicate_model.fc1) + replicate(replicate_model.fc2) + replicate(replicate_model.fc3) + self._compare_module(model, replicate_model) + if __name__ == "__main__": run_tests() From e2229a89b0618b58011a69a28e3d23cf7096e547 Mon Sep 
17 00:00:00 2001 From: keineahnung2345 Date: Thu, 17 Nov 2022 22:28:20 +0000 Subject: [PATCH 308/453] Fix typo in aten/src/README.md (#89175) remove redundant "have to" Pull Request resolved: https://github.com/pytorch/pytorch/pull/89175 Approved by: https://github.com/kit1980 --- aten/src/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/README.md b/aten/src/README.md index add281692633..3127ed5c8c39 100644 --- a/aten/src/README.md +++ b/aten/src/README.md @@ -69,8 +69,8 @@ will `retain` it itself. ``` Sometimes, you have a tensor in hand which you'd like to use directly, but -under some conditions you have to have to call, e.g., `newContiguous`, to get -it into the correct form: +under some conditions you have to call, e.g., `newContiguous`, to get it into +the correct form: ``` if (!(k_->stride(3) == 1) || !(k_->stride[2] == k_->size(3))) { From 3d8a853a87515a5e29e384396ff8769f4ee2f946 Mon Sep 17 00:00:00 2001 From: erjia Date: Thu, 17 Nov 2022 23:06:41 +0000 Subject: [PATCH 309/453] [DataPipe] Add container template for _Fork and _Demux (#89216) - This would remove the hard-coded check within `_ChildDataPipe`. - Add `get_length_by_instance` to parent class to make sure there is a chance that child DataPipe can have different lengths - Prevent Error when `__del__` executed when the object has already been removed Pull Request resolved: https://github.com/pytorch/pytorch/pull/89216 Approved by: https://github.com/NivekT --- torch/utils/data/datapipes/iter/combining.py | 51 ++++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/torch/utils/data/datapipes/iter/combining.py b/torch/utils/data/datapipes/iter/combining.py index 0c4f34ad51f1..c874cedbde29 100644 --- a/torch/utils/data/datapipes/iter/combining.py +++ b/torch/utils/data/datapipes/iter/combining.py @@ -1,5 +1,6 @@ import warnings +from abc import ABC, abstractmethod from collections import deque from typing import Any, Callable, Iterator, List, Optional, Sized, Tuple, TypeVar, Deque @@ -96,7 +97,31 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, buffer_size: int = return [_ChildDataPipe(container, i) for i in range(num_instances)] -class _ForkerIterDataPipe(IterDataPipe): +class _ContainerTemplate(ABC): + r""" + Abstract class for container ``DataPipes``. The followings are three required + methods. + """ + @abstractmethod + def get_next_element_by_instance(self, instance_id: int): + ... + + @abstractmethod + def is_every_instance_exhausted(self) -> bool: + ... + + @abstractmethod + def reset(self) -> None: + ... + + @abstractmethod + def get_length_by_instance(self, instance_id: int): + r""" + Raise TypeError if it's not supposed to be implemented to support `list(datapipe)` + """ + + +class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate): r""" Container to hold instance-specific information on behalf of ForkerIterDataPipe. 
It tracks the state of its child DataPipes, maintains the buffer, and yields the next value @@ -159,6 +184,9 @@ def is_every_instance_exhausted(self) -> bool: return self.end_ptr is not None and\ all(self.end_ptr == ptr or self.end_ptr - 1 == ptr for ptr in self.child_pointers) + def get_length_by_instance(self, instance_id: int) -> int: + return len(self.main_datapipe) + def reset(self) -> None: self._datapipe_iterator = iter(self.main_datapipe) self.buffer = deque() @@ -195,7 +223,8 @@ def __setstate__(self, state): self.end_ptr = None def __del__(self): - self.buffer.clear() + if self.buffer: + self.buffer.clear() class _ChildDataPipe(IterDataPipe): @@ -229,10 +258,8 @@ class _ChildDataPipe(IterDataPipe): _is_child_datapipe: bool = True def __init__(self, main_datapipe: IterDataPipe, instance_id: int): - required_attrs = ["get_next_element_by_instance", "is_every_instance_exhausted", "reset"] - required_ops = [getattr(main_datapipe, attr) for attr in required_attrs] - if any(not callable(op) for op in required_ops): - raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.") + assert isinstance(main_datapipe, _ContainerTemplate) + self.main_datapipe: IterDataPipe = main_datapipe self.instance_id = instance_id @@ -242,7 +269,7 @@ def __iter__(self): return self.main_datapipe.get_next_element_by_instance(self.instance_id) def __len__(self): - return len(self.main_datapipe) + return self.main_datapipe.get_length_by_instance(self.instance_id) # This method is called by `hook_iterator` in `_typing.py`. def _set_main_datapipe_valid_iterator_id(self) -> int: @@ -324,7 +351,7 @@ def __new__(cls, datapipe: IterDataPipe, num_instances: int, return [_ChildDataPipe(container, i) for i in range(num_instances)] -class _DemultiplexerIterDataPipe(IterDataPipe): +class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate): r""" Container to hold instance-specific information on behalf of DemultiplexerIterDataPipe. 
It tracks the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value @@ -393,6 +420,9 @@ def get_next_element_by_instance(self, instance_id: int): def is_every_instance_exhausted(self) -> bool: return self.main_datapipe_exhausted and all(not child_buffer for child_buffer in self.child_buffers) + def get_length_by_instance(self, instance_id: int) -> int: + raise TypeError + def reset(self) -> None: self._datapipe_iterator = None self.current_buffer_usage = 0 @@ -429,8 +459,9 @@ def __setstate__(self, state): self.main_datapipe_exhausted = False def __del__(self): - for dq in self.child_buffers: - dq.clear() + if self.child_buffers: + for dq in self.child_buffers: + dq.clear() @functional_datapipe('mux') From 31b10e7d4083acd0eb689ae3873c13b8711770be Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Thu, 17 Nov 2022 19:43:37 +0000 Subject: [PATCH 310/453] Enable inductor CI for TorchBench (#87465) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87465 Approved by: https://github.com/malfet --- .github/ci_commit_pins/text.txt | 1 + .github/scripts/filter_test_configs.py | 2 ++ .github/workflows/inductor.yml | 5 ++-- .jenkins/pytorch/common_utils.sh | 19 ++++-------- .jenkins/pytorch/test.sh | 40 ++++++++++++++++---------- benchmarks/dynamo/common.py | 17 ++--------- 6 files changed, 40 insertions(+), 44 deletions(-) create mode 100644 .github/ci_commit_pins/text.txt diff --git a/.github/ci_commit_pins/text.txt b/.github/ci_commit_pins/text.txt new file mode 100644 index 000000000000..c0e01da17fd0 --- /dev/null +++ b/.github/ci_commit_pins/text.txt @@ -0,0 +1 @@ +5b78d074bd303eb230d30567646fcf0358ee2dd4 diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index bb5314434e07..f5c438c29e90 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -24,7 +24,9 @@ "functorch", "inductor", "inductor_distributed", + "inductor_huggingface", "inductor_timm", + "inductor_torchbench", "jit_legacy", "multigpu", "nogpu_AVX512", diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index e8390681e4ab..eb953ff42321 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -23,10 +23,11 @@ jobs: cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, ]} diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index 8af2c93a1e50..7fc1dd6c0f1a 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -101,20 +101,16 @@ function get_pinned_commit() { cat .github/ci_commit_pins/"${1}".txt } -function install_torchvision() { +function install_torchtext() { local commit - commit=$(get_pinned_commit vision) - pip_install 
--no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" + commit=$(get_pinned_commit text) + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/text.git@${commit}" } -function checkout_install_torchvision() { +function install_torchvision() { local commit commit=$(get_pinned_commit vision) - git clone https://github.com/pytorch/vision - pushd vision - git checkout "${commit}" - time python setup.py install - popd + pip_install --no-use-pep517 --user "git+https://github.com/pytorch/vision.git@${commit}" } function clone_pytorch_xla() { @@ -194,13 +190,10 @@ function install_timm() { } function checkout_install_torchbench() { - local commit - commit=$(get_pinned_commit torchbench) git clone https://github.com/pytorch/benchmark torchbench pushd torchbench - git checkout "${commit}" + git checkout no_torchaudio python install.py - pip_install gym==0.25.2 # workaround issue in 0.26.0 popd } diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 135fb50762d6..17437a56ae0e 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -256,20 +256,15 @@ test_inductor() { # pytest test/test_ops_gradients.py --verbose -k "not _complex and not test_inplace_grad_acos_cuda_float64" } -test_inductor_huggingface_shard() { - if [[ -z "$NUM_TEST_SHARDS" ]]; then - echo "NUM_TEST_SHARDS must be defined to run a Python test shard" - exit 1 - fi +test_inductor_huggingface() { # Use test-reports directory under test folder will allow the CI to automatically pick up # the test reports and upload them to S3. Need to use full path here otherwise the script # will bark about file not found later on TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" python benchmarks/dynamo/huggingface.py --ci --training --accuracy \ - --device cuda --inductor --float32 --total-partitions 1 --partition-id "$1" \ - --output "$TEST_REPORTS_DIR"/inductor_huggingface_"$1".csv - python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_huggingface_"$1".csv + --device cuda --inductor --float32 --output "$TEST_REPORTS_DIR"/inductor_huggingface.csv + python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_huggingface.csv } test_inductor_timm_shard() { @@ -288,6 +283,14 @@ test_inductor_timm_shard() { python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_timm_"$1".csv } +test_inductor_torchbench() { + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + PYTHONPATH=$(pwd)/torchbench python benchmarks/dynamo/torchbench.py --ci --training --accuracy \ + --device cuda --inductor --float32 --output "$TEST_REPORTS_DIR"/inductor_torchbench.csv + python benchmarks/dynamo/check_csv.py -f "$TEST_REPORTS_DIR"/inductor_torchbench.csv +} + test_python_gloo_with_tls() { source "$(dirname "${BASH_SOURCE[0]}")/run_glootls_test.sh" assert_git_not_dirty @@ -742,25 +745,32 @@ elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHAR install_filelock install_triton test_dynamo_shard 2 -elif [[ "${TEST_CONFIG}" == *inductor_timm* && $SHARD_NUMBER -lt 3 && $NUM_TEST_SHARDS -gt 1 ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_huggingface* ]]; then + install_torchvision + install_filelock + install_triton + install_huggingface + test_inductor_huggingface +elif [[ "${TEST_CONFIG}" == *inductor_timm* && $NUM_TEST_SHARDS -gt 1 ]]; then install_torchvision install_filelock install_triton install_timm id=$((SHARD_NUMBER-1)) test_inductor_timm_shard $id 
-elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_torchbench* ]]; then + install_torchtext install_torchvision install_filelock install_triton - test_inductor - test_inductor_distributed -elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 2 && $NUM_TEST_SHARDS -gt 1 ]]; then + checkout_install_torchbench + test_inductor_torchbench +elif [[ "${TEST_CONFIG}" == *inductor* && "${SHARD_NUMBER}" == 1 ]]; then install_torchvision install_filelock install_triton - install_huggingface - test_inductor_huggingface_shard 0 + test_inductor + test_inductor_distributed elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then test_without_numpy install_torchvision diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 789ebc3683d3..cad954f825b2 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -110,27 +110,16 @@ # *CI_SKIP_AOT_EAGER_TRAINING, # *CI_SKIP_INDCUTOR_INFERENCE, # TorchBench - "attention_is_all_you_need_pytorch", - "drq", - "hf_Albert", - "hf_Bart", - "hf_GPT2", - "hf_Reformer", + "detectron2", "mobilenet_v3_large", "moco", - "pytorch_struct", - "vgg16", - "speech_transformer", # from functionalization - "vision_maskrcnn", # from functionalization - "timm_efficientnet", # from functionalization (only fails for inductor) - "hf_Bert", - "soft_actor_critic", "tacotron2", - "yolov3", + "vision_maskrcnn", # from functionalization # OOM "Background_Matting", "fastNLP_Bert", "hf_BigBird", + "hf_T5_base", # fp64_OOM "mobilenet_v2", "mobilenet_v2_quantized_qat", "resnet50_quantized_qat", From 2b3ac879a7d68aca8a7608e97a7cfc713dbf5c6c Mon Sep 17 00:00:00 2001 From: Sean Ross-Ross Date: Thu, 17 Nov 2022 23:36:15 +0000 Subject: [PATCH 311/453] feat: adding view_copy_batch_rule and opinfo for view_copy (#88150) to add view_copy to vmap dispatch and adding opinfo part of https://github.com/pytorch/functorch/issues/825 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88150 Approved by: https://github.com/kshitij12345, https://github.com/zou3519 --- aten/src/ATen/functorch/BatchRulesViews.cpp | 14 ++++++++++++++ .../_internal/common_methods_invocations.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp index e4513cf69c18..5eb18f71dd11 100644 --- a/aten/src/ATen/functorch/BatchRulesViews.cpp +++ b/aten/src/ATen/functorch/BatchRulesViews.cpp @@ -438,6 +438,19 @@ std::tuple> view_batching_rule( return std::make_tuple(self_.view_symint(size_), 0); } +std::tuple> view_copy_batch_rule( + const Tensor& self, + optional self_bdim, + c10::SymIntArrayRef size) { + auto self_ = moveBatchDimToFront(self, self_bdim); + SymDimVector view_size(size.size() + 1); + view_size[0] = self_.size(0); + std::copy(size.cbegin(), size.cend(), view_size.begin() + 1); + + return std::make_tuple(at::view_copy_symint(self_, view_size), 0); +} + + template std::tuple> expand_batch_rule( const Tensor &self, optional self_bdim, SymIntArrayRef size, bool implicit) @@ -544,6 +557,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(select_backward, select_backward_batch_rule); VMAP_SUPPORT(slice_backward, slice_backward_batch_rule); VMAP_SUPPORT(view, view_batching_rule); + VMAP_SUPPORT(view_copy, view_copy_batch_rule); VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule)); VMAP_SUPPORT(expand_copy, 
SINGLE_ARG(expand_batch_rule)); VMAP_SUPPORT(unfold, unfold_batch_rule); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 8fe70e71614d..24ef757b768d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12636,6 +12636,20 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), )), + OpInfo('view_copy', + dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16), + ref=lambda x, newshape: np.reshape(x, newshape).copy(), + supports_out=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_autograd=True, + sample_inputs_func=sample_inputs_view_reshape, + error_inputs_func=error_inputs_view_reshape, + skips=( + # https://github.com/pytorch/pytorch/issues/89068 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + )), UnaryUfuncInfo('neg', aliases=('negative', ), ref=np.negative, From 57e05e822d0f53db04d2ee2216906f6fc01b4a4f Mon Sep 17 00:00:00 2001 From: Dmitry Tomshin Date: Fri, 18 Nov 2022 00:10:48 +0000 Subject: [PATCH 312/453] Issue 68576 prefetch factor (#88972) Fixes #68576 This PR allows set the `prefetch_factor=None` making it really optional according to the documentation Pull Request resolved: https://github.com/pytorch/pytorch/pull/88972 Approved by: https://github.com/kit1980 --- torch/utils/data/dataloader.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 4c141eccc3be..c836c9fa975f 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -217,7 +217,7 @@ class DataLoader(Generic[T_co]): timeout: float sampler: Union[Sampler, Iterable] pin_memory_device: str - prefetch_factor: int + prefetch_factor: Optional[int] _iterator : Optional['_BaseDataLoaderIter'] __initialized = False @@ -228,7 +228,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None, multiprocessing_context=None, generator=None, - *, prefetch_factor: int = 2, + *, prefetch_factor: Optional[int] = None, persistent_workers: bool = False, pin_memory_device: str = ""): torch._C._log_api_usage_once("python.data_loader") @@ -240,10 +240,13 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, if timeout < 0: raise ValueError('timeout option should be non-negative') - if num_workers == 0 and prefetch_factor != 2: + if num_workers == 0 and prefetch_factor is not None: raise ValueError('prefetch_factor option could only be specified in multiprocessing.' 
- 'let num_workers > 0 to enable multiprocessing.') - assert prefetch_factor > 0 + 'let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None.') + elif num_workers > 0 and prefetch_factor is None: + prefetch_factor = 2 + elif prefetch_factor is not None and prefetch_factor < 0: + raise ValueError('prefetch_factor option should be non-negative') if persistent_workers and num_workers == 0: raise ValueError('persistent_workers option needs num_workers > 0') @@ -581,7 +584,6 @@ def __init__(self, loader: DataLoader) -> None: ws, rank = _get_distributed_settings() self._world_size = ws self._rank = rank - self._prefetch_factor = loader.prefetch_factor # for other backends, pin_memory_device need to set. if not set # default behaviour is CUDA device. if pin_memory_device is selected # and pin_memory is not set, the default behaviour false. @@ -991,6 +993,8 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter): def __init__(self, loader): super(_MultiProcessingDataLoaderIter, self).__init__(loader) + self._prefetch_factor = loader.prefetch_factor + assert self._num_workers > 0 assert self._prefetch_factor > 0 From 177621a0b28b931d9be6976c2c38cb57af7949d9 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Nov 2022 00:11:42 +0000 Subject: [PATCH 313/453] Use pytest-flakefinder to rerun tests multiple times (#89106) Per title. The way re-run is handled in https://github.com/pytorch/pytorch/pull/88646 only applies to unittest. ### Testing * https://github.com/pytorch/pytorch/actions/runs/3484930558 * https://github.com/pytorch/pytorch/actions/runs/3484930319 Manually download the test report artifacts and verify that that pytest test_ops is called multiple times. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89106 Approved by: https://github.com/clee2000 --- test/run_test.py | 65 ++++++++++++++++++++----- torch/testing/_internal/common_utils.py | 12 +++-- 2 files changed, 61 insertions(+), 16 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 94bee60cc24e..62ce99ae7937 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -743,33 +743,72 @@ def print_log_file(test: str, file_path: str, failed: bool) -> None: def run_test_ops(test_module, test_directory, options): + if os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1": + # When under rerun-disabled-tests mode, run the same tests multiple times to determine their + # flakiness status. Default to 50 re-runs + rerun_options = ["--flake-finder", "--flake-runs=50"] + else: + # When under the normal mode, retry a failed test 2 more times. 
-x means stop at the first + # failure + rerun_options = ["-x", "--reruns=2"] + + default_unittest_args = [ + "--use-pytest", + "-vv", + "-rfEX" + ] + default_unittest_args.extend(rerun_options) + if 'slow-gradcheck' in os.getenv("BUILD_ENVIRONMENT", ""): + extra_unittest_args = default_unittest_args.copy() # there are a lot of tests that take up a lot of space in slowgrad check, so don't bother parallelizing # it's also on periodic so we don't care about TTS as much - return run_test(test_module, test_directory, copy.deepcopy(options), - extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX'], - ) + return run_test( + test_module, + test_directory, + copy.deepcopy(options), + extra_unittest_args=extra_unittest_args, + ) + return_codes = [] os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS) pool = get_context("spawn").Pool(NUM_PROCS) for i in range(NUM_PROCS): - return_code = pool.apply_async(run_test, args=(test_module, test_directory, copy.deepcopy(options)), - kwds={"extra_unittest_args": ["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX', - f'--shard-id={i}', f'--num-shards={NUM_PROCS}', - "-k=not _linalg_cholesky_"], - }) + extra_unittest_args = default_unittest_args.copy() + extra_unittest_args.extend([ + f"--shard-id={i}", + f"--num-shards={NUM_PROCS}", + "-k=not _linalg_cholesky_", + ]) + + return_code = pool.apply_async( + run_test, + args=(test_module, test_directory, copy.deepcopy(options)), + kwds={ + "extra_unittest_args": extra_unittest_args, + }, + ) return_codes.append(return_code) + pool.close() pool.join() - del os.environ['NUM_PARALLEL_PROCS'] + del os.environ["NUM_PARALLEL_PROCS"] for return_code in return_codes: if return_code.get() != 0: return return_code.get() - return_code = run_test(test_module, test_directory, copy.deepcopy(options), - extra_unittest_args=["--use-pytest", '-vv', '-x', '--reruns=2', '-rfEX', - "-k=_linalg_cholesky_"], - ) + + extra_unittest_args = default_unittest_args.copy() + extra_unittest_args.extend([ + "-k=_linalg_cholesky_", + ]) + + return_code = run_test( + test_module, + test_directory, + copy.deepcopy(options), + extra_unittest_args=extra_unittest_args, + ) return return_code diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index fa3eda3758e4..35ec53381c1f 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -741,9 +741,15 @@ def run_tests(argv=UNITTEST_ARGS): if TEST_SAVE_XML: sanitize_pytest_xml(test_report_path) print("If in CI, skip info is located in the xml test reports, please either go to s3 or the hud to download them") - # exitcode of 5 means no tests were found, which happens since some test configs don't - # run tests from certain files - exit(0 if exit_code == 5 else exit_code) + + if not RERUN_DISABLED_TESTS: + # exitcode of 5 means no tests were found, which happens since some test configs don't + # run tests from certain files + exit(0 if exit_code == 5 else exit_code) + else: + # Only record the test report and always return a success code when running under rerun + # disabled tests mode + exit(0) elif TEST_SAVE_XML is not None: # import here so that non-CI doesn't need xmlrunner installed import xmlrunner # type: ignore[import] From b652fbc57a331df5aa28b0bcd07f9e72db2fdbae Mon Sep 17 00:00:00 2001 From: David Boetius Date: Fri, 18 Nov 2022 01:57:38 +0000 Subject: [PATCH 314/453] Fix torch.nn.functional.gelu docstring formatting (#89061) The docstring of `torch.nn.functional.gelu` is formatted incorrectly, so 
that part of the math isn't rendered and there are extra blocks when there shouldn't: https://pytorch.org/docs/stable/generated/torch.nn.functional.gelu.html I didn't build the docs, so I am not 100% sure that I got the formatting right, but I am confident. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89061 Approved by: https://github.com/bdhirsh, https://github.com/kit1980 --- torch/nn/functional.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 961dd83f57b2..e3aea9f0acea 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1713,8 +1713,10 @@ def rrelu( where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. -When the approximate argument is 'tanh', Gelu is estimated with: - :math:: \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) +When the approximate argument is 'tanh', Gelu is estimated with + +.. math:: + \text{GELU}(x) = 0.5 * x * (1 + \text{Tanh}(\sqrt(2 / \pi) * (x + 0.044715 * x^3))) See `Gaussian Error Linear Units (GELUs) `_. """) From 9fd00f194ae4e28948a9a03a6382c20dde04e4fd Mon Sep 17 00:00:00 2001 From: Dmytro Dzhulgakov Date: Fri, 18 Nov 2022 02:42:45 +0000 Subject: [PATCH 315/453] Fix the kineto daemon build condition (#89174) If we're not building the lite interpreter we shouldn't be disabling Kineto. This eliminates a step from https://github.com/facebookincubator/dynolog/blob/main/docs/pytorch_profiler.md Pull Request resolved: https://github.com/pytorch/pytorch/pull/89174 Approved by: https://github.com/kimishpatel, https://github.com/malfet --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6efd3f2df936..5ea01f0c0f53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -795,7 +795,7 @@ if(USE_SOURCE_DEBUG_ON_MOBILE) string(APPEND CMAKE_CXX_FLAGS " -DSYMBOLICATE_MOBILE_DEBUG_HANDLE") endif() -if(USE_LITE_INTERPRETER_PROFILER) +if(BUILD_LITE_INTERPRETER AND USE_LITE_INTERPRETER_PROFILER) string(APPEND CMAKE_CXX_FLAGS " -DEDGE_PROFILER_USE_KINETO") endif() From fd0efb01a7a3a5b487d3d23c2c53a936620ba28a Mon Sep 17 00:00:00 2001 From: Raman kumar Date: Fri, 18 Nov 2022 02:53:39 +0000 Subject: [PATCH 316/453] [MPS] Support for median with dim (#88807) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary ⚡ **Aim**: Add support for aten::median for MPS backend (Fixes #87220) This is fresh clean PR from the previous [PR](https://github.com/pytorch/pytorch/pull/88554) - Implementing the new median function in aten/src/ATen/native/mps/operations/ReduceOps.mm - Adding it to aten/src/ATen/native/native_functions.yaml - Adding it to existing test_median ### **this will works like this** 🪶 median of entire input tensor on MPS `torch.median(mps_inputTensor)` median of along a dim `torch.median(mps_inputTensor, dim=[int], keepdim=[Bool])` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88807 Approved by: https://github.com/kulinseth --- aten/src/ATen/native/mps/MPSGraphVenturaOps.h | 8 + .../ATen/native/mps/operations/ReduceOps.mm | 315 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 41 +++ 4 files changed, 366 insertions(+) diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h index 
86153b58ed87..b77db66795cf 100644 --- a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h +++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h @@ -6,4 +6,12 @@ - (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor axis:(NSInteger)axis name:(NSString *)name; + +- (MPSGraphTensor *)sortWithTensor:(MPSGraphTensor *)tensor + axis:(NSInteger)axis + name:(NSString *)name; + +- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor + axis:(NSInteger)axis + name:(NSString *)name; @end diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index 91aa245b8991..c99f22d89295 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -9,6 +9,7 @@ #include #include #include +#include namespace at { namespace native { @@ -1638,5 +1639,319 @@ Tensor min_mps(const Tensor& input_t) { return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps"); } +// Median of entire tensor into scalar result +Tensor median_mps(const Tensor& input_t) { + + if(!is_macos_13_or_newer()){ + TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ", + "Falling back on CPU. This may have performace implications."); + return at::median(input_t.to("cpu")); + } + + TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS"); + + namespace native_mps = at::native::mps; + using CachedGraph = native_mps::MPSUnaryCachedGraph; + + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + + // calculate total no. of elements in the input tensor to reduce it to one dimension + NSMutableArray *apparent_input_shape = [NSMutableArray arrayWithCapacity:1]; + int64_t num_in_elements = 1; + for(int i = 0; i < num_input_dims; i++) { + num_in_elements *= input_shape[i]; + } + + apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements]; + + Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt); + + if (output_t.numel() == 0 || num_in_elements == 0) { + return output_t; + } + + @autoreleasepool { + string key = "median_mps:"+ mps::getMPSTypeString(input_t.scalar_type()) + mps::getTensorsStringKey(input_t); + CachedGraph* cachedGraph = cache_->LookUpAs(key); + // Initialize once if configuration not found in cache + if(!cachedGraph) { + native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t); + + MPSGraphTensor* outputTensor = nil; + + MPSGraphTensor * reshapedTensor = [mpsGraph reshapeTensor:inputTensor + withShape:@[@-1] + name:nil]; + MPSGraphTensor * sortedTensor = [mpsGraph + sortWithTensor:reshapedTensor + axis:((NSUInteger) (int)0) + name:nil]; + + outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:0 + start:((NSUInteger) (int)((num_in_elements+1)/2 ) - 1) + length:1 + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + auto inputPlaceholder = 
native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]); + + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); + } + + return output_t; +} + + +void median_out_mps + (const Tensor& input_t, + int64_t dim, + bool keepdim, + const Tensor& output_t, + const Tensor& indices_t, + const std::string& func_name) { + + namespace native_mps = at::native::mps; + + if (output_t.numel() == 0) { + return; + } + if (input_t.numel() == 1 && input_t.dim() == 0) { + output_t.fill_(input_t); + indices_t.fill_(0); + return; + } + + // Derive from MPSCachedGraph + struct CachedGraph : public native_mps::MPSCachedGraph + { + CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} + MPSGraphTensor *inputTensor_ = nil; + MPSGraphTensor *outputTensor_ = nil; + MPSGraphTensor *indicesTensor_ = nil; + }; + + native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance(); + + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_out_shape = nil; + + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) + apparent_out_shape[i] = @1; + else + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; + } + int dim_total_elements = input_shape[dim_]; + + auto stream = at::mps::getCurrentMPSStream(); + + @autoreleasepool { + string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type()); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + + if(!cachedGraph) { + native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () { + + CachedGraph *newCachedGraph = nil; + + @autoreleasepool { + MPSGraph* mpsGraph = native_mps::make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type())); + MPSGraphTensor* outputTensor = nil; + MPSGraphTensor * sortedTensor = [mpsGraph + sortWithTensor:inputTensor + axis:((NSUInteger) (int)dim_) + name:nil]; + + outputTensor = [mpsGraph sliceTensor:sortedTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; + MPSGraphTensor* argreduceOutTensor = nil; + argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor + axis:(NSInteger)dim_ + name:@"argmax_out"]; + MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor + dimension:dim_ + start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1) + length:1 + name:nil]; + + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->indicesTensor_ = argOutputTensor; + } + return newCachedGraph; + }); + cachedGraph = static_cast(tmpCachedGraph); + } + + auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t); + auto outputPlaceholder = 
native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape); + auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape); + + NSDictionary *feeds = @{ + inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(), + }; + + NSDictionary *results = @{ + outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(), + indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData() + }; + + native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results); + + } + +} + +// in case mps sortWithTensor do not supported on macOS +std::tuple median_from_cpu( + const Tensor& self, + int64_t dim, + bool keepdim, Tensor & valuesI, Tensor & indicesI, IntArrayRef vec_out_shape, IntArrayRef vec_apparent_out_shape) { + // Tensor a = at::median(self.to("cpu")); + Tensor values; + Tensor indices; + if (!keepdim){ + values = at::empty({vec_out_shape}, self.options()); + indices = at::empty({vec_out_shape}, self.options().dtype(kLong)); + + } + else{ + values = at::empty({vec_apparent_out_shape}, self.options()); + indices = at::empty({vec_apparent_out_shape}, self.options().dtype(kLong)); + } + at::median_out(values, indices, self, dim, keepdim); + + valuesI.copy_(values); + indicesI.copy_(indices); + return std::forward_as_tuple(valuesI, indicesI); +} + +TORCH_API ::std::tuple median_out_mps + (const at::Tensor & input_t, + int64_t dim, + bool keepdim, + at::Tensor & values, + at::Tensor & indices){ + + TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS"); + + namespace native_mps = at::native::mps; + int64_t dim_ = maybe_wrap_dim(dim, input_t.dim()); + native::zero_numel_check_dims(input_t, dim_, "max()"); + + // Calculate the output shape according to keepdim=True + // If there is no dim argument, the input shape is flattened + IntArrayRef input_shape = input_t.sizes(); + int64_t num_input_dims = input_shape.size(); + NSMutableArray *apparent_out_shape = nil; + // Use this if keepdim is false + int64_t num_output_dims = num_input_dims - 1; + + std::vector vec_apparent_out_shape(num_input_dims); + std::vector vec_out_shape(num_output_dims); + + apparent_out_shape = [NSMutableArray arrayWithCapacity:num_input_dims]; + // Counter for shape when keepdim is false + int out_i = 0; + for(int i = 0; i < num_input_dims; i++) { + if(dim_ == i) { + apparent_out_shape[i] = @1; + vec_apparent_out_shape[i] = 1; + } + else { + apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]]; + vec_apparent_out_shape[i] = input_shape[i]; + vec_out_shape[out_i] = input_shape[i]; + out_i++; + } + } + + if(!keepdim) { + values = at::native::empty_mps( + IntArrayRef(vec_out_shape), + input_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + indices = at::native::empty_mps( + IntArrayRef(vec_out_shape), + ScalarType::Long, + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + } else { + values = at::native::empty_mps( + IntArrayRef(vec_apparent_out_shape), + input_t.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + indices = at::native::empty_mps( + IntArrayRef(vec_apparent_out_shape), + ScalarType::Long, + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + } + + if (values.numel() == 0 || input_t.numel() == 0) { + return std::tuple{values, indices}; + } + + if(!is_macos_13_or_newer()){ + TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0.", + 
"Falling back on CPU. This may have performace implications."); + return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices, IntArrayRef(vec_out_shape),IntArrayRef(vec_apparent_out_shape) ); + } + + median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps"); + + return std::tuple{values, indices}; +} + } // native } // at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8046b4f6ac4b..b1d1094667e1 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3492,6 +3492,7 @@ dispatch: CPU: median_cpu CUDA: median_cuda + MPS: median_mps autogen: median.out - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) @@ -3503,6 +3504,7 @@ dispatch: CPU: median_out_cpu CUDA: median_out_cuda + MPS: median_out_mps - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) variants: function, method diff --git a/test/test_mps.py b/test/test_mps.py index 31e2e367e7de..52d669545b30 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -2624,6 +2624,47 @@ def helper(n, c, h, w, dtype=torch.float32): helper(2, 8, 4, 5, torch.int32) # helper(2, 8, 4, 5, torch.int64) + def test_median(self): + def helper_dtype_int32(n1, n2, n3): + cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32) + mps_x = cpu_x.detach().clone().to('mps') + + result_cpu = torch.median(cpu_x) + result_mps = torch.median(mps_x) + + self.assertEqual(result_cpu, result_mps) + + for dim in [0, 1, 2]: + for keepdim in [True, False]: + y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim) + refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim) + self.assertEqual(y, refy) + self.assertEqual(idx, refidx) + + def helper_dtype_float32(n1, n2, n3): + cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32) + mps_x = cpu_x.detach().clone().to('mps') + + result_cpu = torch.median(cpu_x) + result_mps = torch.median(mps_x) + + self.assertEqual(result_cpu, result_mps) + + for dim in [0, 1, 2]: + for keepdim in [True, False]: + y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim) + refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim) + self.assertEqual(y, refy) + self.assertEqual(idx, refidx) + + helper_dtype_int32(10, 10, 10) # median at even place + helper_dtype_int32(3, 3, 3) # median at odd place + helper_dtype_int32(1, 1, 1) + helper_dtype_int32(1, 2, 3) + helper_dtype_float32(10, 10, 10) + helper_dtype_float32(3, 3, 3) + helper_dtype_float32(1, 1, 1) + def test_any(self): def helper(shape): input_xs = [] From 92f9214a311a6b94dff9e38836d5b0849a539647 Mon Sep 17 00:00:00 2001 From: mikey dagitses Date: Thu, 17 Nov 2022 16:20:45 -0500 Subject: [PATCH 317/453] add -Wnarrowing as error to cmake builds (#89207) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89207 Approved by: https://github.com/wconstab, https://github.com/malfet --- CMakeLists.txt | 2 +- aten/src/ATen/native/NNPACK.cpp | 4 ++-- aten/src/ATen/native/mps/operations/Distributions.mm | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5ea01f0c0f53..3d70f6ef5816 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -815,7 +815,6 @@ endif() # ---[ Build flags if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -O2 -fPIC") - string(APPEND CMAKE_CXX_FLAGS " -Wno-narrowing") # Eigen fails to build with some versions, so convert this to a 
warning # Details at http://eigen.tuxfamily.org/bz/show_bug.cgi?id=1459 string(APPEND CMAKE_CXX_FLAGS " -Wall") @@ -824,6 +823,7 @@ if(NOT MSVC) append_cxx_flag_if_supported("-Werror=non-virtual-dtor" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=braced-scalar-init" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=range-loop-construct" CMAKE_CXX_FLAGS) + append_cxx_flag_if_supported("-Wnarrowing" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-missing-field-initializers" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-type-limits" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Wno-array-bounds" CMAKE_CXX_FLAGS) diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 544641f091a3..4fb40a17d026 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -209,8 +209,8 @@ Tensor _nnpack_spatial_convolution( .height = (size_t)output.size(2), }; const nnp_size output_subsample = { - .width = stride[1], - .height = stride[0], + .width = static_cast(stride[1]), + .height = static_cast(stride[0]), }; const auto input_ = input.contiguous(); diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index a1a41d11e5b5..1da2457f3a37 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -438,7 +438,7 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional(n_sample), numCategories}; MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count] shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]] dataType:MPSDataTypeUInt32]; From 65bcd1f88099dfeefccb6c6b7a0918e3a7ded606 Mon Sep 17 00:00:00 2001 From: John Detloff Date: Fri, 18 Nov 2022 03:17:35 +0000 Subject: [PATCH 318/453] Add previously deleted circleci readme back to repo (#85598) This readme was deleted here: https://github.com/pytorch/pytorch/pull/73224 I chatted with the author, who doesn't remember exactly why it was deleted but suspects it was due either to out of date contents or because of the upcoming migration to github actions. With that said, we have references to this readme through our circleci directory, and since we do still have a lot of circleci workflows I feel this readme still adds a lot of value. (I recently did some CI tasks that required me to dig this readme up in order to solve a problem). I recommend we restore this file with a warning that its contents may be out of date, until our CircleCI workflows are entirely migrated to Github Actions Pull Request resolved: https://github.com/pytorch/pytorch/pull/85598 Approved by: https://github.com/clee2000, https://github.com/malfet --- .circleci/README.md | 468 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 468 insertions(+) create mode 100644 .circleci/README.md diff --git a/.circleci/README.md b/.circleci/README.md new file mode 100644 index 000000000000..e2429b4d1f03 --- /dev/null +++ b/.circleci/README.md @@ -0,0 +1,468 @@ +Warning +======= + +Contents may be out of date. Our CircleCI workflows are gradually being migrated to Github actions. + +Structure of CI +=============== + +setup job: +1. Does a git checkout +2. Persists CircleCI scripts (everything in `.circleci`) into a workspace. Why? 
+ We don't always do a Git checkout on all subjobs, but we usually + still want to be able to call scripts one way or another in a subjob. + Persisting files this way lets us have access to them without doing a + checkout. This workspace is conventionally mounted on `~/workspace` + (this is distinguished from `~/project`, which is the conventional + working directory that CircleCI will default to starting your jobs + in.) +3. Write out the commit message to `.circleci/COMMIT_MSG`. This is so + we can determine in subjobs if we should actually run the jobs or + not, even if there isn't a Git checkout. + + +CircleCI configuration generator +================================ + +One may no longer make changes to the `.circleci/config.yml` file directly. +Instead, one must edit these Python scripts or files in the `verbatim-sources/` directory. + + +Usage +---------- + +1. Make changes to these scripts. +2. Run the `regenerate.sh` script in this directory and commit the script changes and the resulting change to `config.yml`. + +You'll see a build failure on GitHub if the scripts don't agree with the checked-in version. + + +Motivation +---------- + +These scripts establish a single, authoritative source of documentation for the CircleCI configuration matrix. +The documentation, in the form of diagrams, is automatically generated and cannot drift out of sync with the YAML content. + +Furthermore, consistency is enforced within the YAML config itself, by using a single source of data to generate +multiple parts of the file. + +* Facilitates one-off culling/enabling of CI configs for testing PRs on special targets + +Also see https://github.com/pytorch/pytorch/issues/17038 + + +Future direction +---------------- + +### Declaring sparse config subsets +See comment [here](https://github.com/pytorch/pytorch/pull/17323#pullrequestreview-206945747): + +In contrast with a full recursive tree traversal of configuration dimensions, +> in the future I think we actually want to decrease our matrix somewhat and have only a few mostly-orthogonal builds that taste as many different features as possible on PRs, plus a more complete suite on every PR and maybe an almost full suite nightly/weekly (we don't have this yet). Specifying PR jobs in the future might be easier to read with an explicit list when we come to this. +---------------- +---------------- + +# How do the binaries / nightlies / releases work? + +### What is a binary? + +A binary or package (used interchangeably) is a pre-built collection of c++ libraries, header files, python bits, and other files. We build these and distribute them so that users do not need to install from source. + +A **binary configuration** is a collection of + +* release or nightly + * releases are stable, nightlies are beta and built every night +* python version + * linux: 3.7m (mu is wide unicode or something like that. It usually doesn't matter but you should know that it exists) + * macos: 3.7, 3.8 + * windows: 3.7, 3.8 +* cpu version + * cpu, cuda 9.0, cuda 10.0 + * The supported cuda versions occasionally change +* operating system + * Linux - these are all built on CentOS. There haven't been any problems in the past building on CentOS and using on Ubuntu + * MacOS + * Windows - these are built on Azure pipelines +* devtoolset version (gcc compiler version) + * This only matters on Linux cause only Linux uses gcc. 
tldr is gcc made a backwards incompatible change from gcc 4.8 to gcc 5, because it had to change how it implemented std::vector and std::string + +### Where are the binaries? + +The binaries are built in CircleCI. There are nightly binaries built every night at 9pm PST (midnight EST) and release binaries corresponding to Pytorch releases, usually every few months. + +We have 3 types of binary packages + +* pip packages - nightlies are stored on s3 (pip install -f \). releases are stored in a pip repo (pip install torch) (ask Soumith about this) +* conda packages - nightlies and releases are both stored in a conda repo. Nighty packages have a '_nightly' suffix +* libtorch packages - these are zips of all the c++ libraries, header files, and sometimes dependencies. These are c++ only + * shared with dependencies (the only supported option for Windows) + * static with dependencies + * shared without dependencies + * static without dependencies + +All binaries are built in CircleCI workflows except Windows. There are checked-in workflows (committed into the .circleci/config.yml) to build the nightlies every night. Releases are built by manually pushing a PR that builds the suite of release binaries (overwrite the config.yml to build the release) + +# CircleCI structure of the binaries + +Some quick vocab: + +* A \**workflow** is a CircleCI concept; it is a DAG of '**jobs**'. ctrl-f 'workflows' on https://github.com/pytorch/pytorch/blob/master/.circleci/config.yml to see the workflows. +* **jobs** are a sequence of '**steps**' +* **steps** are usually just a bash script or a builtin CircleCI command. *All steps run in new environments, environment variables declared in one script DO NOT persist to following steps* +* CircleCI has a **workspace**, which is essentially a cache between steps of the *same job* in which you can store artifacts between steps. + +## How are the workflows structured? + +The nightly binaries have 3 workflows. We have one job (actually 3 jobs: build, test, and upload) per binary configuration + +1. binary_builds + 1. every day midnight EST + 2. linux: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/linux-binary-build-defaults.yml + 3. macos: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/macos-binary-build-defaults.yml + 4. For each binary configuration, e.g. linux_conda_3.7_cpu there is a + 1. binary_linux_conda_3.7_cpu_build + 1. Builds the build. On linux jobs this uses the 'docker executor'. + 2. Persists the package to the workspace + 2. binary_linux_conda_3.7_cpu_test + 1. Loads the package to the workspace + 2. Spins up a docker image (on Linux), mapping the package and code repos into the docker + 3. Runs some smoke tests in the docker + 4. (Actually, for macos this is a step rather than a separate job) + 3. binary_linux_conda_3.7_cpu_upload + 1. Logs in to aws/conda + 2. Uploads the package +2. update_s3_htmls + 1. every day 5am EST + 2. https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/binary_update_htmls.yml + 3. See below for what these are for and why they're needed + 4. Three jobs that each examine the current contents of aws and the conda repo and update some html files in s3 +3. binarysmoketests + 1. every day + 2. https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-build-smoke-tests-defaults.yml + 3. For each binary configuration, e.g. linux_conda_3.7_cpu there is a + 1. smoke_linux_conda_3.7_cpu + 1. 
Downloads the package from the cloud, e.g. using the official pip or conda instructions + 2. Runs the smoke tests + +## How are the jobs structured? + +The jobs are in https://github.com/pytorch/pytorch/tree/master/.circleci/verbatim-sources. Jobs are made of multiple steps. There are some shared steps used by all the binaries/smokes. Steps of these jobs are all delegated to scripts in https://github.com/pytorch/pytorch/tree/master/.circleci/scripts . + +* Linux jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/linux-binary-build-defaults.yml + * binary_linux_build.sh + * binary_linux_test.sh + * binary_linux_upload.sh +* MacOS jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/macos-binary-build-defaults.yml + * binary_macos_build.sh + * binary_macos_test.sh + * binary_macos_upload.sh +* Update html jobs: https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/binary_update_htmls.yml + * These delegate from the pytorch/builder repo + * https://github.com/pytorch/builder/blob/master/cron/update_s3_htmls.sh + * https://github.com/pytorch/builder/blob/master/cron/upload_binary_sizes.sh +* Smoke jobs (both linux and macos): https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-build-smoke-tests-defaults.yml + * These delegate from the pytorch/builder repo + * https://github.com/pytorch/builder/blob/master/run_tests.sh + * https://github.com/pytorch/builder/blob/master/smoke_test.sh + * https://github.com/pytorch/builder/blob/master/check_binary.sh +* Common shared code (shared across linux and macos): https://github.com/pytorch/pytorch/blob/master/.circleci/verbatim-sources/nightly-binary-build-defaults.yml + * binary_checkout.sh - checks out pytorch/builder repo. Right now this also checks out pytorch/pytorch, but it shouldn't. pytorch/pytorch should just be shared through the workspace. This can handle being run before binary_populate_env.sh + * binary_populate_env.sh - parses BUILD_ENVIRONMENT into the separate env variables that make up a binary configuration. Also sets lots of default values, the date, the version strings, the location of folders in s3, all sorts of things. This generally has to be run before other steps. + * binary_install_miniconda.sh - Installs miniconda, cross platform. Also hacks this for the update_binary_sizes job that doesn't have the right env variables + * binary_run_in_docker.sh - Takes a bash script file (the actual test code) from a hardcoded location, spins up a docker image, and runs the script inside the docker image + +### **Why do the steps all refer to scripts?** + +CircleCI creates a final yaml file by inlining every <<* segment, so if we were to keep all the code in the config.yml itself then the config size would go over 4 MB and cause infra problems. + +### **What is binary_run_in_docker for?** + +So, CircleCI has several executor types: macos, machine, and docker are the ones we use. The 'machine' executor gives you two cores on some linux vm. The 'docker' executor gives you considerably more cores (nproc was 32 instead of 2 back when I tried in February). Since the dockers are faster, we try to run everything that we can in dockers. Thus + +* linux build jobs use the docker executor. 
Running them on the docker executor was at least 2x faster than running them on the machine executor +* linux test jobs use the machine executor in order for them to properly interface with GPUs since docker executors cannot execute with attached GPUs +* linux upload jobs use the machine executor. The upload jobs are so short that it doesn't really matter what they use +* linux smoke test jobs use the machine executor for the same reason as the linux test jobs + +binary_run_in_docker.sh is a way to share the docker start-up code between the binary test jobs and the binary smoke test jobs + +### **Why does binary_checkout also checkout pytorch? Why shouldn't it?** + +We want all the nightly binary jobs to run on the exact same git commit, so we wrote our own checkout logic to ensure that the same commit was always picked. Later circleci changed that to use a single pytorch checkout and persist it through the workspace (they did this because our config file was too big, so they wanted to take a lot of the setup code into scripts, but the scripts needed the code repo to exist to be called, so they added a prereq step called 'setup' to checkout the code and persist the needed scripts to the workspace). The changes to the binary jobs were not properly tested, so they all broke from missing pytorch code no longer existing. We hotfixed the problem by adding the pytorch checkout back to binary_checkout, so now there's two checkouts of pytorch on the binary jobs. This problem still needs to be fixed, but it takes careful tracing of which code is being called where. + +# Code structure of the binaries (circleci agnostic) + +## Overview + +The code that runs the binaries lives in two places, in the normal [github.com/pytorch/pytorch](http://github.com/pytorch/pytorch), but also in [github.com/pytorch/builder](http://github.com/pytorch/builder), which is a repo that defines how all the binaries are built. The relevant code is + + +``` +# All code needed to set-up environments for build code to run in, +# but only code that is specific to the current CI system +pytorch/pytorch +- .circleci/ # Folder that holds all circleci related stuff + - config.yml # GENERATED file that actually controls all circleci behavior + - verbatim-sources # Used to generate job/workflow sections in ^ + - scripts/ # Code needed to prepare circleci environments for binary build scripts +- setup.py # Builds pytorch. This is wrapped in pytorch/builder +- cmake files # used in normal building of pytorch +# All code needed to prepare a binary build, given an environment +# with all the right variables/packages/paths. +pytorch/builder +# Given an installed binary and a proper python env, runs some checks +# to make sure the binary was built the proper way. Checks things like +# the library dependencies, symbols present, etc. +- check_binary.sh +# Given an installed binary, runs python tests to make sure everything +# is in order. These should be de-duped. Right now they both run smoke +# tests, but are called from different places. Usually just call some +# import statements, but also has overlap with check_binary.sh above +- run_tests.sh +- smoke_test.sh +# Folders that govern how packages are built. See paragraphs below +- conda/ + - build_pytorch.sh # Entrypoint. 
Delegates to proper conda build folder + - switch_cuda_version.sh # Switches activate CUDA installation in Docker + - pytorch-nightly/ # Build-folder +- manywheel/ + - build_cpu.sh # Entrypoint for cpu builds + - build.sh # Entrypoint for CUDA builds + - build_common.sh # Actual build script that ^^ call into +- wheel/ + - build_wheel.sh # Entrypoint for wheel builds +- windows/ + - build_pytorch.bat # Entrypoint for wheel builds on Windows +``` + +Every type of package has an entrypoint build script that handles the all the important logic. + +## Conda + +Linux, MacOS and Windows use the same code flow for the conda builds. + +Conda packages are built with conda-build, see https://conda.io/projects/conda-build/en/latest/resources/commands/conda-build.html + +Basically, you pass `conda build` a build folder (pytorch-nightly/ above) that contains a build script and a meta.yaml. The meta.yaml specifies in what python environment to build the package in, and what dependencies the resulting package should have, and the build script gets called in the env to build the thing. +tl;dr on conda-build is + +1. Creates a brand new conda environment, based off of deps in the meta.yaml + 1. Note that environment variables do not get passed into this build env unless they are specified in the meta.yaml + 2. If the build fails this environment will stick around. You can activate it for much easier debugging. The “General Python” section below explains what exactly a python “environment” is. +2. Calls build.sh in the environment +3. Copies the finished package to a new conda env, also specified by the meta.yaml +4. Runs some simple import tests (if specified in the meta.yaml) +5. Saves the finished package as a tarball + +The build.sh we use is essentially a wrapper around `python setup.py build`, but it also manually copies in some of our dependent libraries into the resulting tarball and messes with some rpaths. + +The entrypoint file `builder/conda/build_conda.sh` is complicated because + +* It works for Linux, MacOS and Windows + * The mac builds used to create their own environments, since they all used to be on the same machine. There’s now a lot of extra logic to handle conda envs. This extra machinery could be removed +* It used to handle testing too, which adds more logic messing with python environments too. This extra machinery could be removed. + +## Manywheels (linux pip and libtorch packages) + +Manywheels are pip packages for linux distros. Note that these manywheels are not actually manylinux compliant. + +`builder/manywheel/build_cpu.sh` and `builder/manywheel/build.sh` (for CUDA builds) just set different env vars and then call into `builder/manywheel/build_common.sh` + +The entrypoint file `builder/manywheel/build_common.sh` is really really complicated because + +* This used to handle building for several different python versions at the same time. The loops have been removed, but there's still unnecessary folders and movements here and there. + * The script is never used this way anymore. This extra machinery could be removed. +* This used to handle testing the pip packages too. This is why there’s testing code at the end that messes with python installations and stuff + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * This should really be separate. libtorch packages are c++ only and have no python. They should not share infra with all the python specific stuff in this file. 
+* There is a lot of messing with rpaths. This is necessary, but could be made much much simpler if the above issues were fixed. + +## Wheels (MacOS pip and libtorch packages) + +The entrypoint file `builder/wheel/build_wheel.sh` is complicated because + +* The mac builds used to all run on one machine (we didn’t have autoscaling mac machines till circleci). So this script handled siloing itself by setting-up and tearing-down its build env and siloing itself into its own build directory. + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * Ditto the comment above. This should definitely be separated out. + +Note that the MacOS Python wheels are still built in conda environments. Some of the dependencies present during build also come from conda. + +## Windows Wheels (Windows pip and libtorch packages) + +The entrypoint file `builder/windows/build_pytorch.bat` is complicated because + +* This used to handle building for several different python versions at the same time. This is why there are loops everywhere + * The script is never used this way anymore. This extra machinery could be removed. +* This used to handle testing the pip packages too. This is why there’s testing code at the end that messes with python installations and stuff + * The script is never used this way anymore. This extra machinery could be removed. +* This also builds libtorch packages + * This should really be separate. libtorch packages are c++ only and have no python. They should not share infra with all the python specific stuff in this file. + +Note that the Windows Python wheels are still built in conda environments. Some of the dependencies present during build also come from conda. + +## General notes + +### Note on run_tests.sh, smoke_test.sh, and check_binary.sh + +* These should all be consolidated +* These must run on all OS types: MacOS, Linux, and Windows +* These all run smoke tests at the moment. They inspect the packages some, maybe run a few import statements. They DO NOT run the python tests nor the cpp tests. The idea is that python tests on master and PR merges will catch all breakages. All these tests have to do is make sure the special binary machinery didn’t mess anything up. +* There are separate run_tests.sh and smoke_test.sh because one used to be called by the smoke jobs and one used to be called by the binary test jobs (see circleci structure section above). This is still true actually, but these could be united into a single script that runs these checks, given an installed pytorch package. + +### Note on libtorch + +Libtorch packages are built in the wheel build scripts: manywheel/build_*.sh for linux and build_wheel.sh for mac. There are several things wrong with this + +* It’s confusing. Most of those scripts deal with python specifics. +* The extra conditionals everywhere severely complicate the wheel build scripts +* The process for building libtorch is different from the official instructions (a plain call to cmake, or a call to a script) + +### Note on docker images / Dockerfiles + +All linux builds occur in docker images. The docker images are + +* pytorch/conda-cuda + * Has ALL CUDA versions installed. The script pytorch/builder/conda/switch_cuda_version.sh sets /usr/local/cuda to a symlink to e.g. 
/usr/local/cuda-10.0 to enable different CUDA builds + * Also used for cpu builds +* pytorch/manylinux-cuda90 +* pytorch/manylinux-cuda100 + * Also used for cpu builds + +The Dockerfiles are available in pytorch/builder, but there is no circleci job or script to build these docker images, and they cannot be run locally (unless you have the correct local packages/paths). Only Soumith can build them right now. + +### General Python + +* This is still a good explanation of python installations https://caffe2.ai/docs/faq.html#why-do-i-get-import-errors-in-python-when-i-try-to-use-caffe2 + +# How to manually rebuild the binaries + +tl;dr make a PR that looks like https://github.com/pytorch/pytorch/pull/21159 + +Sometimes we want to push a change to master and then rebuild all of today's binaries after that change. As of May 30, 2019 there isn't a way to manually run a workflow in the UI. You can manually re-run a workflow, but it will use the exact same git commits as the first run and will not include any changes. So we have to make a PR and then force circleci to run the binary workflow instead of the normal tests. The above PR is an example of how to do this; essentially you copy-paste the binarybuilds workflow steps into the default workflow steps. If you need to point the builder repo to a different commit then you'd need to change https://github.com/pytorch/pytorch/blob/master/.circleci/scripts/binary_checkout.sh#L42-L45 to checkout what you want. + +## How to test changes to the binaries via .circleci + +Writing PRs that test the binaries is annoying, since the default circleci jobs that run on PRs are not the jobs that you want to run. Likely, changes to the binaries will touch something under .circleci/ and require that .circleci/config.yml be regenerated (.circleci/config.yml controls all .circleci behavior, and is generated using `.circleci/regenerate.sh` in python 3.7). But you also need to manually hardcode the binary jobs that you want to test into the .circleci/config.yml workflow, so you should actually make at least two commits, one for your changes and one to temporarily hardcode jobs. See https://github.com/pytorch/pytorch/pull/22928 as an example of how to do this. + +```sh +# Make your changes +touch .circleci/verbatim-sources/nightly-binary-build-defaults.yml +# Regenerate the yaml, has to be in python 3.7 +.circleci/regenerate.sh +# Make a commit +git add .circleci * +git commit -m "My real changes" +git push origin my_branch +# Now hardcode the jobs that you want in the .circleci/config.yml workflows section +# Also eliminate ensure-consistency and should_run_job checks +# e.g. https://github.com/pytorch/pytorch/commit/2b3344bfed8772fe86e5210cc4ee915dee42b32d +# Make a commit you won't keep +git add .circleci +git commit -m "[DO NOT LAND] testing binaries for above changes" +git push origin my_branch +# Now you need to make some changes to the first commit. +git rebase -i HEAD~2 # mark the first commit as 'edit' +# Make the changes +touch .circleci/verbatim-sources/nightly-binary-build-defaults.yml +.circleci/regenerate.sh +# Ammend the commit and recontinue +git add .circleci +git commit --amend +git rebase --continue +# Update the PR, need to force since the commits are different now +git push origin my_branch --force +``` + +The advantage of this flow is that you can make new changes to the base commit and regenerate the .circleci without having to re-write which binary jobs you want to test on. 
The downside is that all updates will be force pushes. + +## How to build a binary locally + +### Linux + +You can build Linux binaries locally easily using docker. + +```sh +# Run the docker +# Use the correct docker image, pytorch/conda-cuda used here as an example +# +# -v path/to/foo:path/to/bar makes path/to/foo on your local machine (the +# machine that you're running the command on) accessible to the docker +# container at path/to/bar. So if you then run `touch path/to/bar/baz` +# in the docker container then you will see path/to/foo/baz on your local +# machine. You could also clone the pytorch and builder repos in the docker. +# +# If you know how, add ccache as a volume too and speed up everything +docker run \ + -v your/pytorch/repo:/pytorch \ + -v your/builder/repo:/builder \ + -v where/you/want/packages/to/appear:/final_pkgs \ + -it pytorch/conda-cuda /bin/bash +# Export whatever variables are important to you. All variables that you'd +# possibly need are in .circleci/scripts/binary_populate_env.sh +# You should probably always export at least these 3 variables +export PACKAGE_TYPE=conda +export DESIRED_PYTHON=3.7 +export DESIRED_CUDA=cpu +# Call the entrypoint +# `|& tee foo.log` just copies all stdout and stderr output to foo.log +# The builds generate lots of output so you probably need this when +# building locally. +/builder/conda/build_pytorch.sh |& tee build_output.log +``` + +**Building CUDA binaries on docker** + +You can build CUDA binaries on CPU only machines, but you can only run CUDA binaries on CUDA machines. This means that you can build a CUDA binary on a docker on your laptop if you so choose (though it’s gonna take a long time). + +For Facebook employees, ask about beefy machines that have docker support and use those instead of your laptop; it will be 5x as fast. + +### MacOS + +There’s no easy way to generate reproducible hermetic MacOS environments. If you have a Mac laptop then you can try emulating the .circleci environments as much as possible, but you probably have packages in /usr/local/, possibly installed by brew, that will probably interfere with the build. If you’re trying to repro an error on a Mac build in .circleci and you can’t seem to repro locally, then my best advice is actually to iterate on .circleci :/ + +But if you want to try, then I’d recommend + +```sh +# Create a new terminal +# Clear your LD_LIBRARY_PATH and trim as much out of your PATH as you +# know how to do +# Install a new miniconda +# First remove any other python or conda installation from your PATH +# Always install miniconda 3, even if building for Python <3 +new_conda="~/my_new_conda" +conda_sh="$new_conda/install_miniconda.sh" +curl -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +chmod +x "$conda_sh" +"$conda_sh" -b -p "$MINICONDA_ROOT" +rm -f "$conda_sh" +export PATH="~/my_new_conda/bin:$PATH" +# Create a clean python env +# All MacOS builds use conda to manage the python env and dependencies +# that are built with, even the pip packages +conda create -yn binary python=2.7 +conda activate binary +# Export whatever variables are important to you. All variables that you'd +# possibly need are in .circleci/scripts/binary_populate_env.sh +# You should probably always export at least these 3 variables +export PACKAGE_TYPE=conda +export DESIRED_PYTHON=3.7 +export DESIRED_CUDA=cpu +# Call the entrypoint you want +path/to/builder/wheel/build_wheel.sh +``` + +N.B. installing a brand new miniconda is important. 
This has to do with how conda installations work. See the “General Python” section above, but tldr; is that + +1. You make the ‘conda’ command accessible by prepending `path/to/conda_root/bin` to your PATH. +2. You make a new env and activate it, which then also gets prepended to your PATH. Now you have `path/to/conda_root/envs/new_env/bin:path/to/conda_root/bin:$PATH` +3. Now say you (or some code that you ran) call python executable `foo` + 1. if you installed `foo` in `new_env`, then `path/to/conda_root/envs/new_env/bin/foo` will get called, as expected. + 2. But if you forgot to installed `foo` in `new_env` but happened to previously install it in your root conda env (called ‘base’), then unix/linux will still find `path/to/conda_root/bin/foo` . This is dangerous, since `foo` can be a different version than you want; `foo` can even be for an incompatible python version! + +Newer conda versions and proper python hygiene can prevent this, but just install a new miniconda to be safe. + +### Windows + +TODO: fill in From 3c2676de3d35fd22f79c46eaa770d03f1418c480 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 18 Nov 2022 03:37:14 +0000 Subject: [PATCH 319/453] [LTC] Restore GetPythonFrames (#89122) Summary: pytorch/pytorch@936e930 delete the registration of GetPythonFramesFunction. Restore that and add a test case to prevent regression. Test Plan: python test/lazy/test_debug_util.py Fixes pytorch/xla#4206. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89122 Approved by: https://github.com/JackCaoG --- test/lazy/test_debug_util.py | 44 +++++++++++++++++++++++++++++++++ torch/csrc/lazy/python/init.cpp | 4 +++ 2 files changed, 48 insertions(+) create mode 100644 test/lazy/test_debug_util.py diff --git a/test/lazy/test_debug_util.py b/test/lazy/test_debug_util.py new file mode 100644 index 000000000000..df201d54737f --- /dev/null +++ b/test/lazy/test_debug_util.py @@ -0,0 +1,44 @@ +# Owner(s): ["oncall: jit"] + +import os +import re +import tempfile +import torch.nn as nn +import unittest + +import torch._lazy +import torch._lazy.ts_backend +from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase + +torch._lazy.ts_backend.init() + + +@unittest.skipIf(IS_WINDOWS, "To be fixed") +class DebugUtilTest(TestCase): + def _run_linear(self): + device = "lazy" + model = nn.Linear(5, 5).to(device) + output = model(torch.randn(1, 5).to(device)) + torch._lazy.mark_step() + + + def test_get_python_frames(self): + # We only care about the first "Python Stacktrace" part of the saved + # graph. However, we cannot save the whole stack for comparison given + # it depends on a lot of things. 
+ partial_graph = (r"Python Stacktrace:.*" + r"mark_step \(.*/_lazy/__init__.py:[0-9]+\).*" + r"_run_linear \(.*lazy/test_debug_util.py:[0-9]+\).*" + r"test_get_python_frames \(.*lazy/test_debug_util.py:[0-9]+\)") + + with tempfile.NamedTemporaryFile(mode="r+", encoding="utf-8") as graph_file: + os.environ["LTC_SAVE_TENSORS_FILE"] = graph_file.name + self._run_linear() + file = graph_file.read() + if re.search(partial_graph, file, re.DOTALL) is None: + print(file) + self.assertTrue(False) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 774df68e26de..0b773788eff9 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -305,6 +305,10 @@ void initLazyBindings(PyObject* module) { #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) return result; }); + + // When libtorch_python is loaded, we register the python frame getter + // otherwise, debug util simply omits python frames + GetPythonFramesFunction() = GetPythonFrames; } } // namespace lazy From 6ed14c7dcfb261e84016407d8025bf3e27999730 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 18 Nov 2022 03:45:53 +0000 Subject: [PATCH 320/453] [vision hash update] update the pinned vision hash (#89102) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89102 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index c9bfe60001af..cc0724ac842d 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -b1f6c9e271368cd84837522af39e68dd4b5768a7 +d710f3d1edc06afa244468cb96603ba6dbd4d9d5 From f4efc5e821259aee1b64ee32f992ea3458dcd546 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Thu, 17 Nov 2022 16:45:47 -0800 Subject: [PATCH 321/453] [quant][be] Move some helper functions to the top level to reduce function length (#89246) Summary: att Test Plan: python test/test_quantization.py TestQuantizeFx Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89246 Approved by: https://github.com/vkuzo --- torch/ao/quantization/fx/convert.py | 162 ++++++++++++++-------------- 1 file changed, 80 insertions(+), 82 deletions(-) diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 0c1249b4858d..ca6ae61a4c97 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -88,6 +88,83 @@ "run_weight_observers", ] +def _replace_observer_with_quantize_dequantize_node( + model: torch.nn.Module, + graph: Graph, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny], + is_decomposed: bool) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... 
+ """ + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + observer_module = modules[node.target] + maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed) + # Skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all([ + has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())]) + if skip_replacement or maybe_quantize_node_info is None: + # didn't find correponding quantize op and info for the observer_module + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + else: + # otherwise, we can convert the observer moduel call to quantize/dequantize node + node_type, quantize_op, qparams = maybe_quantize_node_info + # replace observer node with quant - dequant node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + if is_decomposed: + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor, + tuple(dq_inputs), + {} + ) + else: + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + +# this is a temporary hack for custom module, we may want to implement +# this properly after the custom module class design is finalized +# TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted +# after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs +# after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. 
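For orientation, a minimal eager-mode sketch of the pattern these rewrites produce, with made-up qparams (0.05 / 128); at convert time the real scale and zero_point are read from the observer, and the decomposed branch emits the `quantized_decomposed` ops with the same qparams instead of the quantized-tensor ops:

```python
import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops

x = torch.randn(4)
scale, zero_point = 0.05, 128  # illustrative values; normally computed by the observer

# non-decomposed form: observer_0(x) becomes quantize_per_tensor + dequantize
q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
dq = q.dequantize()

# decomposed form: plain-tensor ops that carry the qparams explicitly
q2 = torch.ops.quantized_decomposed.quantize_per_tensor(
    x, scale, zero_point, 0, 255, torch.uint8)
dq2 = torch.ops.quantized_decomposed.dequantize_per_tensor(
    q2, scale, zero_point, 0, 255, torch.uint8)
```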
+def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph): + call_custom_module_node = node.args[0] + assert isinstance(call_custom_module_node, Node), \ + f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" + node.replace_all_uses_with(call_custom_module_node) + graph.erase_node(node) + insert_dequantize_node(call_custom_module_node, graph) def restore_state( observed: torch.nn.Module @@ -599,85 +676,6 @@ def convert( if node.op == 'placeholder': graph_inputs.append(node.name) - # TODO: move this outside of this function - def replace_observer_with_quantize_dequantize_node( - model: torch.nn.Module, - graph: Graph, - node: Node, - modules: Dict[str, torch.nn.Module], - node_name_to_scope: Dict[str, Tuple[str, type]], - node_name_to_qconfig: Dict[str, QConfigAny], - is_decomposed: bool) -> None: - """ Replace activation_post_process module call node with quantize and - dequantize node - - Before: - ... -> observer_0(x) -> ... - After: - ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... - """ - assert modules is not None - assert isinstance(node.target, str) - module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) - observer_module = modules[node.target] - maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed) - # Skip replacing observers to quant/dequant nodes if the qconfigs of all - # consumers and producers of this observer are None - skip_replacement = all([ - has_none_qconfig(n, node_name_to_qconfig) for n in - list(node.args) + list(node.users.keys())]) - if skip_replacement or maybe_quantize_node_info is None: - # didn't find correponding quantize op and info for the observer_module - # so we just remove the observer - with graph.inserting_before(node): - node.replace_all_uses_with(node.args[0]) - graph.erase_node(node) - else: - # otherwise, we can convert the observer moduel call to quantize/dequantize node - node_type, quantize_op, qparams = maybe_quantize_node_info - # replace observer node with quant - dequant node - with graph.inserting_before(node): - input_node = node.args[0] - quantize_op_inputs = [input_node] - for key, value in qparams.items(): - # TODO: we can add the information of whether a value needs to - # be registered as an attribute in qparams dict itself - if key in ['_scale_', '_zero_point_']: - # For scale and zero_point values we register them as buffers in the root module. - # TODO: maybe need more complex attr name here - qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) - quantize_op_inputs.append(qparam_node) - else: - # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
- quantize_op_inputs.append(value) - - quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) - if is_decomposed: - # use the same qparams from quantize op - dq_inputs = [quantized_node] + quantize_op_inputs[1:] - dequantized_node = graph.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor, - tuple(dq_inputs), - {} - ) - else: - dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) - node.replace_all_uses_with(dequantized_node) - graph.erase_node(node) - - # this is a temporary hack for custom module, we may want to implement - # this properly after the custom module class design is finalized - # TODO: DeQuantStubs are currently inserted only after custom module LSTM, while observers are inserted - # after all other custom modules. In the future, we should simply insert QuantStubs before and DeQuantStubs - # after custom modules in general, and replace these with "quantize" and "dequantize" nodes respectively. - def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Graph): - call_custom_module_node = node.args[0] - assert isinstance(call_custom_module_node, Node), \ - f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" - node.replace_all_uses_with(call_custom_module_node) - graph.erase_node(node) - insert_dequantize_node(call_custom_module_node, graph) - # additional state to override inputs to be quantized, if specified # by the user placeholder_node_seen_cnt = 0 @@ -728,13 +726,13 @@ def replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Gra if _is_activation_post_process(mod): observed_node = node.args[0] if observed_node in statically_quantized_custom_module_nodes: - replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) else: - replace_observer_with_quantize_dequantize_node( + _replace_observer_with_quantize_dequantize_node( model, model.graph, node, modules, node_name_to_scope, node_name_to_qconfig, is_decomposed) elif isinstance(mod, DeQuantStub): - replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) + _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) elif is_observed_standalone_module(mod): convert_standalone_module( node, modules, model, is_reference, backend_config) From 6f4f69f54d181b34373e07dcb415f6c2af61868f Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 18 Nov 2022 04:13:03 +0000 Subject: [PATCH 322/453] [Executorch] [Quantization] New pattern for dynamic dequant (#89236) Summary: The op exposed should be qparams, and then we have concerns about prims not being supported so make q and dq ops that take in tensors Test Plan: unit test Differential Revision: D41382580 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89236 Approved by: https://github.com/jerryzh168 --- .../core/test_quantized_tensor.py | 14 ++++---- torch/ao/quantization/fx/_decomposed.py | 34 +++++++++++++++---- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/test/quantization/core/test_quantized_tensor.py b/test/quantization/core/test_quantized_tensor.py index dab53de5b107..a89c98f4e5ab 100644 --- a/test/quantization/core/test_quantized_tensor.py +++ b/test/quantization/core/test_quantized_tensor.py @@ -14,7 +14,6 @@ from torch.testing._internal.common_utils import TestCase, DeterministicGuard import torch.testing._internal.hypothesis_utils as hu from 
torch.testing._internal.common_quantization import get_supported_device_types -from torch.ao.quantization import MinMaxObserver hu.assert_deadline_disabled() @@ -1499,7 +1498,7 @@ def test_decomposed_dequantize(self): self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) self.assertEqual(dequantized_X, dequantized_decomposed_X) - def test_decomposed_quantize_dynamic(self): + def test_decomposed_dynamic_quant_pattern(self): import torch.ao.quantization.fx._decomposed X = torch.randn(5, 10) dtype = torch.uint8 @@ -1510,14 +1509,13 @@ def test_decomposed_quantize_dynamic(self): quantized_X = torch.quantize_per_tensor(X, scale, zero_point, qdtype) dequantized_X = torch.dequantize(quantized_X) - quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor_dynamic( + # Now try decomposed pattern + (scale_decomposed, zero_point_decomposed) = torch.ops.quantized_decomposed.choose_qparams.tensor( X, quant_min, quant_max, dtype) + quantized_decomposed_X = torch.ops.quantized_decomposed.quantize_per_tensor.tensor( + X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype) - # observer logic is what quantize_per_tensor_dynamic does internally - observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max) - observer(X) - scale_decomposed, zero_point_decomposed = observer.calculate_qparams() - dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor( + dequantized_decomposed_X = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor( quantized_decomposed_X, scale_decomposed, zero_point_decomposed, quant_min, quant_max, dtype ) self.assertEqual(quantized_X.int_repr(), quantized_decomposed_X) diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index 3f4d38872e17..fcb4a77a5f49 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -1,5 +1,5 @@ import torch -from torch.library import impl, Library +from torch.library import Library, impl from torch.ao.quantization import MinMaxObserver # Note: decomposed means decomposed quantized tensor, using decomposed so that the @@ -38,6 +38,16 @@ def quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype): inv_scale = 1.0 / scale return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype) +quantized_decomposed_lib.define( + "quantize_per_tensor.tensor(" + "Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd") +def quantize_per_tensor_tensor(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in # the signature as metadata for the input Tensor, this might be useful for pattern # matching in the future @@ -56,13 +66,25 @@ def dequantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype) else: raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") + quantized_decomposed_lib.define( - "quantize_per_tensor_dynamic(Tensor input, int quant_min, int quant_max, ScalarType 
dtype) -> Tensor") + "dequantize_per_tensor.tensor(" + "Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype) -> Tensor") + +@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd") +def dequantize_per_tensor_tensor(input, scale, zero_point, quant_min, quant_max, dtype): + assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" + assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" + return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) -@impl(quantized_decomposed_lib, "quantize_per_tensor_dynamic", "CompositeExplicitAutograd") -def quantize_per_tensor_dynamic(input, quant_min, quant_max, dtype): + +quantized_decomposed_lib.define( + "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, ScalarType dtype) -> (Tensor, Tensor)") + +@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") +def choose_qparams_tensor(input, quant_min, quant_max, dtype): assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - _quant_min_max_bounds_check(quant_min, quant_max, dtype) + assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}" # Its weird to create an observer manually just to calculate qparams. I tried refactoring this functionality out of observer # into a util and then use that util directly, but I kept running into jit typing errors related to torch.qscheme not @@ -71,4 +93,4 @@ def quantize_per_tensor_dynamic(input, quant_min, quant_max, dtype): observer = MinMaxObserver(quant_min=quant_min, quant_max=quant_max, dtype=tensor_dtype_to_observer_dtype[dtype]) observer(input) scale, zero_point = observer.calculate_qparams() - return torch.ops.quantized_decomposed.quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + return (scale, zero_point) From 6f609dd0e03e11395cc637a34abd68472e5a1e12 Mon Sep 17 00:00:00 2001 From: Yoni Chechik Date: Fri, 18 Nov 2022 04:29:00 +0000 Subject: [PATCH 323/453] docs: conv2d `padding` attribute- add `int` option (#85004) `padding: int` already exists but isn't mentioned in the genereted docs Pull Request resolved: https://github.com/pytorch/pytorch/pull/85004 Approved by: https://github.com/albanD, https://github.com/kit1980 --- torch/nn/modules/conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 93a97f15e7c8..5c081e64ecca 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -340,7 +340,7 @@ class Conv2d(_ConvNd): number or a tuple. * :attr:`padding` controls the amount of padding applied to the input. It - can be either a string {{'valid', 'same'}} or a tuple of ints giving the + can be either a string {{'valid', 'same'}} or an int / a tuple of ints giving the amount of implicit padding applied on both sides. 
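As a quick, illustrative sketch of the three `padding` forms the updated docstring lists (int, tuple of ints, and string), assuming a 3x3 kernel and the default stride of 1:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)

# int: the same implicit zero padding on every side
print(nn.Conv2d(3, 8, kernel_size=3, padding=1)(x).shape)       # torch.Size([1, 8, 32, 32])
# tuple: (padH, padW), applied to both sides of each dimension
print(nn.Conv2d(3, 8, kernel_size=3, padding=(1, 2))(x).shape)  # torch.Size([1, 8, 32, 34])
# string: 'same' keeps the spatial size when stride is 1
print(nn.Conv2d(3, 8, kernel_size=3, padding='same')(x).shape)  # torch.Size([1, 8, 32, 32])
```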
* :attr:`dilation` controls the spacing between the kernel points; also From ba5e39e106caaf4e013fbfc4890d3df13e66d6c9 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 17 Nov 2022 18:10:40 +0000 Subject: [PATCH 324/453] Fix tol for test_nvfuser_correctness__softmax_backward_data_cuda (#89178) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89178 Approved by: https://github.com/kit1980 --- torch/testing/_internal/common_methods_invocations.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24ef757b768d..50732af6f857 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -420,7 +420,9 @@ def sample_inputs_softmax_backward_data(op_info, device, dtype, requires_grad, * input_dtypes += [torch.float16] for (shape, dim), input_dtype in product(cases, input_dtypes): - yield SampleInput(make_arg(shape), make_arg(shape), dim, input_dtype) + input = make_arg(shape) + output = torch.nn.functional.softmax(input, dim=dim, dtype=input_dtype) + yield SampleInput(make_arg(shape), output, dim, input_dtype) def sample_inputs_native_batch_norm(op_info, device, dtype, requires_grad, **kwargs): samples = sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs) @@ -10596,6 +10598,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): skips=( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cpu'), DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + DecorateInfo(toleranceOverride({torch.float16: tol(atol=2e-4, rtol=2e-3), + torch.bfloat16: tol(atol=1e-3, rtol=0.016)}), + 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'), ), ), # `softmin` supports different dtypes based on whether `dtype` argument, From f1fb586bc64b96264f4409421d758e9336f19eef Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 17 Nov 2022 18:50:33 +0000 Subject: [PATCH 325/453] Symintify repeat_interleave.self_int (#89111) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89111 Approved by: https://github.com/ezyang --- .../functorch/BatchRulesDecompositions.cpp | 2 +- aten/src/ATen/native/Repeat.cpp | 17 +++++++--- aten/src/ATen/native/native_functions.yaml | 4 ++- .../cuda/NestedTensorTransformerFunctions.cpp | 4 +-- .../quantized/FakeQuantPerTensorAffine.cpp | 6 ++-- c10/core/SymFloat.cpp | 8 +++++ c10/core/SymFloat.h | 10 ++++++ test/dynamo/test_dynamic_shapes.py | 5 --- test/test_proxy_tensor.py | 3 +- torch/_prims/__init__.py | 10 +++--- torch/csrc/utils/tensor_new.cpp | 32 +++++++++++++++---- 11 files changed, 73 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index e31b36d11241..05ee8d07a410 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -184,7 +184,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(positive); OP_DECOMPOSE(qr); OP_DECOMPOSE(ravel); - OP_DECOMPOSE2(repeat_interleave, self_int); + m.impl("repeat_interleave.self_int", native::repeat_interleave_symint); OP_DECOMPOSE2(repeat_interleave, self_Tensor); m.impl("reshape", native::reshape_symint); OP_DECOMPOSE(resolve_conj); diff --git a/aten/src/ATen/native/Repeat.cpp 
b/aten/src/ATen/native/Repeat.cpp index b671a2232044..c8c4e134929f 100644 --- a/aten/src/ATen/native/Repeat.cpp +++ b/aten/src/ATen/native/Repeat.cpp @@ -75,11 +75,11 @@ Tensor repeat_interleave( } Tensor repeats_ = repeats; - if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.size(0) == 1)) { - repeats_ = repeats.reshape({1}).expand({input.size(dim.value())}); + if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) { + repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())}); } else if (repeats.dim() == 1) { TORCH_CHECK( - repeats.size(0) == input.size(dim.value()), + repeats.sym_size(0) == input.sym_size(dim.value()), "repeats must have the same size as input along dim") } else { AT_ERROR("repeats must be 0-dim or 1-dim tensor"); @@ -102,10 +102,17 @@ Tensor repeat_interleave( int64_t repeats, c10::optional dim, c10::optional output_size) { - at::Tensor repeats_ = - at::empty(1, self.options().dtype(at::kLong)).fill_(repeats); + at::Tensor repeats_ = at::empty(1, self.options().dtype(at::kLong)).fill_(repeats); return at::native::repeat_interleave(self, repeats_, dim, output_size); } +Tensor repeat_interleave_symint( + const Tensor& self, + c10::SymInt repeats, + c10::optional dim, + c10::optional output_size) { + return at::native::repeat_interleave(self, repeats.guard_int(__FILE__, __LINE__), dim, output_size); + } + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b1d1094667e1..5cf0e759db1d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4320,8 +4320,10 @@ - func: repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, int? output_size=None) -> Tensor variants: function, method -- func: repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> Tensor +- func: repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, int? 
output_size=None) -> Tensor variants: function, method + dispatch: + CompositeImplicitAutograd: repeat_interleave_symint - func: reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) variants: function, method diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 411ebdb19b5a..c2bf4e08ce04 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -152,8 +152,8 @@ Tensor NestedTensor_to_padded_tensor_cuda( if (t_dim == 3 && nt_input->opt_size(2) && (*nt_input->opt_size(2) > 0) && !(output_size.has_value())) { Tensor nt_sizes = nt_input->get_nested_size_tensor(); - Tensor sizes_dim1 = at::native::narrow(nt_sizes, 1, 0, 1); - Tensor sizes_dim2 = at::native::narrow(nt_sizes, 1, 1, 1); + Tensor sizes_dim1 = at::native::narrow_symint(nt_sizes, 1, 0, 1); + Tensor sizes_dim2 = at::native::narrow_symint(nt_sizes, 1, 1, 1); Tensor result = at::detail::make_tensor( nt_input->get_buffer(), sizes_dim1 * sizes_dim2[0]); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.dim() == 2); diff --git a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp index 700b3b14b180..aac039f0e03e 100644 --- a/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp +++ b/aten/src/ATen/native/quantized/FakeQuantPerTensorAffine.cpp @@ -122,10 +122,10 @@ Tensor fake_quantize_per_tensor_affine_cachemask_backward( const Tensor& dY, const Tensor& mask) { TORCH_CHECK(mask.scalar_type() == ScalarType::Bool); - TORCH_CHECK(mask.numel() == dY.numel(), + TORCH_CHECK(mask.sym_numel() == dY.sym_numel(), "`mask` and `dY` are not the same size: ", - "`mask` is size ", mask.numel(), " and `dY` is size ", dY.numel()); - if (dY.numel() <= 0) { + "`mask` is size ", mask.sym_numel(), " and `dY` is size ", dY.sym_numel()); + if (dY.sym_numel() <= 0) { return dY; } // Note: no additional kernels needed, since mask is pre-computed diff --git a/c10/core/SymFloat.cpp b/c10/core/SymFloat.cpp index 81e8f25d5bb6..511c50e3398e 100644 --- a/c10/core/SymFloat.cpp +++ b/c10/core/SymFloat.cpp @@ -70,4 +70,12 @@ std::ostream& operator<<(std::ostream& os, const SymFloat& s) { return os; } +double SymFloat::guard_float(const char* file, int64_t line) const { + if (!is_symbolic()) { + return data_; + } + SymNode a = toSymNodeImpl(); + return a->guard_float(file, line); +} + } // namespace c10 diff --git a/c10/core/SymFloat.h b/c10/core/SymFloat.h index 7da364ce127a..ff9e101e31af 100644 --- a/c10/core/SymFloat.h +++ b/c10/core/SymFloat.h @@ -40,6 +40,16 @@ class C10_API SymFloat { SymFloat operator*(const SymFloat&) const; SymFloat operator/(const SymFloat&) const; + // Insert a guard for the float to be its concrete value, and then return + // that value. This operation always works, even if the float is symbolic, + // so long as we know what the underlying value is. Don't blindly put this + // everywhere; you can cause overspecialization of PyTorch programs with + // this method. + // + // It should be called as guard_float(__FILE__, __LINE__). The file and line + // number can be used to diagnose overspecialization. + double guard_float(const char* file, int64_t line) const; + // N.B. 
It's important to keep this definition in the header // as we expect if checks to be folded for mobile builds // where `is_symbolic` is always false diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index f3964a777aa8..2eb16784514d 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -106,11 +106,6 @@ def make_dynamic_cls(cls): DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes ) -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes - # Unable to cast Python instance to C++ type -) - if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8dc42be7fdfb..21682ac76fc6 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1122,6 +1122,8 @@ def f(a, b, c, d, e): xfail('multinomial'), xfail('cholesky'), xfail('cholesky_inverse'), + # cannot do these as they rely on tensor data + xfail('repeat_interleave'), # ASAN failures due to divide by 0 skip('nn.functional.nll_loss'), } @@ -1283,7 +1285,6 @@ def f(a, b, c, d, e): xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition - xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index da8d9af723ac..22917ec048eb 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -2323,10 +2323,12 @@ def _arange_meta( step != 0, lambda: "step must be nonzero", ) - utils.check( - math.isfinite(start) and math.isfinite(end), - lambda: f"unsupported range: {start} -> {end}", - ) + # SymInts can't represent inf + if not isinstance(start, torch.SymInt) and not isinstance(end, torch.SymInt): + utils.check( + math.isfinite(start) and math.isfinite(end), + lambda: f"unsupported range: {start} -> {end}", + ) utils.check( (step > 0 and end >= start) or (step < 0 and end <= start), lambda: "upper bound and lower bound inconsistent with step sign", diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 707ebeb19e84..83506346505e 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -79,10 +79,10 @@ Tensor new_with_sizes( c10::TensorOptions options, at::ScalarType scalar_type, const optional& device, - IntArrayRef sizes) { + c10::SymIntArrayRef sizes) { maybe_initialize_cuda(options.device()); pybind11::gil_scoped_release no_gil; - return torch::empty(sizes, build_options(options, scalar_type, device)); + return at::empty_symint(sizes, build_options(options, scalar_type, device)); } Tensor new_with_storage( @@ -124,6 +124,12 @@ std::vector compute_sizes(PyObject* seq, ScalarType scalar_type) { } ScalarType infer_scalar_type(PyObject* obj) { + if (torch::is_symint(obj)) { + return ScalarType::Long; + } + if (torch::is_symfloat(obj)) { + 
return ScalarType::Double; + } #ifdef USE_NUMPY if (is_numpy_available()) { if (PyArray_Check(obj)) { @@ -204,7 +210,21 @@ void recursive_store( TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr); int64_t ndim = sizes.size(); + bool is_symfloat = torch::is_symfloat(obj); + bool is_symint = torch::is_symint(obj); if (dim == ndim) { + if (is_symfloat) { + auto new_obj = py::reinterpret_borrow(obj); + auto val = new_obj.cast(); + *(double*)data = val.guard_float(__FILE__, __LINE__); + return; + } + if (is_symint) { + auto new_obj = py::reinterpret_borrow(obj); + auto val = new_obj.cast(); + *(int64_t*)data = val.guard_int(__FILE__, __LINE__); + return; + } torch::utils::store_scalar(data, scalarType, obj); return; } @@ -531,7 +551,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new( "new(*, int64_t cdata)|hidden", "new(Tensor indices, Tensor values, *, Device? device=None)", "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)", - "new(IntArrayRef size, *, Device? device=None)", + "new(SymIntArrayRef size, *, Device? device=None)", }); if (ctor_or_new == CtorOrNew::NEW) check_base_legacy_new(dispatch_key, c10::kSparse); @@ -577,7 +597,7 @@ Tensor legacy_sparse_tensor_generic_ctor_new( } } return new_with_sizes( - options, scalar_type, r.deviceOptional(1), r.intlist(0)); + options, scalar_type, r.deviceOptional(1), r.symintlist(0)); } throw std::runtime_error("new(): invalid arguments"); } @@ -615,7 +635,7 @@ Tensor legacy_tensor_generic_ctor_new( // matching with // IntArrayRef, // PyObject* - "new(IntArrayRef size, *, Device? device=None)", + "new(SymIntArrayRef size, *, Device? device=None)", "new(PyObject* data, *, Device? device=None)", }); @@ -690,7 +710,7 @@ Tensor legacy_tensor_generic_ctor_new( options, scalar_type, deviceOptional, r.pyobject(0)); } return new_with_sizes( - options, scalar_type, r.deviceOptional(1), r.intlist(0)); + options, scalar_type, r.deviceOptional(1), r.symintlist(0)); } else if (r.idx == 6) { auto deviceOptional = r.deviceOptional(1); check_legacy_ctor_device(dispatch_key, deviceOptional); From 3bc78295c265df62983fcbcadb4a87ef7d0fbf2d Mon Sep 17 00:00:00 2001 From: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Date: Fri, 18 Nov 2022 05:08:45 +0000 Subject: [PATCH 326/453] Fix consistentcy of histc on CPU and CUDA (#87832) Fixes #87657 The main reason why `histc` returns slightly different outputs is the difference on how bin position is calculate. The CPU calculates it as: https://github.com/pytorch/pytorch/blob/449778a939f2adc8867c5035b08be4e2d88339d8/aten/src/ATen/native/cpu/HistogramKernel.cpp#L168-L170 which is basically `(i - a) / (b - a) * N`, while cuda code https://github.com/pytorch/pytorch/blob/449778a939f2adc8867c5035b08be4e2d88339d8/aten/src/ATen/native/cuda/SummaryOps.cu#L41 which is `(i - a) * N / (b - a)`. For some cases like in #87657 the order of arithmetic operations matters due to the floating point round-off. ________________ Not sure where would be the most appropriate place to put the unit test. Hope `test_reductions::test_histc` will do. 
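
As a side note, here is a minimal sketch (not part of this patch) of the effect described above: it evaluates both orderings in float32 on the same inputs as the new `test_histc` case and reports any elements whose truncated bin index differs. The variable names are only for illustration, and the exact set of mismatches depends on how float32 rounds each intermediate value.

```python
# Minimal sketch, not part of this patch: probe where the two bin-position
# orderings disagree under float32, using the same inputs as the new
# test_histc case (bins=20, min=0, max=0.99).
import torch

lo, hi, nbins = 0.0, 0.99, 20
x = torch.linspace(0, 0.99 - 5.0e-7, 101)           # float32 by default
width = torch.tensor(hi - lo, dtype=torch.float32)
n = torch.tensor(float(nbins), dtype=torch.float32)

cpu_old = ((x - lo) / width * n).to(torch.int64)    # divide first (old CPU kernel)
cuda_way = ((x - lo) * n / width).to(torch.int64)   # multiply first (CUDA kernel)

diff = (cpu_old != cuda_way).nonzero().flatten()
print("elements whose bin index differs:", diff.tolist())
```

With this change the CPU kernel multiplies by the bin count before dividing by the range, matching the CUDA kernel, so both sides truncate the same intermediate value.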
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87832 Approved by: https://github.com/soumith --- aten/src/ATen/native/cpu/HistogramKernel.cpp | 4 ++-- test/test_reductions.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/cpu/HistogramKernel.cpp b/aten/src/ATen/native/cpu/HistogramKernel.cpp index 932bf9beb499..83011aa2e9a7 100644 --- a/aten/src/ATen/native/cpu/HistogramKernel.cpp +++ b/aten/src/ATen/native/cpu/HistogramKernel.cpp @@ -166,8 +166,8 @@ void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges, * the appropriate bin via simple division. */ pos = static_cast((elt - leftmost_edge[dim]) - / (rightmost_edge[dim] - leftmost_edge[dim]) - * (num_bin_edges[dim] - 1)); + * (num_bin_edges[dim] - 1) + / (rightmost_edge[dim] - leftmost_edge[dim])); /* Ensures consistency with bin_edges by checking the bins to the left and right * of the selected position. Necessary for cases in which an element very close diff --git a/test/test_reductions.py b/test/test_reductions.py index a4be31cd6f92..8d91f56545f0 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -2843,6 +2843,9 @@ def test_against_np(tensor, bins=100, min=0, max=0): expanded = torch.randn(1, 5, 1, 2, device=device).expand(3, 5, 7, 2) test_against_np(expanded) + linear = torch.linspace(0, 0.99 - 5.0e-7, 101).to(device) + test_against_np(linear, bins=20, min=0, max=0.99) + @onlyCPU def test_histc_bfloat16(self, device): actual = torch.histc( From ab75982d3a8d76052dbaf1eb37c5b9b729ac0dd8 Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 18 Nov 2022 07:03:22 +0000 Subject: [PATCH 327/453] Always retry curl downloads (#89157) Modify our curl commands so that they always retry downloads. By default, curl only retries what it considers to be "transient" errors, based on the server's response. However, curl's estimate of what's transient is very conservative. By adding the --retry-all-errors parameter we'll always retry curl commands. 
In particular, I'm hoping this mitigates errors where curl fails with the below error ([logs](https://github.com/pytorch/pytorch/actions/runs/3468758110/jobs/5794939941)) `curl: (35) OpenSSL SSL_connect: SSL_ERROR_SYSCALL in connection to ossci-linux.s3.amazonaws.com:443` Some of the modified downloads didn't even have retries, so I added them in More details: https://everything.curl.dev/usingcurl/downloads/retry Pull Request resolved: https://github.com/pytorch/pytorch/pull/89157 Approved by: https://github.com/kit1980, https://github.com/malfet --- .circleci/docker/common/install_cudnn.sh | 4 ++-- .circleci/docker/common/install_docs_reqs.sh | 4 ++-- .circleci/docker/common/install_protobuf.sh | 2 +- .circleci/scripts/binary_install_miniconda.sh | 4 ++-- .circleci/scripts/binary_ios_build.sh | 2 +- .circleci/scripts/binary_ios_upload.sh | 2 +- .circleci/scripts/driver_update.bat | 2 +- .circleci/scripts/setup_ci_environment.sh | 4 ++-- .../scripts/setup_linux_system_environment.sh | 2 +- .circleci/scripts/vs_install.ps1 | 2 +- .circleci/scripts/vs_install_cmath.ps1 | 2 +- .circleci/scripts/windows_cudnn_install.sh | 2 +- .../templates/macos_binary_build_workflow.yml.j2 | 4 ++-- .github/workflows/_ios-build-test.yml | 2 +- .github/workflows/_mac-build.yml | 2 +- ...enerated-macos-arm64-binary-conda-nightly.yml | 12 ++++++------ ...enerated-macos-arm64-binary-wheel-nightly.yml | 12 ++++++------ .../generated-macos-binary-conda-nightly.yml | 16 ++++++++-------- ...d-macos-binary-libtorch-cxx11-abi-nightly.yml | 16 ++++++++-------- ...d-macos-binary-libtorch-pre-cxx11-nightly.yml | 16 ++++++++-------- .../generated-macos-binary-wheel-nightly.yml | 16 ++++++++-------- .jenkins/pytorch/common_utils.sh | 8 ++++++-- .../installation-helpers/activate_miniconda3.bat | 2 +- .../installation-helpers/install_magma.bat | 2 +- .../installation-helpers/install_mkl.bat | 2 +- .../installation-helpers/install_sccache.bat | 4 ++-- scripts/buck_setup.sh | 6 +++--- third_party/gloo | 2 +- third_party/pybind11 | 2 +- 29 files changed, 80 insertions(+), 76 deletions(-) diff --git a/.circleci/docker/common/install_cudnn.sh b/.circleci/docker/common/install_cudnn.sh index 4a8829b1cba1..f68fc6946c2e 100644 --- a/.circleci/docker/common/install_cudnn.sh +++ b/.circleci/docker/common/install_cudnn.sh @@ -6,9 +6,9 @@ if [[ ${CUDNN_VERSION} == 8 ]]; then CUDNN_NAME="cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive" if [[ ${CUDA_VERSION:0:4} == "11.7" ]]; then CUDNN_NAME="cudnn-linux-x86_64-8.5.0.96_cuda11-archive" - curl -OLs https://ossci-linux.s3.amazonaws.com/${CUDNN_NAME}.tar.xz + curl --retry 3 -OLs https://ossci-linux.s3.amazonaws.com/${CUDNN_NAME}.tar.xz else - curl -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz + curl --retry 3 -OLs https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/${CUDNN_NAME}.tar.xz fi tar xf ${CUDNN_NAME}.tar.xz diff --git a/.circleci/docker/common/install_docs_reqs.sh b/.circleci/docker/common/install_docs_reqs.sh index 1adc9e8009a0..e60171208ae1 100644 --- a/.circleci/docker/common/install_docs_reqs.sh +++ b/.circleci/docker/common/install_docs_reqs.sh @@ -7,10 +7,10 @@ if [ -n "$KATEX" ]; then # Ignore error if gpg-agent doesn't exist (for Ubuntu 16.04) apt-get install -y gpg-agent || : - curl -sL https://deb.nodesource.com/setup_12.x | sudo -E bash - + curl --retry 3 -sL https://deb.nodesource.com/setup_12.x | sudo -E bash - sudo apt-get install -y 
nodejs - curl -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - + curl --retry 3 -sS https://dl.yarnpkg.com/debian/pubkey.gpg | sudo apt-key add - echo "deb https://dl.yarnpkg.com/debian/ stable main" | sudo tee /etc/apt/sources.list.d/yarn.list apt-get update diff --git a/.circleci/docker/common/install_protobuf.sh b/.circleci/docker/common/install_protobuf.sh index 9d9f6c40ba0c..4b7a7a6ac23f 100755 --- a/.circleci/docker/common/install_protobuf.sh +++ b/.circleci/docker/common/install_protobuf.sh @@ -12,7 +12,7 @@ install_protobuf_317() { # g++: error: ./../lib64/crti.o: No such file or directory ln -s /usr/lib64 "$pb_dir/lib64" - curl -LO "https://github.com/protocolbuffers/protobuf/releases/download/v3.17.3/protobuf-all-3.17.3.tar.gz" + curl -LO "https://github.com/protocolbuffers/protobuf/releases/download/v3.17.3/protobuf-all-3.17.3.tar.gz" --retry 3 tar -xvz -C "$pb_dir" --strip-components 1 -f protobuf-all-3.17.3.tar.gz # -j6 to balance memory usage and speed. # naked `-j` seems to use too much memory. diff --git a/.circleci/scripts/binary_install_miniconda.sh b/.circleci/scripts/binary_install_miniconda.sh index 43eb006742ae..3541a32ac6bf 100755 --- a/.circleci/scripts/binary_install_miniconda.sh +++ b/.circleci/scripts/binary_install_miniconda.sh @@ -31,9 +31,9 @@ fi conda_sh="$workdir/install_miniconda.sh" if [[ "$(uname)" == Darwin ]]; then - curl --retry 3 -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh else - curl --retry 3 -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh + curl --retry 3 --retry-all-errors -o "$conda_sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh fi chmod +x "$conda_sh" "$conda_sh" -b -p "$MINICONDA_ROOT" diff --git a/.circleci/scripts/binary_ios_build.sh b/.circleci/scripts/binary_ios_build.sh index 6c7674ed510e..4bb5ea28af73 100644 --- a/.circleci/scripts/binary_ios_build.sh +++ b/.circleci/scripts/binary_ios_build.sh @@ -8,7 +8,7 @@ PROJ_ROOT=/Users/distiller/project export TCLLIBPATH="/usr/local/lib" # Install conda -curl --retry 3 -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +curl --retry 3 --retry-all-errors -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x ~/conda.sh /bin/bash ~/conda.sh -b -p ~/anaconda export PATH="~/anaconda/bin:${PATH}" diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index da38065847ef..7949dc9170b0 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -47,7 +47,7 @@ echo "${IOS_NIGHTLY_BUILD_VERSION}" > version.txt zip -r ${ZIPFILE} install src version.txt LICENSE # upload to aws # Install conda then 'conda install' awscli -curl --retry 3 -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh +curl --retry 3 --retry-all-errors -o ~/conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x ~/conda.sh /bin/bash ~/conda.sh -b -p ~/anaconda export PATH="~/anaconda/bin:${PATH}" diff --git a/.circleci/scripts/driver_update.bat b/.circleci/scripts/driver_update.bat index 46c05475cdba..fb8774366621 100644 --- a/.circleci/scripts/driver_update.bat +++ b/.circleci/scripts/driver_update.bat @@ -1,5 +1,5 @@ set 
"DRIVER_DOWNLOAD_LINK=https://s3.amazonaws.com/ossci-windows/452.39-data-center-tesla-desktop-win10-64bit-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 452.39-data-center-tesla-desktop-win10-64bit-international.exe +curl --retry 3 --retry-all-errors -kL %DRIVER_DOWNLOAD_LINK% --output 452.39-data-center-tesla-desktop-win10-64bit-international.exe if errorlevel 1 exit /b 1 start /wait 452.39-data-center-tesla-desktop-win10-64bit-international.exe -s -noreboot diff --git a/.circleci/scripts/setup_ci_environment.sh b/.circleci/scripts/setup_ci_environment.sh index e8dd9ab7195b..42a605cd4445 100755 --- a/.circleci/scripts/setup_ci_environment.sh +++ b/.circleci/scripts/setup_ci_environment.sh @@ -40,8 +40,8 @@ if [ -n "${USE_CUDA_DOCKER_RUNTIME:-}" ]; then # Taken directly from https://github.com/NVIDIA/nvidia-docker # Add the package repositories distribution=$(. /etc/os-release;echo "$ID$VERSION_ID") - curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - - curl -s -L "https://nvidia.github.io/nvidia-docker/${distribution}/nvidia-docker.list" | sudo tee /etc/apt/sources.list.d/nvidia-docker.list + curl -s -L --retry 3 --retry-all-errors https://nvidia.github.io/nvidia-docker/gpgkey | sudo apt-key add - + curl -s -L --retry 3 --retry-all-errors "https://nvidia.github.io/nvidia-docker/${distribution}/nvidia-docker.list" | sudo tee /etc/apt/sources.list.d/nvidia-docker.list retry sudo apt-get update -qq # Necessary to get the `--gpus` flag to function within docker diff --git a/.circleci/scripts/setup_linux_system_environment.sh b/.circleci/scripts/setup_linux_system_environment.sh index ce64076e2d64..780f7c1bd379 100755 --- a/.circleci/scripts/setup_linux_system_environment.sh +++ b/.circleci/scripts/setup_linux_system_environment.sh @@ -2,7 +2,7 @@ set -eux -o pipefail # Set up CircleCI GPG keys for apt, if needed -curl --retry 3 -s -L https://packagecloud.io/circleci/trusty/gpgkey | sudo apt-key add - +curl --retry 3 --retry-all-errors -s -L https://packagecloud.io/circleci/trusty/gpgkey | sudo apt-key add - # Stop background apt updates. Hypothetically, the kill should not # be necessary, because stop is supposed to send a kill signal to diff --git a/.circleci/scripts/vs_install.ps1 b/.circleci/scripts/vs_install.ps1 index a2e373078adb..4bbbc24bb043 100644 --- a/.circleci/scripts/vs_install.ps1 +++ b/.circleci/scripts/vs_install.ps1 @@ -29,7 +29,7 @@ if (Test-Path "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswher } echo "Downloading VS installer from S3." 
-curl.exe --retry 3 -kL $VS_DOWNLOAD_LINK --output vs_installer.exe +curl.exe --retry 3 --retry-all-errors -kL $VS_DOWNLOAD_LINK --output vs_installer.exe if ($LASTEXITCODE -ne 0) { echo "Download of the VS 2019 Version ${env:VS_VERSION} installer failed" exit 1 diff --git a/.circleci/scripts/vs_install_cmath.ps1 b/.circleci/scripts/vs_install_cmath.ps1 index c2998eba2521..62b637ec21b8 100644 --- a/.circleci/scripts/vs_install_cmath.ps1 +++ b/.circleci/scripts/vs_install_cmath.ps1 @@ -1,5 +1,5 @@ $CMATH_DOWNLOAD_LINK = "https://raw.githubusercontent.com/microsoft/STL/12c684bba78f9b032050526abdebf14f58ca26a3/stl/inc/cmath" $VC14_28_INSTALL_PATH="C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.28.29910\include" -curl.exe --retry 3 -kL $CMATH_DOWNLOAD_LINK --output "$home\cmath" +curl.exe --retry 3 --retry-all-errors -kL $CMATH_DOWNLOAD_LINK --output "$home\cmath" Move-Item -Path "$home\cmath" -Destination "$VC14_28_INSTALL_PATH" -Force diff --git a/.circleci/scripts/windows_cudnn_install.sh b/.circleci/scripts/windows_cudnn_install.sh index c279259e8341..bbf45a3290b3 100644 --- a/.circleci/scripts/windows_cudnn_install.sh +++ b/.circleci/scripts/windows_cudnn_install.sh @@ -36,7 +36,7 @@ else tmp_dir=$(mktemp -d) ( pushd "${tmp_dir}" - curl --retry 3 -o "${cudnn_installer_name}" "$cudnn_installer_link" + curl --retry 3 --retry-all-errors -o "${cudnn_installer_name}" "$cudnn_installer_link" 7z x "${cudnn_installer_name}" -ocudnn # Use '${var:?}/*' to avoid potentially expanding to '/*' # Remove all of the directories before attempting to copy files diff --git a/.github/templates/macos_binary_build_workflow.yml.j2 b/.github/templates/macos_binary_build_workflow.yml.j2 index 95802252a4f9..eb0c2ff4b373 100644 --- a/.github/templates/macos_binary_build_workflow.yml.j2 +++ b/.github/templates/macos_binary_build_workflow.yml.j2 @@ -69,7 +69,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -84,7 +84,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/_ios-build-test.yml b/.github/workflows/_ios-build-test.yml index e9b5461dde7f..269ad3f153ca 100644 --- a/.github/workflows/_ios-build-test.yml +++ b/.github/workflows/_ios-build-test.yml @@ -68,7 +68,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod 
+x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" diff --git a/.github/workflows/_mac-build.yml b/.github/workflows/_mac-build.yml index faf069e7a7c3..5ee909f02c22 100644 --- a/.github/workflows/_mac-build.yml +++ b/.github/workflows/_mac-build.yml @@ -116,7 +116,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" echo "SCCACHE_S3_KEY_PREFIX=${GITHUB_WORKFLOW}" >> "${GITHUB_ENV}" diff --git a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml index ce32755e3209..c88b107a90a9 100644 --- a/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-conda-nightly.yml @@ -67,7 +67,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -103,7 +103,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -177,7 +177,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -213,7 +213,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -287,7 +287,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o 
"${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -323,7 +323,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 7a7df02efe89..c8858fd0501b 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -67,7 +67,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -103,7 +103,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -177,7 +177,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -213,7 +213,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -287,7 +287,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash 
"${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -323,7 +323,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/generated-macos-binary-conda-nightly.yml b/.github/workflows/generated-macos-binary-conda-nightly.yml index ba3697e3fef9..52cfb3d98f76 100644 --- a/.github/workflows/generated-macos-binary-conda-nightly.yml +++ b/.github/workflows/generated-macos-binary-conda-nightly.yml @@ -65,7 +65,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -101,7 +101,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -175,7 +175,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -211,7 +211,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -285,7 +285,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -321,7 +321,7 @@ jobs: max_attempts: 3 
retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -395,7 +395,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -431,7 +431,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml index 381e0a4c73ad..cd9ad45ba561 100644 --- a/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-macos-binary-libtorch-cxx11-abi-nightly.yml @@ -69,7 +69,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -105,7 +105,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -184,7 +184,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -220,7 +220,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output 
/usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -299,7 +299,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -335,7 +335,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -414,7 +414,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -450,7 +450,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml index 55b28480a754..4ce5c6f32c36 100644 --- a/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-macos-binary-libtorch-pre-cxx11-nightly.yml @@ -69,7 +69,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -105,7 +105,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 
--output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -184,7 +184,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -220,7 +220,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -299,7 +299,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -335,7 +335,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -414,7 +414,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -450,7 +450,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.github/workflows/generated-macos-binary-wheel-nightly.yml b/.github/workflows/generated-macos-binary-wheel-nightly.yml index f4baf9129b69..a3839d6e8a14 100644 --- a/.github/workflows/generated-macos-binary-wheel-nightly.yml +++ 
b/.github/workflows/generated-macos-binary-wheel-nightly.yml @@ -65,7 +65,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -101,7 +101,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -175,7 +175,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -211,7 +211,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -285,7 +285,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -321,7 +321,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env @@ -395,7 +395,7 @@ jobs: - name: Install conda and dependencies run: | # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on - curl --retry 3 -o "${RUNNER_TEMP}/conda.sh" https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" 
https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh chmod +x "${RUNNER_TEMP}/conda.sh" /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" @@ -431,7 +431,7 @@ jobs: max_attempts: 3 retry_wait_seconds: 90 command: | - sudo curl --retry 3 https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache sudo chmod +x /usr/local/bin/sccache echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" - name: Populate binary env diff --git a/.jenkins/pytorch/common_utils.sh b/.jenkins/pytorch/common_utils.sh index 7fc1dd6c0f1a..6d3c96b9278f 100644 --- a/.jenkins/pytorch/common_utils.sh +++ b/.jenkins/pytorch/common_utils.sh @@ -9,6 +9,10 @@ log() { printf '%s\n' "$*"; } error() { log "ERROR: $*" >&2; } fatal() { error "$@"; exit 1; } +retry () { + "$@" || (sleep 10 && "$@") || (sleep 20 && "$@") || (sleep 40 && "$@") +} + # compositional trap taken from https://stackoverflow.com/a/7287873/23845 # appends a command to a trap # @@ -78,12 +82,12 @@ function get_exit_code() { function get_bazel() { if [[ $(uname) == "Darwin" ]]; then # download bazel version - curl https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-darwin-x86_64 -Lo tools/bazel + retry curl https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel-4.2.1-darwin-x86_64 -Lo tools/bazel # verify content echo '74d93848f0c9d592e341e48341c53c87e3cb304a54a2a1ee9cff3df422f0b23c tools/bazel' | shasum -a 256 -c >/dev/null else # download bazel version - curl https://ossci-linux.s3.amazonaws.com/bazel-4.2.1-linux-x86_64 -o tools/bazel + retry curl https://ossci-linux.s3.amazonaws.com/bazel-4.2.1-linux-x86_64 -o tools/bazel # verify content echo '1a4f3a3ce292307bceeb44f459883859c793436d564b95319aacb8af1f20557c tools/bazel' | shasum -a 256 -c >/dev/null fi diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat index e6660a17b389..0552d85a407a 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/activate_miniconda3.bat @@ -13,7 +13,7 @@ if not exist %CONDA_PARENT_DIR%\Miniconda3 ( ) if "%INSTALL_FRESH_CONDA%"=="1" ( - curl --retry 3 -k https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe --output %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe + curl --retry 3 --retry-all-errors -k https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe --output %TMP_DIR_WIN%\Miniconda3-latest-Windows-x86_64.exe if errorlevel 1 exit /b if not errorlevel 0 exit /b diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat index d9f3ab1cf821..d0fbf5b20d88 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_magma.bat @@ -24,7 +24,7 @@ if "%CUDA_SUFFIX%" == "" ( if "%REBUILD%"=="" ( if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z + curl --retry 3 --retry-all-errors -k 
https://s3.amazonaws.com/ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --output %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z ) else ( aws s3 cp s3://ossci-windows/magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z %TMP_DIR_WIN%\magma_2.5.4_%CUDA_SUFFIX%_%BUILD_TYPE%.7z --quiet ) diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat index c700a04a1e4a..6c676d1baede 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_mkl.bat @@ -1,6 +1,6 @@ if "%REBUILD%"=="" ( if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/mkl_2020.2.254.7z --output %TMP_DIR_WIN%\mkl.7z + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/mkl_2020.2.254.7z --output %TMP_DIR_WIN%\mkl.7z ) else ( aws s3 cp s3://ossci-windows/mkl_2020.2.254.7z %TMP_DIR_WIN%\mkl.7z --quiet ) diff --git a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat index 0165604400dd..6f8cc15ba868 100644 --- a/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat +++ b/.jenkins/pytorch/win-test-helpers/installation-helpers/install_sccache.bat @@ -7,8 +7,8 @@ if "%REBUILD%"=="" ( del %TMP_DIR_WIN%\bin\sccache.exe || ver > nul del %TMP_DIR_WIN%\bin\sccache-cl.exe || ver > nul if "%BUILD_ENVIRONMENT%"=="" ( - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/sccache.exe --output %TMP_DIR_WIN%\bin\sccache.exe - curl --retry 3 -k https://s3.amazonaws.com/ossci-windows/sccache-cl.exe --output %TMP_DIR_WIN%\bin\sccache-cl.exe + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/sccache.exe --output %TMP_DIR_WIN%\bin\sccache.exe + curl --retry 3 --retry-all-errors -k https://s3.amazonaws.com/ossci-windows/sccache-cl.exe --output %TMP_DIR_WIN%\bin\sccache-cl.exe ) else ( aws s3 cp s3://ossci-windows/sccache.exe %TMP_DIR_WIN%\bin\sccache.exe aws s3 cp s3://ossci-windows/sccache-cl.exe %TMP_DIR_WIN%\bin\sccache-cl.exe diff --git a/scripts/buck_setup.sh b/scripts/buck_setup.sh index 8e60d92a5fd1..331a29956416 100644 --- a/scripts/buck_setup.sh +++ b/scripts/buck_setup.sh @@ -22,16 +22,16 @@ python3 generate-xnnpack-wrappers.py # bazel-skylib printf "\nDownloading bazel-skylib\n" rm -rf bazel-skylib; mkdir bazel-skylib -curl -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib +curl --retry 3 --retry-all-errors -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib # glog printf "\nDownloading glog\n" rm -rf glog; mkdir glog -curl -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 +curl --retry 3 --retry-all-errors -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 # ruy printf "\nDownloading ruy\n" -curl -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip +curl --retry 3 --retry-all-errors -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip unzip -q /tmp/ruy.zip -d /tmp/ rm -rf ruy/ mv 
/tmp/ruy-a09683b8da7164b9c5704f88aef2dc65aa583e5d ruy/ diff --git a/third_party/gloo b/third_party/gloo index 4a5e339b7642..5b1435132631 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 4a5e339b764261d20fc409071dc7a8b8989aa195 +Subproject commit 5b143513263133af2b95547e97c07cebeb72bf72 diff --git a/third_party/pybind11 b/third_party/pybind11 index 80dc998efced..aa304c9c7d72 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit 80dc998efced8ceb2be59756668a7e90e8bef917 +Subproject commit aa304c9c7d725ffb9d10af08a3b34cb372307020 From 7beb1518896482596a0d35ec404338d430933250 Mon Sep 17 00:00:00 2001 From: maxren Date: Thu, 17 Nov 2022 14:31:43 -0800 Subject: [PATCH 328/453] [xnnpack][executorch] remove unordered_set from xnn_compiler (#89231) Removing unordered_set from xnn_compiler for executorch. While some STL libraries are unavoidable, and I think it should be ok for the delegate to pull these libraries, unordered_set wasn't really needed, and we should be serializing the number of external ids anyways. After this, the backend classes should be good to hg copy into executorch. Differential Revision: [D41227391](https://our.internmc.facebook.com/intern/diff/D41227391/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89231 Approved by: https://github.com/salilsdesai, https://github.com/cccclai --- .../jit/backends/xnnpack/compiler/xnn_compiler.cpp | 12 +----------- .../jit/backends/xnnpack/compiler/xnn_compiler.h | 1 - .../jit/backends/xnnpack/serialization/schema.fbs | 3 +++ .../backends/xnnpack/serialization/serializer.cpp | 4 +++- .../jit/backends/xnnpack/serialization/serializer.h | 3 ++- .../jit/backends/xnnpack/xnnpack_graph_builder.cpp | 6 +++++- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index 49e2804c99a9..0f654dff0ac0 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -4,7 +4,6 @@ #include #include -#include namespace torch { namespace jit { @@ -25,17 +24,8 @@ void XNNCompiler::compileModel( // create xnnpack subgraph xnn_subgraph_t subgraph_ptr = nullptr; - - // TODO: @maxren serialize extern_ids in flatbuffer schema - std::unordered_set extern_ids; - for (auto input_id : *flatbuffer_graph->input_ids()) { - extern_ids.insert(input_id); - } - for (auto output_id : *flatbuffer_graph->output_ids()) { - extern_ids.insert(output_id); - } status = xnn_create_subgraph( - /*external_value_ids=*/extern_ids.size(), + /*external_value_ids=*/flatbuffer_graph->num_externs(), /*flags=*/0, &subgraph_ptr); TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph"); diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h index e87fcbcd063d..f74e784111d4 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace torch { diff --git a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs index 87ebe20a825a..cc1290b718fa 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs +++ b/torch/csrc/jit/backends/xnnpack/serialization/schema.fbs @@ -75,6 +75,9 @@ table XNNGraph {
xnodes:[XNode]; xvalues:[XValue]; + // Number of external inputs/outputs + num_externs:uint; + // Ids of external inputs input_ids:[uint]; diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp index df1ccc791781..63cb62c5698e 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp @@ -67,12 +67,14 @@ void XNNSerializer::serializeTensorValue( std::string XNNSerializer::finishAndSerialize( std::vector input_ids, - std::vector output_ids) { + std::vector output_ids, + size_t num_extern_ids) { auto xnnGraph = CreateXNNGraphDirect( _builder, _version_sha1, &_nodes, &_values, + num_extern_ids, &input_ids, &output_ids, &_constantBuffer, diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h index 6d01571d424d..08a3875d3267 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -51,7 +51,8 @@ class XNNSerializer { // finish and serialize xnngraph returning serialized data std::string finishAndSerialize( std::vector input_ids, - std::vector output_ids); + std::vector output_ids, + size_t num_extern_ids); private: // xnnpack version we are serializing diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp index 4eaefea56960..45a4bd2fa795 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp @@ -129,16 +129,20 @@ void XNNGraph::checkOpsToDelegate(std::shared_ptr& graph) { std::string XNNGraph::serializedXNNGraph() { std::vector input_ids; std::vector output_ids; + std::unordered_set num_externs; for (auto val : _inputs) { input_ids.push_back(_val_to_ids[val]); + num_externs.emplace(_val_to_ids[val]); } for (auto val : _outputs) { output_ids.push_back(_val_to_ids[val]); + num_externs.emplace(_val_to_ids[val]); } - return _serializer.finishAndSerialize(input_ids, output_ids); + return _serializer.finishAndSerialize( + input_ids, output_ids, num_externs.size()); } std::vector> XNNGraph::getGraphOutputShapes() { From fc1c0cd3ef5af94e2b6cb262252cf97b61e5d3cb Mon Sep 17 00:00:00 2001 From: PumeTu Date: Fri, 18 Nov 2022 07:24:33 +0000 Subject: [PATCH 329/453] Add support trace on MPS backend (#87910) Fixes [#87221](https://github.com/pytorch/pytorch/issues/87221) `trace` now supported on MPS Pull Request resolved: https://github.com/pytorch/pytorch/pull/87910 Approved by: https://github.com/kulinseth, https://github.com/malfet --- .../ATen/native/mps/operations/ReduceOps.mm | 31 ++++++++++++++++++- aten/src/ATen/native/native_functions.yaml | 1 + test/test_mps.py | 22 ++++++++++--- 3 files changed, 49 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index c99f22d89295..39680240f7f2 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -27,7 +27,8 @@ SUM, PROD, MEAN, - COUNT_NONZERO + COUNT_NONZERO, + TRACE }; @@ -239,6 +240,14 @@ void set_axes_and_shapes(const Tensor& input_t, castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor axes:axes name:nil]; + } else if(reduction_type == MPSReductionType::TRACE) { + MPSGraphTensor 
*bandPartWithTensor = [mpsGraph bandPartWithTensor:inputTensor + numLower:0 + numUpper:0 + name:nil]; + castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor + axes:@[@0, @1] + name:nil]; } MPSGraphTensor* outputTensor = nil; @@ -287,6 +296,26 @@ void set_axes_and_shapes(const Tensor& input_t, reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); } +Tensor trace_mps_out(const Tensor& self) { + + Tensor output_t = at::native::empty_mps( + {}, + self.scalar_type(), + c10::nullopt, + kMPS, + c10::nullopt, + c10::nullopt); + + std::vector dims(self.dim()); + std::iota(dims.begin(), dims.end(), 0); + + reduction_out_mps(self, IntArrayRef(dims), false, c10::nullopt, const_cast(output_t), MPSReductionType::TRACE, "trace_mps_out"); + + return output_t; + + +} + TORCH_IMPL_FUNC(prod_out_mps) (const Tensor& input_t, int64_t dim, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5cf0e759db1d..f625c9faff41 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7790,6 +7790,7 @@ dispatch: CPU: trace_cpu CUDA: trace_cuda + MPS: trace_mps_out autogen: trace.out - func: trace_backward(Tensor grad, SymInt[] sizes) -> Tensor diff --git a/test/test_mps.py b/test/test_mps.py index 52d669545b30..8e40a5cce293 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -315,6 +315,16 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + def test_trace(self): + M_cpu = torch.randn(3, 3) + M_mps = M_cpu.detach().clone().to("mps") + + output_cpu = torch.trace(M_cpu) + output_mps = torch.trace(M_mps) + + self.assertEqual(output_cpu, output_mps) + self.assertEqual(output_cpu.size(), output_mps.size()) + def test_addbmm(self): M_cpu = torch.randn(3, 5) batch1_cpu = torch.randn(10, 3, 4) @@ -5141,10 +5151,14 @@ def test_conv_expand(self): # The test should not crash def test_permute(self): - X = torch.randn(5, 5).to('mps') - torch.log(X) - X = X.permute(1, 0) - torch.log(X) + M_cpu = torch.randn(5, 5) + M_mps = M_cpu.to('mps') + + output_cpu = M_cpu.permute(1, 0) + output_mps = M_mps.permute(1, 0) + + self.assertEqual(output_cpu, output_mps) + self.assertEqual(output_cpu.size(), output_mps.size()) # Printing of non_contiguous should not crash def test_print_non_contiguous(self): From 6a964c16e5125f485372418d129c3eabdec7e881 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Fri, 18 Nov 2022 07:31:10 +0000 Subject: [PATCH 330/453] [flaky] relax tolerance conv1d_vs_scipy (#89193) Fixes https://github.com/pytorch/pytorch/issues/89087 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89193 Approved by: https://github.com/kit1980 --- test/nn/test_convolution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index c94eb5447d5a..a30a27643975 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -1412,7 +1412,7 @@ def _test(t, weight, mode): if mode == 'same': actual = actual[:feat_dim] - self.assertEqual(actual, expected) + self.assertEqual(actual, expected, atol=2e-5, rtol=2e-5) # Global dtype for this test suite is torch.double # This leads to change in type-promotion From afdc48f843afab531a4315a1ca1a43f5f303c5b7 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Nov 2022 07:39:16 +0000 Subject: [PATCH 331/453] Gate CUDA-only inductor tests by HAS_CUDA 
(#89251) This is to prevent these tests from running on platform where CUDA doesn't exist such as macos. And they are quite flaky https://hud.pytorch.org/failure/test_linear_permute_fusion_cpu there failing the CI from time to time Pull Request resolved: https://github.com/pytorch/pytorch/pull/89251 Approved by: https://github.com/soumith, https://github.com/desertfire --- test/inductor/test_torchinductor.py | 149 ++++++++++++++-------------- 1 file changed, 73 insertions(+), 76 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fb7ca1fc92b7..f2b1caeb32ea 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1618,82 +1618,6 @@ def fn(a, b): y = torch.tensor(0) self.assertEqual(fn(x, y), x + x) - @unittest.skipIf(HAS_CPU, "Support GPU so far") - def test_linear_permute_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - a0 = torch.nn.functional.linear(input, self.weight, self.bias) - b0 = a0.permute(0, 2, 1) - return b0 - - m, k, n = 16, 8, 4 - trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, m, k) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_linear_transpose = count_call_function(traced, linear_transpose) - self.assertEqual(num_linear, 0) - self.assertEqual(num_linear_transpose, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - @unittest.skipIf(HAS_CPU, "Support GPU so far") - def test_permute_linear_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, k: int, n: int): - super().__init__() - self.weight = torch.nn.Parameter(torch.randn(n, k)) - self.bias = torch.nn.Parameter(torch.randn(n)) - - def forward(self, input: torch.Tensor): - input1 = input.permute(0, 2, 1) - output = torch.nn.functional.linear(input1, self.weight, self.bias) - return output - - m, k, n = 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) - module = TestModule(k, n).eval() - input = torch.randn(6, k, m) - traced = trace_func(module, [input]) - num_linear = count_call_function(traced, torch.nn.functional.linear) - num_transpose_linear = count_call_function(traced, transpose_linear) - self.assertEqual(num_linear, 0) - self.assertEqual(num_transpose_linear, 1) - - self.assertTrue(torch.allclose(module(input), traced(input))) - - @unittest.skipIf(HAS_CPU, "Support GPU so far") - def test_permute_bmm_fusion(self): - class TestModule(torch.nn.Module): - def __init__(self, batch: int, k: int, n: int): - super().__init__() - self.other = torch.randn(batch, k, n) - - def forward(self, input: torch.Tensor): - input1 = input.permute(0, 2, 1) - output = torch.bmm(input1, self.other) - return output - - batch, m, k, n = 6, 16, 8, 4 - - trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) - module = TestModule(batch, k, n).eval() - input = torch.randn(batch, k, m) - traced = trace_func(module, [input]) - num_bmm = count_call_function(traced, torch.bmm) - num_transpose_matmul = count_call_function(traced, transpose_matmul) - self.assertEqual(num_bmm, 0) - self.assertEqual(num_transpose_matmul, 1) - - self.assertTrue(torch.allclose(module(input), 
traced(input))) - def test_slice1(self): def fn(a): return ( @@ -4710,6 +4634,79 @@ def fn(a): fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],) ) + def test_linear_permute_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + a0 = torch.nn.functional.linear(input, self.weight, self.bias) + b0 = a0.permute(0, 2, 1) + return b0 + + m, k, n = 16, 8, 4 + trace_func = chain_passes(torch.fx.symbolic_trace, linear_permute_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, m, k) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_linear_transpose = count_call_function(traced, linear_transpose) + self.assertEqual(num_linear, 0) + self.assertEqual(num_linear_transpose, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_linear_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, k: int, n: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(n, k)) + self.bias = torch.nn.Parameter(torch.randn(n)) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.nn.functional.linear(input1, self.weight, self.bias) + return output + + m, k, n = 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_linear_fusion) + module = TestModule(k, n).eval() + input = torch.randn(6, k, m) + traced = trace_func(module, [input]) + num_linear = count_call_function(traced, torch.nn.functional.linear) + num_transpose_linear = count_call_function(traced, transpose_linear) + self.assertEqual(num_linear, 0) + self.assertEqual(num_transpose_linear, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + + def test_permute_bmm_fusion(self): + class TestModule(torch.nn.Module): + def __init__(self, batch: int, k: int, n: int): + super().__init__() + self.other = torch.randn(batch, k, n) + + def forward(self, input: torch.Tensor): + input1 = input.permute(0, 2, 1) + output = torch.bmm(input1, self.other) + return output + + batch, m, k, n = 6, 16, 8, 4 + + trace_func = chain_passes(torch.fx.symbolic_trace, permute_matmul_fusion) + module = TestModule(batch, k, n).eval() + input = torch.randn(batch, k, m) + traced = trace_func(module, [input]) + num_bmm = count_call_function(traced, torch.bmm) + num_transpose_matmul = count_call_function(traced, transpose_matmul) + self.assertEqual(num_bmm, 0) + self.assertEqual(num_transpose_matmul, 1) + + self.assertTrue(torch.allclose(module(input), traced(input))) + CommonTemplate.install(CudaTests, "cuda") class CudaReproTests(TestCase): From 30c3e5afb0c0ad22c1084a2064ebdc09f7808ecc Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Fri, 18 Nov 2022 07:46:35 +0000 Subject: [PATCH 332/453] Disable tracing `zero_grad()` (#88731) Tracing through zero grad is slow, and doesn't provide any benefits. 
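A minimal sketch of the pattern this change applies inside `torch/_dynamo/eval_frame.py`, shown on a toy model and optimizer that are assumed here purely for illustration (the actual patch performs the equivalent wrapping automatically when an optimizer is patched):

```python
import torch
import torch._dynamo as dynamo

# Hypothetical toy setup, only to show the wrapping pattern.
model = torch.nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Equivalent of the line added to eval_frame.py below: the wrapped method is
# skipped by the tracer and runs eagerly instead of being compiled.
opt.zero_grad = dynamo.disable(opt.zero_grad)
opt.zero_grad()
```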
Helps https://github.com/pytorch/torchdynamo/issues/1803 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88731 Approved by: https://github.com/anijain2305 --- test/dynamo/test_optimizers.py | 2 +- torch/_dynamo/eval_frame.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 2f204a7a1199..90b8cfaaad7b 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -125,7 +125,7 @@ def training_iter_fn(batch, model, optimizer): batch = {"x": input1, "y": input2} for _ in range(2): opt_training_iter_fn(batch, net, optimizer) - self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.frame_count, 1) if __name__ == "__main__": diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 6b500a87bc32..538f6131d62b 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -730,6 +730,7 @@ def patch(): opt._cuda_graph_capture_health_check = disable( opt._cuda_graph_capture_health_check ) + opt.zero_grad = disable(opt.zero_grad) # disable any currently set hooks # Note: we only want to disable the profiling hook # which is the *last* hook applied, we want to keep the no_grad hook From c5fafb4e1694f141d8a1a31142cce4049d9057ed Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 17 Nov 2022 19:20:22 -0800 Subject: [PATCH 333/453] [ao] maintain BC for is_activation_post_process (#89260) Summary: Tests are failing due to code packaged with trained models calling now-defunct function names (is_activation_post_process). This diff maintains BC temporarily until the cached code can be refreshed. Test Plan: no functional change Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89260 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/quantize.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index b9ef24e35fdb..51eb2c1c1ec6 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -27,7 +27,12 @@ float_qparams_weight_only_qconfig_4bit, _activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations -from torch.ao.quantization.observer import _is_activation_post_process + +from torch.ao.quantization.observer import ( # noqa: F401 + _is_activation_post_process, + _is_activation_post_process as is_activation_post_process, + # TODO remove this once problems from name change are resolved +) __all__ = [ "get_default_custom_config_dict", From 2dcacc6b999a44e13a0dbb679ac17d767b05d898 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Fri, 18 Nov 2022 09:28:46 +0000 Subject: [PATCH 334/453] [LTC] Upstream short_metrics (#89186) Summary: This pull request upstreams pytorch/xla#4148. Test Plan: xla CI.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89186 Approved by: https://github.com/JackCaoG --- test/lazy/test_ts_opinfo.py | 13 +++++++++++-- torch/csrc/lazy/core/metrics.cpp | 32 +++++++++++++++++++++++++++++++- torch/csrc/lazy/core/metrics.h | 5 +++++ 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 2e6703558147..092ba3d0388d 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -71,20 +71,28 @@ def init_lists(): 'linalg_pinv.atol_rtol_tensor', 'logsumexp', ]) + # For some ops, we don't support all variants. Here we use formatted_name + # to uniquely identify the variant. + SKIP_VARIANT_LIST = set([ + 'norm_nuc', + 'min_reduction_with_dim' + ]) return (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST, - HAS_SYMINT_SUFFIX) + HAS_SYMINT_SUFFIX, + SKIP_VARIANT_LIST) (LAZY_OPS_LIST, FALLBACK_LIST, SKIP_RUNTIME_ERROR_LIST, SKIP_INCORRECT_RESULTS_LIST, FUNCTIONAL_DECOMPOSE_LIST, - HAS_SYMINT_SUFFIX) = init_lists() + HAS_SYMINT_SUFFIX, + SKIP_VARIANT_LIST) = init_lists() torch.manual_seed(42) @@ -166,6 +174,7 @@ class TestLazyOpInfo(TestCase): if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST and op.name not in FUNCTIONAL_DECOMPOSE_LIST + and op.formatted_name not in SKIP_VARIANT_LIST ], allowed_dtypes=(torch.float,)) def test_dispatched_to_lazy(self, device, dtype, op): def get_name(op): diff --git a/torch/csrc/lazy/core/metrics.cpp b/torch/csrc/lazy/core/metrics.cpp index cb8120c1d45c..86758edc4dfc 100644 --- a/torch/csrc/lazy/core/metrics.cpp +++ b/torch/csrc/lazy/core/metrics.cpp @@ -172,7 +172,9 @@ std::vector MetricsArena::GetCounterNames() { std::vector names; std::lock_guard lock(lock_); for (auto& name_data : counters_) { - names.push_back(name_data.first); + if (name_data.second->Value() > 0) { + names.push_back(name_data.first); + } } return names; } @@ -353,6 +355,34 @@ std::string CreateMetricReport() { return ss.str(); } +std::string CreateMetricReport( + const std::vector& counter_names, + const std::vector& metric_names) { + MetricsArena* arena = MetricsArena::Get(); + std::stringstream ss; + for (const std::string& metric_name : metric_names) { + MetricData* data = arena->GetMetric(metric_name); + if (data && data->TotalSamples() > 0) { + EmitMetricInfo(metric_name, data, &ss); + } + } + for (const std::string& counter_name : counter_names) { + CounterData* data = arena->GetCounter(counter_name); + if (data && data->Value() > 0) { + EmitCounterInfo(counter_name, data, &ss); + } + } + static std::string fall_back_counter_prefix = "aten::"; + arena->ForEachCounter([&ss](const std::string& name, CounterData* data) { + if (name.rfind(fall_back_counter_prefix, 0) == 0 && data->Value() > 0) { + // it might emit duplicated counter if user also specified exact aten + // counter in the `counter_names` but it should be very rare. + EmitCounterInfo(name, data, &ss); + } + }); + return ss.str(); +} + std::vector GetMetricNames() { return MetricsArena::Get()->GetMetricNames(); } diff --git a/torch/csrc/lazy/core/metrics.h b/torch/csrc/lazy/core/metrics.h index 43fb617c1ba1..1d629c4973db 100644 --- a/torch/csrc/lazy/core/metrics.h +++ b/torch/csrc/lazy/core/metrics.h @@ -216,6 +216,11 @@ class TORCH_API Counter { // Creates a report with the current metrics statistics. TORCH_API std::string CreateMetricReport(); +// Creates a report with the selected metrics statistics. 
+TORCH_API std::string CreateMetricReport( + const std::vector& counter_names, + const std::vector& metric_names); + // Returns the currently registered metric names. Note that the list can grow // since metrics are usually function intialized (they are static function // variables). From 4c6724985d8b85c5719078a25255dbd7369c25e5 Mon Sep 17 00:00:00 2001 From: Iris Date: Fri, 18 Nov 2022 09:49:36 +0000 Subject: [PATCH 335/453] [PT-D][Checkpoint] Update import and update docstring for distributed checkpoint (#89256) Update test import and docstring as we have moved distributed checkpointing from torch.distributed._shard.checkpoint to torch.distributed.checkpoint (https://github.com/pytorch/pytorch/pull/88698). Test: CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/89256 Approved by: https://github.com/fduwjj --- .../distributed/checkpoint/test_checkpoint.py | 125 +++++++++--------- .../fsdp/test_distributed_checkpoint.py | 2 +- .../checkpoint/state_dict_loader.py | 4 +- .../checkpoint/state_dict_saver.py | 4 +- 4 files changed, 68 insertions(+), 67 deletions(-) diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py index 167fdc5e7154..96c98116328c 100644 --- a/test/distributed/checkpoint/test_checkpoint.py +++ b/test/distributed/checkpoint/test_checkpoint.py @@ -2,9 +2,9 @@ import sys from typing import Optional, List, cast -from torch.distributed._shard.checkpoint.storage import WriteResult +from torch.distributed.checkpoint.storage import WriteResult -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( StorageReader, StorageWriter, CheckpointException, @@ -63,6 +63,7 @@ ) sys.exit(0) + class TestModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -121,34 +122,44 @@ def test_default_metadata(self) -> None: ) state_dict = { - 'sharded': sharded_tensor.rand(spec, (10, 10, )), - 'replicated': torch.rand(4, device=device), - 'bytes': [1, 2, 3, 4], + "sharded": sharded_tensor.rand( + spec, + ( + 10, + 10, + ), + ), + "replicated": torch.rand(4, device=device), + "bytes": [1, 2, 3, 4], } metadata = _create_default_local_metadata(state_dict) - self.assertTrue('bytes' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['bytes'], BytesStorageMetadata) + self.assertTrue("bytes" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["bytes"], BytesStorageMetadata + ) - self.assertTrue('replicated' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['replicated'], TensorStorageMetadata) - md = metadata.state_dict_metadata['replicated'] - self.assertEqual(md.size, state_dict['replicated'].size()) + self.assertTrue("replicated" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["replicated"], TensorStorageMetadata + ) + md = metadata.state_dict_metadata["replicated"] + self.assertEqual(md.size, state_dict["replicated"].size()) self.assertEqual(md.properties.dtype, torch.float32) self.assertEqual(1, len(md.chunks)) - self.assertTrue('sharded' in metadata.state_dict_metadata) - self.assertIsInstance(metadata.state_dict_metadata['sharded'], TensorStorageMetadata) - md = metadata.state_dict_metadata['sharded'] + self.assertTrue("sharded" in metadata.state_dict_metadata) + self.assertIsInstance( + metadata.state_dict_metadata["sharded"], TensorStorageMetadata + ) + md = metadata.state_dict_metadata["sharded"] 
self.assertEqual(md.properties.dtype, torch.float32) - self.assertEqual(md.size, state_dict['sharded'].size()) + self.assertEqual(md.size, state_dict["sharded"].size()) self.assertEqual(2, len(md.chunks)) + class TestStorageBase: - def __init__( - self, - fail_conf - ): + def __init__(self, fail_conf): self.fail_conf = fail_conf self.rank = 0 if not dist.is_initialized() else dist.get_rank() @@ -164,16 +175,16 @@ def _fail_rank_async(self, name, result=None): ranks = self._get_ranks(name) fut = Future() if ranks is not None and self.rank in ranks: - fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}")) + fut.set_exception( + ValueError(f"async rank fail {self.rank} for {name}") + ) else: fut.set_result(result) return fut + class FaultyStorageWriter(TestStorageBase, StorageWriter): - def __init__( - self, - fail_conf - ): + def __init__(self, fail_conf): super(FaultyStorageWriter, self).__init__(fail_conf) def init(self, is_coordinator: bool) -> None: @@ -188,23 +199,19 @@ def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: return plans def write_data( - self, - plan: SavePlan, - planner: SavePlanner + self, plan: SavePlan, planner: SavePlanner ) -> Future[List[WriteResult]]: self._fail_rank("fail_write_data") return self._fail_rank_async("fail_write_data_async", []) - def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + def finish( + self, metadata: Metadata, results: List[List[WriteResult]] + ) -> None: self._fail_rank("fail_finish") class FaultyStorageReader(TestStorageBase, StorageReader): - def __init__( - self, - metadata, - fail_conf - ): + def __init__(self, metadata, fail_conf): super(FaultyStorageReader, self).__init__(fail_conf) self.metadata = metadata @@ -219,11 +226,7 @@ def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: self._fail_rank("fail_prepare_global_plan") return plans - def read_data( - self, - plan: LoadPlan, - planner: LoadPlanner - ) -> Future[None]: + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: self._fail_rank("fail_read_data") return self._fail_rank_async("fail_read_data_async") @@ -231,13 +234,14 @@ def read_metadata(self) -> Metadata: self._fail_rank("fail_read_metadata") return self.metadata + class TestDistributedFailure(ShardedTensorTestBase): def get_spec(self): return ChunkShardingSpec( dim=0, placements=[ f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size()) - ] + ], ) @with_comms(init_rpc=False) @@ -245,9 +249,9 @@ def get_spec(self): @requires_nccl() def test_dummy_writer_works(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } save_state_dict(state_dict, FaultyStorageWriter({})) @@ -257,9 +261,9 @@ def test_dummy_writer_works(self) -> None: @requires_nccl() def test_dummy_reader_works(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } metadata = _create_default_local_metadata(state_dict) @@ -283,8 +287,10 @@ def _test_dist_failure(self, callback, kwargs): failed_ranks = e.failures.keys() for rank in bad_ranks: - self.assertTrue(rank in failed_ranks, msg=f"{rank} was supposed to fail was fine") - + 
self.assertTrue( + rank in failed_ranks, + msg=f"{rank} was supposed to fail was fine", + ) def _test_save(self, state_dict, coordinator=0, **kwargs): no_dist = not dist.is_initialized() @@ -296,6 +302,7 @@ def _save(): coordinator_rank=coordinator, no_dist=no_dist, ) + self._test_dist_failure(_save, kwargs) def _test_load(self, state_dict, coordinator=0, **kwargs): @@ -317,9 +324,9 @@ def _load(): @requires_nccl() def test_save_error_handling(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } self._test_save(state_dict, fail_init=[0]) @@ -334,10 +341,7 @@ def test_save_error_handling(self) -> None: self._test_save(state_dict, coordinator=1, fail_finish=[1]) def test_save_error_handling_no_dist(self) -> None: - state_dict = { - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] - } + state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]} self.assertFalse(dist.is_initialized()) @@ -354,9 +358,9 @@ def test_save_error_handling_no_dist(self) -> None: @requires_nccl() def test_load_error_handling(self) -> None: state_dict = { - 'sharded': sharded_tensor.rand(self.get_spec(), 20, 20), - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] + "sharded": sharded_tensor.rand(self.get_spec(), 20, 20), + "replicated": torch.rand(10, 10), + "bytes": [1, 2, 3, 4], } self._test_load(state_dict) @@ -373,12 +377,8 @@ def test_load_error_handling(self) -> None: self._test_load(state_dict, coordinator=3, fail_read_data_async=[2]) self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1]) - def test_load_error_handling_no_dist(self) -> None: - state_dict = { - 'replicated': torch.rand(10, 10), - 'bytes': [1, 2, 3, 4] - } + state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]} self._test_load(state_dict) self._test_load(state_dict, fail_init=[0]) self._test_load(state_dict, fail_read_metadata=[0]) @@ -387,5 +387,6 @@ def test_load_error_handling_no_dist(self) -> None: self._test_load(state_dict, fail_read_data=[0]) self._test_load(state_dict, fail_read_data_async=[0]) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index e64fd358a305..3e9b967e0d11 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -5,7 +5,7 @@ import torch from torch import distributed as dist -from torch.distributed._shard.checkpoint import ( +from torch.distributed.checkpoint import ( FileSystemReader, FileSystemWriter, load_state_dict, diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index de94ffabf663..1d085f4d339e 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -59,9 +59,9 @@ def load_state_dict( >>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() - >>> fs_storage_loader = torch.distributed._shard.checkpoint.FileSystemLoader("/checkpoint/1") + >>> fs_storage_loader = torch.distributed.checkpoint.FileSystemLoader("/checkpoint/1") - >>> torch.distributed._shard.checkpoint.load_state_dict( + >>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> 
storage_reader=fs_storage_loader, >>> ) diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index af18fd0c11dd..5e7fde10324c 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -59,8 +59,8 @@ def save_state_dict( >>> model_state_dict = my_model.state_dict() - >>> fs_storage_writer = torch.distributed._shard.checkpoint.FileSystemWriter("/checkpoint/1") - >>> torch.distributed._shard.checkpoint.save_state_dict( + >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1") + >>> torch.distributed.checkpoint.save_state_dict( >>> state_dict=model_state_dict, >>> storage_writer=fs_stroage_writer, >>> ) From 5654fed23e7728eca717b23c97c1fca8c176112a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 18 Nov 2022 10:51:07 +0000 Subject: [PATCH 336/453] Export c10/[macros|util] headers to be used by internal inductor builds (#89249) Summary: Fixes package boundary violation that existed in previous implementation Test Plan: CI Differential Revision: D41391862 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89249 Approved by: https://github.com/izaitsevfb --- c10/macros/build.bzl | 9 +++++++++ c10/util/build.bzl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/c10/macros/build.bzl b/c10/macros/build.bzl index 932d0cabac4c..50f283560d7e 100644 --- a/c10/macros/build.bzl +++ b/c10/macros/build.bzl @@ -29,3 +29,12 @@ def define_targets(rules): "//conditions:default": [], }), ) + rules.filegroup( + name = "headers", + srcs = rules.glob( + ["*.h"], + exclude = [ + ], + ), + visibility = ["//:__pkg__"], + ) diff --git a/c10/util/build.bzl b/c10/util/build.bzl index b981eba67718..8d79a557477f 100644 --- a/c10/util/build.bzl +++ b/c10/util/build.bzl @@ -68,5 +68,5 @@ def define_targets(rules): exclude = [ ], ), - visibility = ["//c10:__pkg__"], + visibility = ["//c10:__pkg__", "//:__pkg__"], ) From 2e358cc98fab728aad8775de28596d589358b3b2 Mon Sep 17 00:00:00 2001 From: Jacob Hayes Date: Fri, 18 Nov 2022 14:09:21 +0000 Subject: [PATCH 337/453] Add platform markers for linux only extra_install_requires (#88826) Fixes #88049 https://github.com/pytorch/pytorch/pull/85097 added new extra dependencies on `nvidia-*`. They are linux (GPU) only packages, but were not marked as such, causing issues installing pytorch 1.13 via Poetry (and possibly other tools that follow PyPI's metadata API) on non-Linux systems. This "fixes" the issue by adding the `; platform_system = 'Linux'` marker on these dependencies, but the main problem of different metadata for different wheels is a [somewhat larger issue](https://github.com/pytorch/pytorch/issues/88049#issuecomment-1302555269). https://github.com/pytorch/pytorch/pull/85097 used `;` as a delimiter for splitting the different deps, but that is the delimiter used in markers, so I changed to split on `|`. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88826 Approved by: https://github.com/neersighted, https://github.com/lalmei, https://github.com/malfet --- .github/scripts/generate_binary_build_matrix.py | 6 +++--- .../generated-linux-binary-manywheel-nightly.yml | 10 +++++----- setup.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 54949ff27bb1..4031ee9aacca 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -219,9 +219,9 @@ def generate_wheels_matrix(os: str, "container_image": WHEEL_CONTAINER_IMAGES[arch_version], "package_type": package_type, "pytorch_extra_install_requirements": - "nvidia-cuda-runtime-cu11;" - "nvidia-cudnn-cu11==8.5.0.96;" - "nvidia-cublas-cu11==11.10.3.66", + "nvidia-cuda-runtime-cu11; platform_system == 'Linux' | " + "nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | " + "nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux'", "build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn" .replace( diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index efe3e2c0d17c..ba9401d717a6 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -169,7 +169,7 @@ jobs: DESIRED_PYTHON: "3.7" build_name: manywheel-py3_7-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -697,7 +697,7 @@ jobs: DESIRED_PYTHON: "3.8" build_name: manywheel-py3_8-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1225,7 +1225,7 @@ jobs: DESIRED_PYTHON: "3.9" build_name: manywheel-py3_9-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -1753,7 +1753,7 @@ jobs: DESIRED_PYTHON: "3.10" build_name: manywheel-py3_10-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | 
nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} @@ -2281,7 +2281,7 @@ jobs: DESIRED_PYTHON: "3.11" build_name: manywheel-py3_11-cuda11_7-with-pypi-cudnn build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11;nvidia-cudnn-cu11==8.5.0.96;nvidia-cublas-cu11==11.10.3.66 + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-runtime-cu11; platform_system == 'Linux' | nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' | nvidia-cublas-cu11==11.10.3.66; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/setup.py b/setup.py index bc8badb9b2e4..0aa27bef64d9 100644 --- a/setup.py +++ b/setup.py @@ -852,7 +852,7 @@ def configure_extension_build(): pytorch_extra_install_requirements = os.getenv("PYTORCH_EXTRA_INSTALL_REQUIREMENTS", "") if pytorch_extra_install_requirements: report(f"pytorch_extra_install_requirements: {pytorch_extra_install_requirements}") - extra_install_requires += pytorch_extra_install_requirements.split(";") + extra_install_requires += pytorch_extra_install_requirements.split("|") # Cross-compile for M1 From ce0e22a81a2383c7c951310c9c0aa7638748687b Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 18 Nov 2022 10:35:45 +0000 Subject: [PATCH 338/453] Fix names of some reference functions (#88115) The `__name__` field of some binary reference functions was wrong. We fix this to be consistent with unary reference functions. In the future, we should probably make the binary reference wrapper return a wrapper itself to avoid all those calls to `partial`. This change helps performing some homogeneous treatment of functions by their name. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88115 Approved by: https://github.com/mruberry --- torch/_prims_common/__init__.py | 21 ++ torch/_refs/__init__.py | 364 +++++++++++++------------- torch/_refs/nn/functional/__init__.py | 4 - torch/_refs/special/__init__.py | 16 +- 4 files changed, 215 insertions(+), 190 deletions(-) diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 128796dfa3d0..7752f1836141 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1577,6 +1577,27 @@ def mask_tensor(mask: TensorLikeType, t: TensorLikeType): return torch.where(mask, t, 0) +def get_aten_op(fn: Callable, name: str): + """ + Given the __module__ of reference and its name, it returns + (our best guess of) the ATen name of the associated operation + + Note: In ATen, the __name__ of a function within a module often + starts by the module name. E.g. 
linalg_eigh, or special_zeta + """ + module = fn.__module__ + prefix = "torch._refs" + assert(module.startswith(prefix)) + module = module[len(prefix):] + # We want to go from .special / .nn.functional + # to special and special_ / nn_functional_ + if module: + module = module[1:] + module = module.replace(".", "_") + module = module + "_" + return getattr(torch.ops.aten, f"{module}{name}") + + def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype: return dtype if dtype is not None else torch.get_default_dtype() diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 111c5c956f5d..25b6f2da37c8 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -394,18 +394,13 @@ def inner(prim: Callable): type_promotion_kind=type_promotion_kind, ) def _ref(a: TensorLikeType) -> TensorLikeType: - if not isinstance(a, TensorLike): - raise RuntimeError( - "Expected a tensor input for an elementwise unary operation!" - ) - if extra_meta is not None: extra_meta(a) return prim(a) if aten_op is infer_aten_op: - aten_op = getattr(torch.ops.aten, prim.__name__) + aten_op = utils.get_aten_op(prim, prim.__name__) if aten_op is not None: register_decomposition(aten_op)(_ref) @@ -860,54 +855,59 @@ def trunc(a): def _make_elementwise_binary_reference( - prim: Callable, - *, type_promotion_kind, aten_op=infer_aten_op, + name=None, has_out=True, supports_lhs_python_scalar=True, supports_rhs_python_scalar=True, supports_two_python_scalars=False, ) -> Callable: - @elementwise_type_promotion_wrapper( - type_promoting_args=("a", "b"), - type_promotion_kind=type_promotion_kind, - ) - def _ref( - a: Union[Tensor, NumberType], - b: Union[Tensor, NumberType], - ) -> Tensor: - if not supports_lhs_python_scalar and isinstance(a, Number): - raise ValueError( - "Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!" - ) + def inner(prim: Callable): + nonlocal aten_op, name + if name is None: + name = prim.__name__ - if not supports_rhs_python_scalar and isinstance(b, Number): - raise ValueError( - "Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!" + @wraps(prim) + @elementwise_type_promotion_wrapper( + type_promoting_args=("a", "b"), + type_promotion_kind=type_promotion_kind, + ) + def _ref( + a: Union[Tensor, NumberType], + b: Union[Tensor, NumberType], + ) -> Tensor: + check( + supports_lhs_python_scalar or not isinstance(a, Number), + lambda: "{name}: Received a lhs Python scalar to an elementwise binary operation that does not accept lhs scalars!", + ValueError, ) - - if ( - not supports_two_python_scalars - and isinstance(a, Number) - and isinstance(b, Number) - ): - raise ValueError( - f"Receive two Number inputs to an elementwise binary operation {prim}!" 
+ check( + supports_rhs_python_scalar or not isinstance(b, Number), + lambda: "{name}: Received a rhs Python scalar to an elementwise binary operation that does not accept rhs scalars!", + ValueError, + ) + check( + supports_two_python_scalars + or not (isinstance(a, Number) and isinstance(b, Number)), + lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!", + ValueError, ) + a, b = _maybe_broadcast(a, b) + return prim(a, b) - a, b = _maybe_broadcast(a, b) - return prim(a, b) + if has_out: + _ref = out_wrapper()(_ref) - if has_out: - _ref = out_wrapper()(_ref) + _ref.__name__ = name + if aten_op is infer_aten_op: + aten_op = utils.get_aten_op(prim, name) + if aten_op is not None: + register_decomposition(aten_op)(_ref) - if aten_op is infer_aten_op: - aten_op = getattr(torch.ops.aten, prim.__name__.split(".")[0]) - if aten_op is not None: - register_decomposition(aten_op)(_ref) + return _ref - return _ref + return inner # Add has its own implementation because it has an alpha argument @@ -947,47 +947,61 @@ def add( # TODO: add docstring -atan2 = _make_elementwise_binary_reference( - prims.atan2, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def atan2(a, b): + return prims.atan2(a, b) + # TODO: add docstring -bitwise_and = _make_elementwise_binary_reference( - prims.bitwise_and, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_and(a, b) + # TODO: add docstring -bitwise_left_shift = _make_elementwise_binary_reference( - prims.shift_left, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.bitwise_left_shift, # prim/aten name mismatch ) +def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_left(a, b) + # TODO: add docstring -bitwise_or = _make_elementwise_binary_reference( - prims.bitwise_or, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_or(a, b) + # TODO: add docstring -bitwise_right_shift = _make_elementwise_binary_reference( - prims.shift_right_arithmetic, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.bitwise_right_shift, # prim/aten name mismatch ) +def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.shift_right_arithmetic(a, b) + # TODO: add docstring -bitwise_xor = _make_elementwise_binary_reference( - prims.bitwise_xor, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.bitwise_xor(a, b) -def _copysign( +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + supports_lhs_python_scalar=False, +) +def copysign( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): if isinstance(b, Number) and isinstance(a, Tensor): @@ 
-1000,14 +1014,6 @@ def _copysign( return where(signbit(b), neg(abs(a)), abs(a)) -# TODO: add docstring -copysign = _make_elementwise_binary_reference( - _copysign, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - supports_lhs_python_scalar=False, - aten_op=torch.ops.aten.copysign, -) - # TODO: add docstring # complex = _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT) @@ -1038,14 +1044,19 @@ def div( # TODO: add docstring -eq = _make_elementwise_binary_reference( - prims.eq, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.eq(a, b) -def _pow( +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, +) +def pow( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType], ) -> TensorLikeType: @@ -1061,13 +1072,6 @@ def _pow( return prims.pow(a, b) -# TODO: add docstring -pow = _make_elementwise_binary_reference( - _pow, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG, - aten_op=torch.ops.aten.pow, -) - # TODO: add docstring # Float power has its own implementation because it has unique type promotion. # NB: aten_op not registered because CompositeExplicitAutograd @@ -1127,7 +1131,13 @@ def float_power( # # For reference, see CPython's implementation: # https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636 -def _floor_divide( + +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_two_python_scalars=True, +) +def floor_divide( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): # Wrap scalars because some references only accept tensor arguments. 
@@ -1194,66 +1204,69 @@ def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor: # TODO: add docstring -floor_divide = _make_elementwise_binary_reference( - _floor_divide, - type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.floor_divide, - supports_two_python_scalars=True, -) - - -# TODO: add docstring -fmax = _make_elementwise_binary_reference( - prims.fmax, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmax, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmax(a, b) + # TODO: add docstring -fmin = _make_elementwise_binary_reference( - prims.fmin, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmin, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmin(a, b) + # TODO: add docstring -fmod = _make_elementwise_binary_reference( - prims.fmod, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.fmod, supports_lhs_python_scalar=False, supports_rhs_python_scalar=True, ) +def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.fmod(a, b) + # TODO: add docstring -gcd = _make_elementwise_binary_reference( - prims.gcd, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.gcd, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gcd(a, b) + # TODO: add docstring -ge = _make_elementwise_binary_reference( - prims.ge, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ge(a, b) + # TODO: add docstring -gt = _make_elementwise_binary_reference( - prims.gt, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.gt(a, b) -def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: input_eq_zero = eq(input, 0) input_lt_zero = logical_or(lt(input, 0), isnan(input)) zeros_and_ones = where(input_lt_zero, 0, 1) @@ -1261,34 +1274,31 @@ def _heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType: return output -heaviside = _make_elementwise_binary_reference( - _heaviside, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, - supports_lhs_python_scalar=False, - supports_rhs_python_scalar=False, - aten_op=torch.ops.aten.heaviside, -) - -hypot = _make_elementwise_binary_reference( - prims.hypot, # type: ignore[has-type] +@_make_elementwise_binary_reference( 
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.hypot(a, b) + -igamma = _make_elementwise_binary_reference( - prims.igamma, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igamma(a, b) -igammac = _make_elementwise_binary_reference( - prims.igammac, # type: ignore[has-type] + +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.igammac(a, b) def _check_close_args( @@ -1363,7 +1373,13 @@ def isclose( return result -def _lcm(a: TensorLikeType, b: TensorLikeType): +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + supports_lhs_python_scalar=False, + supports_rhs_python_scalar=False, +) +def lcm(a: TensorLikeType, b: TensorLikeType): dtype = a.dtype # promoting to int32 to maintain 100% consistency with C++ and to # prevent overflow in case of int8 and int16 @@ -1380,24 +1396,19 @@ def _lcm(a: TensorLikeType, b: TensorLikeType): # TODO: add docstring -lcm = _make_elementwise_binary_reference( - _lcm, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.lcm, +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, - supports_rhs_python_scalar=False, ) +def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.le(a, b) # TODO: add docstring -le = _make_elementwise_binary_reference( - prims.le, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - supports_lhs_python_scalar=False, ) - - -def _logical_and(a: TensorLikeType, b: TensorLikeType): +def logical_and(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1405,23 +1416,19 @@ def _logical_and(a: TensorLikeType, b: TensorLikeType): return a & b -logical_and = _make_elementwise_binary_reference( - _logical_and, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_and, -) - - -@_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, aten_op=torch.ops.aten.logical_not -) +# TODO: add docstring +@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL) def logical_not(a: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): return a == 0 return ~a -def _logical_or(a: TensorLikeType, b: TensorLikeType): +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, +) +def logical_or(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1429,14 +1436,12 @@ def _logical_or(a: TensorLikeType, b: TensorLikeType): return bitwise_or(a, b) -logical_or = _make_elementwise_binary_reference( - _logical_or, +# TODO: add docstring +# TODO: skip unnecessary 
conversion of long to float +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_or, ) - - -def _logical_xor(a: TensorLikeType, b: TensorLikeType): +def logical_xor(a: TensorLikeType, b: TensorLikeType): if not utils.is_boolean_dtype(a.dtype): a = a != 0 if not utils.is_boolean_dtype(b.dtype): @@ -1444,61 +1449,66 @@ def _logical_xor(a: TensorLikeType, b: TensorLikeType): return a ^ b -# TODO: skip unnecessary conversion of long to float -logical_xor = _make_elementwise_binary_reference( - _logical_xor, - type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, - aten_op=torch.ops.aten.logical_xor, -) - - # TODO: add docstring -lt = _make_elementwise_binary_reference( - prims.lt, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.lt(a, b) + # TODO: add docstring -maximum = _make_elementwise_binary_reference( - prims.maximum, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.maximum(a, b) + # TODO: add docstring -minimum = _make_elementwise_binary_reference( - prims.minimum, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) +def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.minimum(a, b) + # TODO: add docstring -mul = _make_elementwise_binary_reference( - prims.mul, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, supports_two_python_scalars=True, ) +def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.mul(a, b) + # TODO: add docstring -ne = _make_elementwise_binary_reference( - prims.ne, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL, supports_lhs_python_scalar=False, ) +def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.ne(a, b) + # TODO: add docstring -nextafter = _make_elementwise_binary_reference( - prims.nextafter, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, supports_lhs_python_scalar=False, supports_rhs_python_scalar=False, ) +def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.nextafter(a, b) + # TODO: add docstring -remainder = _make_elementwise_binary_reference( - prims.remainder, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=torch.ops.aten.remainder, ) +def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.remainder(a, b) + # reverse sub def rsub( @@ -1550,12 +1560,14 @@ def sub( # TODO: add docstring -true_divide = _make_elementwise_binary_reference( - prims.div, # type: ignore[has-type] +@_make_elementwise_binary_reference( type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + name="true_divide", aten_op=None, # CompositeImplicitAutograd supports_two_python_scalars=True, ) +def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.div(a, b) 
@register_decomposition(torch.ops.aten.xlogy) @@ -1583,7 +1595,13 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT return torch.where(torch.isnan(b), float("nan"), rhs) -def _trunc_divide( +# TODO: add docstring +@_make_elementwise_binary_reference( + type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + aten_op=None, # CompositeImplicitAutograd + supports_two_python_scalars=True, +) +def trunc_divide( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType] ): dtype = utils.get_dtype(a) @@ -1593,14 +1611,6 @@ def _trunc_divide( return trunc(prims.div(a, b)) -# TODO: add docstring -trunc_divide = _make_elementwise_binary_reference( - _trunc_divide, - type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, - aten_op=None, # CompositeImplicitAutograd - supports_two_python_scalars=True, -) - # # Elementwise Ternary References # diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index 12f44c4092a4..ab352c40a93a 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -20,10 +20,6 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) -from torch._refs import ( - _make_elementwise_binary_reference, - _make_elementwise_unary_reference, -) from torch._subclasses.fake_tensor import FakeTensor diff --git a/torch/_refs/special/__init__.py b/torch/_refs/special/__init__.py index 1227a2631475..498382324265 100644 --- a/torch/_refs/special/__init__.py +++ b/torch/_refs/special/__init__.py @@ -46,7 +46,6 @@ @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_bessel_j0, ) def bessel_j0(a: TensorLikeType) -> TensorLikeType: return prims.bessel_j0(a) @@ -54,7 +53,6 @@ def bessel_j0(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_bessel_j1, ) def bessel_j1(a: TensorLikeType) -> TensorLikeType: return prims.bessel_j1(a) @@ -89,21 +87,21 @@ def erfcx(a: TensorLikeType) -> TensorLikeType: @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i0e + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i0e(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i0e(a) @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1 + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i1(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i1(a) @_make_elementwise_unary_reference( - ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=torch.ops.aten.special_i1e + ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, ) def i1e(a: TensorLikeType) -> TensorLikeType: return prims.bessel_i1e(a) @@ -223,14 +221,14 @@ def softmax( @_make_elementwise_unary_reference( ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_spherical_bessel_j0, ) def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType: return prims.spherical_bessel_j0(a) -zeta = _make_elementwise_binary_reference( - prims.zeta, # type: ignore[has-type] +# TODO: add docstring +@_make_elementwise_binary_reference( type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, - aten_op=torch.ops.aten.special_zeta, ) +def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: + return prims.zeta(a, b) From 6741443c7ceae0201fd76b5e6fc59ebd8cd6876a Mon Sep 17 00:00:00 2001 From: lezcano 
Date: Fri, 18 Nov 2022 10:35:46 +0000 Subject: [PATCH 339/453] Simplify maybe_resize_out (#88116) The previous behaviour would call `resize_` on 0-sized elements even when their size was correct. This would make some test fail, as resize_ may be an in-place operation and it's not supported by some subsystems Pull Request resolved: https://github.com/pytorch/pytorch/pull/88116 Approved by: https://github.com/mruberry --- torch/_prims_common/wrappers.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 76886f886a72..349e450cf372 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -4,6 +4,7 @@ NumberType, TensorLike, TensorLikeType, + ShapeType, ELEMENTWISE_TYPE_PROMOTION_KIND, ) import torch._prims_common as utils @@ -11,8 +12,7 @@ from typing import Callable, Sequence, Union, Tuple, NamedTuple import inspect -from functools import wraps, reduce -import operator +from functools import wraps import warnings from itertools import chain @@ -129,25 +129,22 @@ def _fn(*args, **kwargs): # TODO: handle tuples of tensors -def _maybe_resize_out(out: TensorLikeType, shape): - if out.numel() == 0: - return out.resize_(shape) - - if out.numel() != reduce(operator.mul, shape, 1): - msg = ( - "An output with one or more elements was resized since it had shape {0} " - "which does not match the required output shape {1}. " - "This behavior is deprecated, and in a future PyTorch release outputs will not " - "be resized unless they have zero elements. " - "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0).".format( - str(out.shape), str(shape) +def _maybe_resize_out(out: TensorLikeType, shape: ShapeType): + # If the shapes are correct there's nothing to do + if utils.same_shape(out.shape, shape): + return out + else: + if out.numel() != 0: + msg = ( + f"An output with one or more elements was resized since it had shape {str(out.shape)} " + "which does not match the required output shape {str(shape)}. " + "This behavior is deprecated, and in a future PyTorch release outputs will not " + "be resized unless they have zero elements. " + "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)." ) - ) - warnings.warn(msg) + warnings.warn(msg) return out.resize_(shape) - return out - def _safe_copy_out( *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False From 154e58c03285f3d399b8818dd17e973d486efefa Mon Sep 17 00:00:00 2001 From: lezcano Date: Fri, 18 Nov 2022 11:25:36 +0000 Subject: [PATCH 340/453] Add most in-place references/decompositions (#88117) We add most in-place references in a generic way. We also implement a wrapper to implement the annoying interface that `nn.functional` nonlinearities have. We fix along the way a couple decompositions for some non-linearities by extending the arguments that the references have. 
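To make the pattern concrete: any reference that already supports `out=` can be turned into an in-place variant by writing the result back into its first argument, which is what the `_make_inplace` helper in the diff below does. Here is a minimal, self-contained sketch of that idea — illustrative only; the helper name is made up and it uses `torch.sin` rather than the `torch._refs` functions this patch actually registers:

```python
import torch

def make_inplace_sketch(fn):
    # `fn` must accept an `out=` kwarg; the in-place variant simply aliases
    # the output buffer with the first input.
    def inplace(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)
    inplace.__name__ = fn.__name__ + "_"
    return inplace

sin_sketch_ = make_inplace_sketch(torch.sin)
t = torch.tensor([0.0, 1.0])
sin_sketch_(t)  # t now holds sin() of its previous values
```

The actual helper additionally registers each generated function as a decomposition for the corresponding `aten.*_` overload and appends it to the defining module's `__all__`, and the `nn.functional` wrapper in this patch handles the `inplace: bool` interface the same way, by rerouting `inplace=True` to `out=a`.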
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88117 Approved by: https://github.com/mruberry --- test/functorch/test_aotdispatch.py | 9 -- test/test_meta.py | 11 +- test/test_ops.py | 7 +- test/test_proxy_tensor.py | 116 +++--------------- torch/_decomp/decompositions.py | 13 -- torch/_refs/__init__.py | 114 ++++++++++++++++- torch/_refs/nn/functional/__init__.py | 95 +++++++++++--- .../_internal/common_methods_invocations.py | 9 ++ 8 files changed, 231 insertions(+), 143 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 1dc5476158f9..de6d82960adc 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1078,7 +1078,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('linalg.tensorinv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('linalg.tensorsolve', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('linalg.vander', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('linalg.vector_norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition xfail('logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposition xfail('logcumsumexp', ''), # aten.logcumsumexp.default - couldn't find symbolic meta function/decomposition @@ -1105,9 +1104,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('mvlgamma', 'mvlgamma_p_1'), # aten.digamma_.default - couldn't find symbolic meta function/decom... - xfail('mvlgamma', 'mvlgamma_p_3'), # aten.digamma_.default - couldn't find symbolic meta function/decom... - xfail('mvlgamma', 'mvlgamma_p_5'), # aten.digamma_.default - couldn't find symbolic meta function/decom... # Deleting this in a followup xfail('nn.functional.poisson_nll_loss', ''), @@ -1121,7 +1117,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): skip('nn.functional.batch_norm', ''), # '0 is not tracked with proxy for threshold, out_grad, out_grad * z / (z + 1.0)) -@register_decomposition(aten.elu) -@pw_cast_for_opmath -def elu( - self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1 -) -> Tensor: - negcoef = alpha * scale - poscoef = scale - negiptcoef = input_scale - return torch.where( - self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef - ) - - @register_decomposition(aten.elu_backward) @pw_cast_for_opmath def elu_backward( diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 25b6f2da37c8..3355400db43c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -122,6 +122,7 @@ "bitwise_right_shift", "bitwise_xor", "clamp_min", + "clamp_max", "copysign", "div", "eq", @@ -422,6 +423,31 @@ def _fn(*args, **kwargs): return _fn +def _make_inplace(fn): + """ + Given a function with out variant (i.e. using `out_wrapper()), it returns its in-place variant + See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch + """ + + # nb. 
We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, **kwargs): + return fn(a, *args, out=a, **kwargs) + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_decomposition(getattr(torch.ops.aten, inplace_name))(_fn) + + # We access the __all__ attribute of the module where fn is defined + # There may be a cleaner way of doing this... + from inspect import getmodule + + _all = getmodule(fn).__all__ # type: ignore[union-attr] + if inplace_name not in _all: + _all.append(inplace_name) + return _fn + + @_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT) def abs(a): return prims.abs(a) @@ -3419,7 +3445,6 @@ def index_select(x: TensorLike, dim: int, index: TensorLike): return x[idx] -# Note: although squeeze is documented as having the out= kwarg it doesn't @register_decomposition(torch.ops.aten.squeeze) def squeeze(a: TensorLikeType, dim: Optional[int] = None) -> TensorLikeType: if dim is not None: @@ -3843,6 +3868,7 @@ def cumsum( return sum(masked_a, dim=dim, keepdim=keepdim, dtype=dtype, out=out) +# Note: although squeeze is documented as having the out= kwarg it doesn't @register_decomposition(torch.ops.aten.unsqueeze) def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType: # Note that unsqueeze canonicalizes with rank + 1 because it allows @@ -5013,6 +5039,92 @@ def bucketize( return start.to(dtype=out_dtype) +# inplace +abs_ = _make_inplace(abs) +acos_ = _make_inplace(acos) +acosh_ = _make_inplace(acosh) +addcmul_ = _make_inplace(addcmul) +addcdiv_ = _make_inplace(addcdiv) +asin_ = _make_inplace(asin) +asinh_ = _make_inplace(asinh) +atan_ = _make_inplace(atan) +atanh_ = _make_inplace(atanh) +atan2_ = _make_inplace(atan2) +ceil_ = _make_inplace(ceil) +clamp_ = _make_inplace(clamp) +clamp_min_ = _make_inplace(clamp_min) +clamp_max_ = _make_inplace(clamp_max) +conj_physical_ = _make_inplace(conj_physical) +copysign_ = _make_inplace(copysign) +cos_ = _make_inplace(cos) +cosh_ = _make_inplace(cosh) +cumsum_ = _make_inplace(cumsum) +digamma_ = _make_inplace(digamma) +div_ = _make_inplace(div) +eq_ = _make_inplace(eq) +erf_ = _make_inplace(erf) +erfc_ = _make_inplace(erfc) +erfinv_ = _make_inplace(erfinv) +exp_ = _make_inplace(exp) +exp2_ = _make_inplace(exp2) +expm1_ = _make_inplace(expm1) +float_power_ = _make_inplace(float_power) +floor_ = _make_inplace(floor) +floor_divide_ = _make_inplace(floor_divide) +fmod_ = _make_inplace(fmod) +frac_ = _make_inplace(frac) +ge_ = _make_inplace(ge) +gt_ = _make_inplace(gt) +heaviside_ = _make_inplace(heaviside) +hypot_ = _make_inplace(hypot) +igamma_ = _make_inplace(igamma) +igammac_ = _make_inplace(igammac) +le_ = _make_inplace(le) +lerp_ = _make_inplace(lerp) +lgamma_ = _make_inplace(lgamma) +log10_ = _make_inplace(log10) +log1p_ = _make_inplace(log1p) +log2_ = _make_inplace(log2) +log_ = _make_inplace(log) +logical_and_ = _make_inplace(logical_and) +logical_or_ = _make_inplace(logical_or) +logical_xor_ = _make_inplace(logical_xor) +lt_ = _make_inplace(lt) +mvlgamma_ = _make_inplace(mvlgamma) +nan_to_num_ = _make_inplace(nan_to_num) +ne_ = _make_inplace(ne) +neg_ = _make_inplace(neg) +nextafter_ = _make_inplace(nextafter) +pow_ = _make_inplace(pow) +reciprocal_ = _make_inplace(reciprocal) +remainder_ = _make_inplace(remainder) +rsqrt_ = _make_inplace(rsqrt) +sgn_ = _make_inplace(sgn) +sigmoid_ = _make_inplace(sigmoid) +sign_ = _make_inplace(sign) +sin_ = _make_inplace(sin) +sinc_ = _make_inplace(sinc) +sinh_ = _make_inplace(sinh) +sqrt_ = 
_make_inplace(sqrt) +square_ = _make_inplace(square) +tan_ = _make_inplace(tan) +tanh_ = _make_inplace(tanh) +tril_ = _make_inplace(tril) +triu_ = _make_inplace(triu) +true_divide_ = _make_inplace(true_divide) +trunc_ = _make_inplace(trunc) +xlogy_ = _make_inplace(xlogy) + +# Views +# We can't model these as above, as the pattern of doing `op(a, out=a)` does not work for a view function +# given that it does not reshape the input (it just copies the result into it) + +# squeeze_ = _make_inplace(squeeze) +# t_ = _make_inplace(t) +# transpose_ = _make_inplace(transpose) +# unsqueeze_ = _make_inplace(unsqueeze) + + import torch._refs._conversions import torch._refs.fft import torch._refs.linalg diff --git a/torch/_refs/nn/functional/__init__.py b/torch/_refs/nn/functional/__init__.py index ab352c40a93a..4ebe6e2b05d9 100644 --- a/torch/_refs/nn/functional/__init__.py +++ b/torch/_refs/nn/functional/__init__.py @@ -1,4 +1,5 @@ import math +from functools import wraps from typing import Callable, Optional, Union import torch @@ -20,6 +21,7 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) +from torch._refs import _make_inplace from torch._subclasses.fake_tensor import FakeTensor @@ -116,9 +118,31 @@ def alpha_dropout( return self * dropout_mask + b +def inplace_wrapper(fn): + """ + Given a nn.functional non-linearity, implements its `inplace: bool` argument + """ + + # nb. We use the name of the first argument used in the unary references + @wraps(fn) + def _fn(a, *args, inplace=False, **kwargs): + if inplace: + check( + "out" not in kwargs, + lambda: "Cannot set inplace=True and pass out= at the same time", + ) + return fn(a, *args, inplace=False, out=a, **kwargs) + else: + return fn(a, *args, inplace=False, **kwargs) + + return _fn + + # celu is implemented specially because it has an alpha argument # celu is very similar to elu @register_decomposition(torch.ops.aten.celu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -151,6 +175,8 @@ def celu( @register_decomposition(torch.ops.aten.dropout) +@inplace_wrapper +@out_wrapper() def dropout( a: TensorLikeType, p: float = 0.5, training: bool = True, inplace: bool = False ) -> TensorLikeType: @@ -178,14 +204,19 @@ def dropout( return a * dropout_mask * scale -# elu is implemented specially because it has an alpha argument -# This cannot be used as a decomposition because the aten op takes in 2 extra kwargs +@register_decomposition(torch.ops.aten.elu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) def elu( - a: TensorLikeType, alpha: Optional[NumberType] = None, inplace: bool = False + a: TensorLikeType, + alpha: NumberType = 1.0, + scale: NumberType = 1.0, + input_scale: NumberType = 1.0, + inplace: bool = False, ) -> TensorLikeType: """ Reference implementation of torch.nn.functional.elu @@ -193,24 +224,27 @@ def elu( if inplace: raise NotImplementedError - rhs: TensorLikeType - if alpha is not None: - python_type = utils.dtype_to_type(a.dtype) - if not utils.is_weakly_lesser_type(type(alpha), python_type): - msg = ( - "alpha argument of type {0} cannot be safely cast to type {1}!".format( - type(alpha), python_type - ) - ) - raise ValueError(msg) - rhs = alpha * torch.expm1(a) - else: - rhs = torch.expm1(a) + # nb. 
This should be factored out into a can_cast aux function + python_type = utils.dtype_to_type(a.dtype) + check( + utils.is_weakly_lesser_type(type(input_scale), python_type), + lambda: f"input_scale argument of type {type(input_scale)} cannot be safely cast to type {python_type}!", + ) + check( + utils.is_weakly_lesser_type(type(scale), python_type), + lambda: f"scale argument of type {type(scale)} cannot be safely cast to type {python_type}!", + ) + check( + utils.is_weakly_lesser_type(type(alpha), python_type), + lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!", + ) - return torch.where(a > 0, a, rhs) + return torch.where(a > 0, scale * a, (alpha * scale) * torch.expm1(a * input_scale)) @register_decomposition(torch.ops.aten.relu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -280,6 +314,8 @@ def layer_norm( @register_decomposition(torch.ops.aten.leaky_relu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -302,6 +338,8 @@ def leaky_relu( @register_decomposition(torch.ops.aten.mish) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -317,6 +355,8 @@ def mish(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: @register_decomposition(torch.ops.aten.selu) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -369,6 +409,7 @@ def softmin( # softplus is implemented specially because it has beta and threshold arguments @register_decomposition(torch.ops.aten.softplus) +@inplace_wrapper @out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), @@ -661,11 +702,11 @@ def _nll_loss_nd( @register_decomposition(torch.ops.aten.nll_loss) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("input",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def nll_loss( input: TensorLikeType, target: TensorLikeType, @@ -784,6 +825,8 @@ def tanhshrink(a: TensorLikeType) -> TensorLikeType: @register_decomposition(torch.ops.aten.threshold) +@inplace_wrapper +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, @@ -883,6 +926,8 @@ def _triplet_margin_with_distance_loss( @register_decomposition(torch.ops.aten.hardtanh) +@inplace_wrapper +@out_wrapper() @elementwise_unary_scalar_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("a"), @@ -1022,6 +1067,8 @@ def prelu(a: TensorLikeType, weight: TensorLikeType) -> TensorLikeType: @register_decomposition(torch.ops.aten.relu6) +@inplace_wrapper +@out_wrapper() def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: """ Reference implementation of torch.nn.functional.relu6 @@ -1036,11 +1083,11 @@ def relu6(a: TensorLikeType, inplace: bool = False) -> TensorLikeType: @register_decomposition(torch.ops.aten.glu) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def glu(a: TensorLikeType, dim: int = -1) -> TensorLikeType: dim = utils.canonicalize_dims(a.ndim, dim) check( @@ 
-1065,11 +1112,11 @@ def pairwise_distance( @register_decomposition(torch.ops.aten.pdist) +@out_wrapper() @elementwise_type_promotion_wrapper( type_promoting_args=("a",), type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, ) -@out_wrapper() def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: check(a.ndim == 2, lambda: f"pdist only supports 2D tensors, got: {a.ndim}D") check(p >= 0, lambda: "pdist only supports non-negative p values") @@ -1083,3 +1130,11 @@ def pdist(a: TensorLikeType, p: float = 2) -> TensorLikeType: t = torch.linalg.vector_norm(a.unsqueeze(1) - a, ord=p, dim=2) i = torch.triu_indices(t.shape[0], t.shape[1], offset=1, device=a.device) return t.flatten().index_select(0, i[0] * t.shape[0] + i[1]) + + +# Needed as aten.{celu_,elu_...} exist (even if they don't have the in-place kwarg) +celu_ = _make_inplace(celu) +elu_ = _make_inplace(elu) +mish_ = _make_inplace(mish) +selu_ = _make_inplace(selu) +threshold_ = _make_inplace(threshold) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 50732af6f857..6cff2f6a4749 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -17419,11 +17419,13 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.celu", torch_opinfo_name="nn.functional.celu", + supports_out=True, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.threshold", torch_opinfo_name="nn.functional.threshold", supports_nvfuser=False, + supports_out=True, ), PythonRefInfo( "_refs.nn.functional.dropout", @@ -17458,11 +17460,13 @@ def reference_flatten(input, start_dim=0, end_dim=-1): ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.elu", torch_opinfo_name="nn.functional.elu", + supports_out=True, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.hardtanh", torch_opinfo_name="nn.functional.hardtanh", supports_nvfuser=False, + supports_out=True, ), PythonRefInfo( # TODO: Port this to an UnaryOpInfo "_refs.nn.functional.gelu", @@ -17501,6 +17505,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): PythonRefInfo( "_refs.nn.functional.leaky_relu", torch_opinfo_name="nn.functional.leaky_relu", + supports_out=True, ), PythonRefInfo( "_refs.nn.functional.log_softmax", @@ -17526,18 +17531,22 @@ def reference_flatten(input, start_dim=0, end_dim=-1): "_refs.nn.functional.relu", torch_opinfo_name="nn.functional.relu", supports_nvfuser=False, + supports_out=True, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.relu6", torch_opinfo_name="nn.functional.relu6", + supports_out=True, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.mish", torch_opinfo_name="nn.functional.mish", + supports_out=True, ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.selu", torch_opinfo_name="nn.functional.selu", + supports_out=True, ), PythonRefInfo( "_refs.nn.functional.softmax", From 55e55d95ea9a6f64bba50cdc9e243808cb534202 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Fri, 18 Nov 2022 15:27:15 +0000 Subject: [PATCH 341/453] Update torch.distributed.DistBackendError type (#89235) Summary: Update torch.distributed.DistBackendError type based on https://fb.workplace.com/groups/pyreqa/posts/5753993921357059 Test Plan: Pyre tests should pass? 
let sandcastle run Reviewed By: markkm Differential Revision: D41384130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89235 Approved by: https://github.com/awgu --- test/distributed/test_c10d_nccl.py | 4 +++- torch/_C/__init__.pyi.in | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index cdc167bc4d1a..fb28e744b5ed 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -1031,9 +1031,11 @@ def test_nccl_dist_backend_error(self): self._create_process_group_nccl(store, self.opts()) # Both rank 0 and 1 will use the same CUDA device resulting in ncclInvalidUsage - with self.assertRaises(dist.DistBackendError): + with self.assertRaises(dist.DistBackendError) as cm: dist.broadcast(torch.tensor([1, 2, 3]).cuda(), 0) + self.assertIsInstance(cm.exception, RuntimeError) + class DistributedDataParallelTest( test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase ): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 5833d7d7f2a4..d69cf1f3477e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1508,5 +1508,5 @@ def _current_graph_task_id() -> _int: ... class _OutOfMemoryError: pass -class _DistBackendError: +class _DistBackendError(RuntimeError): pass From 1f7c0ff6e799e7bde94975f7a5bbec39a69ab8f6 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 18 Nov 2022 13:41:51 +0000 Subject: [PATCH 342/453] [inductor] Temporarily disable functorch_dp_cifar10 test in TorchBench (#89281) Summary: The failure wasn't caught because of a land race. Skip the test for now. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89281 Approved by: https://github.com/Krovatkin --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index cad954f825b2..44b020413c9c 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -111,6 +111,7 @@ # *CI_SKIP_INDCUTOR_INFERENCE, # TorchBench "detectron2", + "functorch_dp_cifar10", "mobilenet_v3_large", "moco", "tacotron2", From 19fcb80551854431e7e05c422690751037a18488 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Fri, 18 Nov 2022 16:15:55 +0000 Subject: [PATCH 343/453] [inductor] Skip DALLE2_pytorch in torchbench (#89288) Summary: DALLE2_pytorch fails in eager as well. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89288 Approved by: https://github.com/Krovatkin --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 44b020413c9c..c4e9d62f0a7c 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -110,6 +110,7 @@ # *CI_SKIP_AOT_EAGER_TRAINING, # *CI_SKIP_INDCUTOR_INFERENCE, # TorchBench + "DALLE2_pytorch", "detectron2", "functorch_dp_cifar10", "mobilenet_v3_large", From 19e66fcec235fe46a23186a59446bcfe70ad4f6d Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Thu, 17 Nov 2022 12:47:33 -0800 Subject: [PATCH 344/453] [Quant] Allow setting fixed qparams for inner LSTM ops (#88456) Summary: In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is used as an observed custom module, which is responsible for inserting its own observers. 
By default, the user specifies a single QConfig for the custom module (either through QConfigMapping or by setting the "qconfig" attribute"), and all inner ops will [inherit this QConfig](https://github.com/pytorch/pytorch/blob/dc00bb51b8d370bf3891f0edb2c6e0c2914e329a/torch/ao/nn/quantizable/modules/rnn.py#L366-L378) and use the same observer/fake_quantize constructors. Today, users who wish to override this behavior must extend `torch.ao.nn.quantizable.LSTM` and write a lot of custom code to manually assign the QConfigs to the inner ops. This commit alleviates this burden on the user by providing a helper function to assign QConfigs with custom observers. An example use case of this is providing a reference implementation for a backend kernel that hardcodes qparams for efficiency. Example usage: ``` import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.fx.custom_config import ( PrepareCustomConfig, ConvertCustomConfig, ) class MyModel(torch.nn.Module): ... class UserLSTM(torch.ao.nn.quantizable.LSTM): @classmethod def from_float(cls, other): assert isinstance(other, cls._FLOAT_MODULE) linear_output_obs_ctr = FixedQParamsObserver.with_args( scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) sigmoid_obs_ctr = FixedQParamsObserver.with_args( scale=2 ** -16, zero_point=0, dtype=torch.qint32) tanh_obs_ctr = FixedQParamsObserver.with_args( scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) cell_state_obs_ctr = FixedQParamsObserver.with_args( scale=2 ** -11, zero_point=0, dtype=torch.qint32) hidden_state_obs_ctr = FixedQParamsObserver.with_args( scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) return torch.ao.quantization.utils._get_lstm_with_individually_observed_parts( float_lstm=other, linear_output_obs_ctr=linear_output_obs_ctr, sigmoid_obs_ctr=sigmoid_obs_ctr, tanh_obs_ctr=tanh_obs_ctr, cell_state_obs_ctr=cell_state_obs_ctr, hidden_state_obs_ctr=hidden_state_obs_ctr, ) qconfig_mapping = get_default_qconfig_mapping() example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) prepare_custom_config = PrepareCustomConfig() \ .set_float_to_observed_mapping(torch.nn.LSTM, UserLSTM) convert_custom_config = ConvertCustomConfig() \ .set_observed_to_quantized_mapping(UserLSTM, torch.ao.nn.quantized.LSTM) model = MyModel() model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) model(*example_inputs) # calibrate model = convert_fx(model, convert_custom_config=convert_custom_config) model(*example_inputs) ``` Test Plan: python test/test_quantization.py TestQuantizeFx.test_static_lstm_with_custom_fixed_qparams Reviewers: jerryzh168, vkuzo Subscribers: jerryzh168, vkuzo Pull Request resolved: https://github.com/pytorch/pytorch/pull/88456 Approved by: https://github.com/jerryzh168, https://github.com/vkuzo --- test/quantization/fx/test_quantize_fx.py | 77 ++++++++++++++++++++- torch/ao/nn/quantizable/modules/rnn.py | 24 +++++-- torch/ao/quantization/utils.py | 88 ++++++++++++++++++++++++ 3 files changed, 182 insertions(+), 7 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index b03b7fb0cf0e..2d91ba80b7e0 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -192,7 +192,7 @@ import operator import unittest import io -from typing import Callable, Optional, List +from typing import Callable, Optional, List, Tuple class 
BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): @@ -4217,6 +4217,81 @@ def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): } self._test_static_lstm_helper(m2, prepare_node_occurrence, convert_node_occurrence2) + def test_static_lstm_with_custom_fixed_qparams(self): + """ + Test statically quantized LSTM with custom fixed qparams assigned to each of the + inner submodules. This flow requires users to extend `torch.ao.nn.quantizable.LSTM` + and use the child class in the custom module mapping. + """ + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.my_lstm = torch.nn.LSTM(50, 50, 1) + + def forward(self, inputs: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor): + x = self.my_lstm(inputs, (h0, c0)) + return x + + class UserLSTM(torch.ao.nn.quantizable.LSTM): + """ + Example of user provided LSTM implementation that has fixed qparams assigned + to the inner submodules. + """ + @classmethod + def from_float(cls, other): + assert isinstance(other, cls._FLOAT_MODULE) + # uint16, [-16, 16) + linear_output_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=2 ** 15, dtype=torch.qint32) + # uint16, [0, 1) + sigmoid_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -16, zero_point=0, dtype=torch.qint32) + # uint16, [-1, 1) + tanh_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -15, zero_point=2 ** 15, dtype=torch.qint32) + # int16, [-16, 16) + cell_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -11, zero_point=0, dtype=torch.qint32) + # uint8, [-1, 1) + hidden_state_obs_ctr = FixedQParamsObserver.with_args(scale=2 ** -7, zero_point=2 ** 7, dtype=torch.quint8) + return torch.ao.quantization.utils._get_lstm_with_individually_observed_parts( + float_lstm=other, + linear_output_obs_ctr=linear_output_obs_ctr, + sigmoid_obs_ctr=sigmoid_obs_ctr, + tanh_obs_ctr=tanh_obs_ctr, + cell_state_obs_ctr=cell_state_obs_ctr, + hidden_state_obs_ctr=hidden_state_obs_ctr, + ) + + # Prepare model + qconfig_mapping = get_default_qconfig_mapping() + example_inputs = (torch.rand(5, 3, 50), torch.rand(1, 3, 50), torch.randn(1, 3, 50)) + prepare_custom_config = PrepareCustomConfig() \ + .set_float_to_observed_mapping(torch.nn.LSTM, UserLSTM) + convert_custom_config = ConvertCustomConfig() \ + .set_observed_to_quantized_mapping(UserLSTM, torch.ao.nn.quantized.LSTM) + model = MyModel() + model = prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=prepare_custom_config) + + # Validate that the observers inserted to each inner module has the expected qparams + def validate_qparams(inner_module: torch.nn.Module, scale: float, zero_point: int, dtype: torch.dtype): + self.assertTrue(hasattr(inner_module, "activation_post_process")) + obs = inner_module.activation_post_process + self.assertTrue(isinstance(obs, FixedQParamsObserver)) + self.assertEqual(obs.scale, scale) + self.assertEqual(obs.zero_point, zero_point) + self.assertEqual(obs.dtype, dtype) + cell = model.my_lstm.layers[0].layer_fw.cell + validate_qparams(cell.igates, 2 ** -11, 2 ** 15, torch.qint32) + validate_qparams(cell.hgates, 2 ** -11, 2 ** 15, torch.qint32) + validate_qparams(cell.input_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.forget_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.cell_gate, 2 ** -15, 2 ** 15, torch.qint32) + validate_qparams(cell.output_gate, 2 ** -16, 0, torch.qint32) + validate_qparams(cell.fgate_cx_igate_cgate, 2 ** -11, 0, torch.qint32) + 
validate_qparams(cell.ogate_cy, 2 ** -7, 2 ** 7, torch.quint8) + + # Make sure the rest of the flow runs + model(*example_inputs) + model = convert_fx(model, convert_custom_config=convert_custom_config, _remove_qconfig=False) + model(*example_inputs) + def test_reroute_tuple_getitem_patterns(self): """ The following graph should redirect the output to `b`. After the transformation, diff --git a/torch/ao/nn/quantizable/modules/rnn.py b/torch/ao/nn/quantizable/modules/rnn.py index 59f23137097c..72156a7ba5fe 100644 --- a/torch/ao/nn/quantizable/modules/rnn.py +++ b/torch/ao/nn/quantizable/modules/rnn.py @@ -41,12 +41,22 @@ def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True, self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs) self.gates = torch.ao.nn.quantized.FloatFunctional() + self.input_gate = torch.nn.Sigmoid() + self.forget_gate = torch.nn.Sigmoid() + self.cell_gate = torch.nn.Tanh() + self.output_gate = torch.nn.Sigmoid() + self.fgate_cx = torch.ao.nn.quantized.FloatFunctional() self.igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional() self.ogate_cy = torch.ao.nn.quantized.FloatFunctional() + self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0) + self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0) + self.hidden_state_dtype: torch.dtype = torch.quint8 + self.cell_state_dtype: torch.dtype = torch.quint8 + def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]: if hidden is None or hidden[0] is None or hidden[1] is None: hidden = self.initialize_hidden(x.shape[0], x.is_quantized) @@ -58,10 +68,10 @@ def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) - input_gate = torch.sigmoid(input_gate) - forget_gate = torch.sigmoid(forget_gate) - cell_gate = torch.tanh(cell_gate) - out_gate = torch.sigmoid(out_gate) + input_gate = self.input_gate(input_gate) + forget_gate = self.forget_gate(forget_gate) + cell_gate = self.cell_gate(cell_gate) + out_gate = self.output_gate(out_gate) fgate_cx = self.fgate_cx.mul(forget_gate, cx) igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) @@ -75,8 +85,10 @@ def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]: h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size)) if is_quantized: - h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8) - c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8) + (h_scale, h_zp) = self.initial_hidden_state_qparams + (c_scale, c_zp) = self.initial_cell_state_qparams + h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype) + c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype) return h, c def _get_name(self): diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index 9f3dc712a9fe..662d0068fef4 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -546,6 +546,94 @@ def _patched_module_call(self, *args, **kwargs): torch.nn.Module.__call__ = orig_module_call return fqn_to_example_inputs +def _get_lstm_with_individually_observed_parts( + float_lstm: torch.nn.LSTM, + # Use Callable instead of _PartialWrapper 
here to avoid circular dependencies + linear_output_obs_ctr: Optional[Callable] = None, + sigmoid_obs_ctr: Optional[Callable] = None, + tanh_obs_ctr: Optional[Callable] = None, + cell_state_obs_ctr: Optional[Callable] = None, + hidden_state_obs_ctr: Optional[Callable] = None, +) -> torch.ao.nn.quantizable.LSTM: + """ + Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` + with specific observers or fake quantizes assigned to the inner ops or submodules. + + In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is + used as an observed custom module, which is responsible for inserting its own + observers. By default, all inner ops inherit the parent custom module's QConfig. + Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM` + and use this helper function to customize the observer insertion logic. + + Args: + `float_lstm`: The float LSTM module + `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b, + where W is the weight matrix, b is the bias, and x is either the inputs + or the hidden state from the previous layer (if any) + `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations + `tanh_obs_ctr`: observer or fake quantize for tanh activations + `cell_state_obs_ctr`: observer or fake quantize for the cell state + `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and + the output + + Return: + A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes + attached to the inner submodules. + """ + def make_qconfig(obs_ctr: Callable) -> torch.ao.quantization.QConfig: + """ + Make a QConfig with fixed qparams observers or fake quantizes. + """ + if isinstance(obs_ctr(), torch.ao.quantization.FakeQuantizeBase): + weight = torch.ao.quantization.default_weight_fake_quant + else: + weight = torch.ao.quantization.default_weight_observer + return torch.ao.quantization.QConfig(activation=obs_ctr, weight=weight) + + observed_lstm = torch.ao.nn.quantizable.LSTM( + float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias, + float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional) + + # Assign QConfigs with fixed qparams to all inner submodules + # Module hierarchy: LSTM > _LSTMLayer > _LSTMSingleLayer (forward or backward) > LSTMCell + for layer in observed_lstm.layers: + inner_layers = [layer.layer_fw] + if float_lstm.bidirectional: + inner_layers.append(layer.layer_bw) + for inner_layer in inner_layers: + cell = inner_layer.cell + if linear_output_obs_ctr is not None: + qconfig = make_qconfig(linear_output_obs_ctr) + cell.igates.qconfig = qconfig + cell.hgates.qconfig = qconfig + if sigmoid_obs_ctr is not None: + qconfig = make_qconfig(sigmoid_obs_ctr) + cell.input_gate.qconfig = qconfig + cell.forget_gate.qconfig = qconfig + cell.output_gate.qconfig = qconfig + if tanh_obs_ctr is not None: + cell.cell_gate.qconfig = make_qconfig(tanh_obs_ctr) + if cell_state_obs_ctr is not None: + cell.fgate_cx_igate_cgate.qconfig = make_qconfig(cell_state_obs_ctr) + obs = cell_state_obs_ctr() + if hasattr(obs, "scale") and hasattr(obs, "zero_point"): + cell.initial_cell_state_qparams = (obs.scale, obs.zero_point) + cell.cell_state_dtype = obs.dtype + if hidden_state_obs_ctr is not None: + cell.ogate_cy.qconfig = make_qconfig(hidden_state_obs_ctr) + obs = hidden_state_obs_ctr() + if hasattr(obs, "scale") and hasattr(obs, "zero_point"): + cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point) + 
cell.hidden_state_dtype = obs.dtype + + # Insert the observers based on the previously attached QConfigs + # Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden + torch.ao.quantization.add_observer_( + observed_lstm, + non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh] + ) + return observed_lstm + __all__ = [ "NodePattern", "Pattern", From 12a97444c3f5b640be54f3307895cd0e0c18085a Mon Sep 17 00:00:00 2001 From: Richard Howell Date: Fri, 18 Nov 2022 16:30:53 +0000 Subject: [PATCH 345/453] [xplat] remove -weak_framework (#89233) Summary: The `-weak_framework` flag is no longer necessary, Buck will weakly link frameworks depending on the `target_sdk_version` of the binary being linked. Test Plan: Compare IG load commands before and after change with P553208168 ``` load command difference in Instagram.app/Frameworks/InstagramXplatFramework.framework/InstagramXplatFramework --- /tmp/tmpvd97s2v0 2022-11-16 12:13:54.082910598 -0800 +++ /tmp/tmpj20r_4ca 2022-11-16 12:13:54.082910598 -0800 @@ -9,7 +9,7 @@ /System/Library/Frameworks/CoreHaptics.framework/CoreHaptics (compatibility version 1.0.0, current version 1.0.0, weak) /System/Library/Frameworks/CoreImage.framework/CoreImage (compatibility version 1.0.0, current version 5.0.0) /System/Library/Frameworks/CoreLocation.framework/CoreLocation (compatibility version 1.0.0, current version 2780.0.17) - /System/Library/Frameworks/CoreML.framework/CoreML (compatibility version 1.0.0, current version 1.0.0, weak) + /System/Library/Frameworks/CoreML.framework/CoreML (compatibility version 1.0.0, current version 1.0.0) /System/Library/Frameworks/CoreMedia.framework/CoreMedia (compatibility version 1.0.0, current version 1.0.0) /System/Library/Frameworks/CoreServices.framework/CoreServices (compatibility version 1.0.0, current version 1226.0.0) /System/Library/Frameworks/CoreTelephony.framework/CoreTelephony (compatibility version 1.0.0, current version 0.0.0) @@ -33,9 +33,9 @@ /System/Library/Frameworks/Security.framework/Security (compatibility version 1.0.0, current version 60420.40.34) /System/Library/Frameworks/SystemConfiguration.framework/SystemConfiguration (compatibility version 1.0.0, current version 1241.40.2) /System/Library/Frameworks/UIKit.framework/UIKit (compatibility version 1.0.0, current version 6109.1.108) - /System/Library/Frameworks/UserNotifications.framework/UserNotifications (compatibility version 1.0.0, current version 1.0.0, weak) + /System/Library/Frameworks/UserNotifications.framework/UserNotifications (compatibility version 1.0.0, current version 1.0.0) /System/Library/Frameworks/VideoToolbox.framework/VideoToolbox (compatibility version 1.0.0, current version 1.0.0) - /System/Library/Frameworks/WebKit.framework/WebKit (compatibility version 1.0.0, current version 614.2.9, weak) + /System/Library/Frameworks/WebKit.framework/WebKit (compatibility version 1.0.0, current version 614.2.9) /usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1319.0.0) /usr/lib/libbz2.1.0.dylib (compatibility version 1.0.0, current version 1.0.8) /usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 1300.32.0) ``` Both these changes are correct, WebKit is available from 8.0, UserNotifications from 10.0 and CoreML from 11.0. Instagram has a deployment target of 12.4. 
Reviewed By: ebgraham Differential Revision: D41348639 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89233 Approved by: https://github.com/malfet --- c2_defs.bzl | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/c2_defs.bzl b/c2_defs.bzl index 0a89bb88093d..fedbb4bca84b 100644 --- a/c2_defs.bzl +++ b/c2_defs.bzl @@ -221,26 +221,13 @@ def get_c2_fbobjc_ios_frameworks(): frameworks = [] if get_c2_mpscnn(): - frameworks.append( + frameworks.extend([ "$SDKROOT/System/Library/Frameworks/Metal.framework", - ) + "$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", + ]) return frameworks -def get_c2_fbobjc_linker_flags(): - flags = [] - - if get_c2_mpscnn(): - # Need linker flags as no platform_frameworks exist, and we can't - # use MPSCNN on x86_64. - # We use weak_framework as it's iOS 10 - flags = [ - "-L$SDKROOT/System/Library/Frameworks/MetalPerformanceShaders.framework", - "-weak_framework", - "MetalPerformanceShaders", - ] - return flags - def get_c2_fbobjc_exported_preprocessor_flags(): flags = [] @@ -311,12 +298,6 @@ def get_c2_default_cxx_args(): STATIC_LIBRARY_IOS_CONFIG, extra_target_config = C2_FBOBJC_EXTRA_TARGET_CONFIG, ), - fbobjc_exported_platform_linker_flags = [ - ( - "iphoneos", - get_c2_fbobjc_linker_flags(), - ), - ], fbobjc_exported_platform_preprocessor_flags = [ ( "iphoneos", From c219b55b5f8d5718d382735628e9eb8a46caee9f Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Thu, 17 Nov 2022 21:35:51 -0800 Subject: [PATCH 346/453] Use standard __func__ macro in symbolic shape. (#89264) Summary: I saw the following issue only on Windows build in PR #88767: ``` RuntimeError: AttributeError: 'SymNode' object has no attribute 'torch::impl::PythonSymNodeImpl::ge' ``` It's only on Windows because we get the attributes of SymNode in C++ with `__FUNCTION__` macro, which is not in C++ standard, therefore has platform specific behavior. In this case, MSVC will include a function's namespace and class name, which is not intended here. Instead we should use `__func__`. 
see: https://en.cppreference.com/w/cpp/language/function#Function_definition godbolt example to show the difference: https://godbolt.org/z/PGfvecxPx Test Plan: CI Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89264 Approved by: https://github.com/ezyang --- torch/csrc/utils/python_symnode.h | 38 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 3a9fa79d37d6..00bddfb9e4dc 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -94,78 +94,78 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { } c10::SymNode add(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode sub(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode mul(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode truediv(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode pow(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode floordiv(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode mod(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode eq(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode gt(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode lt(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode le(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode ge(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode min(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode max(const c10::SymNode& other) override { - return dispatch_common_(__FUNCTION__, other); + return dispatch_common_(__func__, other); } c10::SymNode ceil() override { - return dispatch_common_(__FUNCTION__); + return dispatch_common_(__func__); } c10::SymNode floor() override { - return dispatch_common_(__FUNCTION__); + return dispatch_common_(__func__); } c10::SymNode neg() override { - return dispatch_common_(__FUNCTION__); + return dispatch_common_(__func__); } c10::SymNode clone() override { - return dispatch_common_(__FUNCTION__); + return dispatch_common_(__func__); } c10::SymNode sym_float() override { - return dispatch_common_(__FUNCTION__); + return dispatch_common_(__func__); } py::handle getPyObj() { From 38ccd08f9b79bc2102050833948f5112aed2dfc4 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 18 Nov 2022 00:15:45 -0800 Subject: [PATCH 347/453] 
[quant][fx][be] Refactor replace observer with q/dq op code (#89247) Summary: This is a refactor to prepare for future extensions, no functionality changes Test Plan: python test/test_quantization.py TestQuantizeFx Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89247 Approved by: https://github.com/vkuzo, https://github.com/andrewor14 --- torch/ao/quantization/fx/convert.py | 132 +++++++++++++++++++++------- 1 file changed, 98 insertions(+), 34 deletions(-) diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index ca6ae61a4c97..f09785679e37 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -53,12 +53,15 @@ _get_module, _is_custom_module_lstm, get_custom_module_class_keys, - get_quantize_node_info, create_getattr_from_value, collect_producer_nodes, graph_module_from_producer_nodes, node_arg_is_weight, ) +from torch.ao.quantization.utils import ( + is_per_channel, + to_underlying_dtype, +) from torch.ao.quantization.quantize import ( _remove_qconfig, ) @@ -107,51 +110,103 @@ def _replace_observer_with_quantize_dequantize_node( assert modules is not None assert isinstance(node.target, str) module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) - observer_module = modules[node.target] - maybe_quantize_node_info = get_quantize_node_info(observer_module, is_decomposed) + activation_post_process = modules[node.target] # Skip replacing observers to quant/dequant nodes if the qconfigs of all # consumers and producers of this observer are None skip_replacement = all([ has_none_qconfig(n, node_name_to_qconfig) for n in list(node.args) + list(node.users.keys())]) - if skip_replacement or maybe_quantize_node_info is None: - # didn't find correponding quantize op and info for the observer_module + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find correponding quantize op and info for the activation_post_process # so we just remove the observer with graph.inserting_before(node): node.replace_all_uses_with(node.args[0]) graph.erase_node(node) - else: - # otherwise, we can convert the observer moduel call to quantize/dequantize node - node_type, quantize_op, qparams = maybe_quantize_node_info - # replace observer node with quant - dequant node - with graph.inserting_before(node): - input_node = node.args[0] - quantize_op_inputs = [input_node] - for key, value in qparams.items(): - # TODO: we can add the information of whether a value needs to - # be registered as an attribute in qparams dict itself - if key in ['_scale_', '_zero_point_']: - # For scale and zero_point values we register them as buffers in the root module. - # TODO: maybe need more complex attr name here - qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) - quantize_op_inputs.append(qparam_node) - else: - # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. - quantize_op_inputs.append(value) + return - quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # otherwise, we can convert the observer module call to quantize/dequantize node + # 1. 
extract the information from activation_post_process module for generating + # the quantize and dequantize operator + dtype = activation_post_process.dtype # type: ignore[attr-defined] + compute_dtype = None + if hasattr(activation_post_process, "compute_dtype"): + compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] + quantize_op : Optional[Union[Callable, str]] = None + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ + not hasattr(activation_post_process, 'compute_dtype'): + node_type = "call_function" + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined] + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype} if is_decomposed: - # use the same qparams from quantize op - dq_inputs = [quantized_node] + quantize_op_inputs[1:] - dequantized_node = graph.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor, - tuple(dq_inputs), - {} - ) + raise NotImplementedError("decomposed quantize_per_channel op not implemented yet") else: - dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) - node.replace_all_uses_with(dequantized_node) - graph.erase_node(node) + quantize_op = torch.quantize_per_channel + else: + scale = float(scale) + zero_point = int(zero_point) + if is_decomposed: + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min": quant_min, + "_quant_max": quant_max, + "_dtype_": dtype + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor + else: + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]: + # TODO(future PR): switch compute_dtype to is_dynamic + # dynamic quantization + node_type = "call_function" + if is_decomposed: + raise NotImplementedError("decomposed quantize_per_tensor_dynamic op not implemented yet") + else: + quantize_op = torch.quantize_per_tensor_dynamic + # TODO: get reduce range from observer + # reduce_range = activation_post_process.reduce_range + reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") + qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range} + elif dtype == torch.float16: + node_type = "call_method" + quantize_op = "to" + qparams = {"_dtype_": dtype} + + # 2. replace observer node with quant - dequant node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + if is_decomposed: + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantized_node = graph.call_function( + torch.ops.quantized_decomposed.dequantize_per_tensor, + tuple(dq_inputs), + {} + ) + else: + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) # this is a temporary hack for custom module, we may want to implement # this properly after the custom module class design is finalized @@ -166,6 +221,15 @@ def _replace_observer_or_dequant_stub_with_dequantize_node(node: Node, graph: Gr graph.erase_node(node) insert_dequantize_node(call_custom_module_node, graph) +def _is_conversion_supported(activation_post_process: torch.nn.Module) -> bool: + dtype = activation_post_process.dtype # type: ignore[attr-defined] + compute_dtype = None + if hasattr(activation_post_process, "compute_dtype"): + compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] + return (dtype in [torch.quint8, torch.qint8, torch.qint32] and compute_dtype is None) or \ + compute_dtype in [torch.quint8, torch.qint8, torch.float16] or \ + dtype == torch.float16 + def restore_state( observed: torch.nn.Module ) -> Tuple[Dict[str, Tuple[str, type]], From 8a419cbffb939ef00ce723bbdf5bf1b8c62a7d74 Mon Sep 17 00:00:00 2001 From: Horace He Date: Fri, 18 Nov 2022 10:56:03 +0000 Subject: [PATCH 348/453] Added partial decomposition of conv_backward and grad_bias computation (#89128) `convolution_backward` often just kicks off the `sum` as a separate kernel. Splitting it off in a decomp allows us to fuse it into other ops: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L2150 Improves `convnext_base` from 373 img/s => 383 img/s Not sure what other models use convolution with bias haha. 
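The key observation, sketched below with made-up shapes (illustrative only, not the decomposition code itself), is that the bias gradient is just `grad_output` reduced over every dimension except the channel dimension, so it can be split out as a plain `sum` that the compiler is then free to fuse:

```
# Illustrative check: grad_bias from conv2d equals grad_output summed over
# every dimension except the channel dimension (dim 1).
import torch

x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

out = torch.nn.functional.conv2d(x, w, b)
grad_out = torch.randn_like(out)
out.backward(grad_out)

# The reduction that the decomposition splits off so it can be fused.
torch.testing.assert_close(grad_out.sum(dim=[0, 2, 3]), b.grad)
```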
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89128 Approved by: https://github.com/ezyang --- test/inductor/test_torchinductor.py | 72 +++++++++++++++++++++++++++++ torch/_inductor/decomposition.py | 33 +++++++++++++ 2 files changed, 105 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index f2b1caeb32ea..fc0ae82a2598 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4101,6 +4101,78 @@ def fn(x): rtol=0.5, ) + def test_conv_backward(self): + def fn(rank4_inps, rank3_inps, rank5_inps): + + out1 = aten.convolution_backward( + *rank4_inps, + [C], + [1, 1], + [0, 0], + [1, 1], + False, + [0, 0], + 1, + [True, True, True], + ) + out2 = aten.convolution_backward( + *rank4_inps, + [C], + [1, 1], + [0, 0], + [1, 1], + False, + [0, 0], + 1, + [True, False, False], + ) + out3 = aten.convolution_backward( + *rank3_inps, + [C], + [1], + [0], + [1], + False, + [0], + 1, + [True, True, True], + ) + out4 = aten.convolution_backward( + *rank5_inps, + [C], + [1, 1, 1], + [0, 0, 0], + [1, 1, 1], + False, + [0, 0, 0], + 1, + [True, True, True], + ) + return (out1, out2, out3, out4) + + B = 3 + C = 4 + H = 5 + grad_out = torch.randn(B, C, H - 2, H - 2, H - 2) + inp = torch.randn(B, C, H, H, H) + weight = torch.randn(C, C, 3, 3, 3) + + def shrink_rank(x, rank): + res = x + while res.dim() > rank: + res = torch.select(res, -1, 0) + return res.contiguous() + + rank4_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] + rank3_inps = [shrink_rank(x, 4) for x in [grad_out, inp, weight]] + rank5_inps = [shrink_rank(x, 5) for x in [grad_out, inp, weight]] + + with torch.backends.cudnn.flags(allow_tf32=False): + self.common( + fn, + [rank4_inps, rank3_inps, rank5_inps], + ) + @unittest.skip( """ FIXME: In the case of having equally max/min elements, our implementation returns diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 3254f174b495..09ee53579345 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -317,6 +317,39 @@ def bmm_decomp(mat1, mat2): return NotImplemented # go directly to lowering +@register_decomposition([aten.convolution_backward]) +def convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, +): + if not output_mask[2] or grad_output.device.type != "cuda": + return NotImplemented + grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim()))) + grad_inp, grad_weight, _ = aten.convolution_backward( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + [output_mask[0], output_mask[1], False], + ) + return (grad_inp, grad_weight, grad_bias) + + @register_decomposition([aten.rsqrt]) def rsqrt(x): return torch.reciprocal(torch.sqrt(x)) From 81a4aeabdf9d550ceda52a5060f19568de61b265 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 18 Nov 2022 18:43:15 +0000 Subject: [PATCH 349/453] [Dynamo] Support Tensor.nelement & torch.cuda.is_available (#89164) Fix several errors in [7k github models](https://github.com/pytorch/torchdynamo/issues/1198). 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89164 Approved by: https://github.com/soumith --- test/dynamo/test_misc.py | 17 +++++++++++++++-- torch/_dynamo/variables/tensor.py | 2 +- torch/_dynamo/variables/torch.py | 1 + 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 2825b157bc68..e3274738fc21 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -419,10 +419,10 @@ def fn(x): def test_numel(self): def fn(a): - return a + a.numel() + torch.numel(a) + return (a + a.numel() + torch.numel(a), a + a.nelement()) return torch._dynamo.testing.standard_test( - self, fn=fn, nargs=1, expected_ops=2, expected_ops_dynamic=4 + self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6 ) def test_pair(self): @@ -2963,6 +2963,19 @@ def forward(self, x): res = opt_model(x) self.assertTrue(same(ref, res)) + def test_torch_cuda_is_available(self): + def fn(x): + if torch.cuda.is_available(): + return x + 1 + else: + return x - 1 + + x = torch.rand(4) + ref = fn(x) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + class CustomFunc(torch.autograd.Function): @staticmethod diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 8867f7e6cc93..ab94aaf537d2 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -203,7 +203,7 @@ def call_method( ), **options, ) - elif name == "numel" and self.size is not None: + elif name in ("numel", "nelement") and self.size is not None: constant_result = ConstantVariable(product(self.size), **options) elif name in ("ndimension", "dim") and self.ndim is not None: constant_result = ConstantVariable(self.ndim, **options) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 56e74503faca..651f80b5d77d 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -163,6 +163,7 @@ def can_constant_fold_through(self): torch.finfo, torch.iinfo, torch.is_floating_point, + torch.cuda.is_available, ): return True return getattr(self.value, "__module__", None) == "math" From 7ec8a4d2a26f717d0a4073e6005f9edfdd7ab641 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 18 Nov 2022 18:46:50 +0000 Subject: [PATCH 350/453] Vectorized horizontal flip implementation (#88989) When we benchmarked image processing transforms in torchvision : tensor vs pillow we saw that horizontal flip on uint8 data `(3, X, X)` is 2-3x slower. Due to the fact that output's first stride is negative, implementation does a simple data copy using [`basic_loop`](https://github.com/pytorch/pytorch/blob/8371bb8a3dddbead709bc1e9d26715818a34fa8a/aten/src/ATen/native/cpu/Loops.h#L286). 
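For reference, the operation in question is simply a flip along the last (horizontal) dimension of a channels-first image tensor; because the output is traversed with a negative stride, the contiguous fast path is not taken. A tiny, illustrative sanity check of the semantics:

```
# Illustrative only: horizontal flip is a reversal along the last dimension.
import torch

img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
flipped = torch.flip(img, dims=[-1])
assert torch.equal(flipped, img[..., torch.arange(255, -1, -1)])
```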
In this PR, a vectorized path is added for horizontal flip op for dtypes: uint8, int, float32, long and double and there is a speed-up that reduces the gap between PIL and tensor ops ``` CPU capability usage: AVX2 [----------------------------------------------------------------- Horizontal flip -----------------------------------------------------------------] | torch (1.14.0a0+git2ed1d29) PR | Pillow (9.3.0) | torch (1.14.0.dev20221116+cu116) nightly 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------ channels=3, size=256, dtype=torch.int64 | 101.307 (+-0.904) | | 111.364 (+-0.328) channels=3, size=520, dtype=torch.int64 | 462.369 (+-2.184) | | 505.602 (+-0.541) channels=3, size=712, dtype=torch.int64 | 1855.441 (+-6.528) | | 1828.370 (+-8.600) channels=1, size=256, dtype=torch.int32 | 22.282 (+-0.130) | 44.218 (+-0.936) | 34.651 (+-0.162) channels=1, size=520, dtype=torch.int32 | 72.180 (+-0.076) | 166.639 (+-1.180) | 118.820 (+-0.210) channels=1, size=712, dtype=torch.int32 | 129.621 (+-0.649) | 307.140 (+-2.221) | 216.104 (+-0.793) channels=3, size=256, dtype=torch.uint8 | 51.685 (+-0.200) | 44.171 (+-0.818) | 361.611 (+-0.276) channels=3, size=520, dtype=torch.uint8 | 223.320 (+-0.726) | 166.607 (+-2.256) | 1462.012 (+-4.917) channels=3, size=712, dtype=torch.uint8 | 423.298 (+-1.156) | 307.067 (+-1.999) | 2738.481 (+-1.715) channels=1, size=256, dtype=torch.float32 | 22.281 (+-0.056) | 44.149 (+-0.808) | 35.316 (+-0.028) channels=1, size=520, dtype=torch.float32 | 72.268 (+-0.106) | 166.631 (+-1.212) | 119.504 (+-0.340) channels=1, size=712, dtype=torch.float32 | 129.777 (+-0.632) | 307.078 (+-1.909) | 216.987 (+-0.185) channels=1, size=256, dtype=torch.float16 | 32.789 (+-0.081) | | 34.044 (+-0.039) channels=1, size=520, dtype=torch.float16 | 112.693 (+-0.478) | | 117.445 (+-0.125) channels=1, size=712, dtype=torch.float16 | 203.644 (+-0.791) | | 213.283 (+-0.397) channels=3, size=256, dtype=torch.float64 | 102.058 (+-0.333) | | 108.404 (+-0.346) channels=3, size=520, dtype=torch.float64 | 473.139 (+-1.327) | | 503.265 (+-0.365) channels=3, size=712, dtype=torch.float64 | 1854.489 (+-9.513) | | 1844.345 (+-1.371) channels=1, size=256, dtype=torch.int16 | 11.927 (+-0.056) | | 33.993 (+-0.037) channels=1, size=520, dtype=torch.int16 | 39.724 (+-0.148) | | 117.577 (+-0.153) channels=1, size=712, dtype=torch.int16 | 68.264 (+-0.133) | | 213.118 (+-0.157) Times are in microseconds (us). 
``` ``` CPU capability usage: AVX512 [----------------------------------------------------------------- Horizontal flip ------------------------------------------------------------------] | torch (1.14.0a0+git2ed1d29) PR | Pillow (9.3.0) | torch (1.14.0.dev20221118+cu116) nightly 1 threads: ------------------------------------------------------------------------------------------------------------------------------------------- channels=3, size=256, dtype=torch.int64 | 131.244 (+-1.954) | | 135.649 (+-4.066) channels=3, size=520, dtype=torch.int64 | 522.032 (+-4.660) | | 539.822 (+-10.420) channels=3, size=712, dtype=torch.int64 | 1041.111 (+-53.575) | | 1322.411 (+-80.017) channels=1, size=256, dtype=torch.int32 | 10.108 (+-0.414) | 49.164 (+-1.000) | 34.606 (+-0.865) channels=1, size=520, dtype=torch.int32 | 93.218 (+-1.417) | 191.985 (+-5.047) | 133.664 (+-5.372) channels=1, size=712, dtype=torch.int32 | 167.919 (+-2.854) | 353.574 (+-6.568) | 246.162 (+-5.753) channels=3, size=256, dtype=torch.uint8 | 34.710 (+-0.541) | 49.005 (+-0.923) | 136.603 (+-2.339) channels=3, size=520, dtype=torch.uint8 | 154.873 (+-3.049) | 191.729 (+-4.997) | 534.329 (+-10.754) channels=3, size=712, dtype=torch.uint8 | 290.319 (+-4.819) | 351.619 (+-6.978) | 997.119 (+-33.086) channels=1, size=256, dtype=torch.float32 | 10.345 (+-0.338) | 49.105 (+-0.942) | 35.478 (+-0.733) channels=1, size=520, dtype=torch.float32 | 81.131 (+-5.281) | 191.697 (+-4.555) | 133.554 (+-4.193) channels=1, size=712, dtype=torch.float32 | 169.581 (+-3.476) | 352.995 (+-10.792) | 251.089 (+-7.485) channels=1, size=256, dtype=torch.float16 | 35.259 (+-0.612) | | 35.154 (+-0.924) channels=1, size=520, dtype=torch.float16 | 132.407 (+-1.980) | | 131.850 (+-5.611) channels=1, size=712, dtype=torch.float16 | 240.192 (+-5.479) | | 239.555 (+-7.273) channels=3, size=256, dtype=torch.float64 | 129.649 (+-2.349) | | 130.429 (+-6.240) channels=3, size=520, dtype=torch.float64 | 548.534 (+-5.179) | | 622.568 (+-25.720) channels=3, size=712, dtype=torch.float64 | 1208.091 (+-77.095) | | 1679.204 (+-316.292) channels=1, size=256, dtype=torch.int16 | 7.801 (+-0.115) | | 34.517 (+-0.482) channels=1, size=520, dtype=torch.int16 | 36.010 (+-0.855) | | 131.001 (+-1.686) channels=1, size=712, dtype=torch.int16 | 87.395 (+-1.355) | | 237.731 (+-4.181) Times are in microseconds (us). ``` [Source](https://gist.github.com/vfdev-5/c0421f54c8aed655b042dd1ce4cb621e) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88989 Approved by: https://github.com/lezcano, https://github.com/datumbox, https://github.com/peterbell10, https://github.com/ngimel --- aten/src/ATen/cpu/vec/vec256/vec256.h | 45 ++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512.h | 50 +++++++++++++ aten/src/ATen/cpu/vec/vec_base.h | 12 ++++ aten/src/ATen/native/cpu/IndexKernel.cpp | 92 ++++++++++++++++++++++++ 4 files changed, 199 insertions(+) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 98ec588137ce..d0a8cb03604a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -222,6 +222,51 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3. 
4 bits apart } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_ps(v, mask_float); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template<> +inline Vectorized flip(const Vectorized & v) { + return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3) +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7); + return _mm256_permutevar8x32_epi32(v, mask_int32); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask = _mm256_set_epi8( + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14, + 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 + ); + auto reversed = _mm256_shuffle_epi8(v, mask); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m256i mask_int8 = _mm256_set_epi8( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + auto reversed = _mm256_shuffle_epi8(v, mask_int8); + return _mm256_permute2x128_si256(reversed, reversed, 1); +} + + #endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h index 0c6f33fa08a0..dd1235e82ece 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512.h @@ -190,6 +190,56 @@ inline deinterleave2(const Vectorized& a, const Vectorized& _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b)); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_ps(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_pd(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7); + return _mm512_permutexvar_epi64(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15); + return _mm512_permutexvar_epi32(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask = _mm512_set_epi16( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 + ); + return _mm512_permutexvar_epi16(mask, v); +} + +template<> +inline Vectorized flip(const Vectorized & v) { + const __m512i mask1 = _mm512_set_epi8( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + const __m512i mask2 = _mm512_set_epi64(1, 0, 3, 2, 5, 4, 7, 6); + auto reversed_vec = _mm512_shuffle_epi8(v, mask1); + return _mm512_permutexvar_epi64(mask2, reversed_vec); +} + #endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER) }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h 
b/aten/src/ATen/cpu/vec/vec_base.h index f045437ac368..e9e87fa605f7 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -1001,4 +1001,16 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) { } } +template +inline Vectorized flip(const Vectorized & data) { + static constexpr int size = Vectorized::size(); + T output[size]; + T buffer[size]; + data.store(static_cast(buffer)); + for (const auto i : c10::irange(size)) { + output[i] = buffer[size - i - 1]; + } + return Vectorized::loadu(static_cast(output)); +} + }}} diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index be0dc3301a00..81e135d1e749 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -457,6 +457,75 @@ void masked_select_kernel(TensorIterator& iter, int64_t result_stride) { }); } + +template +void cpu_hflip_vec(at::TensorIterator& iter) { + + auto loop2d = [&](char** base, const int64_t *strides, int64_t size0, int64_t size1) { + + static constexpr int ntensors = 3; + std::array data_arr; + std::copy_n(base, ntensors, data_arr.data()); + const int64_t *outer_strides = &strides[ntensors]; + + using Vec = Vectorized; + + constexpr auto stride = sizeof(scalar_t); + TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]); + + for (const auto j C10_UNUSED : c10::irange(size1)) { + + // vectorized loop with negative stride for output + char** C10_RESTRICT data_ = data_arr.data(); + int64_t n = size0; + + char* C10_RESTRICT data[ntensors]; + for (const auto arg : c10::irange(ntensors)) { + data[arg] = data_[arg]; + } + + int64_t i = 0; + + // data[0] unaligned pre-pass + int64_t offset = (j * n + (n - i - Vec::size())) % 32; + offset = (offset >= n) ? n : offset; + for (; i < offset; i++) { + scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); + *out_ptr = *(scalar_t *)(data[1] + i * stride); + } + // Empirically found that it is faster to process 3 data items together vs 2 or 4 + for (; i <= n - 3 * Vec::size(); i += 3 * Vec::size()) { + auto out1 = Vec::loadu(data[1] + i * stride); + auto out2 = Vec::loadu(data[1] + (i + Vec::size()) * stride); + auto out3 = Vec::loadu(data[1] + (i + 2 * Vec::size()) * stride); + // flip the vector: 1234 -> 4321 + out1 = flip(out1); + out2 = flip(out2); + out3 = flip(out3); + out1.store(data[0] - (i + Vec::size() - 1) * stride); + out2.store(data[0] - (i + 2 * Vec::size() - 1) * stride); + out3.store(data[0] - (i + 3 * Vec::size() - 1) * stride); + } + if (i < n) { + for (; i < n; i++) { + scalar_t* out_ptr = (scalar_t*)(data[0] - i * stride); + *out_ptr = *(scalar_t *)(data[1] + i * stride); + } + } + + // advance: + for (const auto arg : c10::irange(data_arr.size())) { + data_arr[arg] += outer_strides[arg]; + } + } + }; + + int64_t grain_size = at::internal::GRAIN_SIZE; + iter.for_each(loop2d, grain_size); + iter.cast_outputs(); +} + + void flip_kernel(TensorIterator& iter, const bool quantized) { if (quantized) { AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(iter.dtype(), "flip_quantized_cpu", @@ -466,6 +535,29 @@ void flip_kernel(TensorIterator& iter, const bool quantized) { }); }); } else { + // Special case: horizontal flip with vectorization and input is contiguous + // Context: horizontal flip leads to strides[0] < 0 and + // thus is_contiguous condition is not satisfied and non-vectorized code path is taken. 
+ auto output_strides = iter.strides(0); + auto input_strides = iter.strides(1); + if (iter.ndim() > 0 && output_strides[0] < 0 && input_strides[0] == iter.element_size(1)) { + auto iter_dtype = iter.dtype(); + if (iter_dtype == kByte) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kFloat) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kInt) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kShort) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kLong) { + return cpu_hflip_vec(iter); + } else if (iter_dtype == kDouble) { + return cpu_hflip_vec(iter); + } + // other dtypes are handled below with cpu_kernel_vec + } + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kHalf, kBFloat16, iter.dtype(), "flip_cpu", [&iter] { cpu_kernel_vec(iter, [](scalar_t a, scalar_t /*dummy input*/) -> scalar_t { From ee2ce3fef6d6bd073eb31303808618db88cec2e1 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Fri, 18 Nov 2022 18:55:33 +0000 Subject: [PATCH 351/453] Set make max load when building libtorch (#89237) The nccl build is still OOM sometimes when using `$(MAKE)`: ``` virtual memory exhausted: Cannot allocate memory Makefile:73: recipe for target '/var/lib/jenkins/cpp-build/caffe2/build/nccl/obj/collectives/device/devlink.o' failed make[5]: *** [/var/lib/jenkins/cpp-build/caffe2/build/nccl/obj/collectives/device/devlink.o] Error 1 make[5]: Leaving directory '/var/lib/jenkins/workspace/third_party/nccl/nccl/src/collectives/device' ``` * https://github.com/pytorch/pytorch/actions/runs/3476485191/jobs/5811758058 * https://github.com/pytorch/pytorch/actions/runs/3422228421/jobs/5702153639 So trying to set the same limit here as when building with ninja Pull Request resolved: https://github.com/pytorch/pytorch/pull/89237 Approved by: https://github.com/malfet --- cmake/External/nccl.cmake | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/cmake/External/nccl.cmake b/cmake/External/nccl.cmake index cb928baf3a59..160d2b648c05 100644 --- a/cmake/External/nccl.cmake +++ b/cmake/External/nccl.cmake @@ -15,23 +15,24 @@ if(NOT __NCCL_INCLUDED) # this second replacement is needed when there are multiple archs string(REPLACE ";-gencode" " -gencode" NVCC_GENCODE "${NVCC_GENCODE}") - if("${CMAKE_GENERATOR}" MATCHES "Make") - # Recursive make with jobserver for parallelism - set(MAKE_COMMAND "$(MAKE)") + if(DEFINED ENV{MAX_JOBS}) + set(MAX_JOBS "$ENV{MAX_JOBS}") else() - if(DEFINED ENV{MAX_JOBS}) - set(MAX_JOBS "$ENV{MAX_JOBS}") - else() - include(ProcessorCount) - ProcessorCount(NUM_HARDWARE_THREADS) - # Assume 2 hardware threads per cpu core - math(EXPR MAX_JOBS "${NUM_HARDWARE_THREADS} / 2") - # ProcessorCount might return 0, set to a positive number - if(MAX_JOBS LESS 2) - set(MAX_JOBS 2) - endif() + include(ProcessorCount) + ProcessorCount(NUM_HARDWARE_THREADS) + # Assume 2 hardware threads per cpu core + math(EXPR MAX_JOBS "${NUM_HARDWARE_THREADS} / 2") + # ProcessorCount might return 0, set to a positive number + if(MAX_JOBS LESS 2) + set(MAX_JOBS 2) endif() + endif() + if("${CMAKE_GENERATOR}" MATCHES "Make") + # Recursive make with jobserver for parallelism, and also put a load limit + # here to avoid flaky OOM, https://www.gnu.org/software/make/manual/html_node/Parallel.html + set(MAKE_COMMAND "$(MAKE)" "-l${MAX_JOBS}") + else() # Parallel build with CPU load limit to avoid oversubscription set(MAKE_COMMAND "make" "-j${MAX_JOBS}" "-l${MAX_JOBS}") endif() From 
837ca8f344380f2356b01662f215ff561b09401f Mon Sep 17 00:00:00 2001 From: Zain Rizvi Date: Fri, 18 Nov 2022 19:36:09 +0000 Subject: [PATCH 352/453] Remove --retry-all-errors from environment with old curl (#89298) The version of curl on the `ubuntu-latest` box doesn't support the `--retry-all-errors` param and is breaking periodic builds Example: https://github.com/pytorch/pytorch/actions/runs/3495466804/jobs/5852265880 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89298 Approved by: https://github.com/huydhn --- scripts/buck_setup.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/buck_setup.sh b/scripts/buck_setup.sh index 331a29956416..f6152537435c 100644 --- a/scripts/buck_setup.sh +++ b/scripts/buck_setup.sh @@ -22,16 +22,16 @@ python3 generate-xnnpack-wrappers.py # bazel-skylib printf "\nDownloading bazel-skylib\n" rm -rf bazel-skylib; mkdir bazel-skylib -curl --retry 3 --retry-all-errors -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib +curl --retry 3 -L $PROXY https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz|tar zx -C bazel-skylib # glog printf "\nDownloading glog\n" rm -rf glog; mkdir glog -curl --retry 3 --retry-all-errors -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 +curl --retry 3 -L $PROXY https://github.com/google/glog/archive/v0.4.0.tar.gz | tar zx -C glog --strip-components 1 # ruy printf "\nDownloading ruy\n" -curl --retry 3 --retry-all-errors -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip +curl --retry 3 -L $PROXY -o /tmp/ruy.zip https://github.com/google/ruy/archive/a09683b8da7164b9c5704f88aef2dc65aa583e5d.zip unzip -q /tmp/ruy.zip -d /tmp/ rm -rf ruy/ mv /tmp/ruy-a09683b8da7164b9c5704f88aef2dc65aa583e5d ruy/ From e04dc35a6a1d1447f6e067db5f29f88adff91acf Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 18 Nov 2022 06:59:20 -0800 Subject: [PATCH 353/453] Symintify obeys_layout_contract (#89138) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89138 Approved by: https://github.com/bdhirsh --- torch/csrc/autograd/utils/grad_layout_contract.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/csrc/autograd/utils/grad_layout_contract.h b/torch/csrc/autograd/utils/grad_layout_contract.h index 2addde79c8ec..37dda0f9acaa 100644 --- a/torch/csrc/autograd/utils/grad_layout_contract.h +++ b/torch/csrc/autograd/utils/grad_layout_contract.h @@ -28,9 +28,9 @@ inline bool obeys_layout_contract( return false; } else if (variable.is_non_overlapping_and_dense()) { // Only look at stride for dimensions that are not of size 1. - const auto& grad_sizes = grad.sizes(); - const auto& grad_strides = grad.strides(); - const auto& variable_strides = variable.strides(); + const auto& grad_sizes = grad.sym_sizes(); + const auto& grad_strides = grad.sym_strides(); + const auto& variable_strides = variable.sym_strides(); for (const auto idx : c10::irange(grad_sizes.size())) { if (grad_sizes[idx] != 1) { if (grad_strides[idx] != variable_strides[idx]) { From ba605c3b0439fd5dfe062f42e60b990c88c061d4 Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Fri, 18 Nov 2022 06:59:21 -0800 Subject: [PATCH 354/453] Don't trace when we track_tensor_tree (#89139) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89139 Approved by: https://github.com/bdhirsh --- torch/fx/experimental/proxy_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 8a51294c5a8f..c3a5d706e3cc 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -426,7 +426,9 @@ def wrap_key(f, tensors, tracer): def wrapped(*proxies): flat_proxies, proxies_spec = pytree.tree_flatten(proxies) assert len(flat_proxies) == len(flat_tensors) - track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + assert isinstance(_get_current_dispatch_mode(), ProxyTorchDispatchMode) + with _pop_mode_temporarily(): + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) out = f(*tensors) out = pytree.tree_map_only( From 304b5de1b01213b18947ffcb6f5782f89fcd0b2e Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 17 Nov 2022 09:58:29 -0800 Subject: [PATCH 355/453] Re-enable test_hf_bert_fsdp (#89223) It looks like this failure was actually caused by https://github.com/pytorch/pytorch/pull/88629, see the revert message on that PR. It probably just looked like a flaky test on CI because of how quickly the PR was reverted. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89223 Approved by: https://github.com/voznesenskym --- test/distributed/test_dynamo_distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 21550a0120e4..b6bc16edb941 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -258,8 +258,6 @@ def test_fsdp_inductor(self): # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert @patch.object(torch._inductor.config.triton, "cudagraphs", False) @patch.object(torch._inductor.config, "fallback_random", True) - # TODO(voz): Flaky on CI failure, consistent failure on local master. - @unittest.skipIf(True, "Flaky on CI failure, consistent failure on local master") def test_hf_bert_fsdp(self): from transformers.models.bert.modeling_bert import BertLayer From bfffc8d8efc3247853d706148146a5fd62d5ef08 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Thu, 17 Nov 2022 23:06:09 +0000 Subject: [PATCH 356/453] [DDP][Docs] Add warning that `no_sync()` should include forward (#89244) The issue where the user only includes `loss.backward()` inside `no_sync()` but not the forward pass has arisen several times now. I think adding an explicit warning in the docs is worthwhile. Rendered doc: Screen Shot 2022-11-17 at 9 21 32 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/89244 Approved by: https://github.com/zhaojuanmao --- torch/nn/parallel/distributed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 47eb6bb2ebf1..b6673874eecc 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1001,6 +1001,10 @@ def no_sync(self): >>> for input in inputs: >>> ddp(input).backward() # no synchronization, accumulate grads >>> ddp(another_input).backward() # synchronize grads + + .. 
warning:: + The forward pass should be included inside the context manager, or + else gradients will still be synchronized. """ old_require_backward_grad_sync = self.require_backward_grad_sync self.require_backward_grad_sync = False From 35d5fc52f01f0314ab1bf1555ea27d6fedbb7d98 Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Thu, 17 Nov 2022 13:33:39 -0800 Subject: [PATCH 357/453] [Profiler] Don't raise SOFT_ASSERT in debug builds. (#89240) Enough people are hitting this issue that we need to turn off hard failures until the fire rate is zero in steady state. (via scuba logging.) Differential Revision: [D41382914](https://our.internmc.facebook.com/intern/diff/D41382914/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89240 Approved by: https://github.com/aaronenyeshi --- torch/csrc/profiler/util.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index f4fb4dd1eee1..6833e8abef70 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -94,13 +94,7 @@ void setSoftAssertRaises(c10::optional value) { } bool softAssertRaises() { - return soft_assert_raises_.value_or( -#ifdef NDEBUG - false -#else - true -#endif - ); + return soft_assert_raises_.value_or(false); } // ---------------------------------------------------------------------------- From 7551136b81251fef0505a935ab614a44dd355479 Mon Sep 17 00:00:00 2001 From: Bryce Long Date: Fri, 18 Nov 2022 22:36:05 +0000 Subject: [PATCH 358/453] Add NVTX markers that dump additional information for nvprim_nvfuser Dynamo graphs (#88259) dump information on graphs that NVFuser JIT compiles: - the markers show the list of ops, args, and inputs that make up the graph also dumps information on FX nodes that are not touched by NVFuser: - the markers show the op, name, and arg list of the node Pull Request resolved: https://github.com/pytorch/pytorch/pull/88259 Approved by: https://github.com/IvanYashchuk, https://github.com/jjsjann123, https://github.com/mruberry --- torch/_prims/nvfuser_executor.py | 47 ++++++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py index 0f4e7b49fa27..a155433231e1 100644 --- a/torch/_prims/nvfuser_executor.py +++ b/torch/_prims/nvfuser_executor.py @@ -28,6 +28,14 @@ else: DataType = None +import os + + +@lru_cache(None) +def get_nvprim_dump_nvtx(): + return os.getenv("PYTORCH_NVFUSER_DUMP_NVTX") + + DEFAULT_NVFUSER_PYTHON_CONFIG = MappingProxyType( { "use_python_fusion_cache": True, @@ -247,10 +255,30 @@ def nvfuser_execute(gm: GraphModule, *args, executor_parameters=None): arg for arg in flat_args if isinstance(arg, (torch.Tensor, Number)) ) - return tree_unflatten( + if get_nvprim_dump_nvtx(): + torch.cuda.nvtx.range_push( + "fusion: {0}, graph: {1}".format( + fusion.id(), + str( + [ + { + "op": n.op, + "name": n.name, + "args": n.args, + "kwargs": n.kwargs, + } + for n in gm.graph.nodes + ] + ), + ) + ) + result = tree_unflatten( fusion.execute(concrete_fusion_inputs), # type: ignore[has-type] unflatten_spec, # type: ignore[has-type] ) + if get_nvprim_dump_nvtx(): + torch.cuda.nvtx.range_pop() + return result else: warn( "nvfuser_executor is executed with non-cuda args, fallback to aten executor" @@ -421,6 +449,18 @@ def maybe_partition_graph( return gm, any_unsupported +class NVTXInterpreter(torch.fx.Interpreter): + def run_node(self, n): + 
torch.cuda.nvtx.range_push( + "name: {0}, args: {1}, op: {2}, kwargs: {3}".format( + n.name, n.args, n.op, n.kwargs + ) + ) + result = super().run_node(n) + torch.cuda.nvtx.range_pop() + return result + + def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None): executor_parameters = executor_parameters or DEFAULT_NVFUSER_PYTHON_CONFIG # maybe_partition_graph function is cached so we can't use non-hashable arguments @@ -440,6 +480,9 @@ def nvfuser_execute_partitioned(gm: GraphModule, *args, executor_parameters=None use_python_fusion_cache=use_python_fusion_cache, ) if is_partitioned: - return gm(*args) + if get_nvprim_dump_nvtx(): + return NVTXInterpreter(gm).run(*args) + else: + return gm(*args) else: return nvfuser_execute(gm, *args, executor_parameters=executor_parameters) From ecfb4e064ccedb42fd73d99f24cb749e05e28801 Mon Sep 17 00:00:00 2001 From: Wei Wang Date: Fri, 18 Nov 2022 23:05:50 +0000 Subject: [PATCH 359/453] [Inductor CI] Use string format for cuda-arch-list input to prevent 8.0/9.0/10.0 etc from being interpreted as 8/9/10 (#89279) Currently or in future whenever we change the cuda-arch-list to num.0, github action or some agent would pass just num to TORCH_CUDA_ARCH_LIST This num is not regex matched during cuda arch analysis phase. (here: https://github.com/pytorch/pytorch/blob/c5fafb4e1694f141d8a1a31142cce4049d9057ed/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake#L229) Example failure: https://github.com/weiwangmeta/pytorch/actions/runs/3495656108/jobs/5852735299 Unknown CUDA Architecture Name 8 in CUDA_SELECT_NVCC_ARCH_FLAGS This change reminds us to use e.g. '8.0', '9.0', '10.0' etc instead of 8.0, 9.0, 10.0 as GHA or some other agent may erroneously truncate it to pure numbers. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89279 Approved by: https://github.com/desertfire, https://github.com/atalman --- .github/workflows/inductor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index eb953ff42321..9179b186e918 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -20,7 +20,7 @@ jobs: with: build-environment: linux-bionic-cuda11.6-py3.10-gcc7-sm86 docker-image-name: pytorch-linux-bionic-cuda11.6-cudnn8-py3-gcc7 - cuda-arch-list: 8.6 + cuda-arch-list: '8.6' test-matrix: | { include: [ { config: "inductor", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, From 0e1fcc8aa8790e54a85efdc81b959f46f089e3d3 Mon Sep 17 00:00:00 2001 From: Tran Le Date: Fri, 18 Nov 2022 23:19:14 +0000 Subject: [PATCH 360/453] [FX] Add type annotation to `getitem` node before `split_module` (#88510) Summary: Some nodes lost the type annotation during `split_module`, causing the submodels to be un-scriptable. This is because compiler always infer Tensor type, which is wrong for non-Tensor types. We attempt to infer type annotation for `getitem` node to improve scriptability. 
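The inference rule itself is small; a simplified, conceptual sketch of it outside of FX (the helper name is illustrative, the real logic lives inside `split_module`) reads the element type off the Tuple's parameters, including the `Tuple[T, ...]` form:

```
from typing import Tuple
import torch

def infer_getitem_type(sequence_type, index):
    # Mirror of the rule: look at the Tuple's parameterized types.
    args = getattr(sequence_type, "__args__", ())
    if not args:
        return None
    if len(args) == 2 and args[1] is Ellipsis:  # Tuple[T, ...]
        return args[0]
    return args[index]

assert infer_getitem_type(Tuple[torch.Tensor, int], 1) is int
assert infer_getitem_type(Tuple[int, ...], 5) is int
```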
Test Plan: ``` buck2 test //caffe2/test:fx_experimental ``` Differential Revision: D41037819 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88510 Approved by: https://github.com/xush6528 --- torch/fx/passes/split_module.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index c6954c2cc717..0343bae94c31 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,4 +1,5 @@ import inspect +import operator from typing import Any, Callable, Dict, List, Optional import torch @@ -159,6 +160,25 @@ def record_cross_partition_use( # split nodes into parititons for node in m.graph.nodes: + # Annotations on local names within function are lost during FX transforms. + # Adding back known type annotation for getitem nodes for jit scriptability. + if node.target == operator.getitem: + sequence_node, index_node = node.args + # only support type Tuple for now + if ( + hasattr(sequence_node.type, "_name") + and sequence_node.type._name == "Tuple" + ): + parameterized_types = sequence_node.type.__args__ + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + orig_nodes[node.name] = node # TODO currently placeholders/parameters aren't put into random partitions, @@ -210,7 +230,10 @@ def record_cross_partition_use( for partition_name in sorted_partitions: partition = partitions[partition_name] for input in partition.inputs: - placeholder = partition.graph.placeholder(input) + placeholder = partition.graph.placeholder( + input, + type_expr=orig_nodes[input].type, + ) placeholder.meta = orig_nodes[input].meta.copy() partition.environment[orig_nodes[input]] = placeholder @@ -248,7 +271,11 @@ def record_cross_partition_use( assert isinstance(gathered_args, tuple) assert isinstance(gathered_kwargs, dict) new_node = partition.graph.create_node( - op=node.op, target=target, args=gathered_args, kwargs=gathered_kwargs + op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + type_expr=node.type, ) new_node.meta = node.meta.copy() partition.environment[node] = new_node From 885f8a56d445796100f3ab6f806633890662021a Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 18 Nov 2022 23:44:57 +0000 Subject: [PATCH 361/453] [BE] Print backtraces from coredumps (#89309) By simply invoking `gdb python core -ex "bt" -ex "q"` Test plan: See: [linux-focal-py3.7-gcc7 / test (default, 1, 2, linux.2xlarge)](https://github.com/pytorch/pytorch/actions/runs/3500498821/jobs/5863369649#step:14:39) Not sure why multiprocessing tests SEGFAULT, but they do Pull Request resolved: https://github.com/pytorch/pytorch/pull/89309 Approved by: https://github.com/clee2000, https://github.com/huydhn --- .github/workflows/_linux-test.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 16f25fed9121..454e558fbee4 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -192,6 +192,7 @@ jobs: -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" ) + echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}" docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && 
${TEST_COMMAND}" - name: Get workflow job id @@ -216,6 +217,12 @@ jobs: with: file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} + - name: Collect backtraces from coredumps (if any) + if: always() + run: | + # shellcheck disable=SC2156 + find . -iname "core.[1-9]*" -exec docker exec "${DOCKER_CONTAINER_ID}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; + - name: Store Core dumps on S3 uses: seemethere/upload-artifact-s3@v5 if: failure() From 94b5c807fdb1fdf62bc2ab5f0161936f564b140c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Fri, 18 Nov 2022 13:14:40 -0800 Subject: [PATCH 362/453] Detach fake tensors into val, so they aren't affected by metadata mutation (#89140) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89140 Approved by: https://github.com/bdhirsh --- test/test_proxy_tensor.py | 19 +++++++++++++++---- torch/fx/experimental/proxy_tensor.py | 17 +++++++++++++++-- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index aa12b5b74d1c..e174a1483791 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -401,6 +401,19 @@ def f(x): ) ) + def test_val_metadata_mutation(self): + def f(x): + y = x.clone() + y.unsqueeze_(0) + return y + + traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True)) + self.assertEqual([ + tuple(node.meta['val'].shape) + for node in traced.graph.nodes + if 'val' in node.meta + ], [(3,), (3,), (1, 3)]) + def test_make_fx_overloads(self): def f(x): return x.cos() + torch.randn(x.shape) @@ -847,8 +860,7 @@ def forward(self, a_1): sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None mul = sym_size * 2; sym_size = None empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None - detach = torch.ops.aten.detach.default(empty); empty = None - return detach""") + return empty""") def test_neg_shape(self): @@ -862,8 +874,7 @@ def forward(self, a_1): neg = -sym_size; sym_size = None add = neg + 10; neg = None empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None - detach = torch.ops.aten.detach.default(empty); empty = None - return detach""") + return empty""") def test_sqrt_size(self): def f(a): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index c3a5d706e3cc..daa17f94b7bb 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -106,20 +106,33 @@ def get_proxy(obj): def has_proxy(obj): return get_proxy(obj) is not None +def snapshot_fake(val): + return val.detach() + +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) 
def set_meta(proxy, val): if isinstance(val, FakeTensor): - proxy.node.meta['val'] = val + proxy.node.meta['val'] = snapshot_fake(val) proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) elif isinstance(val, py_sym_types): proxy.node.meta['val'] = val elif isinstance(val, list) or isinstance(val, tuple): if all(isinstance(x, FakeTensor) for x in val): - proxy.node.meta['val'] = val + proxy.node.meta['val'] = [snapshot_fake(x) for x in val] elif isinstance(val, torch.Tensor): if not val.is_sparse: proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(val) # NB: Kinda hacky, but we should try to get val as the metadata # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) with fake_tensor_mode: proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype) From c3938bb97ab2bf0942bee2a97d30051733e839ca Mon Sep 17 00:00:00 2001 From: zhxchen17 Date: Sat, 19 Nov 2022 00:19:47 +0000 Subject: [PATCH 363/453] [functorch] introduce an experimental map() op. (#88767) Summary: We want to introduce an experimental control flow op: map() to export some models as FX graphs correctly. Some calrification on basic requirements we have in mind: 1. This op can nest cond() and other control flow primitives internally. 2. We don't necessarily need loop carried dependencies for the models we've seen. 3. This map() op can handle dynamically shaped tensor as input and return dynamically shaped output based on input shapes. 4. We should be able to pass through additional arguments to the loop body as extra arguments. 
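For instance (a usage sketch in the spirit of the tests added below), requirement 4 means a shared argument can be forwarded to the body unchanged while `xs` is iterated along its leading dimension:

```
import torch
from functorch.experimental import control_flow

def body(x, y):
    # x is one slice of xs along dim 0; y is passed through as-is.
    return x + y

xs = torch.randn(3, 2, 2)
y = torch.randn(2)
out = control_flow.map(body, xs, y)
torch.testing.assert_close(out, torch.stack([body(x, y) for x in xs]))
```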
In this diff we introduce a new control flow op `map()` which has the following semantics: ``` def map(f: Callable, xs: Tensor, *args): # one possible implementation: return torch.stack([f(x, *args) for x in xs]) ``` Test Plan: pytest functorch/test_control_flow.py CI Differential Revision: D41165796 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88767 Approved by: https://github.com/zou3519 --- functorch/experimental/__init__.py | 4 +- functorch/experimental/_map.py | 105 +++++++++++++++++++++ functorch/experimental/cond.py | 48 ++++++++-- functorch/experimental/control_flow.py | 1 + test/functorch/test_control_flow.py | 122 ++++++++++++++++++++++++- torch/csrc/utils/python_dispatch.cpp | 29 ++++-- 6 files changed, 290 insertions(+), 19 deletions(-) create mode 100644 functorch/experimental/_map.py create mode 100644 functorch/experimental/control_flow.py diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py index ea874acafc42..3a4c92ffbe7a 100644 --- a/functorch/experimental/__init__.py +++ b/functorch/experimental/__init__.py @@ -1,5 +1,5 @@ -from .batch_norm_replacement import replace_all_batch_norm_modules_ # PyTorch forward-mode is not mature yet -from .._src.eager_transforms import jvp, jacfwd, hessian +from .._src.eager_transforms import hessian, jacfwd, jvp from .._src.vmap import chunk_vmap +from .batch_norm_replacement import replace_all_batch_norm_modules_ from functorch import functionalize diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py new file mode 100644 index 000000000000..d681526da4b3 --- /dev/null +++ b/functorch/experimental/_map.py @@ -0,0 +1,105 @@ +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard +from torch._ops import PyOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_slot, + make_fx, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode, + _pop_mode_temporarily, +) +from torch.utils._pytree import tree_flatten + + +map = PyOperator("map") + + +def trace_map(proxy_mode, func_overload, f, xs, *args): + def _unwrap_proxy(e): + if not isinstance(e, (torch.Tensor, torch.SymInt, torch.SymFloat)): + return e + return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) + + + if not isinstance(xs, torch.Tensor): + raise ValueError("map() must loop over a tensor") + if len(xs.shape) == 0 or xs.shape[0] == 0: + raise ValueError("map() cannot be traced with scalar tensors or zero dimension tensors") + if not all(isinstance(o, (torch.Tensor, torch.nn.Module)) for o in args): + raise ValueError("map() operands must be a list of tensors or modules") + + with disable_proxy_modes_tracing(): + body_graph = make_fx(f)(xs[0], *args) + + next_name = None + i = 0 + while not next_name: + candidate = f"body_graph_{i}" + if hasattr(proxy_mode.tracer.root, candidate): + i += 1 + else: + next_name = candidate + + proxy_mode.tracer.root.register_module(next_name, body_graph) + node_args = (body_graph, xs, *args) + proxy_args = pytree.tree_map(_unwrap_proxy, node_args) + out_proxy = proxy_mode.tracer.create_proxy('call_function', func_overload, proxy_args, {}, + name="map") + outs = [body_graph(x, *args) for x in xs] + # Implementation notes: we need to use new_empty() + copy_() here instead of stack() directly + # because stack([...]) 
takes a fixed size list which will specialize dynamic shape here. + # Meanwhile we want to preserve the looped over dimension as symbolic shape, such that: + # ys: Tensor[s0, ...] = map(xs: Tensor[s0, ...], *args) + out = xs.new_empty([xs.shape[0], *outs[0].shape]) + out.copy_(torch.stack(outs)) + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@map.py_impl(DispatchKey.CPU) +def map_cpu(f, xs, *args): + mode = _get_current_dispatch_mode() + assert (mode is None), "Mode should never be enabled for CPU key" + return torch.stack([f(x, *args) for x in xs]) + + +@map.py_impl(DispatchKey.AutogradCPU) +def map_autograd(f, xs, *args): + # TODO: support autograd + flat_operands, _ = tree_flatten([f, xs, args]) + assert all([not f.requires_grad for f in flat_operands + if isinstance(f, torch.Tensor)]) + + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)) + return map(f, xs, *args) + + +@map.py_impl(ProxyTorchDispatchMode) +def map_proxy_torch_dispatch_mode(f, xs, *args): + mode = _get_current_dispatch_mode() + assert (mode is not None), "Mode should always be enabled for python fallback key" + with _pop_mode_temporarily() as mode: + res = trace_map(mode, map, f, xs, *args) + return res + + +@map.py_impl(FakeTensorMode) +def map_fake_tensor_mode(f, xs, *args): + return torch.stack([f(x, *args) for x in xs]) + +# We cannot directly call fallthrough here due to issue #89037. +@map.py_impl(DispatchKey.PythonDispatcher) +def map_python_dispatcher(*args): + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher)) + return map(*args) + + +# TODO(voz) Make this automatic for keys, this is very ugly atm +map.fallthrough(DispatchKey.PythonTLSSnapshot) +map.fallthrough(DispatchKey.ADInplaceOrView) +map.fallthrough(DispatchKey.BackendSelect) diff --git a/functorch/experimental/cond.py b/functorch/experimental/cond.py index e620dbadeccb..bc6f776d073f 100644 --- a/functorch/experimental/cond.py +++ b/functorch/experimental/cond.py @@ -1,19 +1,31 @@ +# TODO(zhxchen17) Expose API through functorhc.experimental.control_flow +# and rename this file to _cond.py. import torch + +import torch.utils._pytree as pytree + from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard from torch._ops import PyOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + get_isolated_graphmodule, + get_proxy_slot, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils._python_dispatch import ( + _get_current_dispatch_mode, + _pop_mode_temporarily, +) from torch.utils._pytree import tree_flatten -from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot -import torch.utils._pytree as pytree -from torch.utils._python_dispatch import _get_current_dispatch_mode, _pop_mode_temporarily -from torch.fx.experimental.proxy_tensor import track_tensor_tree -from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode """ We're going to define a `cond` operation. In order to do this, we need implementations for each of the dispatch keys. 
""" -cond = PyOperator('cond') +cond = PyOperator("cond") def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): @@ -115,6 +127,30 @@ def inner(pred, true_fn, false_fn, operands): return res +@cond.py_impl(FakeTensorMode) +def cond_fake_tensor_mode(pred, true_fn, false_fn, operands): + true_outs = true_fn(*operands) + flat_true_outs, _ = pytree.tree_flatten(true_outs) + flat_false_outs, _ = pytree.tree_flatten(false_fn(*operands)) + if len(flat_true_outs) != len(flat_false_outs): + raise RuntimeError("Unmatched number of outputs from cond() branches.") + + for true_out, false_out in zip(flat_true_outs, flat_false_outs): + true_meta = _extract_tensor_metadata(true_out) + false_meta = _extract_tensor_metadata(false_out) + if true_meta != false_meta: + raise RuntimeError( + f"Unmatched tensor metadata from cond() branches.\ntrue branch: {true_meta}, false branch: {false_meta}") + return true_outs + + +# We cannot directly call fallthrough here due to issue #89037. +@cond.py_impl(DispatchKey.PythonDispatcher) +def cond_python_dispatcher(*args): + _ = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.PythonDispatcher)) + return cond(*args) + + # TODO(voz): Make this automatic for keys, this is very ugly atm cond.fallthrough(DispatchKey.PythonTLSSnapshot) cond.fallthrough(DispatchKey.ADInplaceOrView) diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py new file mode 100644 index 000000000000..c46c83fd005d --- /dev/null +++ b/functorch/experimental/control_flow.py @@ -0,0 +1 @@ +from ._map import map # noqa: F401 diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 5c3cb2dd72ad..39e1967d1b27 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -1,10 +1,11 @@ # Owner(s): ["module: functorch"] import torch - -from torch.testing._internal.common_utils import TestCase, run_tests from functorch.experimental.cond import cond +from functorch.experimental import control_flow from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_utils import run_tests, TestCase + class TestControlFlow(TestCase): def test_cond_no_trace(self): def true_fn(x): @@ -345,5 +346,122 @@ def f(x, y): with self.assertRaises(AssertionError): make_fx(f, tracing_mode="fake")(x, torch.tensor(False)) + def check_map_graph(self, gm, key): + i = 0 + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.map: + i += 1 + self.assertEqual( + node.meta[key].shape[0], node.args[1].meta[key].shape[0] + ) + self.assertEqual(i, 1) + + def test_map_real(self): + def f(x, y): + return x + y + + def g(xs, y): + return control_flow.map(f, xs, y) + + gm = make_fx(g, tracing_mode="real")(torch.ones(3, 2, 2), torch.ones(2)) + x = torch.randn(3, 2, 2) + y = torch.randn(2) + res = gm(x, y) + self.assertEqual(res, g(x, y)) + self.check_map_graph(gm, "tensor_meta") + + def test_map_symbolic(self): + def f(x, y): + return x + y + + def g(xs, y): + return control_flow.map(f, xs, y) + + gm = make_fx(g, tracing_mode="symbolic")(torch.ones(3, 2, 4), torch.ones(4)) + x = torch.randn(3, 2, 2) + y = torch.randn(2) + res = gm(x, y) + self.assertEqual(res, g(x, y)) + self.check_map_graph(gm, "val") + + def test_nested_map_cond_real(self): + def true_fn(x, y): + return x * y + + def false_fn(x, y): + return x + y + + def f(x, pred, y): + return cond(pred, true_fn, false_fn, [x, y]) + + def g(pred, xs, y): + return control_flow.map(f, xs, pred, y) + + gm = 
make_fx(g, tracing_mode="real")( + torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) + ) + pred = torch.tensor(False) + x = torch.randn(3, 2, 2) + y = torch.randn(2) + res = gm(pred, x, y) + self.assertEqual(res, g(pred, x, y)) + self.check_map_graph(gm, "tensor_meta") + + def test_nested_map_cond_symbolic(self): + def true_fn(x, y): + return x * y + + def false_fn(x, y): + return x + y + + def f(x, pred, y): + return cond(pred, true_fn, false_fn, [x, y]) + + def g(pred, xs, y): + return control_flow.map(f, xs, pred, y) + + gm = make_fx(g, tracing_mode="symbolic")( + torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) + ) + pred = torch.tensor(False) + x = torch.randn(3, 2, 2) + y = torch.randn(2) + res = gm(pred, x, y) + self.assertEqual(res, g(pred, x, y)) + self.check_map_graph(gm, "val") + + def test_nested_cond_map_cond_symbolic(self): + + def true_fn(x, y): + return x * y + + def false_fn(x, y): + return x + y + + def f(x, pred, y): + return cond(pred, true_fn, false_fn, [x, y]) + + def g(pred, xs, y): + return control_flow.map(f, xs, pred, y) + + def main_true_fn(pred, xs, y): + return g(pred, xs, y) * 2 + + def main_false_fn(pred, xs, y): + return g(pred, xs, y) + 1 + + def main(p, pred, xs, y): + return cond(p, main_true_fn, main_false_fn, [pred, xs, y]) + + gm = make_fx(main, tracing_mode="symbolic")( + torch.tensor(True), torch.tensor(True), torch.ones(3, 2, 4), torch.ones(4) + ) + p = torch.tensor(False) + pred = torch.tensor(False) + xs = torch.randn(3, 2, 2) + y = torch.randn(2) + res = gm(p, pred, xs, y) + self.assertEqual(res, main(p, pred, xs, y)) + if __name__ == '__main__': run_tests() diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 662ab9981a1d..381e82e1fcdb 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -479,14 +479,23 @@ void initDispatchBindings(PyObject* module) { #define DEF_ONE(n) .value(#n, c10::DispatchKey::n) - py::enum_(m, "DispatchKey") DEF_ONE(Undefined) DEF_ONE( - CompositeExplicitAutogradNonFunctional) DEF_ONE(CompositeExplicitAutograd) + py::enum_(m, "DispatchKey") + // clang-format off + DEF_ONE(Undefined) + DEF_ONE(CompositeExplicitAutogradNonFunctional) + DEF_ONE(CompositeExplicitAutograd) DEF_ONE(CompositeImplicitAutogradNestedTensor) - DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther) - DEF_ONE(Autograd) DEF_ONE(BackendSelect) DEF_ONE(ADInplaceOrView) - DEF_ONE(PythonTLSSnapshot) DEF_ONE(Python) - DEF_ONE(FuncTorchDynamicLayerFrontMode) - DEF_ONE(FuncTorchDynamicLayerBackMode) + DEF_ONE(CompositeImplicitAutograd) + DEF_ONE(AutogradOther) + DEF_ONE(Autograd) + DEF_ONE(BackendSelect) + DEF_ONE(ADInplaceOrView) + DEF_ONE(PythonTLSSnapshot) + DEF_ONE(Python) + DEF_ONE(FuncTorchDynamicLayerFrontMode) + DEF_ONE(FuncTorchDynamicLayerBackMode) + DEF_ONE(PythonDispatcher) + // clang-format on #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) #define DEF_MULTIPLE(fullname, prefix) \ @@ -495,11 +504,13 @@ void initDispatchBindings(PyObject* module) { C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \ DEF_SINGLE(, EndOf##fullname##Backends) - C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) + // clang-format off + C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) + // clang-format on #undef DEF_MULTIPLE #undef DEF_SINGLE - ; + ; py::class_(m, "DispatchKeySet") .def(py::init()) From ee907375fa085fbc61bd087f7d459fd62fd1ae4f Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Sat, 19 Nov 2022 00:21:11 +0000 Subject: [PATCH 364/453] [small] 
Update error message (#89294) Summary: `RuntimeError: Invalid function argument. Expected parameter "tensor_list" to be of type List[torch.Tensor].` to `RuntimeError: Invalid function argument. Expected parameter "input_tensor_list" to be of type List[torch.Tensor].` Test Plan: sandcastle Differential Revision: D41405238 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89294 Approved by: https://github.com/awgu --- torch/distributed/distributed_c10d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index f46aaaef94ef..4d343bcffec3 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -2517,7 +2517,7 @@ def all_gather_coalesced( if _rank_not_in_group(group): _warn_not_in_group("all_gather_coalesced") return - _check_tensor_list(input_tensor_list, "tensor_list") + _check_tensor_list(input_tensor_list, "input_tensor_list") _ensure_all_tensors_same_dtype(input_tensor_list) if not isinstance(output_tensor_lists, list): raise RuntimeError( From cad5772c2c2e2c719664765119172610eed9c590 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Sat, 19 Nov 2022 00:22:43 +0000 Subject: [PATCH 365/453] =?UTF-8?q?[dashboard][huggingface]=20skip=20accur?= =?UTF-8?q?acy=20checks=20for=20really=20large=20models=E2=80=A6=20(#89273?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/pytorch/pull/89273 Approved by: https://github.com/desertfire --- benchmarks/dynamo/huggingface.py | 17 ++++++++++++++++- benchmarks/dynamo/torchbench.py | 1 + torch/_dynamo/testing.py | 2 +- torch/_dynamo/utils.py | 4 +++- 4 files changed, 21 insertions(+), 3 deletions(-) diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 489fcd69df94..bf127deaa43a 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -89,6 +89,8 @@ def pip_install(package): SKIP = { + # Difficult to setup accuracy test because .eval() not supported + "Reformer", # Fails deepcopy "BlenderbotForConditionalGeneration", "GPTNeoForCausalLM", @@ -124,7 +126,7 @@ def pip_install(package): "GPT2ForSequenceClassification": 2, # "GPTJForCausalLM" : 2, # "GPTJForQuestionAnswering" : 2, - # "GPTNeoForCausalLM" : 2, + # "GPTNeoForCausalLM" : 32, # "GPTNeoForSequenceClassification" : 2, "GoogleFnet": 2, "LayoutLMForMaskedLM": 2, @@ -153,6 +155,13 @@ def pip_install(package): "YituTechConvBert": 2, } +SKIP_ACCURACY_CHECK_MODELS = { + # Models too large to have eager, dynamo and fp64_numbers simultaneosuly + # even for 40 GB machine. + "DebertaV2ForMaskedLM", + "BlenderbotForCausalLM", +} + def get_module_cls_by_model_name(model_cls_name): _module_by_model_name = { @@ -445,6 +454,12 @@ def iter_model_names(self, args): continue yield model_name + @property + def skip_accuracy_checks_large_models_dashboard(self): + if self.args.dashboard or self.args.accuracy: + return SKIP_ACCURACY_CHECK_MODELS + return set() + def pick_grad(self, name, is_training): if is_training: return torch.enable_grad() diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index cec284ebcd8c..b7d4a3be7933 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -50,6 +50,7 @@ def setup_torchbench_cwd(): # size to test the accuracy. 
USE_SMALL_BATCH_SIZE = { "demucs": 4, + "dlrm": 1024, "densenet121": 4, "hf_Reformer": 4, "timm_efficientdet": 1, diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index eea4c26a171c..55186931988b 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -236,7 +236,7 @@ def rand_strided(size, stride, dtype=torch.float32, device="cpu"): if dtype.is_floating_point: buffer = torch.randn(needed_size, dtype=dtype, device=device) else: - buffer = torch.ones(size=[needed_size], dtype=dtype, device=device) + buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device) return torch.as_strided(buffer, size, stride) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index f426ef691307..e4b92a73aacf 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -816,7 +816,9 @@ def same( res = res.to_dense() assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}" if exact_dtype: - assert ref.dtype == res.dtype, f"dtype mismatch {ref.dtype}, {res.dtype}" + if ref.dtype != res.dtype: + log.error(f"dtype mismatch {ref.dtype}, {res.dtype}") + return False if ref.dtype == torch.bool: # triton stores bool as int8, so add this for more accurate checking return torch.allclose( From ea58955dda6452307ce43a5beef0a466b49f1bef Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 19 Nov 2022 01:13:08 +0000 Subject: [PATCH 366/453] Move bazel to c++17 (#89297) Splitting out various smaller pieces from https://github.com/pytorch/pytorch/pull/85969 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89297 Approved by: https://github.com/huydhn --- .bazelrc | 2 +- third_party/gloo.BUILD | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.bazelrc b/.bazelrc index ce8406b58aaa..f8ff2215f2d6 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,4 +1,4 @@ -build --cxxopt=--std=c++14 +build --cxxopt=--std=c++17 build --copt=-I. # Bazel does not support including its cc_library targets as system # headers. We work around this for generated code diff --git a/third_party/gloo.BUILD b/third_party/gloo.BUILD index 3f623e54e6ad..e9deaa13fc63 100644 --- a/third_party/gloo.BUILD +++ b/third_party/gloo.BUILD @@ -75,8 +75,7 @@ cc_library( ] ) + if_cuda(glob(["gloo/cuda*.cc"])), copts = [ - "-std=gnu++11", - "-std=c++11", + "-std=c++17", ], visibility = ["//visibility:public"], deps = [":gloo_headers"] + if_cuda( From 85a87e635c677e1c6d992bb9ea21c634e8b1d58f Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Sat, 19 Nov 2022 01:47:45 +0000 Subject: [PATCH 367/453] [dynamo] mutable local caching to make dynamo faster at tracing mutation (#89170) Make mutation faster to speed up tracing optimizers, helps with https://github.com/pytorch/torchdynamo/issues/1803 `replace_all` no longer iterates over the entire variable tracker data structure every time a mutation is performed Each variable tracker internally keeps a set of contained mutable variable trackers, to provide a hint to `replace_all`. 
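As a rough illustration only (not the code in this patch; `Node`, `build_hints`, and the standalone `replace_all` below are hypothetical stand-ins for the real `VariableTracker` machinery), the hint set lets a tree rewrite skip any subtree that cannot contain the mutated tracker:

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class Node:
    # Stand-in for a VariableTracker: each node caches which mutable nodes
    # are reachable at or beneath it (the "recursively_contains" hint).
    name: str
    children: List["Node"] = field(default_factory=list)
    recursively_contains: Set[int] = field(default_factory=set)


def build_hints(node: Node) -> Set[int]:
    # Aggregate contained mutables once, bottom-up (roughly what the
    # __post_init__ hook in this patch is doing via apply()).
    node.recursively_contains = {id(node)}
    for child in node.children:
        node.recursively_contains |= build_hints(child)
    return node.recursively_contains


def replace_all(node: Node, old: Node, new: Node) -> Node:
    # Prune any subtree whose hint set proves it cannot reference `old`,
    # instead of walking the entire structure on every mutation.
    if id(old) not in node.recursively_contains:
        return node
    if node is old:
        return new
    node.children = [replace_all(c, old, new) for c in node.children]
    return node


# Tiny usage example: only the branch containing `leaf` is visited.
leaf = Node("leaf")
root = Node("root", [Node("a"), Node("b", [leaf])])
build_hints(root)
root = replace_all(root, leaf, Node("replacement"))
```

The per-tracker set of contained mutables is what makes that pruning cheap.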
This is populated with a call to `apply` from `__post_init__` in the base `VariableTracker` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89170 Approved by: https://github.com/jansel --- torch/_dynamo/side_effects.py | 6 +-- torch/_dynamo/symbolic_convert.py | 13 +++++-- torch/_dynamo/variables/base.py | 61 +++++++++++++++++++++++++------ torch/_dynamo/variables/dicts.py | 41 +++++++++++++++++---- torch/_dynamo/variables/lists.py | 20 +++++++--- torch/_dynamo/variables/misc.py | 10 +++-- torch/_dynamo/variables/torch.py | 2 +- 7 files changed, 119 insertions(+), 34 deletions(-) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 1f8675ae1c9e..55e6e9f927e8 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -82,16 +82,16 @@ def clone(self): keepalive=list(self.keepalive), ) - def apply(self, fn, cache=None): + def apply(self, fn, cache=None, skip_fn=lambda _: False): if cache is None: cache = dict() self.id_to_variable = collections.OrderedDict( - (k, VariableTracker.apply(fn, v, cache)) + (k, VariableTracker.apply(fn, v, cache, skip_fn)) for k, v in self.id_to_variable.items() ) self.store_attr_mutations = collections.OrderedDict( - (k, VariableTracker.apply(fn, v, cache)) + (k, VariableTracker.apply(fn, v, cache, skip_fn)) for k, v in self.store_attr_mutations.items() ) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index d2bc5332719c..7a16b6b982a0 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -400,11 +400,18 @@ def repl(v: VariableTracker): return newvar return v + def skip(v: VariableTracker): + return oldvar.mutable_local not in v.recursively_contains + cache = dict() - self.output.side_effects.apply(repl, cache) - self.stack = [VariableTracker.apply(repl, x, cache) for x in self.stack] + self.output.side_effects.apply(repl, cache, skip_fn=skip) + self.stack = [ + VariableTracker.apply(repl, x, cache, skip_fn=skip) for x in self.stack + ] for k, x in self.symbolic_locals.items(): - self.symbolic_locals[k] = VariableTracker.apply(repl, x, cache) + self.symbolic_locals[k] = VariableTracker.apply( + repl, x, cache, skip_fn=skip + ) def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker): if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects): diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 62cddfff0cb2..4c5aee344061 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -21,7 +21,15 @@ def __eq__(self, other): return self is other -class VariableTracker: +# metaclass to call post_init +class HasPostInit(type): + def __call__(cls, *args, **kwargs): + obj = type.__call__(cls, *args, **kwargs) + obj.__post_init__(*args, **kwargs) + return obj + + +class VariableTracker(object, metaclass=HasPostInit): """ Base class for tracked locals and stack values @@ -70,7 +78,11 @@ def copy(cls, value): @classmethod def apply( - cls, fn: Callable[["VariableTracker"], "VariableTracker"], value, cache=None + cls, + fn: Callable[["VariableTracker"], "VariableTracker"], + value, + cache=None, + skip_fn=lambda _: False, # Whether we should skip applying to this var ): """ Walk this object and call fn on all the VariableTracker @@ -84,21 +96,29 @@ def apply( return cache[idx][0] if isinstance(value, VariableTracker): - updated_dict = dict(value.__dict__) - for key in updated_dict.keys(): - if key not in value._nonvar_fields: - 
updated_dict[key] = cls.apply(fn, updated_dict[key], cache) - result = fn(value.clone(**updated_dict)) + if not skip_fn(value): + updated_dict = dict(value.__dict__) + for key in updated_dict.keys(): + if key not in value._nonvar_fields: + updated_dict[key] = cls.apply( + fn, updated_dict[key], cache, skip_fn + ) + result = fn(value.clone(**updated_dict)) + else: + result = fn(value) + elif istype(value, list): - result = [cls.apply(fn, v, cache) for v in value] + result = [cls.apply(fn, v, cache, skip_fn) for v in value] elif istype(value, tuple): - result = tuple(cls.apply(fn, v, cache) for v in value) + result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value) elif istype(value, collections.OrderedDict): result = collections.OrderedDict( - cls.apply(fn, v, cache) for v in value.items() + cls.apply(fn, v, cache, skip_fn) for v in value.items() ) elif istype(value, dict): - result = {k: cls.apply(fn, v, cache) for k, v in list(value.items())} + result = { + k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items()) + } else: result = value @@ -244,11 +264,30 @@ def __init__( guards: Optional[Set] = None, source: Source = None, mutable_local: MutableLocal = None, + recursively_contains: Optional[Set] = None, ): super(VariableTracker, self).__init__() self.guards = guards or set() self.source = source self.mutable_local = mutable_local + self.recursively_contains = ( + recursively_contains # provides hint to replace_all when replacing vars + ) + + def __post_init__(self, *args, **kwargs): + if self.recursively_contains is None: + self.recursively_contains = set() + + def aggregate_mutables(var): + self.recursively_contains.update(var.recursively_contains) + if var.mutable_local is not None: + self.recursively_contains.add(var.mutable_local) + + return var + + VariableTracker.apply( + aggregate_mutables, self, skip_fn=lambda var: var is not self + ) def typestr(*objs): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 30df18f6d6e9..f28efc713db4 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -16,8 +16,10 @@ class ConstDictVariable(VariableTracker): - def __init__(self, items, user_cls, **kwargs): - super(ConstDictVariable, self).__init__(**kwargs) + def __init__(self, items, user_cls, recursively_contains=None, **kwargs): + super(ConstDictVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) self.items = items self.user_cls = user_cls @@ -112,7 +114,17 @@ def call_method( tx.store_dict_key(global_key_name(k), k) newval = collections.OrderedDict(val) newval[k] = args[1] - return tx.replace_all(self, self.modifed(newval, **options)) + + new_rec_contains = self.recursively_contains.union( + args[1].recursively_contains + ) + if args[1].mutable_local is not None: + new_rec_contains.add(args[1].mutable_local) + + return tx.replace_all( + self, + self.modifed(newval, new_rec_contains, **options), + ) elif ( name in ("pop", "get") and args @@ -130,7 +142,7 @@ def call_method( ): newval = collections.OrderedDict(val) result = newval.pop(ConstDictVariable.get_key(args[0])) - tx.replace_all(self, self.modifed(newval, **options)) + tx.replace_all(self, self.modifed(newval, None, **options)) return result.add_options(options) elif ( name == "update" @@ -140,7 +152,12 @@ def call_method( ): newval = collections.OrderedDict(val) newval.update(args[0].items) - result = self.modifed(newval, **options) + new_rec_contains = self.recursively_contains.union( + args[0].recursively_contains + 
) + result = self.modifed( + newval, recursively_contains=new_rec_contains, **options + ) return tx.replace_all(self, result) elif ( name in ("get", "__getattr__") @@ -159,9 +176,11 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) - def modifed(self, items, **options): + def modifed(self, items, recursively_contains, **options): """a copy of self with different items""" - return self.clone(items=items, **options) + return self.clone( + items=items, recursively_contains=recursively_contains, **options + ) def unpack_var_sequence(self, tx): options = VariableTracker.propagate([self]) @@ -237,7 +256,13 @@ def call_method( f"defaultdict with default_factory = {self.default_factory}" ) new_val[k] = default_var - tx.replace_all(self, self.modifed(new_val, **options)) + new_rec_contains = self.recursively_contains.union( + default_var.recursively_contains + ) + new_rec_contains.add(default_var.mutable_local) + tx.replace_all( + self, self.modifed(new_val, new_rec_contains, **options) + ) return default_var else: return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 70c6da07adb5..553c9ca1e664 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -23,8 +23,12 @@ def cls_for(obj): tuple: TupleVariable, }[obj] - def __init__(self, items: List[VariableTracker], **kwargs): - super(BaseListVariable, self).__init__(**kwargs) + def __init__( + self, items: List[VariableTracker], recursively_contains=None, **kwargs + ): + super(BaseListVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) assert isinstance(items, list) assert all(isinstance(x, VariableTracker) for x in items) self.items: List[VariableTracker] = items @@ -145,9 +149,13 @@ def call_method( if name == "append" and self.mutable_local: assert not kwargs (arg,) = args + new_rec_contains = self.recursively_contains.union(arg.recursively_contains) + new_rec_contains.add(arg.mutable_local) tx.replace_all( self, - ListVariable(self.items + [arg], **options), + ListVariable( + self.items + [arg], recursively_contains=new_rec_contains, **options + ), ) return ConstantVariable(None) elif ( @@ -454,8 +462,10 @@ def var_getattr(self, tx, name): class ListIteratorVariable(VariableTracker): - def __init__(self, items, index: int = 0, **kwargs): - super(ListIteratorVariable, self).__init__(**kwargs) + def __init__(self, items, index: int = 0, recursively_contains=None, **kwargs): + super(ListIteratorVariable, self).__init__( + recursively_contains=recursively_contains, **kwargs + ) assert isinstance(items, list) # Removing this check as it slows things down too much # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492 diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 952cbd2c6424..f8975f70fcfb 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -107,6 +107,10 @@ def __init__(self, target_values, initial_values=None, **kwargs): super(ContextWrappingVariable, self).__init__(**kwargs) self.target_values = target_values self.initial_values = initial_values + self.recursively_contains = ( + set() + ) # This var doesn't contain any child vars and doesn't support clone() properly, + # so don't populate this automatically def enter(self, tx): self._call_func(tx, self.target_values) @@ -294,7 +298,7 @@ def fn_name(self): class AutocastModeVariable(ContextWrappingVariable): @staticmethod - def 
create(tx, target_values, kwargs): + def create(target_values, kwargs): values = target_values # device_type : str, # dtype : Optional[_dtype] = None, @@ -322,10 +326,10 @@ def create(tx, target_values, kwargs): else: values.append(variables.ConstantVariable(None)) - var = AutocastModeVariable(tx, target_values, initial_values=None, **kwargs) + var = AutocastModeVariable(target_values, initial_values=None, **kwargs) return var - def __init__(self, tx, target_values, initial_values=None, **kwargs): + def __init__(self, target_values, initial_values=None, **kwargs): super(AutocastModeVariable, self).__init__( target_values=target_values, initial_values=initial_values, **kwargs ) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 651f80b5d77d..4c4681b75622 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -293,7 +293,7 @@ def call_function( tensor_with_tf_override.subclass_type, ) elif self.value is torch.amp.autocast_mode.autocast: - return AutocastModeVariable.create(tx, target_values=args, kwargs=kwargs) + return AutocastModeVariable.create(target_values=args, kwargs=kwargs) elif self.value in ( torch.profiler.profile, torch.profiler.record_function, From 2e72ec79823111e8dd8c5e82c5d1b56197cd52d3 Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Sat, 19 Nov 2022 02:06:24 +0000 Subject: [PATCH 368/453] Update sdp dispatch logic to enable fused backward (#89154) # Summary Reorganizes how the sdp dispatch logic is down in order to enable backwards for fused kernels Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 52 ++--- .../cuda/NestedTensorTransformerFunctions.cpp | 100 ++++++--- .../ATen/native/transformers/attention.cpp | 65 ++++-- .../native/transformers/cuda/attention.cu | 46 ++--- .../transformers/cuda/attention_backward.cu | 40 +++- .../transformers/cuda/flash_attn/fmha_api.cpp | 7 +- .../transformers/cuda/flash_attn/fmha_api.h | 2 +- .../ATen/native/transformers/cuda/sdp_utils.h | 34 +++- benchmarks/transformer/sdp_backwards.py | 189 ++++++++++++++++++ .../check_forward_backward_compatibility.py | 3 + test/functorch/test_ops.py | 8 +- test/test_meta.py | 1 - test/test_transformers.py | 74 +++++-- tools/autograd/derivatives.yaml | 6 +- .../_internal/common_methods_invocations.py | 5 + 15 files changed, 497 insertions(+), 135 deletions(-) create mode 100644 benchmarks/transformer/sdp_backwards.py diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f625c9faff41..8c759cd09c48 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13252,18 +13252,39 @@ CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda -# Register the math kernel for cpu -- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) +- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? 
attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function + +- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor) dispatch: - CUDA: _scaled_dot_product_attention_forward_cuda - CPU: _scaled_dot_product_attention_forward_math - NestedTensorCUDA: _scaled_dot_product_attention_forward_nested - NestedTensorCPU: _scaled_dot_product_attention_forward_math - Meta: _scaled_dot_product_attention_forward_math + CUDA: _scaled_dot_product_flash_attention_cuda + NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda -- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) +- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_cuda + NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + +- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_backward_cuda + +# Returns ouput, softmax_logsumexp, softmax +- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor) variants: function + dispatch: + CUDA: _flash_attention_forward + +# Returns ouput, logsumexp if compute_logsumexp +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_forward + +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function @@ -13290,21 +13311,6 @@ structured: True variants: function -- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor - variants: function - dispatch: - CUDA: flash_scaled_dot_product_attention - -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? 
max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_forward - -- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_backward - - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index c2bf4e08ce04..9c72454560d3 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -214,26 +214,6 @@ Tensor NestedTensor_to_padded_tensor_cuda( return NestedTensor_to_padded_tensor_generic(t, padding, output_size); } -std::tuple _scaled_dot_product_attention_forward_nested( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - - // Determine which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - // TODO: enable flash attention kernel - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention"); - return std::make_tuple(Tensor(), Tensor()); - } -} namespace{ /** @@ -340,19 +320,80 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { } } // namespace -std::tuple mem_efficient_helper_nested_unpacked( + +std::tuple _scaled_dot_product_flash_attention_nestedtensor_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { + TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.") // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) const int64_t num_heads = query.size(1); const int64_t head_dim = query.size(3); + // Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head) + // Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + // Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + Tensor q_t = query.transpose(1, 2).contiguous(); + Tensor k_t = key.transpose(1, 2).contiguous(); + Tensor v_t = value.transpose(1, 
2).contiguous(); + + // K and V have to have the same Nnz, should probably torch_check + // assume in order to not iterate over v + + auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t); + auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t); + + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); + Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k); + + const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q); + const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k); + + const int64_t Nnz_q = cumulative_sequence_length_q[-1].item(); + const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item(); + + auto query_buffer_reshaped = + get_buffer(q_t).view({Nnz_q, num_heads, head_dim}); + auto key_buffer_reshaped = + get_buffer(k_t).view({Nnz_kv, num_heads, head_dim}); + auto value_buffer_reshaped = + get_buffer(v_t).view({Nnz_kv, num_heads, head_dim}); + + auto attention_and_lse_and_softmax = + at::_flash_attention_forward( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + return_softmax, + dropout_p, + is_causal); + // Reshape output to convert nnz to batch_size and seq_len + Tensor attention = std::get<0>(attention_and_lse_and_softmax); + attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2); + return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax)); +} + +std::tuple _scaled_dot_product_efficient_attention_nestedtensor_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { + // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) + // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + const int64_t num_heads = query.size(1); + const int64_t head_dim = query.size(3); + Tensor q_t = query.transpose(1, 2); Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); @@ -432,7 +473,7 @@ std::tuple mem_efficient_helper_nested_unpacked( {Nnz_kv, num_heads, head_dim}, {nnz_v_stride, head_v_stride, head_dim_stride}, value_impl->get_storage_offsets()[0]); - std::tuple attention_and_weights = + std::tuple attention_and_logsumexp= at::_efficient_attention_forward( query_buffer_reshaped.unsqueeze(0), key_buffer_reshaped.unsqueeze(0), @@ -440,14 +481,14 @@ std::tuple mem_efficient_helper_nested_unpacked( cumulative_sequence_length_q, cumulative_sequence_length_k, max_seqlen_batch_q, - false, - false); + compute_log_sumexp, + is_causal); // Reshape output to convert nnz to batch_size and seq_len - Tensor attention = std::get<0>(attention_and_weights); + Tensor attention = std::get<0>(attention_and_logsumexp); attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()) .transpose(1, 2); - return std::tie(attention, std::get<1>(attention_and_weights)); + return std::tie(attention, std::get<1>(attention_and_logsumexp)); } Tensor flash_attention_helper( @@ -492,7 +533,7 @@ Tensor flash_attention_helper( // If we are passing in query, key, value all the same tensors then we have // packed them into one tensor and need to slice for flash attention Tensor attention = - at::_flash_scaled_dot_product_attention( + std::get<0>(at::_flash_attention_forward( q, k, v, @@ -500,8 +541,9 @@ Tensor flash_attention_helper( 
cumulative_sequence_length_q, max_seqlen_batch_q, max_seqlen_batch_q, + false /*return_softmax*/, dropout_p, - is_causal); + is_causal)); // Output of flash_attention is a regular tensor lets wrap it back up to // form a nested tensor diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 89a0e4691018..9c5be12ef24d 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -678,20 +678,6 @@ std::tuple native_decoder_only_multi_head_attent // L: Target sequence length // E: Embedding dimension std::tuple _scaled_dot_product_attention( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - if (query_.requires_grad() || key.requires_grad() || value.requires_grad()){ - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - } - return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); -} - -int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ - return static_cast(sdp::SDPBackend::math); -} - -std::tuple _scaled_dot_product_attention_forward_math( const Tensor& query_, const Tensor& key, const Tensor& value, @@ -699,14 +685,49 @@ std::tuple _scaled_dot_product_attention_forward_math( double dropout_p, bool need_attn_weights, bool is_causal) { - return at::_scaled_dot_product_attention_math( - query_, - key, - value, - attn_mask_, - dropout_p, - need_attn_weights, - is_causal); + // TODO: The second return is the attention weights if the math kernel is + // used. The fused kernels do not return this Tensor so for the fused kernels + // The second return SHOULD always be an empty Tensor, unless need_attn_weights + // is true (in which case the fused kernels would not be called). This blows up + // op_info tests. 
+ int64_t choice_int = at::_fused_sdp_choice( + query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + sdp::SDPBackend backend = static_cast(choice_int); + switch (backend) { + case sdp::SDPBackend::flash_attention: { + auto out_lse_softmax = at::_scaled_dot_product_flash_attention( + query_, key, value, dropout_p, need_attn_weights, is_causal); + return std::make_tuple( + std::move(std::get<0>(out_lse_softmax)), + std::move(std::get<2>(out_lse_softmax))); + } + case sdp::SDPBackend::efficient_attention: { + bool compute_logsumexp = + (query_.requires_grad() || key.requires_grad() || + value.requires_grad()); + return at::_scaled_dot_product_efficient_attention( + query_, key, value, compute_logsumexp, is_causal); + } + case sdp::SDPBackend::math: + return at::_scaled_dot_product_attention_math( + query_, + key, + value, + attn_mask_, + dropout_p, + need_attn_weights, + is_causal); + default: + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found."); + return std::make_tuple(Tensor(), Tensor()); + } +} + +int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + return static_cast(sdp::SDPBackend::math); } std::tuple _scaled_dot_product_attention_math( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 602cf319f74a..8dcb99b3380d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -678,12 +678,12 @@ std::tuple native_multi_head_attention_cuda( return std::make_tuple(std::move(proj), std::move(qkt)); } -std::tuple flash_attention_helper_dense_unpacked( +std::tuple _scaled_dot_product_flash_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -726,8 +726,9 @@ std::tuple flash_attention_helper_dense_unpacked( Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim}); Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim}); - Tensor attention = - at::_flash_scaled_dot_product_attention( + Tensor attention, log_sumexp, softmax; + std::tie(attention, log_sumexp, softmax) = + at::_flash_attention_forward( query_reshaped, key_reshaped, value_reshaped, @@ -735,15 +736,17 @@ std::tuple flash_attention_helper_dense_unpacked( cumulative_sequence_length_k, max_seqlen_batch_q, max_seqlen_batch_k, + return_softmax, dropout_p, is_causal); // Reshape output to convert nnz to batch_size and seq_len attention = attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2); - return std::tuple(attention, Tensor()); + return std::make_tuple(attention, log_sumexp, softmax); } -std::tuple mem_eff_helper( + +std::tuple _scaled_dot_product_efficient_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, @@ -767,26 +770,7 @@ std::tuple mem_eff_helper( compute_log_sumexp, is_causal); attention = attention.transpose(1,2); - return std::make_tuple(std::move(attention), Tensor()); -} - -std::tuple _scaled_dot_product_attention_forward_cuda( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - // Determine 
which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found."); - return std::make_tuple(Tensor(), Tensor()); - } + return std::make_tuple(std::move(attention), std::move(log_sumexp)); } int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, @@ -802,7 +786,7 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te return static_cast(backend); } -Tensor flash_scaled_dot_product_attention( +std::tuple _flash_attention_forward( const Tensor& query, const Tensor& key, const Tensor& value, @@ -810,11 +794,12 @@ Tensor flash_scaled_dot_product_attention( const Tensor& cumulative_sequence_length_k, const int64_t max_seqlen_batch_q, const int64_t max_seqlen_batch_k, + bool return_softmax, double dropout_p, bool is_causal) { #if defined(USE_FLASH_ATTENTION) auto softmax_scale = std::pow(query.size(-1), -0.5); - std::vector output = fmha::mha_fwd( + return fmha::mha_fwd( query, key, value, @@ -826,12 +811,11 @@ Tensor flash_scaled_dot_product_attention( softmax_scale, false, is_causal, - false, + return_softmax, c10::nullopt); - return output[0]; #endif TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return Tensor(); + return std::make_tuple(Tensor(), Tensor(), Tensor()); } std::tuple _efficient_attention_forward( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index af005b2669b2..a063aacb901e 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -10,6 +10,7 @@ #include #include +#include #ifdef USE_FLASH_ATTENTION #include #endif @@ -73,14 +74,14 @@ std::tuple _efficient_attention_backward( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, - const at::Tensor& logsumexp, const at::Tensor& out, + const at::Tensor& logsumexp, bool causal) { #if defined(USE_FLASH_ATTENTION) if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } - // ndim + // ndim TORCH_CHECK(query.dim() == grad_out_.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); @@ -128,6 +129,7 @@ std::tuple _efficient_attention_backward( // initialized bool grad_kv_needs_init = causal && N > M; at::Tensor grad_q, grad_k, grad_v; + int8_t gQKV_strideM_multiplier = 1; if (!grad_kv_needs_init && query.size(1) == key.size(1) && query.size(3) == value.size(3) && query.storage().is_alias_of(key.storage()) && @@ -141,10 +143,13 @@ std::tuple _efficient_attention_backward( grad_q = chunk.select(2, 0); grad_k = chunk.select(2, 1); grad_v = chunk.select(2, 2); + gQKV_strideM_multiplier=3; } else { - grad_q = at::empty_like(query); - grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - grad_v = grad_kv_needs_init ? 
at::zeros_like(value) : at::empty_like(value); + grad_q = at::empty(query.sizes(), query.options()); + grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options()) + : at::empty(key.sizes(), key.options()); + grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options()) + : at::empty(value.sizes(), value.options()); } auto launchKernel = [&](auto _k, int computeCapability) { @@ -198,7 +203,7 @@ std::tuple _efficient_attention_backward( ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; + p.gQKV_strideM_multiplier = gQKV_strideM_multiplier; TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); @@ -257,5 +262,28 @@ std::tuple _efficient_attention_backward( return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } + +std::tuple _scaled_dot_product_efficient_attention_backward_cuda( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + bool causal){ + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + auto grad_out = grad_out_.transpose(1, 2); + auto out_t = out.transpose(1, 2); + auto q_t = query.transpose(1, 2); + auto k_t = key.transpose(1, 2); + auto v_t = value.transpose(1, 2); + + Tensor grad_q, grad_k, grad_v; + std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal); + return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index aaf7d833fe83..7cc0c250664e 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -26,6 +26,7 @@ * ******************************************************************************/ +#include #ifdef USE_FLASH_ATTENTION #include #include @@ -115,7 +116,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.is_causal = is_causal; } -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -241,9 +242,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fprop(launch_params, /*configure=*/false); - std::vector result = {o, softmax_lse}; - if (return_softmax) {result.push_back(s);} - return result; + return std::make_tuple(o, softmax_lse, s); } } // namespace fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h index 226d4ddd2b55..b0555463be04 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h @@ -7,7 +7,7 @@ namespace fmha { TORCH_API -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // 
total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 5d62a6cbd0dc..55e9aeb184a2 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -91,6 +91,31 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { return true; } +inline bool check_for_nested_inputs(sdp_params params, bool debug){ + if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) { + TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors."); + return false; + } + return true; +} + +inline bool check_requires_grad(sdp_params params, bool debug) { + if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { + TORCH_CHECK(!debug, "Flash Attention does not currently support training."); + return false; + } + return true; +} + +inline bool check_requires_grad_and_nested(sdp_params params, bool debug) { + // If we fail both checks then we return false + if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){ + TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs."); + return false; + } + return true; +} + inline bool check_for_attn_mask(sdp_params params, bool debug) { if (params.has_attn_mask) { TORCH_CHECK(!debug, "Flash Attention does not support attention mask."); @@ -198,13 +223,15 @@ inline bool use_flash_attention(sdp_params params, bool debug) { return false; #endif // Define gate functions that determine if a flash kernel can be ran - constexpr std::array constraints {{ + constexpr std::array constraints {{ check_runtime_disabled_flash, + check_requires_grad, check_tensor_shapes, check_for_attn_weights, check_for_attn_mask, check_head_dim_size, check_gpu_sm75_or_greater, + check_for_nested_inputs, check_for_seq_len_1_nested_tensor}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { @@ -232,14 +259,15 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { at::kHalf, at::kFloat, at::kBFloat16}; // Define gate functions that determine if a flash kernel can be ran - std::vector> constraints{ + constexpr std::array constraints{{ check_gpu_sm50_or_greater, check_runtime_disabled_mem_efficient, + check_requires_grad_and_nested, check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, check_for_seq_len_1_nested_tensor, - check_for_non_zero_dropout}; + check_for_non_zero_dropout}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/benchmarks/transformer/sdp_backwards.py b/benchmarks/transformer/sdp_backwards.py new file mode 100644 index 000000000000..2f745e157b28 --- /dev/null +++ b/benchmarks/transformer/sdp_backwards.py @@ -0,0 +1,189 @@ +import torch +import numpy as np +import random +import torch.utils.benchmark as benchmark +from torch.profiler import profile, record_function, ProfilerActivity + + +class CompositeMHA(torch.nn.Module): + def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj): + super().__init__() + self.in_proj_weight = in_proj_weight + self.in_proj_bias = in_proj_bias + self.out_proj = out_proj + self.num_heads = num_heads + + def forward(self, query, key, value, mask): + if not (query is key 
and key is value): + raise NotImplementedError( + "query, key and value must be the same Tensor for now." + ) + if mask is not None: + raise NotImplementedError("mask is currently not supported.") + + query_projected = torch.nn.functional.linear( + query, self.in_proj_weight, self.in_proj_bias + ) + + batch_size = query_projected.size(0) + embed_dim = query_projected.size(2) + head_dim = embed_dim // (self.num_heads * 3) + + query, key, value = query_projected.chunk(3, -1) + + query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + attn, _ = torch.nn.functional._scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + need_attn_weights=False, + is_causal=False, + ) + + attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim) + # Match return signature of nn.MHA + return self.out_proj(attn) + + +def build_composite_mha_from_nn_mha(pt): + assert pt._qkv_same_embed_dim + in_proj_weight = pt.in_proj_weight + assert in_proj_weight is not None + assert pt.batch_first + return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj) + + +def forw_back(model, input, upward): + output = model(*input) + output.backward(upward) + + +# Context manger not working in timer + + +def forw_back_fused(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + output = model(*input) + output.backward(upward) + + +def forw_back_eager(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + output = model(*input) + output.backward(upward) + + +def run_timing( + min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype +): + dropout_p = 0.0 + mask = None + + pt = torch.nn.MultiheadAttention( + embed_dim=embed_dimension, + num_heads=num_heads, + batch_first=True, + dropout=dropout_p, + ) + npt = pt.cuda().to(dtype) + cpt = build_composite_mha_from_nn_mha(npt) + x = torch.randn( + batch_size, + max_sequence_len, + embed_dimension, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + rand_fused_upward = cpt(x, x, x, mask).clone().detach() + + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + rand_eager_upward = cpt(x, x, x, mask).clone().detach() + + t0 = benchmark.Timer( + stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)", + globals={ + "forw_back_fused": forw_back_fused, + "cpt": cpt, + "x": x, + "rand_fused_upward": rand_fused_upward, + "mask": mask, + }, + label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + t1 = benchmark.Timer( + stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)", + globals={ + "forw_back_eager": forw_back_eager, + "cpt": cpt, + "x": x, + "rand_eager_upward": rand_eager_upward, + "mask": mask, + }, + label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + m0 = 
t0.blocked_autorange(min_run_time=min_run_time) + m1 = t1.blocked_autorange(min_run_time=min_run_time) + + print(m0) + print(m1) + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + print("Profile for Fused".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_fused_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + print("Profile for eager".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_eager_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + +def main(): + seed = 123 + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + min_run_time = 10 + batch_size = 64 + num_heads = 32 + max_seq_len = 256 + embed_dim = 1024 + dtype = torch.bfloat16 + + print( + f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} " + f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}" + ) + run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype) + + +if __name__ == "__main__": + main() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 90080ab0934f..853f5206969b 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -317,6 +317,9 @@ ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)), + ("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)), + ("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)), + ("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)), ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)), ] diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 91ea2443777b..f276b739f81d 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -401,6 +401,7 @@ def wrapped_fn(*args, **kwargs): skip('nn.functional.max_unpool2d'), # fails everywhere except on windows skip('nn.functional.max_unpool3d'), # fails everywhere except on mac xfail("native_batch_norm"), + xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.rrelu') # in-place test errors out with no formula implemented })) @@ -522,6 +523,7 @@ def f(inp, *args, **kwargs): xfail('nn.functional.ctc_loss'), # Not Implemented xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other' xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), # AssertionError: Tensor-likes are not close! 
# Mismatched elements: 1 / 15 (6.7%) # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) @@ -616,7 +618,7 @@ def fn(inp, *args, **kwargs): skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. @@ -1093,6 +1095,7 @@ def test(): skip('nn.functional.rrelu'), # randomness skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), skip('nn.functional.alpha_dropout'), # randomness skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format skip('to_sparse', ''), # non-dense output @@ -1216,6 +1219,7 @@ def get_vjp(cotangents, *primals): xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides @@ -1336,7 +1340,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.dropout2d'), # calls random op xfail('nn.functional.dropout3d'), # calls random op xfail('nn.functional.dropout'), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition xfail('nn.functional.alpha_dropout'), # calls randomn op xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op diff --git a/test/test_meta.py b/test/test_meta.py index 6d21d5c7bd75..0e3cfb6ef140 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -294,7 +294,6 @@ def test_tensor_outlives_converter(self): aten._fft_c2r.default, aten._fft_r2c.default, aten._linalg_svd.default, - aten._scaled_dot_product_attention_forward.default, aten.binary_cross_entropy.default, aten.complex.default, aten.copysign.Tensor, diff --git a/test/test_transformers.py b/test/test_transformers.py index abb4c71ec19a..f6bc0cc2d663 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1059,6 +1059,11 @@ def rand_tensor(shape): if fused_kernel == "flash": with sdp_kernel(enable_mem_efficient=False, enable_math=False): + # TODO Flash for the nested path is currently not working due to cuda memory issues + if type == "nested": + self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)) + return actual = torch.nn.functional._scaled_dot_product_attention( query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, 
is_causal=False) elif fused_kernel == "mem_efficient": @@ -1097,28 +1102,73 @@ def rand_tensor(shape): @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) - def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): + def test_sdp_math_gradcheck(self, contiguous_inputs: bool): batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) query, key, value = qkv.chunk(3, dim=-1) - query = query.view(batch_size, -1, num_heads, head_dim) - key = key.view(batch_size, -1, num_heads, head_dim) - value = value.view(batch_size, -1, num_heads, head_dim) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) if contiguous_inputs: query = query.contiguous() key = key.contiguous() value = value.contiguous() - # Normally we would transpose the inputs but the fused kernels expect - # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel - # in fp32 - assert gradcheck(lambda *args, **kwargs: - wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), - (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (query, key, value, None, 0.0, False, False) + ) + + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool): + batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) + + qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) + qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() + + query, key, value = qkv.chunk(3, dim=-1) + query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + query_lp = query_lp.contiguous() + key_lp = key_lp.contiguous() + value_lp = value_lp.contiguous() + + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False) + + with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False): + out_lp, atten_lp = 
torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, None, 0.0, False, False) + + rand_upward = torch.rand_like(out) + rand_upward_lp = rand_upward.to(torch.float32) + + out.backward(rand_upward) + out_lp.backward(rand_upward_lp) + + # Cast up and compare + self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) @parametrize("type", ["dense", "nested"]) def test_fused_sdp_choice(self, type: str): @@ -1144,7 +1194,7 @@ def test_fused_sdp_choice(self, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - if SM80OrLater: + if SM80OrLater and not type == "nested": assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION else: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a0892b32a835..52c0f76bf070 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2613,9 +2613,13 @@ nested_strides: non_differentiable # Transformers +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal) + - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) output_differentiability: [True, False] - query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) + query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal) # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6cff2f6a4749..cf68a68cf629 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12008,16 +12008,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # This is only failing on Linux Bionic 3.10 Cuda 11.6 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', + device_type='cuda', dtypes=(torch.float32,)), # AssertionError: JIT Test does not execute any logic DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), # Doesn't support autocasting DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensorNonErroring', 'test_fake_autocast', device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), + # Forward works for dtype=float64 which is the math path + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), # No meta function DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), 
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', device_type='cuda'), DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),), ), UnaryUfuncInfo( From 8c0515dbff04f03cae584e10100715e05f7cb32e Mon Sep 17 00:00:00 2001 From: Nikolay Korovaiko Date: Sat, 19 Nov 2022 02:18:03 +0000 Subject: [PATCH 369/453] cast C++ py-bound SymNode to SymInt correctly (#89295) Unfortunately, it's a bit hard to test purely on the Pytorch core side, but it passes the XLA tests which are currently disabled. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89295 Approved by: https://github.com/ezyang --- torch/csrc/utils/pybind.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/csrc/utils/pybind.cpp b/torch/csrc/utils/pybind.cpp index 4cd148fdfa91..1b9d1e3a2f73 100644 --- a/torch/csrc/utils/pybind.cpp +++ b/torch/csrc/utils/pybind.cpp @@ -7,8 +7,14 @@ namespace detail { bool type_caster::load(py::handle src, bool) { if (torch::is_symint(src)) { + auto node = src.attr("node"); + if (py::isinstance(node)) { + value = c10::SymInt(py::cast(node)); + return true; + } + value = c10::SymInt(static_cast( - c10::make_intrusive(src.attr("node")))); + c10::make_intrusive(node))); return true; } From e6996ea172b01fa6c6586379ccb26746c32e2b2c Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Sat, 19 Nov 2022 02:24:18 +0000 Subject: [PATCH 370/453] Don't redefine __STDC_FORMAT_MACROS (#89310) Similar to https://github.com/pytorch/pytorch/pull/39608 and https://github.com/pytorch/pytorch/pull/6676 This causes a compile error in our internal build. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89310 Approved by: https://github.com/kit1980 --- torch/csrc/cuda/Tensor.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/csrc/cuda/Tensor.cpp b/torch/csrc/cuda/Tensor.cpp index beb81f187a6e..f9486164358d 100644 --- a/torch/csrc/cuda/Tensor.cpp +++ b/torch/csrc/cuda/Tensor.cpp @@ -1,4 +1,6 @@ +#ifndef __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS +#endif // Order of these includes matters, which should be fixed. 
// clang-format off From 631baecbcd821124296cfe40e5c297cf2def410c Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Sat, 19 Nov 2022 03:35:07 +0000 Subject: [PATCH 371/453] Add --explain flag to bench (#89316) TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 time python benchmarks/dynamo/torchbench.py --accuracy --explain --backend aot_eager --train --only BERT_pytorch Dynamo produced 76 graphs with 75 graph break and 198 ops Pull Request resolved: https://github.com/pytorch/pytorch/pull/89316 Approved by: https://github.com/ezyang --- benchmarks/dynamo/common.py | 10 ++++++++++ torch/_dynamo/eval_frame.py | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index c4e9d62f0a7c..f4d1bfad37d7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1318,6 +1318,7 @@ def run_one_model( experiment, diff=False, branch=None, + explain=False, ): if diff: self.compare_branches( @@ -1337,6 +1338,8 @@ def run_one_model( name, model, example_inputs, optimize_ctx, experiment ) print(status) + if explain: + print(torch._dynamo.explain(model, *example_inputs)[0]) def help(fn): @@ -1515,6 +1518,12 @@ def get_example_inputs(self): help="Delta this branch against main. In the future, we may add support for picking the branch.", ) + parser.add_argument( + "--explain", + action="store_true", + help="run .explain() on the graph at the end of the run.", + ) + parser.add_argument( "--cold_start_latency", action="store_true", @@ -1982,6 +1991,7 @@ def run(runner, args, original_dir=None): optimize_ctx, experiment, diff=args.diff_main, + explain=args.explain, ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 538f6131d62b..31fb479906e1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -429,6 +429,7 @@ def toy_example(a, b): ) +# TODO(voz): Consider making "explain" output alongside a run / part of a run @patch("torch._dynamo.symbolic_convert.explain", True) def explain(f, *args, **kwargs): # TODO(voz): Do we want a decorator for this? @@ -487,15 +488,23 @@ def guard_export_print(guards): msg = f"{break_reason.reason}\n{formatted_stack}" formatted_list += f"{idx + 1}. {msg} \n" - explanation = f"Dynamo produced {graph_count} graphs" + explanation = f"Dynamo produced {graph_count} graphs " explanation += f"with {graph_count - 1} graph break and {op_count} ops" - explanation += f"\n Break reasons: \n\n{formatted_list}" + explanation_verbose = explanation + explanation_verbose += f"\n Break reasons: \n\n{formatted_list}" - explanation += compile_times() + explanation_verbose += compile_times() # TODO(voz): Do we want a decorator for this? reset() - return explanation, out_guards, graphs, ops_per_graph, break_reasons + return ( + explanation, + out_guards, + graphs, + ops_per_graph, + break_reasons, + explanation_verbose, + ) def export( From 7a2930b357a4e62bb0bab53bb0d23c607b6ede38 Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Sat, 19 Nov 2022 04:09:29 +0000 Subject: [PATCH 372/453] add jvp test with non-contig inputs (#89131) Ref: https://github.com/pytorch/functorch/issues/1029 We update `test_jvp` to do contiguous and non-contiguous testing in a single test. 
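A minimal sketch of the property being exercised (illustrative only — the real test builds its non-contiguous inputs with the internal `noncontiguous_like` helper rather than the transpose trick used here):

```
import torch
from functorch import jvp

def f(x):
    return (x * x).sum(dim=0)

primal, tangent = torch.randn(3, 3), torch.randn(3, 3)

# contiguous inputs
out_c, tan_c = jvp(f, (primal,), (tangent,))

# same values, non-contiguous layout (stride-swapped view)
primal_nc = primal.t().contiguous().t()
tangent_nc = tangent.t().contiguous().t()
out_nc, tan_nc = jvp(f, (primal_nc,), (tangent_nc,))

torch.testing.assert_close(out_c, out_nc)
torch.testing.assert_close(tan_c, tan_nc)
```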
Prev time for `test_jvp` : ~28s New time for `test_jvp`: ~45s Pull Request resolved: https://github.com/pytorch/pytorch/pull/89131 Approved by: https://github.com/zou3519 --- test/functorch/test_ops.py | 53 +++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index f276b739f81d..e9451b596b4a 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -10,7 +10,7 @@ import unittest from torch.testing._internal.common_utils import TestCase, run_tests, is_iterable_of_tensors, IS_MACOS, \ - IS_ARM64, parametrize, TEST_WITH_ASAN + IS_ARM64, IS_X86, parametrize, TEST_WITH_ASAN, noncontiguous_like import torch from torch import Tensor import functools @@ -403,15 +403,31 @@ def wrapped_fn(*args, **kwargs): xfail("native_batch_norm"), xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'), - xfail('nn.functional.rrelu') # in-place test errors out with no formula implemented + xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented + + # --- Non-Contiguous Failures! --- + # This is expected to fail as the operator + # expects last dim to have stride=1 + xfail('view_as_complex'), + # BUG + # AssertionError: Tensor-likes are not close! + xfail('as_strided'), + decorate('linalg.det', 'singular', + decorator=unittest.skipIf(IS_MACOS and IS_X86, "Fails on x86 MacOS CI")), })) @opsToleranceOverride('TestOperators', 'test_jvp', ( tol1('nn.functional.conv_transpose3d', {torch.float32: tol(atol=1e-04, rtol=1.3e-06)}, device_type='cuda'), + tol1('linalg.tensorsolve', + {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}, device_type='cuda'), tol1('nn.functional.binary_cross_entropy_with_logits', {torch.float32: tol(atol=4e-04, rtol=4e-04)}), tol1('nn.functional.batch_norm', {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1('nn.functional.conv2d', + {torch.float32: tol(atol=4e-05, rtol=5e-05)}), + tol1('pca_lowrank', + {torch.float32: tol(atol=5e-05, rtol=5e-05)}), )) def test_jvp(self, device, dtype, op): # TODO: get rid of vjp_decomp when we add decomposition support to @@ -435,28 +451,38 @@ def test_jvp(self, device, dtype, op): inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None for sample in samples: - args = (sample.input,) + sample.args - kwargs = sample.kwargs if outplace_variant: - self.jvp_opinfo_test(outplace_variant, args, kwargs, + self.jvp_opinfo_test(outplace_variant, sample, sample.output_process_fn_grad, clone_inputs=False, fixme_ref_jvp_local=fixme_ref_jvp_local) if is_valid_inplace_sample_input(sample, op, inplace_variant): - self.jvp_opinfo_test(inplace_variant, args, kwargs, + self.jvp_opinfo_test(inplace_variant, sample, sample.output_process_fn_grad, clone_inputs=True, fixme_ref_jvp_local=fixme_ref_jvp_local) - def jvp_opinfo_test(self, fn, args, kwargs, output_process_fn, + + def jvp_opinfo_test(self, fn, sample, output_process_fn, clone_inputs, fixme_ref_jvp_local): # NB: we used requires_grad=True to determine where the primals are, # but don't need that information otherwise - fn, primals = normalize_op_input_output2( + args = (sample.input,) + sample.args + kwargs = sample.kwargs + contig_fn, primals = normalize_op_input_output2( fn, args, kwargs, output_process_fn, requires_grad=True) orig_primals = tree_map(lambda x: x.detach(), primals) orig_tangents = tree_map(lambda x: torch.randn_like(x), primals) + noncontig_sample = sample.noncontiguous() + noncontig_args = 
(noncontig_sample.input,) + noncontig_sample.args + noncontig_kwargs = sample.kwargs + noncontig_fn, primals = normalize_op_input_output2( + fn, noncontig_args, noncontig_kwargs, + output_process_fn, requires_grad=True) + noncontig_primals = tree_map(lambda x: x.detach(), primals) + noncontig_tangents = tree_map(lambda x: noncontiguous_like(x), orig_tangents) + def maybe_clone_inputs(): if clone_inputs: primals = tree_map(torch.clone, orig_primals) @@ -466,14 +492,21 @@ def maybe_clone_inputs(): primals, tangents = maybe_clone_inputs() expected_primal_outs, expected_tangent_outs = \ - fixme_ref_jvp_local(fn, primals, tangents) + fixme_ref_jvp_local(contig_fn, primals, tangents) primals, tangents = maybe_clone_inputs() - primal_outs, tangent_outs = jvp(fn, primals, tangents) + primal_outs, tangent_outs = jvp(contig_fn, primals, tangents) + + noncontig_primal_outs, noncontig_tangent_outs = jvp(noncontig_fn, + noncontig_primals, + noncontig_tangents) self.assertEqual(primal_outs, expected_primal_outs) self.assertEqual(tangent_outs, expected_tangent_outs) + self.assertEqual(noncontig_primal_outs, expected_primal_outs) + self.assertEqual(noncontig_tangent_outs, expected_tangent_outs) + @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_vjp', vjp_fail.union({ xfail('sparse.sampled_addmm', ''), From 419ef2cdcfe84442de5232739284c6a51a18632f Mon Sep 17 00:00:00 2001 From: Horace He Date: Fri, 18 Nov 2022 21:39:11 +0000 Subject: [PATCH 373/453] Added utility to count memory reads/written in Inductor (#89203) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89203 Approved by: https://github.com/jansel, https://github.com/ngimel --- test/inductor/test_perf.py | 434 ++++++++++++++++++++++++++++++++ torch/_inductor/compile_fx.py | 28 ++- torch/_inductor/dependencies.py | 39 ++- torch/_inductor/graph.py | 43 +++- torch/_inductor/metrics.py | 5 + torch/_inductor/scheduler.py | 2 +- torch/_inductor/utils.py | 5 + torch/_inductor/virtualized.py | 2 +- 8 files changed, 545 insertions(+), 13 deletions(-) create mode 100644 test/inductor/test_perf.py diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py new file mode 100644 index 000000000000..d473ff4b7449 --- /dev/null +++ b/test/inductor/test_perf.py @@ -0,0 +1,434 @@ +# Owner(s): ["module: inductor"] +import contextlib +from unittest.mock import patch + +import torch._dynamo +import torch._inductor.config as config +from torch._dynamo.optimizations.backends import register_backend +from torch._inductor import metrics +from torch._inductor.compile_fx import compile_fx, count_bytes_inner +from torch.testing._internal.common_utils import ( + TEST_WITH_ROCM, + TestCase as TorchTestCase, +) +from torch.testing._internal.inductor_utils import HAS_CUDA + +aten = torch.ops.aten + + +@register_backend +def count_bytes_inductor(gm, example_inputs): + return compile_fx(gm, example_inputs, inner_compile=count_bytes_inner) + + +@torch._dynamo.optimize("count_bytes_inductor") +def f(x): + return torch.cat([x, x.cos()]) + + +def count_numel(f, *args): + """ + Assumes all inputs are fp32 + """ + metrics.reset() + torch._dynamo.optimize("count_bytes_inductor")(f)(*args) + print(metrics.nodes_num_elem) + return str(metrics.num_bytes_accessed // 4) + + +DEVICE = "cuda" + + +def T(*size, dtype=torch.float32, device=DEVICE): + return torch.randn(size, dtype=dtype, device=device) + + +def TI(*size, mx=10, dtype=torch.int32, device=DEVICE): + return torch.randint(0, mx, size, 
dtype=dtype, device=device) + + +class TestCase(TorchTestCase): + device = DEVICE + pass + + +class NumBytesMetricTests(TestCase): + """ + Primarily used for sanity testing that the num_bytes_accessed metrics is correct. + """ + + def test_pointwise(self): + def f(x): + return x.cos() + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """20""") + + def f(x, y): + return x + y + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + def f(x, y): + return x + y + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def f(x): + return x + x + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """20""") + + def f(x): + return x + x.t() + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def f(a, b, c): + return a.cos(), b.sin() + c.sin() + + inp = (T(10), T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """50""") + + def test_reduction(self): + def f(x): + return x.sum(dim=1) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def f(x): + return x.sum(dim=0) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def test_extern(self): + def f(x): + return torch.mm(x, x) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def f(a, b): + return torch.mm(a, b) + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def f(x): + x = x.cos() + x = torch.mm(x, x) + x = x.cos() + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """600""") + + def f(x): + a = x.cos() + b = x.sin() + x = torch.mm(a, b) + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """600""") + + def test_cat(self): + def f(a, b): + return torch.cat([a.sin(), b.sin()]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a, b): + return torch.cat([a, b]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a, b): + return torch.cat([a.cos(), b]) + + inp = (T(10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """40""") + + def f(a): + return torch.cat([a.cos(), a.sin()]) + + inp = (T(10),) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + def test_index(self): + def f(a, b): + return a[b] + + inp = (T(10), TI(10, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """30""") + + +class FusionTests(TestCase): + """ + Tests that things can be fused into a single kernel + """ + + def test_horizontal_reduction_pointwise(self): + def f(a): + b = a.sum(dim=1) + c = a.cos() + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def test_horizontal_reduction_reduction(self): + def f(a): + b = a.sum(dim=1) + c = a.amax(dim=1) + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_reduction_pointwise2(self): + def f(a, b): + c = a.sum(dim=1) + b = b.cos() + return b + c + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_reduction_outer_pointwise(self): + def f(a, b): + c = a.sum(dim=0) + b = b.cos() + return b + c + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """120""") + + def test_horizontal_sum_pw_broadcast(self): + def f(a, b): + a = a.sum(dim=1, keepdim=True) 
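+            # Rough accounting for the single fused kernel: `a` is read
+            # (10*10 = 100 elements), `b` is read (10), and the broadcasted
+            # product is written (10*10 = 100), so count_numel reports
+            # 100 + 10 + 100 = 210; the keepdim-sum intermediate never
+            # leaves the kernel and is not counted.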
+ b = b.cos() + return a * b + + inp = (T(10, 10), T(10)) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + def test_vertical_sum_pw(self): + def f(a): + a = a.cos() + a = a.sum(dim=1) + return a.cos() + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """110""") + + def test_norm_chain(self): + def f(a): + b = a.sum(dim=1, keepdim=True) + a = a * b + b = a.sum(dim=1, keepdim=True) + a = a * b + b = a.sum(dim=1, keepdim=True) + a = a * b + return a + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_softmax_inner(self): + def f(a): + return torch.softmax(a, dim=1) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_layer_norm(self): + # TODO: Suboptimal! We shouldn't need to save normalization stats. + mod = torch.nn.LayerNorm(10, device=self.device) + + def f(x): + return mod(x) + + inp = (T(10, 10),) + with torch.no_grad(): + self.assertExpectedInline(count_numel(f, *inp), """220""") + + def test_double_softmax(self): + def f(x): + x = torch.softmax(x, dim=1) + x = torch.softmax(x, dim=1) + return x + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + def test_softmax_backward(self): + def f(grad_out, out): + return aten._softmax_backward_data(grad_out, out, 1, torch.float32) + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def test_neighbor(self): + def f(a, b): + return ((a - b) ** 2).sum(dim=-1).amax(dim=1) + + inp = (T(10, 1, 4), T(1, 10, 4)) + self.assertExpectedInline(count_numel(f, *inp), """90""") + + def test_factory_reduction(self): + def f(): + a = torch.ones(10, device=self.device) + b = torch.ones(10, 10, device=self.device) + return (a + b).sum(dim=-1) + + inp = () + self.assertExpectedInline(count_numel(f, *inp), """10""") + + def test_index_pointwise(self): + def f(a, b): + return a[b].cos() + + inp = (T(10, 10), TI(20, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """320""") + + def test_index_reduction(self): + def f(a, b): + return a[b].cos().sum(dim=1) + + inp = (T(10, 10), TI(20, mx=10)) + self.assertExpectedInline(count_numel(f, *inp), """140""") + + +class SchedulerFusionTests(TestCase): + """ + Testing the fusion group creation heuristic (i.e. cases where we can't fuse + everything into a single kernel) + Disables inductor rematerialization for easier reasoning of tests. + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._stack = contextlib.ExitStack() + cls._stack.enter_context(patch.object(config, "realize_bytes_threshold", 0)) + + @classmethod + def tearDownClass(cls): + cls._stack.close() + super().tearDownClass() + + def test_fusion_choice1(self): + # Doesn't matter where we break fusion group here + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c.cos() + return d + e + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """700""") + + def test_fusion_choice2(self): + # We should materialize e (it's smaller!) + # [c, e]: 210, [f]: 210, [d]: 200 + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c.sum(dim=1) + f = d + e + return f + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """620""") + + def test_fusion_choice3(self): + # We should materialize e. 
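+        # Each bracketed group below is one kernel after fusion and the number
+        # is its element traffic (reads plus writes, in elements); the three
+        # groups sum to the expected total of 800.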
+ # [c, e]: 300, [f]: 300, [d]: 200 + def f(a): + c = a.cos() + d = torch.mm(c, c) + e = c + a + f = d + e + return f, e + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """800""") + + +class TilingTests(TestCase): + def test_tiling_simple(self): + def f(a, b): + return a + b.t() + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def f(a, b): + return a.t() + b + + inp = (T(10, 10), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """300""") + + def test_tiling_three(self): + def f(a, b, c): + return a + b.permute(1, 2, 0) + c.permute(2, 0, 1) + + inp = (T(10, 10, 10), T(10, 10, 10), T(10, 10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """4000""") + + +# Test cases where we don't do the right thing yet. +class WouldBeNiceIfItWorked: + def test_horizontal(self): + def f(a): + b = a.sum(dim=0) + c = a.cos() + return b, c + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """210""") + + # TODO: We aren't fusing outer dim softmaxes + def test_softmax_outer(self): + def f(a): + return torch.softmax(a, dim=0) + + inp = (T(10, 10),) + self.assertExpectedInline(count_numel(f, *inp), """200""") + + # TODO: The greedy fusion strategy results in suboptimal grouping + @patch.object(config, "realize_bytes_threshold", 0) + def test_fusion_choice4(self): + def f(a, b, b2): + c = a + b + d = torch.mm(c, c) + e = c + b + b2 + f = d + e + b2 + return f, e + + inp = (T(10, 10), T(10, 10, dtype=torch.float16), T(10, 10)) + self.assertExpectedInline(count_numel(f, *inp), """1000""") + + # TODO: We materialize the intermediate if we don't unroll the reduction + def test_neighbor(self): + def f(a, b): + return ((a - b) ** 2).sum(dim=-1).amax(dim=1) + + inp = (T(10, 1, 8), T(1, 10, 8)) + self.assertExpectedInline(count_numel(f, *inp), """170""") + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + if HAS_CUDA and not TEST_WITH_ROCM: + run_tests(needs="filelock") diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 813daee1252f..c482e55a954d 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -12,7 +12,7 @@ import torch.fx from torch._subclasses.fake_tensor import FakeTensor -from . import config, overrides +from . 
import config, metrics, overrides from .debug import DebugContext from .decomposition import select_decomp_table from .graph import GraphLowering @@ -83,6 +83,22 @@ def _step_logger(): return dynamo_logging.get_step_logger(log) +@DebugContext.wrap +def count_bytes_inner(gm, example_inputs, num_fixed=0, **kwargs): + shape_env = None + for inp in example_inputs: + if isinstance(inp, FakeTensor) and inp.fake_mode.shape_env is not None: + shape_env = inp.fake_mode.shape_env + + graph = GraphLowering(gm, shape_env=shape_env, num_static_inputs=num_fixed) + with V.set_graph_handler(graph): + graph.run(*example_inputs) + num_bytes, nodes_num_elem = graph.count_bytes() + metrics.num_bytes_accessed += num_bytes + metrics.nodes_num_elem += nodes_num_elem + return make_boxed_func(gm.forward) + + @DebugContext.wrap @torch.utils._python_dispatch._disable_current_modes() def compile_fx_inner( @@ -326,7 +342,11 @@ def is_not_gradout(x): _graph_counter = itertools.count(0) -def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor]): +def compile_fx( + model_: torch.fx.GraphModule, + example_inputs_: List[torch.Tensor], + inner_compile=compile_fx_inner, +): """Main entrypoint to a compile given FX graph""" if not is_aot_autograd_safe_to_run(model_, example_inputs_): @@ -348,7 +368,7 @@ def compile_fx(model_: torch.fx.GraphModule, example_inputs_: List[torch.Tensor] @dynamo_utils.dynamo_timed def fw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = len(example_inputs) - num_example_inputs - return compile_fx_inner( + return inner_compile( model, example_inputs, num_fixed=fixed, @@ -359,7 +379,7 @@ def fw_compiler(model: torch.fx.GraphModule, example_inputs): @dynamo_utils.dynamo_timed def bw_compiler(model: torch.fx.GraphModule, example_inputs): fixed = count_tangents(model) - return compile_fx_inner( + return inner_compile( model, example_inputs, num_fixed=fixed, diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 27c92f82c07c..5434d7addfa9 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -9,7 +9,14 @@ from . 
import config from .codegen.common import index_prevent_reordering -from .utils import sympy_product, sympy_str, sympy_subs, sympy_symbol, VarRanges +from .utils import ( + get_dtype_size, + sympy_product, + sympy_str, + sympy_subs, + sympy_symbol, + VarRanges, +) from .virtualized import V log = logging.getLogger(__name__) @@ -69,11 +76,18 @@ def rename(self, renames: Dict[str, str]) -> "MemoryDep": return MemoryDep(renames[self.name], self.index, self.size) return self - def numel_hint(self): + def numbytes_hint(self): vars = set(self.index.free_symbols) + size_vars_used = [] + for var in vars: + if var.name.startswith(canonicalization_prefix()): + # Sometimes with indirect indexing we have very weird symbol names + assert " " not in var.name + size_vars_used.append(int(var.name[len(canonicalization_prefix()) :])) + return V.graph.sizevars.size_hint( - sympy_product([s for s in self.size if s in vars]) - ) + sympy_product([self.size[i] for i in size_vars_used]) + ) * get_dtype_size(V.graph.get_dtype(self.name)) def is_contiguous(self) -> bool: return isinstance(self.index, (sympy.Symbol, sympy.Integer)) @@ -88,8 +102,21 @@ def rename(self, renames: Dict[str, str]) -> "StarDep": return StarDep(renames[self.name]) return self - def numel_hint(self): - return 1 + def numbytes_hint(self): + from .ir import MultiOutputLayout + + if self.name in V.graph.name_to_buffer: + buf = V.graph.name_to_buffer[self.name] + elif self.name in V.graph.graph_inputs: + buf = V.graph.graph_inputs[self.name] + else: + return 1 + if hasattr(buf, "layout") and isinstance(buf.layout, MultiOutputLayout): + # NB: Too annoying to acquire, should only be used for instrumentation + return 1 + return V.graph.sizevars.size_hint( + sympy_product(buf.get_size()) + ) * get_dtype_size(buf.get_dtype()) def is_contiguous(self) -> bool: return False diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 5114ffa76111..a47d9c1a02e1 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -27,7 +27,7 @@ needs_realized_inputs, ) from .sizevars import SizeVarAllocator -from .utils import dynamo_utils, gather_origins +from .utils import dynamo_utils, gather_origins, get_dtype_size, sympy_product from .virtualized import V log = logging.getLogger(__name__) @@ -356,6 +356,47 @@ def codegen(self): self.scheduler.codegen() return self.wrapper_code.generate() + def count_bytes(self): + from .scheduler import FusedSchedulerNode, NopKernelSchedulerNode, Scheduler + + scheduler = Scheduler(self.buffers) + + def get_read_write_buffers_sizes(node): + if isinstance(node, NopKernelSchedulerNode): + return 0 + reads = set(dep.name for dep in node.read_writes.reads) + writes = set(dep.name for dep in node.read_writes.writes) + + def is_materialized(buf): + buf_uses = set( + [user.node for user in scheduler.name_to_node[buf].users] + ) + return len(buf_uses - set(node.snodes)) > 0 + + if isinstance(node, FusedSchedulerNode): + writes = set([dep for dep in writes if is_materialized(dep)]) + node_bytes = 0 + for buf in reads | writes: + if buf in self.name_to_buffer: + buf = self.name_to_buffer[buf] + elif buf in self.graph_inputs: + buf = self.graph_inputs[buf] + else: + continue + + node_bytes += V.graph.sizevars.size_hint( + sympy_product(buf.get_size()) + ) * get_dtype_size(buf.get_dtype()) + return node_bytes + + total_bytes = 0 + node_counts = [] + for node in scheduler.nodes: + num_bytes = get_read_write_buffers_sizes(node) + node_counts.append((node, num_bytes // 4)) + total_bytes += num_bytes + return 
total_bytes, node_counts + @dynamo_utils.dynamo_timed def compile_to_module(self): from .codecache import PyCodeCache diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 582c5aca7f88..f7e05288c9a5 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -1,12 +1,17 @@ # counter for tracking how many kernels have been generated generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 +num_bytes_accessed = 0 +nodes_num_elem = [] # reset all counters def reset(): global generated_kernel_count global generated_cpp_vec_kernel_count + global num_bytes_accessed, nodes_num_elem generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 + num_bytes_accessed = 0 + nodes_num_elem.clear() diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index cb71a4443804..8609617897bf 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -977,7 +977,7 @@ def score_fusion_memory(self, node1, node2): common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & ( node2.read_writes.reads | node2.read_writes.writes ) - return sum(dep.numel_hint() for dep in common_memory_deps) + return sum(dep.numbytes_hint() for dep in common_memory_deps) def score_fusion_key(self, nodes): """ diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 08e95b9b5cc3..62357be8bcf3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -375,3 +375,8 @@ def fresh_inductor_cache(cache_entries=None): def argsort(seq): return sorted(range(len(seq)), key=seq.__getitem__) + + +@functools.lru_cache(8) +def get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 5d40d05f751f..27e60b1daf1d 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -74,7 +74,7 @@ def masked(cls, mask, body, other): @staticmethod def indirect_indexing(index_var): - return sympy_symbol(str(index_var)) + return sympy_symbol(f"({str(index_var)})") @classmethod def _init_cls(cls): From 808bdbab89e875abbbe9652bde675b4402eed532 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Sat, 19 Nov 2022 07:16:29 +0000 Subject: [PATCH 374/453] Fix try/except flow where DataDependentOutputException is getting wrapped in a RuntimeError (#89314) Repro fixed ``` def fn(a): return a.repeat_interleave(14, dim=0).repeat_interleave(14, dim=1) x = torch.ones(14, 14).to(dtype=torch.int64) opt_fn = torch._dynamo.optimize("eager")(fn) opt_fn(x) ``` Fixes [#1886](https://github.com/pytorch/torchdynamo/issues/1886) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89314 Approved by: https://github.com/anijain2305, https://github.com/eellison --- torch/_dynamo/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index e4b92a73aacf..889bb5683b6b 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1055,13 +1055,20 @@ def visit(n: torch.fx.Node): except Unsupported: raise except RuntimeError as e: - if isinstance(e, torch._subclasses.fake_tensor.DataDependentOutputException): + cause = e + if e.__cause__ is not None: + cause = e.__cause__ + if isinstance( + cause, torch._subclasses.fake_tensor.DataDependentOutputException + ): if config.capture_scalar_outputs and node.target == "item": return torch.zeros(size=(), dtype=args[0].dtype).item() else: - unimplemented(f"data dependent 
operator: {e.func}") - elif isinstance(e, torch._subclasses.fake_tensor.DynamicOutputShapeException): - unimplemented(f"dynamic shape operator: {e.func}") + unimplemented(f"data dependent operator: {cause.func}") + elif isinstance( + cause, torch._subclasses.fake_tensor.DynamicOutputShapeException + ): + unimplemented(f"dynamic shape operator: {cause.func}") raise TorchRuntimeError() from e From 940959ebbfa54204b3cd45f918c5ee65b5efc3d0 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Fri, 18 Nov 2022 22:46:47 -0800 Subject: [PATCH 375/453] [quant][fix] Add quant_min/quant_max for default dynamic quantization observer (#89267) Summary: This is needed for choose qparams, but previously it is not configurable, and in the reference quantization flow with decomposed Tensor, we are making this explicit Test Plan: tested in future PR Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89267 Approved by: https://github.com/vkuzo --- torch/ao/quantization/observer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 3156b4245a12..42962fe7c29a 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1316,6 +1316,8 @@ class PlaceholderObserver(ObserverBase): Args: dtype: dtype argument to the `quantize` node needed to implement the reference model spec. + quant_min: minimum value in quantized domain (TODO: align behavior with other observers) + quant_min: maximum value in quantized domain custom_op_name: (temporary) specify this observer for an operator that doesn't require any observation (Can be used in Graph Mode Passes for special case ops). compute_dtype: if set, marks the future quantize function to use @@ -1325,12 +1327,15 @@ class PlaceholderObserver(ObserverBase): """ def __init__( - self, dtype=torch.float32, custom_op_name="", compute_dtype=None + self, dtype=torch.float32, custom_op_name="", compute_dtype=None, + quant_min=None, quant_max=None, ) -> None: - super(PlaceholderObserver, self).__init__(dtype=dtype) + super().__init__(dtype=dtype) # dtype of input of the target operator, e.g. for dynamic quantization # ops, the dtype will be float32 self.dtype = dtype + self.quant_min = quant_min + self.quant_max = quant_max self.custom_op = custom_op_name # used for configuration of computation type for dynamic quantization # TODO(future PR): replace this with `is_dynamic` @@ -1551,7 +1556,7 @@ def load_observer_state_dict(mod, obs_dict): """ default_dynamic_quant_observer = PlaceholderObserver.with_args( - dtype=torch.quint8, compute_dtype=torch.quint8 + dtype=torch.quint8, compute_dtype=torch.quint8, quant_min=0, quant_max=255 ) """ Default observer for dynamic quantization. From 6daf60be5abe4184121bc41e69e336015a268d6a Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Sat, 19 Nov 2022 02:56:14 +0000 Subject: [PATCH 376/453] [ONNX] Add setType from user into InferredType and Reliable in ConstantValueMap (#88622) `setType` API is not respected in current exporter because the graph-level shape type inference simply overrides every NOT ONNX Op shape we had from node-level shape type inference. To address this issue, this PR (1) makes custom Op with `setType` **reliable** in ConstantValueMap to secure its shape/type information in pass: _C._jit_pass_onnx. (2) If an invalid Op with shape/type in pass: _C._jit_pass_onnx_graph_shape_type_inference(graph-level), we recognize it as reliable. 1. 
In #62856, The refactor in onnx.cpp made regression on custom Op, as that was the step we should update custom Op shape/type information into ConstantValueMap for remaining Ops. 2. Add another condition besides IsValidONNXNode for custom Op setType in shape_type_inference.cpp. If all the node output has shape (not all dynamic), we say it's custom set type. 3. ~However, this PR won't solve the [issue](https://github.com/pytorch/pytorch/issues/87738#issuecomment-1292831219) that in the node-level shape type inference, exporter invokes the warning in terms of the unknow custom Op, since we process its symbolic_fn after this warning, but it would have shape/type if setType is used correctly. And that will be left for another issue to solve. #84661~ Add `no_type_warning` in UpdateReliable() and it only warns if non ONNX node with no given type appears. Fixes #81693 Fixes #87738 NOTE: not confident of this not breaking anything. Please share your thoughts if there is a robust test on your mind. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88622 Approved by: https://github.com/BowenBao --- test/onnx/internal/test_diagnostics.py | 5 +- .../onnx/test_pytorch_onnx_shape_inference.py | 171 +++++++++++++++++- torch/csrc/jit/passes/onnx.cpp | 17 +- .../jit/passes/onnx/shape_type_inference.cpp | 85 ++++++--- .../jit/passes/onnx/shape_type_inference.h | 7 +- 5 files changed, 250 insertions(+), 35 deletions(-) diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index 884b7cb1c388..49402204e9d2 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -215,10 +215,11 @@ def test_diagnostics_records_cpp_call_stack(self): assert stack is not None # for mypy self.assertGreater(len(stack.frames), 0) frame_messages = [frame.location.message for frame in stack.frames] + # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx) + # after node-level shape type inference and processed symbolic_fn output type self.assertTrue( any( - isinstance(message, str) - and "torch::jit::ONNXShapeTypeInference" in message + isinstance(message, str) and "torch::jit::NodeToONNX" in message for message in frame_messages ) ) diff --git a/test/onnx/test_pytorch_onnx_shape_inference.py b/test/onnx/test_pytorch_onnx_shape_inference.py index cf9ef2fd893e..915677279d01 100644 --- a/test/onnx/test_pytorch_onnx_shape_inference.py +++ b/test/onnx/test_pytorch_onnx_shape_inference.py @@ -1,8 +1,10 @@ # Owner(s): ["module: onnx"] +import io + import numpy as np +import onnx import pytorch_test_common - import torch from pytorch_test_common import skipIfUnsupportedMinOpsetVersion from torch.onnx import _constants, symbolic_helper @@ -284,5 +286,172 @@ def test_reduce_prod_without_axes(self): self.run_test(g, reduce_prod.node(), expect_tensor("Long", shape=(1,))) +class TestONNXCustomOpShapeInference(pytorch_test_common.ExportTestCase): + def setUp(self): + super().setUp() + self.opset_version = _constants.ONNX_MAX_OPSET + + def test_setType_maintains_output_shape_for_single_custom_op(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType(self.type()) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = CustomInverse() + x = 
torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + for dim, rank in zip(dims, x.size()): + self.assertEqual(dim.dim_value, rank) + + def test_no_setType_for_single_custom_op(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_no_settype(g, self): + return g.op("com.microsoft::Inverse", self) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_no_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_param")) + + def test_setType_maintains_output_shape_for_single_custom_op_with_dynamic_axes( + self, + ): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x): + return torch.inverse(x) + x + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType( + self.type().with_dtype(torch.float).with_sizes([None, 3, 3]) + ) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = CustomInverse() + x = torch.randn(2, 3, 3) + f = io.BytesIO() + torch.onnx.export( + model, + (x,), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + input_names=["x"], + dynamic_axes={"x": {0: "batch"}}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + dims = model_value_info[0].type.tensor_type.shape.dim + # The first axe should be dynamic as we defined when exporting + self.assertTrue(dims[0].HasField("dim_param")) + for i in range(1, len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + self.assertEqual(dims[i].dim_value, x.size()[i]) + + def test_setType_maintains_output_shape_for_single_custom_op_with_onnx_ops(self): + + self.addCleanup(torch.onnx.unregister_custom_op_symbolic, "::linalg_inv", 9) + + class CustomInverse(torch.nn.Module): + def forward(self, x, y, z): + x = torch.inverse(x) + return x + y + z + + def linalg_inv_settype(g, self): + return g.op("com.microsoft::Inverse", self).setType( + self.type().with_dtype(torch.float).with_sizes([2, 3, 10, 10]) + ) + + torch.onnx.register_custom_op_symbolic("::linalg_inv", linalg_inv_settype, 9) + model = 
CustomInverse() + x = torch.randn(2, 3, 10, 10) + y = torch.randn(2, 3, 10, 10) + z = torch.randn(2, 3, 10, 10) + f = io.BytesIO() + torch.onnx.export( + model, + (x, y, z), + f, + opset_version=self.opset_version, + custom_opsets={"com.microsoft": 1}, + ) + + model_proto = onnx.load(io.BytesIO(f.getvalue())) + # To validate the shape of inverse Op, we need to find inverse output name, + # and then use it to identify its value_info for the shape. + output_name = "" + for node in model_proto.graph.node: + if node.op_type == "Inverse": + output_name = node.output[0] + break + assert output_name + model_value_info = model_proto.graph.value_info + self.assertIsNotNone(model_value_info) + assert model_value_info + for value_info in model_value_info: + assert value_info.name + if value_info.name == output_name: + dims = value_info.type.tensor_type.shape.dim + for i in range(len(dims)): + # If node output has shape info, it should have dim_value + # Otherwise, it has dim_params with dynamic shape + self.assertTrue(dims[i].HasField("dim_value")) + for dim, rank in zip(dims, x.size()): + self.assertEqual(dim.dim_value, rank) + + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 607f2ce61ada..75e2d754aa50 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -14,7 +14,6 @@ #include #include #include - namespace torch { namespace jit { @@ -326,10 +325,20 @@ void NodeToONNX( ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); env[old] = const_node->output(); } else { - // ConstantValueMap has been set in shape inference, - // set_constant_value_map = false here to avoid redundancy. + // An update in ConstantValueMap is also needed here, since + // the user setType can be only accessed in this step, and it + // should be reliable. MergeInferredTypeAndSetMap( - outputs[i], old->type(), outputs[i]->type(), false); + outputs[i], old->type(), outputs[i]->type()); + // non ONNX node with no type given will throw out the warnings here. + UpdateReliable( + outputs[i], + AreInputsReliableOrStatic(outputs[i]->node()), + /*no_type_warning=*/true); + // For the node type that does not have ComputeConstant logic, it may + // have reliable shape but its shape is not in ConstantValueMap. So we + // need to update ConstantValueMap. 
+ UpdateShapeConstantIfReliable(outputs[i]); // Copy over source location and scope information to all nodes // created by the symbolic diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 8baa439bdb58..a9087508e6ad 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -76,16 +76,13 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, TypePtr existing_type, - TypePtr inferred_type, - bool set_constant_value_map) { + TypePtr inferred_type) { TypePtr mergedType; bool inferred; std::tie(mergedType, inferred) = MergeInferredType(existing_type, inferred_type); dest_v->setType(mergedType); - if (set_constant_value_map) { - ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); - } + ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); } namespace { @@ -232,6 +229,28 @@ bool IsValidONNXNode(const Node* n) { return true; } +bool CustomSettype(Node* node) { + // This is a helper function to decide if the non-ONNX node actually has + // custom setType from user + // Go through every symbolic_sizes and if any one of them is static, we say + // this is set by user. On the other hand, if all of them are * (dynamic), we + // take this node does not have given type, since unreliable nodes have * + // shape anyway. + auto all_output_has_type = [](Value* output) { + if (auto output_type = output->type()->cast()) { + if (auto sizes = output_type->symbolic_sizes().sizes()) { + return std::any_of(std::begin(*sizes), std::end(*sizes), [](auto size) { + return size.is_static(); + }); + } + } + return false; + }; + + return std::all_of( + node->outputs().begin(), node->outputs().end(), all_output_has_type); +} + Value* CloneValueFromListConstruct( Value* v, std::shared_ptr n_graph, @@ -1879,7 +1898,8 @@ static std::unordered_set nodeTypeReliableForTracer = { void UpdateReliable( torch::jit::Value* output, - const std::pair& inferred_type_reliable) { + const std::pair& inferred_type_reliable, + bool no_type_warning) { auto inferred = ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false); auto isTypeReliableForTracer = @@ -1887,7 +1907,9 @@ void UpdateReliable( output->node()->kind().toDisplayString()) != nodeTypeReliableForTracer.end(); if (!inferred && !isTypeReliableForTracer && - !output->node()->kind().is_onnx()) { + !output->node()->kind().is_onnx() && no_type_warning) { + // TODO(84661): This warning comes before setType in symbolic_fn. + // tracked in #84661 TORCH_WARN( "The shape inference of ", output->node()->kind().toDisplayString(), @@ -1949,6 +1971,7 @@ void ONNXShapeTypeInference( SetGraphInputTypeReliable(n->owningGraph()); GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); + if (IsValidONNXNode(n)) { // Create a Graph containing only the single node n. // This graph is later converted to ONNX to run shape inference. @@ -2041,6 +2064,15 @@ void ONNXShapeTypeInference( GRAPH_DEBUG( "ONNX graph after shape inference: ", prettyPrint(*model_proto)); } + } else if (CustomSettype(n)) { + // If the node is not ONNX standard, go through every output to check if + // they all have shape. If they all do, this should be reliable even if the + // Op is not from ONNX. + for (auto node_output : n->outputs()) { + // Custom setType output should get in here if it's set correctly. They + // will be updated to inferred for later updatereliable function. 
+ ConstantValueMap::SetUseInferredType(node_output->debugName(), true); + } } SpecialPostProcess(n); @@ -2082,20 +2114,7 @@ void ONNXShapeTypeInference( // reliable shape but its shape is not in ConstantValueMap. So we need this // logic to update ConstantValueMap. for (auto node_output : n->outputs()) { - if (ConstantValueMap::HasTypeReliable(node_output->debugName())) { - auto reliable = - ConstantValueMap::GetTypeReliable(node_output->debugName()) - .value_or(false); - if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) { - // TODO: ListType case - if (auto output_tensor_type = node_output->type()->cast()) { - if (output_tensor_type->dim()) { - auto symbolic_sizes = output_tensor_type->symbolic_sizes(); - UpdateShapeConstantValueMap(node_output, symbolic_sizes); - } - } - } - } + UpdateShapeConstantIfReliable(node_output); } GRAPH_DEBUG( @@ -2280,10 +2299,10 @@ size_t ONNXAssignOutputShape( // Tracing: // Ignore None, since it is not captured in IR graph as output. // Scripting: - // Ignore None, if observing a fixed `None` node in IR graph. Because it - // is meaningless to include it as graph output as it carries no - // data/information. Plus that static `None` is not supported in ONNX IR. - // Otherwise, the output should have type `Optional`, and should be + // Ignore None, if observing a fixed `None` node in IR graph. Because + // it is meaningless to include it as graph output as it carries no + // data/information. Plus that static `None` is not supported in ONNX + // IR. Otherwise, the output should have type `Optional`, and should be // converted to ONNX `Optional`. // More context: @@ -2343,5 +2362,21 @@ void ONNXShapeTypeInference( ConstantValueMap::ClearMaps(); } +void UpdateShapeConstantIfReliable(torch::jit::Value* node_output) { + if (ConstantValueMap::HasTypeReliable(node_output->debugName())) { + auto reliable = ConstantValueMap::GetTypeReliable(node_output->debugName()) + .value_or(false); + if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) { + // TODO: ListType case + if (auto output_tensor_type = node_output->type()->cast()) { + if (output_tensor_type->dim()) { + auto symbolic_sizes = output_tensor_type->symbolic_sizes(); + UpdateShapeConstantValueMap(node_output, symbolic_sizes); + } + } + } + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index afda5b176537..39350ed273d4 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -34,8 +34,7 @@ std::pair MergeInferredType( void MergeInferredTypeAndSetMap( Value* dest_v, TypePtr existing_type, - TypePtr inferred_type, - bool set_constant_value_map = true); + TypePtr inferred_type); // Update graph input types with dynamic axes info. // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. 
@@ -80,9 +79,11 @@ TORCH_API void ONNXShapeTypeInference( std::pair AreInputsReliableOrStatic(Node* n); void UpdateReliable( torch::jit::Value* output, - const std::pair& input_reliable); + const std::pair& input_reliable, + bool no_type_warning = false); void UpdateReliable(torch::jit::Node* n); +void UpdateShapeConstantIfReliable(torch::jit::Value* output); } // namespace jit } // namespace torch From 6b8c1b19b513ec3d82d588961f8a2b4a86e08f99 Mon Sep 17 00:00:00 2001 From: Michael Voznesensky Date: Sat, 19 Nov 2022 17:49:39 +0000 Subject: [PATCH 377/453] RM expectedFailure UnspecReproTests.test_batch_norm_act_unspec (#89340) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89340 Approved by: https://github.com/bertmaher --- test/dynamo/test_unspec.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index e46d79208de0..fd5396981b74 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -50,8 +50,6 @@ class UnspecTest(cls): UnspecReproTests = make_unspec_cls(test_repros.ReproTests) UnspecNNModuleTests = make_unspec_cls(test_modules.NNModuleTests) -unittest.expectedFailure(UnspecReproTests.test_batch_norm_act_unspec) - @patch.object(torch._dynamo.config, "specialize_int_float", False) class UnspecTests(torch._dynamo.test_case.TestCase): From 6afe341276f9ffa660446c5fa15b68558791869a Mon Sep 17 00:00:00 2001 From: fduwjj Date: Sat, 19 Nov 2022 18:01:25 +0000 Subject: [PATCH 378/453] [PT-D][1/N] Sync TP Beta change to prod (#89242) This is part of TP Beta Release efforts. ref: https://github.com/pytorch/tau/issues/576 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89242 Approved by: https://github.com/wanchaol --- .../_tensor/parallel/test_parallelize_api.py | 136 ++++++++++++ .../_tensor/parallel/test_tp_examples.py | 78 +------ .../_tensor/parallel/test_tp_style.py | 192 +++++++++++++++++ .../distributed/_tensor/parallel/__init__.py | 12 ++ torch/distributed/_tensor/parallel/api.py | 112 ++++++++++ torch/distributed/_tensor/parallel/style.py | 197 ++++++++++++++++++ torch/distributed/_tensor/parallel/utils.py | 149 +++++++++++++ 7 files changed, 805 insertions(+), 71 deletions(-) create mode 100644 test/distributed/_tensor/parallel/test_parallelize_api.py create mode 100644 test/distributed/_tensor/parallel/test_tp_style.py create mode 100644 torch/distributed/_tensor/parallel/style.py create mode 100644 torch/distributed/_tensor/parallel/utils.py diff --git a/test/distributed/_tensor/parallel/test_parallelize_api.py b/test/distributed/_tensor/parallel/test_parallelize_api.py new file mode 100644 index 000000000000..fb3e8f4721c8 --- /dev/null +++ b/test/distributed/_tensor/parallel/test_parallelize_api.py @@ -0,0 +1,136 @@ +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms +from torch.distributed._tensor import distribute_tensor, DeviceMesh, Shard, Replicate +from torch.distributed._tensor.parallel import PairwiseParallel, ParallelStyle +from torch.distributed._tensor.parallel.api import _parallelize_mlp +from torch.distributed._tensor.parallel.utils import _create_1d_device_mesh +from torch.distributed._tensor.parallel.style import ( + make_input_replicate_1d, + make_output_replicate_1d, +) + + +class MLPModule(torch.nn.Module): + def __init__(self, device): + super(MLPModule, 
self).__init__() + torch.manual_seed(5) + self.net1 = torch.nn.Linear(10, 16, device=device) + self.relu = torch.nn.ReLU() + self.net2 = torch.nn.Linear(16, 12, device=device) + + def forward(self, x): + return self.net2(self.relu(self.net1(x))) + + +class TensorParallelAPITests(DTensorTestBase): + @property + def world_size(self): + gpu_num = torch.cuda.device_count() + return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 + + @with_comms + def test_creat_1d_device_mesh(self): + dim_one_size = 2 + mesh_shape = ( + torch.arange(self.world_size) + .reshape( + self.world_size // dim_one_size, + dim_one_size, + ) + .to(torch.int) + ) + mesh = DeviceMesh(self.device_type, mesh_shape) + # When 1D dim is 1. + one_dimention_mesh_shape = mesh_shape[self.rank // dim_one_size, :] + pg = mesh.get_dim_groups()[1] + new_mesh = _create_1d_device_mesh(mesh, 1) + expected_mesh = DeviceMesh( + self.device_type, one_dimention_mesh_shape, [pg] + ) + self.assertEqual(new_mesh.mesh, expected_mesh.mesh) + self.assertEqual(new_mesh.device_type, expected_mesh.device_type) + # When 1D dim is 0. + one_dimention_mesh_shape = mesh_shape[:, self.rank % dim_one_size] + pg = mesh.get_dim_groups()[0] + new_mesh = _create_1d_device_mesh(mesh, 0) + expected_mesh = DeviceMesh( + self.device_type, one_dimention_mesh_shape, [pg] + ) + self.assertEqual(new_mesh.mesh, expected_mesh.mesh) + self.assertEqual(new_mesh.device_type, expected_mesh.device_type) + + @with_comms + def test_creat_1d_device_mesh_error(self): + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + with self.assertRaisesRegex( + AssertionError, + "Expect tp_mesh_dim within range \\[-1, 1\\), but found 3.", + ): + _create_1d_device_mesh(mesh, 3) + + @with_comms + def test_parallelize_mlp(self): + model = MLPModule(self.device_type) + model_tp = MLPModule(self.device_type) + + # Ensure model are initialized the same way. + self.assertEqual(model.net1.weight, model_tp.net1.weight) + self.assertEqual(model.net1.bias, model_tp.net1.bias) + self.assertEqual(model.net2.weight, model_tp.net2.weight) + self.assertEqual(model.net2.bias, model_tp.net2.bias) + + # Parallelize module. + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size) + ) + _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) + + # Ensure the parameter is properly distributed. + self.assertEqual( + distribute_tensor(model.net1.weight, device_mesh, [Shard(0)]), + model_tp.net1.weight, + ) + self.assertEqual( + distribute_tensor(model.net1.bias, device_mesh, [Shard(0)]), + model_tp.net1.bias, + ) + self.assertEqual( + distribute_tensor(model.net2.weight, device_mesh, [Shard(1)]), + model_tp.net2.weight, + ) + self.assertEqual( + distribute_tensor(model.net2.bias, device_mesh, [Replicate()]), + model_tp.net2.bias, + ) + + @with_comms + def test_parallelize_mlp_error(self): + class DummyParallel(ParallelStyle): + def __init__(self) -> None: + super().__init__( + make_input_replicate_1d, make_output_replicate_1d + ) + + model_tp = MLPModule(self.device_type) + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size) + ) + with self.assertRaisesRegex( + NotImplementedError, + "Only support PairwiseParallel for MLP parallelization.", + ): + _parallelize_mlp(model_tp, device_mesh, DummyParallel()) + + with self.assertRaisesRegex( + RuntimeError, "We only support even number of Linear for MLP." 
+ ): + _parallelize_mlp( + torch.nn.Linear(10, 5), device_mesh, PairwiseParallel() + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tensor/parallel/test_tp_examples.py b/test/distributed/_tensor/parallel/test_tp_examples.py index 582108ea7599..696171e4ca88 100644 --- a/test/distributed/_tensor/parallel/test_tp_examples.py +++ b/test/distributed/_tensor/parallel/test_tp_examples.py @@ -11,11 +11,8 @@ skip_unless_torch_gpu, ) from torch.distributed._tensor import ( - distribute_tensor, distribute_module, DeviceMesh, - DTensor, - Shard, Replicate, ) from torch.distributed._tensor.parallel import ( @@ -24,6 +21,8 @@ replicate_input, replicate_output, ) +from torch.distributed._tensor.parallel import PairwiseParallel +from torch.distributed._tensor.parallel.api import _parallelize_mlp class MLPModule(torch.nn.Module): @@ -38,73 +37,6 @@ def forward(self, x): return self.net2(self.relu(self.net1(x))) -def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module: - def hook_func(_module, _input, output): - if isinstance(output, DTensor): - replica_placement = [Replicate()] * device_mesh.ndim - return ( - output.redistribute(output.device_mesh, replica_placement) - .contiguous() - .to_local() - ) - - module.register_forward_hook(hook_func) - return module - - -def shard_mlp(m, device_type, tp_size): - start_idx = 0 - device_mesh = DeviceMesh( - device_type, - list(range(start_idx, start_idx + tp_size)), - ) - col_wise_sharding = [Shard(0)] - row_wise_sharding = [Shard(1)] - replicate = [Replicate()] * device_mesh.ndim - - def shard_params(name, module, device_mesh): - if isinstance(module, nn.Linear): - if name == "net1": - sharded_weight = nn.Parameter( - distribute_tensor( - module.weight, device_mesh, col_wise_sharding - ) - ) - sharded_bias = nn.Parameter( - distribute_tensor( - module.bias, device_mesh, col_wise_sharding - ) - ) - module.register_parameter("weight", sharded_weight) - module.register_parameter("bias", sharded_bias) - elif name == "net2": - sharded_weight = nn.Parameter( - distribute_tensor( - module.weight, device_mesh, row_wise_sharding - ) - ) - replicated_bias = nn.Parameter( - distribute_tensor(module.bias, device_mesh, replicate) - ) - module.register_parameter("weight", sharded_weight) - module.register_parameter("bias", replicated_bias) - - def aggregate_output(outputs, device_mesh): - assert isinstance(outputs, DTensor) - return ( - outputs.redistribute(device_mesh, replicate).contiguous().to_local() - ) - - dist_mod = distribute_module( - m, - device_mesh, - partition_fn=shard_params, - input_fn=replicate_input, - output_fn=aggregate_output, - ) - return dist_mod - - class MultiheadAttnWrap(nn.Module): def __init__(self, embed_dim, num_heads, add_bias_kv=False, device=None): super().__init__() @@ -134,7 +66,11 @@ def test_mlp_megatron_e2e(self): # Shard module and initialize optimizer. LR = 0.25 - shard_mlp(model_tp, self.device_type, NUM_DEVICES) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(0, NUM_DEVICES), + ) + _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) optim = torch.optim.SGD(model.parameters(), lr=LR) optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) diff --git a/test/distributed/_tensor/parallel/test_tp_style.py b/test/distributed/_tensor/parallel/test_tp_style.py new file mode 100644 index 000000000000..314fe470955b --- /dev/null +++ b/test/distributed/_tensor/parallel/test_tp_style.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates +# Owner(s): ["oncall: distributed"] + +import torch +from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms +from torch.distributed._tensor import distribute_tensor, DeviceMesh, Shard, Replicate +from torch.distributed._tensor.parallel.style import ( + RowwiseParallel, + ColwiseParallel, + make_input_shard_1d, + make_input_replicate_1d, + make_output_shard_1d, + make_output_replicate_1d, + make_output_tensor, +) + + +class TensorParallelStyleTest(DTensorTestBase): + @property + def world_size(self): + gpu_num = torch.cuda.device_count() + return gpu_num if gpu_num % 2 == 0 and gpu_num > 4 else 4 + + def _1d_input_func_check( + self, input_local_tensor, expected_local_tensor, func + ) -> None: + with self.assertRaisesRegex( + RuntimeError, "device_mesh is not passed nor can be inferred" + ): + dtensor = func(input_local_tensor) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + with self.assertRaisesRegex( + RuntimeError, + "device_mesh has dims [0-9]+ but expcted to be 1 for input.", + ): + dtensor = func(input_local_tensor, device_mesh) + + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + # test 1: replicate local tensor + dtensor = func(input_local_tensor, device_mesh) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + # test 2: replicate DTensor + dtensor = func(dtensor) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + # test 3: replicate DTensor with DeviceMesh passed + dtensor = func(dtensor, device_mesh) + self.assertEqual(expected_local_tensor, dtensor.to_local()) + + @with_comms + def test_make_input_replicate_1d(self): + tensor = torch.rand(8, 16, device=self.device_type) + self._1d_input_func_check(tensor, tensor, make_input_replicate_1d) + + @with_comms + def test_make_input_shard_1d(self): + tensor = torch.rand(8, 16, device=self.device_type) + self._1d_input_func_check(tensor, tensor, make_input_shard_1d) + + # Common logic for testing prepare output funcs + def _test_prepare_output( + self, func, spec, dim=None, device_mesh_input_none=False + ): + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + tensor = torch.rand(8, 16, device=self.device_type) + dtensor = distribute_tensor(tensor, device_mesh, spec) + device_mesh_input = None if device_mesh_input_none else device_mesh + if dim is not None: + output = func(dtensor, device_mesh_input, dim) + else: + output = func(dtensor, device_mesh_input) + return output, dtensor, device_mesh + + @with_comms + def test_make_output_shard_1d(self): + # test when output is sharded. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Shard(0)], 1 + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(1)])) + # test when output is replicated. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Replicate()], 0 + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(0)])) + # test when input device_mesh is None. 
+ output, dtensor, device_mesh = self._test_prepare_output( + make_output_shard_1d, [Shard(0)], 1, True + ) + self.assertEqual(output, dtensor.redistribute(device_mesh, [Shard(1)])) + + @with_comms + def test_make_output_replicate_1d(self): + output, dtensor, device_mesh = self._test_prepare_output( + make_output_replicate_1d, [Shard(0)] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]) + ) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_replicate_1d, [Shard(0)], None, True + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]) + ) + + @with_comms + def test_make_output_tensor(self): + # test when output is sharded. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Shard(0)] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + # test when output is replicated. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Replicate()] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + # test when input device_mesh is None. + output, dtensor, device_mesh = self._test_prepare_output( + make_output_tensor, [Shard(0)], None, True + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]).to_local() + ) + + # Common logic for testing prepare output funcs errors. + def _test_prepare_output_error(self, func): + tensor = torch.rand(8, 16, device=self.device_type) + device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) + output = [dtensor] + with self.assertRaisesRegex( + AssertionError, + f"Expect output of Tensor Parallel to be a DTensor, but found {type(output)}.", + ): + func(output, device_mesh) + device_mesh = DeviceMesh( + self.device_type, + torch.arange(self.world_size).reshape(self.world_size // 2, 2), + ) + with self.assertRaisesRegex( + AssertionError, + "device_mesh has dims 2 but expcted to be 1 for output.", + ): + func(dtensor, device_mesh) + + @with_comms + def test_prepare_output_error(self): + self._test_prepare_output_error(make_output_shard_1d) + self._test_prepare_output_error(make_output_replicate_1d) + self._test_prepare_output_error(make_output_tensor) + + @with_comms + def test_rowwise_parallel_style(self): + tensor = torch.rand(8, 16, device=self.device_type) + rs = RowwiseParallel() + self._1d_input_func_check(tensor, tensor, rs._prepare_input) + # TODO: change output test + output, dtensor, device_mesh = self._test_prepare_output( + rs._prepare_input, [Shard(0)] + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]) + ) + # test when input device_mesh is None. 
+ output, dtensor, device_mesh = self._test_prepare_output( + rs._prepare_input, [Shard(0)], None, True + ) + self.assertEqual( + output, dtensor.redistribute(device_mesh, [Replicate()]) + ) + self._test_prepare_output_error(rs._prepare_output) + + @with_comms + def test_colwise_parallel_style(self): + tensor = torch.rand(8, 16, device=self.device_type) + cs = ColwiseParallel() + self._1d_input_func_check(tensor, tensor, cs._prepare_input) + self.assertEqual(None, cs._prepare_output) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/_tensor/parallel/__init__.py b/torch/distributed/_tensor/parallel/__init__.py index 5725c5077d4b..0ef0e8ff0b9e 100644 --- a/torch/distributed/_tensor/parallel/__init__.py +++ b/torch/distributed/_tensor/parallel/__init__.py @@ -8,3 +8,15 @@ replicate_input, replicate_output, ) + +from torch.distributed._tensor.parallel.style import ( + ParallelStyle, + PairwiseParallel, + RowwiseParallel, + ColwiseParallel, + make_input_shard_1d, + make_input_replicate_1d, + make_output_shard_1d, + make_output_replicate_1d, + make_output_tensor, +) diff --git a/torch/distributed/_tensor/parallel/api.py b/torch/distributed/_tensor/parallel/api.py index 7ab3ad2199f2..68d444882c4c 100644 --- a/torch/distributed/_tensor/parallel/api.py +++ b/torch/distributed/_tensor/parallel/api.py @@ -3,6 +3,7 @@ import torch.nn as nn from typing import Sequence, Tuple from torch.distributed._tensor import ( + distribute_module, distribute_tensor, DTensor, Shard, @@ -11,6 +12,8 @@ Placement, ) from torch.distributed._tensor.parallel import TensorParallelMultiheadAttention +from torch.distributed._tensor.parallel.style import ParallelStyle, PairwiseParallel +from torch.distributed._tensor.parallel.utils import _create_1d_device_mesh def replicate_input( @@ -84,3 +87,112 @@ def _shard_self_attn_params(name: str, module: nn.Module) -> None: ) tp_multi_head_attention.copy(m) module.register_module(n, tp_multi_head_attention) + + +def _has_even_num_linears(module: nn.Module) -> bool: + """ + We traverse through all the children of the given module and count the + number of Linear module. If the number is even, we return True. + + Args: + module (nn.Module): + :class:``nn.Module`` object to be traversed and counted. + + Return: + A boolean object which specifies whether the module contains + event-number of Linears in its children. + + .. warning:: + The traversal is not recursive for now. + """ + linear_submodules = list( + filter(lambda x: isinstance(x, nn.Linear), module.children()) + ) + return len(linear_submodules) > 0 and len(linear_submodules) % 2 == 0 + + +def _parallelize_mlp( + module: nn.Module, + device_mesh: DeviceMesh, + parallel_style: ParallelStyle = PairwiseParallel(), + tp_mesh_dim: int = 0, +) -> None: + """ + This function assumes the input module is a sequence of nn.Linear + and we parallelize the module based on the given parallel style. + We don't change the FQN of each sub-module and replace each parameter + in place. + + Args: + module (nn.Module): + :class:``nn.Module`` object to be parallelized. + device_mesh (DeviceMesh): + :class:``DeviceMesh`` object which describes the mesh topology + of devices for the DTensor. + parallel_style (ParallelStyle): + :class:``ParallelStyle`` object which contains how + we prepare input/output for Tensor Parallelism. + tp_mesh_dim (int): + the dimension of ``device_mesh`` where we perform + Tensor Parallelism on. + + Return: + None + + .. warning:: + We only support ``PairwiseParallel`` right now. 
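+
+    A minimal usage sketch (illustrative only, assuming an initialized default
+    process group; ``device_type`` and ``world_size`` are placeholders, and any
+    module whose direct children are an even number of ``nn.Linear`` layers
+    would work)::
+
+        >>> mesh = DeviceMesh(device_type, torch.arange(world_size))
+        >>> mlp = torch.nn.Sequential(
+        >>>     torch.nn.Linear(10, 16), torch.nn.Linear(16, 12)
+        >>> ).to(device_type)
+        >>> _parallelize_mlp(mlp, mesh, PairwiseParallel())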
+ """ + + # Define partition functions needed. + def _rowwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] + for name, param in module.named_parameters(): + dist_spec = ( + [Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item] + ) + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, dist_spec) + ) + module.register_parameter(name, dist_param) + + def _colwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, [Shard(0)]) + ) + module.register_parameter(name, dist_param) + + if not isinstance(parallel_style, PairwiseParallel): + raise NotImplementedError( + "Only support PairwiseParallel for MLP parallelization." + ) + + if not _has_even_num_linears(module): + raise RuntimeError("We only support even number of Linear for MLP.") + + if device_mesh.ndim > 1: + device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim) + + linear_submodules = list( + filter(lambda x: isinstance(x, nn.Linear), module.children()) + ) + for i, m in enumerate(linear_submodules): + if i % 2 == 0: + # Col-wise Parallelize the linear layer + distribute_module( + m, + device_mesh, + _colwise_parallelize_fn, + input_fn=parallel_style._prepare_input # type: ignore[arg-type, misc] # pyre-ignore[6] + if i == 0 + else None, + ) + else: + # Row-wise Parallelize the linear layer + distribute_module( + m, + device_mesh, + _rowwise_parallelize_fn, + output_fn=parallel_style._prepare_output # type: ignore[arg-type, misc] # pyre-ignore[6] + if i == (len(linear_submodules) - 1) + else None, + ) diff --git a/torch/distributed/_tensor/parallel/style.py b/torch/distributed/_tensor/parallel/style.py new file mode 100644 index 000000000000..5ea96434118a --- /dev/null +++ b/torch/distributed/_tensor/parallel/style.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from abc import abstractmethod +import torch +from abc import ABC +from typing import Union, Optional +from torch.distributed._tensor import DTensor, Shard, Replicate, DeviceMesh +from torch.distributed._tensor.parallel.utils import ( + _Prepare_Input_Func_Type, + _Prepare_Output_Func_Type, + _prepare_input_validate, + _prepare_output_validate, +) + + +class ParallelStyle(ABC): + """ + The parallel style user wants the module or submodule to be parallelized. + Users can extend this class to build their own parallel style with customized input/output preparations. + """ + + _prepare_input: _Prepare_Input_Func_Type + _prepare_output: _Prepare_Output_Func_Type + + @abstractmethod + def __init__(self, _prepare_input, _prepare_output) -> None: + self._prepare_input = _prepare_input # type: ignore[assignment, misc] + self._prepare_output = _prepare_output # type: ignore[assignment, misc] + + +class PairwiseParallel(ParallelStyle): + """ + PairwiseParallel concatenate colwise and rowwise styles as a fixed + pair like what Megatron-LM(https://arxiv.org/abs/1909.08053) is doing. + We assume both input and output needs to a replicate DTensor. + + .. warning:: + PairwiseParallel only supports ``nn.Multihead Attention``, + ``nn.Transformer`` or even-number-layer MLP for now. + """ + + def __init__(self) -> None: + super().__init__(make_input_replicate_1d, make_output_tensor) + + +class RowwiseParallel(ParallelStyle): + """ + Partitioning the row of a module. + We assume the input to be a sharded :class:``DTensor`` and output to be a replicated :class:``DTensor``. 
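+
+    (As wired up in ``__init__`` below, this style prepares inputs with
+    ``make_input_shard_1d`` and outputs with ``make_output_replicate_1d``.)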
+ """ + + def __init__(self) -> None: + super().__init__(make_input_shard_1d, make_output_replicate_1d) + + +class ColwiseParallel(ParallelStyle): + """ + Partitioning the column of a tensor or module. + We assume the input to be a replicated :class:``DTensor`` and output to be a sharded :class:``DTensor``. + """ + + def __init__(self) -> None: + super().__init__(make_input_replicate_1d, None) + + +@_prepare_input_validate # type: ignore[arg-type] # pyre-ignore[56] +def make_input_shard_1d( + input: Union[torch.Tensor, DTensor], + device_mesh: Optional[DeviceMesh] = None, + dim: int = 0, +) -> DTensor: + """ + Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle. + + Args: + input (Union[Tensor, DTensor]): + This single tensor will be sharded on dimension ``dim`` + over the 1-D :class:``DeviceMesh``. + device_mesh (DeviceMesh, optional): + The 1-D device mesh where ``input`` will be sharded. + If no :class:``DeviceMesh`` is passed and ``input`` is a :class:``DTensor``, + `input.device_mesh` will be used. + If :class:``DeviceMesh`` is not 1-D, an exception will be thrown. + Default: ``None`` + dim (int, optional): The sharding dimension of ``input`` tensor. + Default: 0 + + Returns: + A :class:``DTensor`` sharded on dimension ``dim`` over ``device_mesh``. + """ + shard_spec = [Shard(dim)] + if isinstance(input, DTensor): + return input.redistribute(device_mesh, shard_spec) + elif isinstance(input, torch.Tensor): + return DTensor.from_local( + input, device_mesh, shard_spec, run_check=False + ) + else: + raise RuntimeError( + f"Tensor parallel module expects torch.Tensor or DTensor input but received {type(input)}!" + ) + + +@_prepare_input_validate # type: ignore[arg-type] # pyre-ignore[56] +def make_input_replicate_1d( + input: Union[torch.Tensor, DTensor], + device_mesh: Optional[DeviceMesh] = None, +) -> DTensor: + """ + Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle. + + Args: + input (Union[Tensor, DTensor]): + This single tensor will be replicated over the 1-D :class:``DeviceMesh``. + device_mesh (DeviceMesh, optional): + The 1-D device mesh where ``input`` will be replicated. + If no :class:``DeviceMesh`` is passed and ``input`` is a :class:``DTensor``, + ``input.device_mesh`` will be used. + If :class:``DeviceMesh`` is not 1-D, an exception will be thrown. + Default: ``None`` + + Returns: + A :class:``DTensor`` replicated over ``device_mesh``. + """ + replicate = [Replicate()] + if isinstance(input, DTensor): + return input.redistribute(device_mesh, replicate) + elif isinstance(input, torch.Tensor): + return DTensor.from_local( + input, device_mesh, replicate, run_check=False + ) + else: + raise RuntimeError( + f"Tensor parallel module expects torch.Tensor or DTensor input but received {type(input)}!" + ) + + +@_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] +def make_output_shard_1d( + output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0 +) -> DTensor: + """ + Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle. + Args: + output (DTensor): output of module to be converted. + device_mesh (Optional[DeviceMesh]): :class:``DeviceMesh`` object needed to + shard the output and it needs to be a 1D ``device_mesh`` and we will throw + exceptions if a non-1D ``device_mesh`` is passed in. If no ``device_mesh`` + is passed in, we will reuse the one from output. + Default: ``None`` + dim (int): Sharding dim for output. 
Default: 0 + Return: + A :class:``DTensor`` object sharded on the given dim. + """ + + return output.redistribute(device_mesh, [Shard(dim)]) + + +@_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] +def make_output_replicate_1d( + output: DTensor, device_mesh: Optional[DeviceMesh] = None +) -> DTensor: + """ + Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle. + Args: + output (DTensor): output of module to be converted. + device_mesh (Optional[DeviceMesh]): :class:``DeviceMesh`` object needed to + replicate the output and it needs to be a 1D ``device_mesh`` and we will + throw exceptions if a non-1D ``device_mesh`` is passed in. If no + ``device_mesh`` is passed in, we will reuse the one from output. + Default: ``None`` + Return: + A :class:``DTensor`` object made replicate. + """ + + return output.redistribute(device_mesh, [Replicate()]) + + +@_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] +def make_output_tensor( + output: DTensor, device_mesh: Optional[DeviceMesh] = None +) -> torch.Tensor: + """ + Convert Output DTensor to a replicated DTensor first and then convert it to Tensor. + Args: + output (DTensor): output of module to be converted. + device_mesh (Optional[DeviceMesh]): :class:``DeviceMesh`` object needed to + replicate the output and it needs to be a 1D ``device_mesh`` and we will + throw exceptions if a non-1D ``device_mesh`` is passed in. If no + ``device_mesh`` is passed in, we will reuse the one from output. + Default: ``None`` + Return: + A :class:``torch.Tensor`` object converted from output DTensor. + """ + + return make_output_replicate_1d( # type: ignore[attr-defined] + output, device_mesh + ).to_local() # type: ignore[call-arg] diff --git a/torch/distributed/_tensor/parallel/utils.py b/torch/distributed/_tensor/parallel/utils.py new file mode 100644 index 000000000000..2680ae41ffbe --- /dev/null +++ b/torch/distributed/_tensor/parallel/utils.py @@ -0,0 +1,149 @@ +import functools + +import torch +from torch.distributed._tensor import DeviceMesh, DTensor +from typing import Callable, Optional, Union + +_Prepare_Input_Func_Type = Callable[ + [Union[torch.Tensor, DTensor], Optional[DeviceMesh], Optional[int]], DTensor +] + +_Prepare_Output_Func_Type = Callable[ + [DTensor, Optional[DeviceMesh], Optional[int]], Union[torch.Tensor, DTensor] +] + + +def _prepare_input_validate( + _prepare_input_func: _Prepare_Input_Func_Type, +) -> _Prepare_Input_Func_Type: + """ + Inject common validation logics for `_prepare_input` funcs via this + decorator, including verifying that input needs to be either + a :class:`Tensor` or :class:`DTensor` and only 1D :class:`DeviceMesh` + is passed in. + + Args: + _prepare_input_func (Callable): The func we want to inject the + validation into. + + Returns: + func (Callable): Same input function with validation logic added. + + Example:: + >>> @_prepare_input_validate + >>> def make_input_shard_1d(args, kwargs): + >>> ... + >>> + >>> input = torch.rand(...) + >>> dtensor = make_input_shard_1d(input, device_mesh, 1) + >>> # This will call '_prepare_input_validate' first + """ + + @functools.wraps(_prepare_input_func) + def wrapper(*args, **kwargs): # pyre-ignore[2, 3] + assert len(args) >= 1, "_prepare_input needs at least one arg." 
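+        # Normalize before validating: list/tuple inputs are unwrapped to their
+        # first element, and the device mesh is inferred from a DTensor input
+        # when it is not passed explicitly; non-1D meshes are rejected below.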
+ input = args[0] + if isinstance(input, list) or isinstance(input, tuple): + input = input[0] + args = (input, *args[1:]) + device_mesh = None if len(args) < 2 else args[1] + + if device_mesh is None: + if isinstance(input, DTensor): + device_mesh = input.device_mesh + args = (*args[:1], device_mesh, *args[2:]) # pyre-ignore[60] + else: + raise RuntimeError( + "device_mesh is not passed nor can be inferred" + ) + if device_mesh.ndim != 1: + raise RuntimeError( + f"device_mesh has dims {device_mesh.ndim} but expcted to be 1 for input." + ) + return _prepare_input_func(*args, **kwargs) + + return wrapper + + +def _prepare_output_validate( + _prepare_output_func: _Prepare_Output_Func_Type, +) -> _Prepare_Output_Func_Type: + """ + Inject common validation logics for _prepare_output funcs via this + decorator, including verifying that output needs to be a DTensor + and only 1D Device Mesh is passed in. + Example:: + >>> @_prepare_output_validate + >>> def make_output_shard_1d(args, kwargs): + >>> ... + >>> + >>> dt = distribute(tensor, device_mesh, [Shard(0)]) + >>> make_output_shard_1d(dt, device_mesh, 1) + >>> # This will call '_prepare_output_validate' first + Args: + _prepare_output_func (Callable): The func we want to inject the + validation into. + Return: + func (Callable): Same input func with validation logic added. + """ + + @functools.wraps(_prepare_output_func) + def wrapper(*args, **kwargs): # pyre-ignore[2, 3] + assert len(args) >= 1, "_prepare_output needs at least one arg." + output = args[0] + assert isinstance( + output, DTensor + ), f"Expect output of Tensor Parallel to be a DTensor, but found {type(output)}." + if len(args) < 2 or args[1] is None: + device_mesh = output.device_mesh + args = (*args[:1], device_mesh, *args[2:]) # pyre-ignore[60] + else: + device_mesh = args[1] + + assert ( + device_mesh.ndim == 1 + ), f"device_mesh has dims {device_mesh.ndim} but expcted to be 1 for output." + return _prepare_output_func(*args, **kwargs) + + return wrapper + + +def _create_1d_device_mesh( + device_mesh: DeviceMesh, tp_mesh_dim: int = 0 +) -> DeviceMesh: + """ + This function converts a N-D ``device_mesh`` into a 1D ``device_mesh`` + for 1D Tensor Parallelism. + + Args: + device_mesh (DeviceMesh): + :class:``DeviceMesh`` object which describes the mesh topology + of devices for the DTensor. + tp_mesh_dim (int): + the dimension of ``device_mesh`` where we perform + Tensor Parallelism on. + + Return: + device_mesh (DeviceMesh): 1-D :class:``DeviceMesh`` object that + Tensor Parallelism operates on. + """ + assert ( + tp_mesh_dim < device_mesh.ndim and tp_mesh_dim >= -device_mesh.ndim + ), ( + f"Expect tp_mesh_dim within range [{-device_mesh.ndim}, {device_mesh.ndim})" + f", but found {tp_mesh_dim}." + ) + + if device_mesh.ndim == 1: + return device_mesh + + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. + cur_rank = device_mesh.get_rank() + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, tp_mesh_dim).reshape( + -1, device_mesh.mesh.size(tp_mesh_dim) + ) + dim_mesh_1d = pg_ranks_by_dim[torch.any(pg_ranks_by_dim == cur_rank, 1), :] + + sub_pg = device_mesh.get_dim_groups()[tp_mesh_dim] + return DeviceMesh(device_mesh.device_type, dim_mesh_1d.squeeze(), [sub_pg]) From 5582001bd5e5c66dcab8859ecb84cbaa42524fd4 Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Sat, 19 Nov 2022 12:51:53 -0500 Subject: [PATCH 379/453] Reland 2 "Towards unifying symbolic and non symbolic fake tensor (#89038) (#89143)" (#89346) This reverts commit 8e4c9828f4c990f439179912159086aaed790493. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89346 Approved by: https://github.com/wconstab --- aten/src/ATen/native/TensorFactories.cpp | 6 --- test/functorch/test_aotdispatch.py | 3 -- test/test_proxy_tensor.py | 19 +++----- torch/_meta_registrations.py | 39 +++++++++++++++- torch/_ops.py | 1 + torch/_prims/__init__.py | 5 +- torch/_prims_common/__init__.py | 3 ++ torch/_subclasses/fake_tensor.py | 58 +++++++++--------------- 8 files changed, 71 insertions(+), 63 deletions(-) diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index 9d1c6d8a3633..7245cb77b1c5 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -325,12 +325,6 @@ Tensor empty_like( // See [Note: hacky wrapper removal for TensorOptions] TensorOptions options_ = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); - - TORCH_CHECK( - !(options_.has_memory_format() && optional_memory_format.has_value()), - "Cannot set memory_format both in TensorOptions and explicit argument; please delete " - "the redundant setter."); - TensorOptions options = self.options() .merge_in(options_) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index de6d82960adc..e03fe1e15385 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1011,10 +1011,8 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumsum', ''), # aten.cumsum.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.fft2', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('fft.fft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides @@ -1168,7 +1166,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('put', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition - xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition xfail('repeat_interleave', ''), # aten.repeat_interleave.Te... 
xfail('reshape_as', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index e174a1483791..fa04c57d9426 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1164,9 +1164,7 @@ def f(a, b, c, d, e): xfail('cummin', ''), # aten.cummin.default - couldn't find symbolic meta function/decomposition xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition - xfail('deg2rad', ''), # aten.deg2rad.default - couldn't find symbolic meta function/decomposition xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition - xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1248,8 +1246,6 @@ def f(a, b, c, d, e): xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition - xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 - xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... @@ -1291,7 +1287,6 @@ def f(a, b, c, d, e): xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta function/decompos... xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... - xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco... xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition @@ -1307,10 +1302,8 @@ def f(a, b, c, d, e): xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition - xfail('put', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. 
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition - xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition @@ -1356,14 +1349,18 @@ def f(a, b, c, d, e): symbolic_tensor_failures.update(symbolic_tensor_segfaults) +outplace_symbolic_tensor_failures = { + xfail('masked_fill', ''), # expected predicate to be bool, got torch.float32 + xfail('masked_scatter', ''), # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition + xfail('nn.functional.rrelu', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition +} + inplace_symbolic_tensor_failures = { # bugs xfail('float_power', ''), # base given to float_power_ has dtype Float but the operation's result requires dtype Double # decomp not implemented - xfail('addbmm', ''), xfail('addmm', ''), xfail('addmm', 'decomposed'), - xfail('logit', ''), xfail('nn.functional.hardsigmoid', ''), xfail('round', ''), # ref missing a kwarg xfail('round', 'decimals_0'), # ref missing a kwarg @@ -1373,10 +1370,8 @@ def f(a, b, c, d, e): # in-place has a different signature than out-of-place xfail('uniform', ''), # Views - xfail('squeeze', ''), xfail('t', ''), xfail('transpose', ''), - xfail('nn.functional.dropout3d', ''), # calls unsqueeze_ } # Copies inputs to inplace operations to avoid inplace modifications @@ -1452,7 +1447,7 @@ def test_make_fx_fake_exhaustive(self, device, dtype, op): @skipIfNoSympy @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive', - make_fx_failures | fake_tensor_failures | symbolic_tensor_failures) + make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures) def test_make_fx_symbolic_exhaustive(self, device, dtype, op): _test_make_fx_helper(self, device, dtype, op, "symbolic") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 4fa3ab09d275..9849df0a58af 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1513,7 +1513,6 @@ def full(size, fill_value, *args, **kwargs): aten.randn_like.default, aten.rand_like.default, aten.full_like.default, - aten.zeros_like.default, aten.ones_like.default, ] ) @@ -1521,6 +1520,44 @@ def meta_like(self, *args, **kwargs): return aten.empty_like.default(self, **kwargs) +# zeros_like is special cased to work for sparse +@register_meta(aten.zeros_like.default) +def zeros_like( + self, dtype=None, layout=None, device=None, pin_memory=None, memory_format=None +): + if layout == torch.sparse_coo: + check( + memory_format is None, + lambda: "memory format option is only supported by strided tensors", + ) + + res = torch.empty( + 0, + dtype=self.dtype if dtype is None else dtype, + layout=layout, + device=self.device if device is None else device, + pin_memory=pin_memory, + ) + + if self.is_sparse: + res.sparse_resize_and_clear_( + self.size(), self.sparse_dim(), self.dense_dim() + ) + else: + res.sparse_resize_and_clear_(self.size(), self.dim(), 0) + + res._coalesced_(True) + return res + return aten.empty_like.default( + self, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + memory_format=memory_format, + ) + + # 
hacky: Please remove after math.ceil works with arange @register_meta(aten.arange.default) def arange(end, **kwargs): diff --git a/torch/_ops.py b/torch/_ops.py index 9163932144d0..b20398a7f3ab 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -365,6 +365,7 @@ def handler(*args, **kwargs): return handler final_key = resolve_key(self, key) + # print(self, key, final_key) r = self.py_kernels.get(final_key, final_key) self._dispatch_cache[key] = r return r diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 22917ec048eb..67e16ca102ac 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -1150,9 +1150,6 @@ def _minimum_aten( # # View operations -# -# TODO: model view relationships -# TODO: model storage def _as_strided_meta( a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int ) -> TensorLikeType: @@ -1170,7 +1167,7 @@ def _as_strided_meta( a._typed_storage(), size, stride, storage_offset ) - return TensorMeta(a, shape=size, strides=stride) + return torch.as_strided(a, size, stride, storage_offset) def _as_strided_aten( diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 7752f1836141..6df72f6c158d 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -291,6 +291,9 @@ def is_non_overlapping_and_dense(a: Tensor) -> bool: its dimensions that is contiguous. """ + if a.is_sparse: + return False + # Short-circuits if the tensor is already contiguous or channels-last contiguous if is_contiguous(a) or is_channels_last_contiguous(a): return True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 5d3d3a0e32fe..9a0ac050e6b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1,7 +1,6 @@ import contextlib import functools import itertools -import sys import weakref from dataclasses import dataclass from functools import partial @@ -297,8 +296,9 @@ def constructors(fake_mode, func, *args, **kwargs): out_device = new_kwargs.pop("device", None) out_device = out_device if out_device is not None else default_device new_kwargs["device"] = torch.device("meta") - # Not in_kernel_invocation_manager as no fake tensor inputs - with no_dispatch(): + # _like constructors have fake tensor inputs (maybe this causes the non-like + # to fail? hmmm) + with in_kernel_invocation_manager(fake_mode): r = func(*args, **new_kwargs) return FakeTensor(fake_mode, r, out_device) @@ -821,40 +821,30 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # is written to must be invalidated self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) - from torch._decomp import decomposition_table - - with self: - # Decomposes CompositeImplicitAutograd ops - r = func.decompose(*args, **kwargs) - if r is not NotImplemented: - return r + # If there's a Python meta, prefer that over the decomposition + from torch._decomp import meta_table as meta_table - # IDK: feels bad man, sym_numel on as_strided infinite loops otherwise - if has_symbolic_sizes and not self.cpp_meta_supports_symint(func): - from torch._decomp import meta_table as meta_table + if func not in meta_table and not self.cpp_meta_supports_symint(func): + from torch._decomp import decomposition_table - if func == aten.size.default: - sys.stderr.write( - "Trying to call aten.size on a tensor with symbolic shapes. 
" - "It's likely that this is from calling tensor.shape in C++" + # Prefer Python decompositions over C++ ones + if func in decomposition_table and ( + has_symbolic_sizes + or ( + # TODO: Remove these exclusions, so that we can remove + # this leg entirely + torch_decomp_decompositions(func) + and all(not e.is_sparse for e in flat_arg_fake_tensors) ) - # We do this to allow for better error localization with `TORCH_SHOW_CPP_STACKTRACES=1` - return None - - with self: - if func in meta_table: - r = meta_table[func](*args, **kwargs) - return r - if func in decomposition_table: + ): + with self: return decomposition_table[func](*args, **kwargs) - if ( - func in decomposition_table - and torch_decomp_decompositions(func) - and all(not e.is_sparse for e in flat_arg_fake_tensors) - ): with self: - return decomposition_table[func](*args, **kwargs) + # Decomposes CompositeImplicitAutograd ops + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r # prims already wrap FakeTensor inputs to FakeTensor outputs # and do device logic, we dont need do anything but run them @@ -865,12 +855,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): with self: return func.prim_meta_impl(*args, **kwargs) - if has_symbolic_sizes: - if not self.cpp_meta_supports_symint(func): - raise RuntimeError( - f"{func} - couldn't find symbolic meta function/decomposition" - ) - # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors # and then afterwards wrapping them to a FakeTensor From 8ac58bc2e3449760bef7f36f600a40c96d5bc5ba Mon Sep 17 00:00:00 2001 From: kvathupo Date: Sat, 19 Nov 2022 21:40:07 +0000 Subject: [PATCH 380/453] Add nullptr_t overload to c10::intrusive_ptr (#89196) __What?__ Fixes #82413 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89196 Approved by: https://github.com/ezyang --- c10/test/util/intrusive_ptr_test.cpp | 5 +++++ c10/util/intrusive_ptr.h | 3 +++ 2 files changed, 8 insertions(+) diff --git a/c10/test/util/intrusive_ptr_test.cpp b/c10/test/util/intrusive_ptr_test.cpp index 7ed1c292841d..632fe7fc2f20 100644 --- a/c10/test/util/intrusive_ptr_test.cpp +++ b/c10/test/util/intrusive_ptr_test.cpp @@ -146,6 +146,11 @@ TEST(IntrusivePtrTest, givenInvalidPtr_whenCallingGet_thenReturnsNullptr) { EXPECT_EQ(nullptr, obj.get()); } +TEST(IntrusivePtrTest, givenNullptr_whenCallingGet_thenReturnsNullptr) { + intrusive_ptr obj(nullptr); + EXPECT_EQ(nullptr, obj.get()); +} + TEST(IntrusivePtrTest, givenValidPtr_whenDereferencing_thenReturnsObject) { intrusive_ptr obj = make_intrusive(5); diff --git a/c10/util/intrusive_ptr.h b/c10/util/intrusive_ptr.h index c87305b08be5..e75c1980fdfa 100644 --- a/c10/util/intrusive_ptr.h +++ b/c10/util/intrusive_ptr.h @@ -326,6 +326,9 @@ class intrusive_ptr final { intrusive_ptr() noexcept : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + intrusive_ptr(std::nullptr_t) noexcept + : intrusive_ptr(NullType::singleton(), raw::DontIncreaseRefcount{}) {} + // This constructor will not increase the ref counter for you. 
// We use the tagged dispatch mechanism to explicitly mark this constructor // to not increase the refcount From 8ad39536d741d9fc8c5d33f1344d23bd56f1c050 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 19 Nov 2022 21:47:55 +0000 Subject: [PATCH 381/453] Revert "Symintify numel(), infer_size, prims.elementwise_meta (#88956)" This reverts commit ce2f8700bafcf44850402a39188ec121ba8b5486. Reverted https://github.com/pytorch/pytorch/pull/88956 on behalf of https://github.com/ezyang due to somehow breaks torch.numel --- aten/src/ATen/ExpandUtils.cpp | 10 ++++---- aten/src/ATen/ExpandUtils.h | 2 -- test/test_proxy_tensor.py | 25 +++---------------- torch/_prims/__init__.py | 16 +++--------- torch/_refs/__init__.py | 4 +-- torch/_subclasses/fake_tensor.py | 6 ++++- torch/csrc/autograd/input_metadata.h | 4 +-- .../python_torch_functions_manual.cpp | 2 +- torch/fx/experimental/symbolic_shapes.py | 3 +++ torch/fx/traceback.py | 2 +- 10 files changed, 26 insertions(+), 48 deletions(-) diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp index ee846c9b82e3..a44005a2ef81 100644 --- a/aten/src/ATen/ExpandUtils.cpp +++ b/aten/src/ATen/ExpandUtils.cpp @@ -13,8 +13,8 @@ TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) { namespace { // NOTE: are_expandable did a similar check, please keep them sync if change is needed -template -Container infer_size_impl(ArrayType a, ArrayType b) { +template +Container infer_size_impl(IntArrayRef a, IntArrayRef b) { size_t dimsA = a.size(); size_t dimsB = b.size(); size_t ndim = dimsA > dimsB ? dimsA : dimsB; @@ -25,8 +25,8 @@ Container infer_size_impl(ArrayType a, ArrayType b) { ptrdiff_t offset = ndim - 1 - i; ptrdiff_t dimA = dimsA - 1 - offset; ptrdiff_t dimB = dimsB - 1 - offset; - auto sizeA = (dimA >= 0) ? a[dimA] : 1; - auto sizeB = (dimB >= 0) ? b[dimB] : 1; + int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; + int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; TORCH_CHECK( sizeA == sizeB || sizeA == 1 || sizeB == 1, @@ -35,7 +35,7 @@ Container infer_size_impl(ArrayType a, ArrayType b) { ") at non-singleton dimension ", i); // 1s map to the other size (even 0). - expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA); + expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; } return expandedSizes; diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 9e48421e540f..786cbf132cd7 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -21,8 +21,6 @@ namespace at { TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); -TORCH_API SymDimVector -infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); // Named type instead of a pair/tuple so that we can be sure to // construct the vectors in place and get NRVO. 
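For context on the ExpandUtils hunk above: infer_size_impl implements the usual broadcast
shape-inference rule, and this revert only changes whether it operates on IntArrayRef or on a
templated (possibly symbolic) array type. A minimal Python sketch of that rule, with
illustrative names only (it is not part of this patch), might look like:

    # Broadcast two shapes with trailing dimensions aligned; a missing or size-1
    # dimension maps to the other operand's size (even 0), anything else errors.
    def infer_size(a, b):
        ndim = max(len(a), len(b))
        out = [1] * ndim
        for i in range(ndim - 1, -1, -1):
            dim_a = len(a) - ndim + i
            dim_b = len(b) - ndim + i
            size_a = a[dim_a] if dim_a >= 0 else 1
            size_b = b[dim_b] if dim_b >= 0 else 1
            if size_a != size_b and size_a != 1 and size_b != 1:
                raise RuntimeError(
                    f"The size of tensor a ({size_a}) must match the size of "
                    f"tensor b ({size_b}) at non-singleton dimension {i}")
            # 1s map to the other size (even 0), mirroring the C++ branch.
            out[i] = size_b if size_a == 1 else size_a
        return out

    # e.g. infer_size([5, 1, 3], [4, 3]) == [5, 4, 3]

This mirrors the `sizeA == 1 ? sizeB : sizeA` selection that both the reverted and the
re-landed C++ versions of infer_size_impl keep.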
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fa04c57d9426..34edc5cfac94 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -13,7 +13,6 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException from torch._decomp import decomposition_table -from torch.fx.experimental.symbolic_shapes import sym_float from torch.testing._internal.common_device_type import ops from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy @@ -733,6 +732,7 @@ def deco(cls): @skipIfNoSympy @xfail_inherited_tests([ + "test_mode_tracing_factory_function", "test_make_fx_overloads", "test_trace_subclasses", ]) @@ -972,27 +972,8 @@ def f(x): # happened afterwards self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3) - def test_elementwise_meta_with_sym_numbers(self): - def f(x, offset, as_sym_float=False): - x0 = x.size()[0] - if as_sym_float: - x0 = sym_float(x0) - return torch.add(x0, offset) - - fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) - meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) - self.assertEqual(meta_add.meta['val'].shape, ()) - self.assertEqual(meta_add.meta['val'].dtype, torch.float32) - - fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) - meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) - self.assertEqual(meta_add.meta['val'].shape, ()) - self.assertEqual(meta_add.meta['val'].dtype, torch.int64) - - fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) - meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) - self.assertEqual(meta_add.meta['val'].shape, ()) - self.assertEqual(meta_add.meta['val'].dtype, torch.float32) + + def test_return_symint(self): def f(x): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index 67e16ca102ac..a867a44f72e3 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -31,7 +31,6 @@ ) from torch._prims_common.wrappers import backwards_not_supported from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.fx.experimental.symbolic_shapes import sym_float from torch.overrides import handle_torch_function, has_torch_function from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -391,18 +390,11 @@ def _elementwise_meta( return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype) # Number case + # NOTE: this case is not currently exercised # TODO: fix number type promotion (bool, complex->float) - - # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) - seen_float = False - if isinstance(number, (torch.SymInt, torch.SymFloat)): - for a in args: - assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" - seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) - if seen_float: - number = sym_float(number) - - return TensorMeta(number) # type: ignore[arg-type] + assert not isinstance(number, torch.SymInt), "NYI" + assert not isinstance(number, torch.SymFloat), "NYI" + return TensorMeta(number) def _complex_only_elementwise_meta(*args, **kwargs): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 3355400db43c..8ea1390a4449 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -744,10 +744,10 @@ def nan_to_num( nan = 0.0 if posinf is None: - posinf = torch.finfo(a.dtype).max + posinf = 
prims.maximum_value(a.dtype) if neginf is None: - neginf = torch.finfo(a.dtype).min + neginf = prims.minimum_value(a.dtype) result = where(isnan(a), nan, a) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9a0ac050e6b9..f52bec927b11 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -851,7 +851,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # and ensure that Meta kernels are dispatched to (see) # Fake Tensor Dispatch Keys # TODO - we should be use the prim aten impl - if "prims::" in func._schema.name and hasattr(func, "prim_meta_impl"): + if ( + "prims::" in func._schema.name + and len(flat_arg_fake_tensors) != 0 + and hasattr(func, "prim_meta_impl") + ): with self: return func.prim_meta_impl(*args, **kwargs) diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 8060c11ac457..7cb9e8aedb19 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -125,13 +125,13 @@ struct InputMetadata { if (grad.is_nested()) { ss << at::native::get_nested_size_tensor(grad); } else { - ss << grad.sym_sizes(); + ss << grad.sizes(); } ss << " but expected shape compatible with "; if (is_nested_tensor()) { ss << shape_as_tensor(); } else { - ss << shape_as_dim_vector(); + ss << c10::asIntArrayRefSlow(shape_as_dim_vector()); } return ss; } diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 2c4999c971ea..562f5a427d38 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -692,7 +692,7 @@ static PyObject* THPVariable_numel( } if (r.idx == 0) { - return wrap(r.tensor(0).sym_numel()); + return wrap(r.tensor(0).numel()); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bd52760502c6..ae4427e2320e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -261,6 +261,9 @@ def eval(cls, base, divisor): 'floordiv': lambda a, b: FloorDiv(a, b), } +def _nyi(): + raise NotImplementedError() + magic_methods = { **reflectable_magic_methods, 'eq': lambda a, b: sympy.Eq(a, b), diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index cee7626e5c83..a07b36b997bd 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -54,7 +54,7 @@ def format_stack() -> List[str]: return current_stack.copy() else: # fallback to traceback.format_stack() - return traceback.format_list(traceback.extract_stack()[:-1]) + return traceback.format_stack() @compatibility(is_backward_compatible=False) From 7c811efab70a3546f997e37178c93d1de24e0444 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 19 Nov 2022 12:52:39 -0500 Subject: [PATCH 382/453] Add support for dynamic kwarg to torch._dynamo.optimize (#89290) This is an easier way to enable dynamic shapes for a region. Signed-off-by: Edward Z. 
Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89290 Approved by: https://github.com/soumith, https://github.com/jansel, https://github.com/voznesenskym --- test/dynamo/test_subgraphs.py | 12 ++++++++++ torch/_dynamo/eval_frame.py | 43 +++++++++++++++++++++++++++++++---- 2 files changed, 50 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 3a38561f16d2..27f73026435c 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -367,6 +367,18 @@ def fn(a, b): # just one graph now rather than 10 self.assertEqual(cnt_dynamic.frame_count, 1) + def test_dynamic_kwarg(self): + def fn(a, b): + return a - b * 10 + + torch._dynamo.reset() + cnt_dynamic = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt_dynamic, dynamic=True)(fn) + for i in range(10): + opt_fn(torch.randn(i), torch.randn(i)) + # just one graph + self.assertEqual(cnt_dynamic.frame_count, 1) + @patch.object(torch._dynamo.config, "capture_scalar_outputs", True) def test_no_graph_break_on_item(self): def fn(a, b): diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 31fb479906e1..65e8af4883ab 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -100,6 +100,17 @@ def innermost_fn(fn): return unaltered_fn +@contextlib.contextmanager +def enable_dynamic(enable: bool = True): + if not enable: + yield + return + with patch("torch._dynamo.config.dynamic_shapes", True), patch( + "functorch._src.config.use_dynamic_shapes", True + ): + yield + + class _TorchDynamoContext: def __init__( self, @@ -108,6 +119,8 @@ def __init__( backend_ctx_ctor=null_context, patch_fn=nothing, first_ctx=False, + *, + dynamic=False, ): super().__init__() assert callable(callback) or callback is False or callback is None @@ -116,6 +129,7 @@ def __init__( self.on_enter = on_enter self.extra_ctx_ctor = backend_ctx_ctor self.first_ctx = first_ctx + self.dynamic = dynamic patch_fn() def __enter__(self): @@ -129,10 +143,14 @@ def __enter__(self): self.prior = set_eval_frame(self.callback) self.backend_ctx = self.extra_ctx_ctor() self.backend_ctx.__enter__() + self.dynamic_ctx = enable_dynamic(self.dynamic) + self.dynamic_ctx.__enter__() def __exit__(self, exc_type, exc_val, exc_tb): set_eval_frame(self.prior) self.prior = unset + # TODO: This is totally not the right way to chain contexts manually + self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb) self.backend_ctx.__exit__(exc_type, exc_val, exc_tb) def __call__(self, fn): @@ -170,10 +188,13 @@ def _fn(*args, **kwargs): prior = set_eval_frame(callback) backend_ctx = backend_ctx_ctor() backend_ctx.__enter__() + dynamic_ctx = enable_dynamic(self.dynamic) + dynamic_ctx.__enter__() try: return fn(*args, **kwargs) finally: set_eval_frame(prior) + dynamic_ctx.__exit__(None, None, None) backend_ctx.__exit__(None, None, None) # hooks to properly handle inlining @@ -229,7 +250,7 @@ def _fn(*args, **kwargs): class OptimizeContext(_TorchDynamoContext): - def __init__(self, callback, backend_ctx_ctor, first_ctx=False): + def __init__(self, callback, backend_ctx_ctor, first_ctx=False, *, dynamic=False): def on_enter(): global most_recent_backend if ( @@ -247,6 +268,7 @@ def on_enter(): backend_ctx_ctor=backend_ctx_ctor, patch_fn=TorchPatcher.patch, first_ctx=first_ctx, + dynamic=dynamic, ) @@ -289,11 +311,12 @@ def catch_errors(frame, cache_size): return catch_errors -def _optimize_catch_errors(compile_fn, 
backend_ctx_ctor=null_context): +def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context, dynamic=False): return OptimizeContext( catch_errors_wrapper(compile_fn), backend_ctx_ctor=backend_ctx_ctor, first_ctx=True, + dynamic=dynamic, ) @@ -375,7 +398,12 @@ def __call__(self, fn): def optimize( - backend="inductor", *, nopython=False, guard_export_fn=None, disable=False + backend="inductor", + *, + nopython=False, + guard_export_fn=None, + disable=False, + dynamic=False, ): """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -393,6 +421,7 @@ def optimize( nopython: If True, graph breaks will be errors and there will be a single whole-program graph. disable: If True, turn this decorator into a no-op + dynamic: If True, turn on dynamic shapes support Example Usage: @@ -422,10 +451,13 @@ def toy_example(a, b): backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) if nopython: - return optimize_assert(backend, guard_export_fn=guard_export_fn) + return optimize_assert( + backend, guard_export_fn=guard_export_fn, dynamic=dynamic + ) return _optimize_catch_errors( convert_frame.convert_frame(backend, guard_export_fn=guard_export_fn), backend_ctx_ctor, + dynamic=dynamic, ) @@ -655,7 +687,7 @@ def assume_constant_result(fn): return fn -def optimize_assert(backend, *, guard_export_fn=None, export=False): +def optimize_assert(backend, *, guard_export_fn=None, export=False, dynamic=False): """ The same as `torch._dynamo.optimize(backend, nopython=True)` """ @@ -667,6 +699,7 @@ def optimize_assert(backend, *, guard_export_fn=None, export=False): return _optimize_catch_errors( convert_frame.convert_frame_assert(backend, guard_export_fn, export=export), backend_ctx_ctor, + dynamic=dynamic, ) From caf3d5319f15e47363fe36856326f5e4ab3303e1 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Sat, 19 Nov 2022 23:10:34 +0000 Subject: [PATCH 383/453] Symintify numel(), infer_size, prims.elementwise_meta (#88956) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88956 Approved by: https://github.com/ezyang --- aten/src/ATen/ExpandUtils.cpp | 10 ++++---- aten/src/ATen/ExpandUtils.h | 2 ++ test/test_dynamic_shapes.py | 11 ++++++++ test/test_proxy_tensor.py | 25 ++++++++++++++++--- torch/_prims/__init__.py | 16 +++++++++--- torch/_refs/__init__.py | 4 +-- torch/_subclasses/fake_tensor.py | 6 +---- torch/csrc/autograd/input_metadata.h | 4 +-- .../python_torch_functions_manual.cpp | 2 +- torch/fx/experimental/symbolic_shapes.py | 3 --- torch/fx/traceback.py | 2 +- 11 files changed, 59 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp index a44005a2ef81..ee846c9b82e3 100644 --- a/aten/src/ATen/ExpandUtils.cpp +++ b/aten/src/ATen/ExpandUtils.cpp @@ -13,8 +13,8 @@ TensorBase expand_slow_path(const TensorBase &self, IntArrayRef size) { namespace { // NOTE: are_expandable did a similar check, please keep them sync if change is needed -template -Container infer_size_impl(IntArrayRef a, IntArrayRef b) { +template +Container infer_size_impl(ArrayType a, ArrayType b) { size_t dimsA = a.size(); size_t dimsB = b.size(); size_t ndim = dimsA > dimsB ? dimsA : dimsB; @@ -25,8 +25,8 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ptrdiff_t offset = ndim - 1 - i; ptrdiff_t dimA = dimsA - 1 - offset; ptrdiff_t dimB = dimsB - 1 - offset; - int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; - int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; + auto sizeA = (dimA >= 0) ? 
a[dimA] : 1; + auto sizeB = (dimB >= 0) ? b[dimB] : 1; TORCH_CHECK( sizeA == sizeB || sizeA == 1 || sizeB == 1, @@ -35,7 +35,7 @@ Container infer_size_impl(IntArrayRef a, IntArrayRef b) { ") at non-singleton dimension ", i); // 1s map to the other size (even 0). - expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + expandedSizes[i] = sizeA == 1 ? std::move(sizeB) : std::move(sizeA); } return expandedSizes; diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 786cbf132cd7..9e48421e540f 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -21,6 +21,8 @@ namespace at { TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); +TORCH_API SymDimVector +infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); // Named type instead of a pair/tuple so that we can be sure to // construct the vectors in place and get NRVO. diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 3a8e31151bf3..953b6d9a53f6 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -291,6 +291,17 @@ def test_size_expressions(self): self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) + @skipIfNoSympy + def test_numel(self): + shape_env = ShapeEnv() + x = create_symbolic_tensor("x", torch.randn(5), shape_env) + self.assertIsInstance(x.numel(), torch.SymInt) + self.assertIsInstance(torch.numel(x), torch.SymInt) + + x = torch.rand(3, 3) + self.assertIsInstance(x.numel(), int) + self.assertIsInstance(torch.numel(x), int) + @skipIfNoSympy def test_int_to_float(self): shape_env = ShapeEnv() diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 34edc5cfac94..fa04c57d9426 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -13,6 +13,7 @@ from torch._subclasses.fake_tensor import DynamicOutputShapeException from torch._decomp import decomposition_table +from torch.fx.experimental.symbolic_shapes import sym_float from torch.testing._internal.common_device_type import ops from torch._C import _disabled_torch_function_impl from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule, has_proxy @@ -732,7 +733,6 @@ def deco(cls): @skipIfNoSympy @xfail_inherited_tests([ - "test_mode_tracing_factory_function", "test_make_fx_overloads", "test_trace_subclasses", ]) @@ -972,8 +972,27 @@ def f(x): # happened afterwards self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3) - - + def test_elementwise_meta_with_sym_numbers(self): + def f(x, offset, as_sym_float=False): + x0 = x.size()[0] + if as_sym_float: + x0 = sym_float(x0) + return torch.add(x0, offset) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.int64) + + fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) + meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) + self.assertEqual(meta_add.meta['val'].shape, ()) + self.assertEqual(meta_add.meta['val'].dtype, torch.float32) 
def test_return_symint(self): def f(x): diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index a867a44f72e3..67e16ca102ac 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -31,6 +31,7 @@ ) from torch._prims_common.wrappers import backwards_not_supported from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.symbolic_shapes import sym_float from torch.overrides import handle_torch_function, has_torch_function from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -390,11 +391,18 @@ def _elementwise_meta( return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype) # Number case - # NOTE: this case is not currently exercised # TODO: fix number type promotion (bool, complex->float) - assert not isinstance(number, torch.SymInt), "NYI" - assert not isinstance(number, torch.SymFloat), "NYI" - return TensorMeta(number) + + # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat) + seen_float = False + if isinstance(number, (torch.SymInt, torch.SymFloat)): + for a in args: + assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI" + seen_float = seen_float or isinstance(a, (float, torch.SymFloat)) + if seen_float: + number = sym_float(number) + + return TensorMeta(number) # type: ignore[arg-type] def _complex_only_elementwise_meta(*args, **kwargs): diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 8ea1390a4449..3355400db43c 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -744,10 +744,10 @@ def nan_to_num( nan = 0.0 if posinf is None: - posinf = prims.maximum_value(a.dtype) + posinf = torch.finfo(a.dtype).max if neginf is None: - neginf = prims.minimum_value(a.dtype) + neginf = torch.finfo(a.dtype).min result = where(isnan(a), nan, a) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index f52bec927b11..9a0ac050e6b9 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -851,11 +851,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): # and ensure that Meta kernels are dispatched to (see) # Fake Tensor Dispatch Keys # TODO - we should be use the prim aten impl - if ( - "prims::" in func._schema.name - and len(flat_arg_fake_tensors) != 0 - and hasattr(func, "prim_meta_impl") - ): + if "prims::" in func._schema.name and hasattr(func, "prim_meta_impl"): with self: return func.prim_meta_impl(*args, **kwargs) diff --git a/torch/csrc/autograd/input_metadata.h b/torch/csrc/autograd/input_metadata.h index 7cb9e8aedb19..8060c11ac457 100644 --- a/torch/csrc/autograd/input_metadata.h +++ b/torch/csrc/autograd/input_metadata.h @@ -125,13 +125,13 @@ struct InputMetadata { if (grad.is_nested()) { ss << at::native::get_nested_size_tensor(grad); } else { - ss << grad.sizes(); + ss << grad.sym_sizes(); } ss << " but expected shape compatible with "; if (is_nested_tensor()) { ss << shape_as_tensor(); } else { - ss << c10::asIntArrayRefSlow(shape_as_dim_vector()); + ss << shape_as_dim_vector(); } return ss; } diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 562f5a427d38..bd969f6a26fb 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -692,7 +692,7 @@ static PyObject* THPVariable_numel( } if (r.idx == 0) { - return wrap(r.tensor(0).numel()); + return 
py::cast(r.tensor(0).sym_numel()).release().ptr(); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index ae4427e2320e..bd52760502c6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -261,9 +261,6 @@ def eval(cls, base, divisor): 'floordiv': lambda a, b: FloorDiv(a, b), } -def _nyi(): - raise NotImplementedError() - magic_methods = { **reflectable_magic_methods, 'eq': lambda a, b: sympy.Eq(a, b), diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a07b36b997bd..cee7626e5c83 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -54,7 +54,7 @@ def format_stack() -> List[str]: return current_stack.copy() else: # fallback to traceback.format_stack() - return traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) @compatibility(is_backward_compatible=False) From dbeacf11820e336e803bb719b7aaaf2125ae4d9c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 19 Nov 2022 19:44:18 -0500 Subject: [PATCH 384/453] Fix cat striding in PrimTorch (#89332) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89332 Approved by: https://github.com/ngimel --- test/jit/test_symbolic_shape_analysis.py | 7 +++++- test/test_mps.py | 2 -- torch/_refs/__init__.py | 24 +++++++++++++++++-- torch/testing/_creation.py | 8 ++++++- .../_internal/common_methods_invocations.py | 5 ++++ 5 files changed, 40 insertions(+), 6 deletions(-) diff --git a/test/jit/test_symbolic_shape_analysis.py b/test/jit/test_symbolic_shape_analysis.py index 1c4e359662bd..3e3cb3ffed73 100644 --- a/test/jit/test_symbolic_shape_analysis.py +++ b/test/jit/test_symbolic_shape_analysis.py @@ -319,7 +319,12 @@ def forward(self, x, y): mod = torch.jit.script(CatMod(**inp.kwargs).eval()) args = inp.input - self.assertTrue(len(args) == 2) + + # This test is hard-coded only to work with two sample inputs + # but the OpInfo may have more/less + if len(args) != 2: + continue + out_size = mod(*args).size() inps = list(mod.graph.inputs()) inps[1].setType(inps[1].type().with_sizes(args[0].size())) diff --git a/test/test_mps.py b/test/test_mps.py index 8e40a5cce293..745fa09df5e0 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7326,7 +7326,6 @@ class TestConsistency(TestCase): 'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'], 'bmm': ['f32'], 'broadcast_shapes': ['f32'], - 'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], 'ceil': ['f32', 'int32', 'int64', 'f16'], 'char': ['b8', 'u8'], 'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'], @@ -7558,7 +7557,6 @@ class TestConsistency(TestCase): 'block_diag': ['f16', 'f32'], 'bmm': ['f32'], 'broadcast_shapes': ['f32'], - 'cat': ['f16', 'f32'], 'ceil': ['f32'], 'chunk': ['f16', 'f32'], 'clone': ['f16', 'f32'], diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 3355400db43c..fda73cf0bc60 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -2530,6 +2530,18 @@ def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType: type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH, ) def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: + def cat_compute_output_memory_format(inputs): + format = None + for t in inputs: + f = utils.suggest_memory_format(t) + if f == torch.contiguous_format: + return f + if format is not None and format != f: + return torch.contiguous_format + 
format = f + assert format is not None + return format + if len(tensors) == 0: msg = "cat expects at least one tensor, but received zero!" raise ValueError(msg) @@ -2547,6 +2559,8 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: utils.validate_idx(t.ndim, dim) break + memory_format = cat_compute_output_memory_format(tensors) + # Filters tensors with one dimension of length zero filtered = tuple(x for x in tensors if not (x.ndim == 1 and x.numel() == 0)) if len(filtered) == 0: @@ -2558,9 +2572,15 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType: except Exception: requires_grad = False - return empty((0,), dtype=t.dtype, device=t.device, requires_grad=requires_grad) + return empty( + (0,), + dtype=t.dtype, + device=t.device, + requires_grad=requires_grad, + memory_format=memory_format, + ) - return prims.cat(filtered, dim) + return prims.cat(filtered, dim).clone(memory_format=memory_format) # CompositeImplicitAutograd - don't register decomp diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index b8f41f04743c..33b9739a7f36 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -31,7 +31,8 @@ def make_tensor( high: Optional[float] = None, requires_grad: bool = False, noncontiguous: bool = False, - exclude_zero: bool = False + exclude_zero: bool = False, + memory_format: Optional[torch.memory_format] = None, ) -> torch.Tensor: r"""Creates a tensor with the given :attr:`shape`, :attr:`device`, and :attr:`dtype`, and filled with values uniformly drawn from ``[low, high)``. @@ -74,6 +75,8 @@ def make_tensor( :attr:`dtype`'s :func:`~torch.finfo` object), and for complex types it is replaced with a complex number whose real and imaginary parts are both the smallest positive normal number representable by the complex type. Default ``False``. + memory_format (Optional[torch.memory_format]): The memory format of the returned tensor. Incompatible + with :attr:`noncontiguous`. Raises: ValueError: if ``requires_grad=True`` is passed for integral `dtype` @@ -152,9 +155,12 @@ def clamp(a, l, h): raise TypeError(f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()." 
" To request support, file an issue at: https://github.com/pytorch/pytorch/issues") + assert not (noncontiguous and memory_format is not None) if noncontiguous and result.numel() > 1: result = torch.repeat_interleave(result, 2, dim=-1) result = result[..., ::2] + elif memory_format is not None: + result = result.clone(memory_format=memory_format) if exclude_zero: if dtype in _integral_types or dtype is torch.bool: diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cf68a68cf629..4edba78ec4ae 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1873,6 +1873,9 @@ def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs): for input_shape1, input_shape2, kwargs in cases: yield SampleInput([make_arg(input_shape1), make_arg(input_shape2)], kwargs=kwargs) + # from coat_lite_mini + yield SampleInput([make_arg((2, 2, 2, 2), memory_format=torch.channels_last)], args=(1,),) + def error_inputs_cat(op_info, device, **kwargs): make_arg = partial(make_tensor, device=device, dtype=torch.float32) @@ -15016,6 +15019,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=True, assert_autodiffed=True, skips=( + # https://github.com/pytorch/pytorch/issues/89353 + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), # RuntimeError: Arguments for call not valid. # Expected a value of type 'List[Tensor]' for argument # 'tensors' but instead found type 'Tensor (inferred)'. From 7b0d577c226fae78f377b26feab4122c4203ad59 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Sat, 19 Nov 2022 22:31:24 -0500 Subject: [PATCH 385/453] Set INTERFACE_LINK_DIRECTORIES on caffe2::mkl (#89359) This ensures that subsequent link commands involving mkl libraries know where to find the libraries if they are in a non-standard location (which is the case if you installed mkl via conda, which is what our standard instructions recommend.) This is kind of a hack, because the MKL libraries are not actually guaranteed to be in $MKL_ROOT/lib (they are for the conda install though). The real fix is to properly use the MKL targets from FindMKL.cmake but thats its own can of fish. See https://github.com/pytorch/pytorch/issues/73008 This fixes https://github.com/pytorch/audio/issues/2784 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89359 Approved by: https://github.com/soumith --- cmake/public/mkl.cmake | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cmake/public/mkl.cmake b/cmake/public/mkl.cmake index 9515a4ae9681..57c404299136 100644 --- a/cmake/public/mkl.cmake +++ b/cmake/public/mkl.cmake @@ -10,3 +10,8 @@ set_property( set_property( TARGET caffe2::mkl PROPERTY INTERFACE_LINK_LIBRARIES ${MKL_LIBRARIES}) +# TODO: This is a hack, it will not pick up architecture dependent +# MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008 +set_property( + TARGET caffe2::mkl PROPERTY INTERFACE_LINK_DIRECTORIES + ${MKL_ROOT}/lib) From c09929659ce8ba2f1b7b2f6e50084ccbf854d44b Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Sun, 20 Nov 2022 09:13:30 -0500 Subject: [PATCH 386/453] Also include MKL_THREAD_LIB in link libraries for caffe2::mkl (#89378) Actually fixes https://github.com/pytorch/audio/issues/2784 for real; in my previous testing I didn't check if I could import torchaudio; now torchaudio successfully imports. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89378 Approved by: https://github.com/soumith --- cmake/public/mkl.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/public/mkl.cmake b/cmake/public/mkl.cmake index 57c404299136..f4ab1ffa9d0f 100644 --- a/cmake/public/mkl.cmake +++ b/cmake/public/mkl.cmake @@ -9,7 +9,7 @@ set_property( ${MKL_INCLUDE_DIR}) set_property( TARGET caffe2::mkl PROPERTY INTERFACE_LINK_LIBRARIES - ${MKL_LIBRARIES}) + ${MKL_LIBRARIES} ${MKL_THREAD_LIB}) # TODO: This is a hack, it will not pick up architecture dependent # MKL libraries correctly; see https://github.com/pytorch/pytorch/issues/73008 set_property( From e1d58b1928a9427f05e3f44ab9b8119000bced09 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 20 Nov 2022 22:14:38 +0000 Subject: [PATCH 387/453] Revert "Update sdp dispatch logic to enable fused backward (#89154)" This reverts commit 2e72ec79823111e8dd8c5e82c5d1b56197cd52d3. Reverted https://github.com/pytorch/pytorch/pull/89154 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but the new test_sdp_math_gradcheck test breaks periodic slow gradcheck, i.e. https://hud.pytorch.org/pytorch/pytorch/commit/419ef2cdcfe84442de5232739284c6a51a18632f --- aten/src/ATen/native/native_functions.yaml | 52 +++-- .../cuda/NestedTensorTransformerFunctions.cpp | 100 +++------ .../ATen/native/transformers/attention.cpp | 65 ++---- .../native/transformers/cuda/attention.cu | 46 +++-- .../transformers/cuda/attention_backward.cu | 40 +--- .../transformers/cuda/flash_attn/fmha_api.cpp | 7 +- .../transformers/cuda/flash_attn/fmha_api.h | 2 +- .../ATen/native/transformers/cuda/sdp_utils.h | 34 +--- benchmarks/transformer/sdp_backwards.py | 189 ------------------ .../check_forward_backward_compatibility.py | 3 - test/functorch/test_ops.py | 8 +- test/test_meta.py | 1 + test/test_transformers.py | 74 ++----- tools/autograd/derivatives.yaml | 6 +- .../_internal/common_methods_invocations.py | 5 - 15 files changed, 135 insertions(+), 497 deletions(-) delete mode 100644 benchmarks/transformer/sdp_backwards.py diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8c759cd09c48..f625c9faff41 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13252,39 +13252,18 @@ CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda -- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) +# Register the math kernel for cpu +- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? 
attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function - -- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor) - dispatch: - CUDA: _scaled_dot_product_flash_attention_cuda - NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda - -- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) dispatch: - CUDA: _scaled_dot_product_efficient_attention_cuda - NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + CUDA: _scaled_dot_product_attention_forward_cuda + CPU: _scaled_dot_product_attention_forward_math + NestedTensorCUDA: _scaled_dot_product_attention_forward_nested + NestedTensorCPU: _scaled_dot_product_attention_forward_math + Meta: _scaled_dot_product_attention_forward_math -- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) - dispatch: - CUDA: _scaled_dot_product_efficient_attention_backward_cuda - -# Returns ouput, softmax_logsumexp, softmax -- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor) - variants: function - dispatch: - CUDA: _flash_attention_forward - -# Returns ouput, logsumexp if compute_logsumexp -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) +- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function - dispatch: - CUDA: _efficient_attention_forward - -- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_backward - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function @@ -13311,6 +13290,21 @@ structured: True variants: function +- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor + variants: function + dispatch: + CUDA: flash_scaled_dot_product_attention + +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? 
max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_forward + +- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward + - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 9c72454560d3..c2bf4e08ce04 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -214,6 +214,26 @@ Tensor NestedTensor_to_padded_tensor_cuda( return NestedTensor_to_padded_tensor_generic(t, padding, output_size); } +std::tuple _scaled_dot_product_attention_forward_nested( + const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { + + // Determine which efficient kernel to use + sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; + auto backend = select_sdp_backend(kernel_params); + switch(backend){ + case sdp::SDPBackend::flash_attention: + // TODO: enable flash attention kernel + return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); + case sdp::SDPBackend::efficient_attention: + return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); + case sdp::SDPBackend::math: + return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + default: + TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention"); + return std::make_tuple(Tensor(), Tensor()); + } +} namespace{ /** @@ -320,80 +340,19 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { } } // namespace - -std::tuple _scaled_dot_product_flash_attention_nestedtensor_cuda( +std::tuple mem_efficient_helper_nested_unpacked( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool return_softmax, + bool need_atten_weights, bool is_causal) { - TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.") // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) const int64_t num_heads = query.size(1); const int64_t head_dim = query.size(3); - // Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head) - // Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head) - // Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head) - Tensor q_t = query.transpose(1, 2).contiguous(); - Tensor k_t = key.transpose(1, 2).contiguous(); - Tensor v_t = value.transpose(1, 
2).contiguous(); - - // K and V have to have the same Nnz, should probably torch_check - // assume in order to not iterate over v - - auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t); - auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t); - - Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); - Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k); - - const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q); - const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k); - - const int64_t Nnz_q = cumulative_sequence_length_q[-1].item(); - const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item(); - - auto query_buffer_reshaped = - get_buffer(q_t).view({Nnz_q, num_heads, head_dim}); - auto key_buffer_reshaped = - get_buffer(k_t).view({Nnz_kv, num_heads, head_dim}); - auto value_buffer_reshaped = - get_buffer(v_t).view({Nnz_kv, num_heads, head_dim}); - - auto attention_and_lse_and_softmax = - at::_flash_attention_forward( - query_buffer_reshaped, - key_buffer_reshaped, - value_buffer_reshaped, - cumulative_sequence_length_q, - cumulative_sequence_length_k, - max_seqlen_batch_q, - max_seqlen_batch_k, - return_softmax, - dropout_p, - is_causal); - // Reshape output to convert nnz to batch_size and seq_len - Tensor attention = std::get<0>(attention_and_lse_and_softmax); - attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2); - return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax)); -} - -std::tuple _scaled_dot_product_efficient_attention_nestedtensor_cuda( - const Tensor& query, - const Tensor& key, - const Tensor& value, - bool compute_log_sumexp, - bool is_causal) { - // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) - // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) - // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) - const int64_t num_heads = query.size(1); - const int64_t head_dim = query.size(3); - Tensor q_t = query.transpose(1, 2); Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); @@ -473,7 +432,7 @@ std::tuple _scaled_dot_product_efficient_attention_nestedtensor_ {Nnz_kv, num_heads, head_dim}, {nnz_v_stride, head_v_stride, head_dim_stride}, value_impl->get_storage_offsets()[0]); - std::tuple attention_and_logsumexp= + std::tuple attention_and_weights = at::_efficient_attention_forward( query_buffer_reshaped.unsqueeze(0), key_buffer_reshaped.unsqueeze(0), @@ -481,14 +440,14 @@ std::tuple _scaled_dot_product_efficient_attention_nestedtensor_ cumulative_sequence_length_q, cumulative_sequence_length_k, max_seqlen_batch_q, - compute_log_sumexp, - is_causal); + false, + false); // Reshape output to convert nnz to batch_size and seq_len - Tensor attention = std::get<0>(attention_and_logsumexp); + Tensor attention = std::get<0>(attention_and_weights); attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()) .transpose(1, 2); - return std::tie(attention, std::get<1>(attention_and_logsumexp)); + return std::tie(attention, std::get<1>(attention_and_weights)); } Tensor flash_attention_helper( @@ -533,7 +492,7 @@ Tensor flash_attention_helper( // If we are passing in query, key, value all the same tensors then we have // packed them into one tensor and need to slice for flash attention Tensor attention = - std::get<0>(at::_flash_attention_forward( + at::_flash_scaled_dot_product_attention( q, k, v, @@ -541,9 +500,8 @@ Tensor 
flash_attention_helper( cumulative_sequence_length_q, max_seqlen_batch_q, max_seqlen_batch_q, - false /*return_softmax*/, dropout_p, - is_causal)); + is_causal); // Output of flash_attention is a regular tensor lets wrap it back up to // form a nested tensor diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 9c5be12ef24d..89a0e4691018 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -678,6 +678,20 @@ std::tuple native_decoder_only_multi_head_attent // L: Target sequence length // E: Embedding dimension std::tuple _scaled_dot_product_attention( + const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { + if (query_.requires_grad() || key.requires_grad() || value.requires_grad()){ + return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + } + return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); +} + +int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + return static_cast(sdp::SDPBackend::math); +} + +std::tuple _scaled_dot_product_attention_forward_math( const Tensor& query_, const Tensor& key, const Tensor& value, @@ -685,49 +699,14 @@ std::tuple _scaled_dot_product_attention( double dropout_p, bool need_attn_weights, bool is_causal) { - // TODO: The second return is the attention weights if the math kernel is - // used. The fused kernels do not return this Tensor so for the fused kernels - // The second return SHOULD always be an empty Tensor, unless need_attn_weights - // is true (in which case the fused kernels would not be called). This blows up - // op_info tests. 
- int64_t choice_int = at::_fused_sdp_choice( - query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - sdp::SDPBackend backend = static_cast(choice_int); - switch (backend) { - case sdp::SDPBackend::flash_attention: { - auto out_lse_softmax = at::_scaled_dot_product_flash_attention( - query_, key, value, dropout_p, need_attn_weights, is_causal); - return std::make_tuple( - std::move(std::get<0>(out_lse_softmax)), - std::move(std::get<2>(out_lse_softmax))); - } - case sdp::SDPBackend::efficient_attention: { - bool compute_logsumexp = - (query_.requires_grad() || key.requires_grad() || - value.requires_grad()); - return at::_scaled_dot_product_efficient_attention( - query_, key, value, compute_logsumexp, is_causal); - } - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math( - query_, - key, - value, - attn_mask_, - dropout_p, - need_attn_weights, - is_causal); - default: - TORCH_CHECK( - false, - "No viable backend for scaled_dot_product_attention was found."); - return std::make_tuple(Tensor(), Tensor()); - } -} - -int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ - return static_cast(sdp::SDPBackend::math); + return at::_scaled_dot_product_attention_math( + query_, + key, + value, + attn_mask_, + dropout_p, + need_attn_weights, + is_causal); } std::tuple _scaled_dot_product_attention_math( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 8dcb99b3380d..602cf319f74a 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -678,12 +678,12 @@ std::tuple native_multi_head_attention_cuda( return std::make_tuple(std::move(proj), std::move(qkt)); } -std::tuple _scaled_dot_product_flash_attention_cuda( +std::tuple flash_attention_helper_dense_unpacked( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool return_softmax, + bool need_atten_weights, bool is_causal) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -726,9 +726,8 @@ std::tuple _scaled_dot_product_flash_attention_cuda( Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim}); Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim}); - Tensor attention, log_sumexp, softmax; - std::tie(attention, log_sumexp, softmax) = - at::_flash_attention_forward( + Tensor attention = + at::_flash_scaled_dot_product_attention( query_reshaped, key_reshaped, value_reshaped, @@ -736,17 +735,15 @@ std::tuple _scaled_dot_product_flash_attention_cuda( cumulative_sequence_length_k, max_seqlen_batch_q, max_seqlen_batch_k, - return_softmax, dropout_p, is_causal); // Reshape output to convert nnz to batch_size and seq_len attention = attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2); - return std::make_tuple(attention, log_sumexp, softmax); + return std::tuple(attention, Tensor()); } - -std::tuple _scaled_dot_product_efficient_attention_cuda( +std::tuple mem_eff_helper( const Tensor& query, const Tensor& key, const Tensor& value, @@ -770,7 +767,26 @@ std::tuple _scaled_dot_product_efficient_attention_cuda( compute_log_sumexp, is_causal); attention = attention.transpose(1,2); - return std::make_tuple(std::move(attention), std::move(log_sumexp)); + return std::make_tuple(std::move(attention), 
Tensor()); +} + +std::tuple _scaled_dot_product_attention_forward_cuda( + const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { + // Determine which efficient kernel to use + sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; + auto backend = select_sdp_backend(kernel_params); + switch(backend){ + case sdp::SDPBackend::flash_attention: + return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); + case sdp::SDPBackend::efficient_attention: + return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); + case sdp::SDPBackend::math: + return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + default: + TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found."); + return std::make_tuple(Tensor(), Tensor()); + } } int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, @@ -786,7 +802,7 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te return static_cast(backend); } -std::tuple _flash_attention_forward( +Tensor flash_scaled_dot_product_attention( const Tensor& query, const Tensor& key, const Tensor& value, @@ -794,12 +810,11 @@ std::tuple _flash_attention_forward( const Tensor& cumulative_sequence_length_k, const int64_t max_seqlen_batch_q, const int64_t max_seqlen_batch_k, - bool return_softmax, double dropout_p, bool is_causal) { #if defined(USE_FLASH_ATTENTION) auto softmax_scale = std::pow(query.size(-1), -0.5); - return fmha::mha_fwd( + std::vector output = fmha::mha_fwd( query, key, value, @@ -811,11 +826,12 @@ std::tuple _flash_attention_forward( softmax_scale, false, is_causal, - return_softmax, + false, c10::nullopt); + return output[0]; #endif TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return std::make_tuple(Tensor(), Tensor(), Tensor()); + return Tensor(); } std::tuple _efficient_attention_forward( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index a063aacb901e..af005b2669b2 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -10,7 +10,6 @@ #include #include -#include #ifdef USE_FLASH_ATTENTION #include #endif @@ -74,14 +73,14 @@ std::tuple _efficient_attention_backward( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, - const at::Tensor& out, const at::Tensor& logsumexp, + const at::Tensor& out, bool causal) { #if defined(USE_FLASH_ATTENTION) if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } - // ndim + // ndim TORCH_CHECK(query.dim() == grad_out_.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); @@ -129,7 +128,6 @@ std::tuple _efficient_attention_backward( // initialized bool grad_kv_needs_init = causal && N > M; at::Tensor grad_q, grad_k, grad_v; - int8_t gQKV_strideM_multiplier = 1; if (!grad_kv_needs_init && query.size(1) == key.size(1) && query.size(3) == value.size(3) && query.storage().is_alias_of(key.storage()) && @@ -143,13 +141,10 @@ std::tuple _efficient_attention_backward( grad_q = chunk.select(2, 0); grad_k = chunk.select(2, 1); grad_v = chunk.select(2, 2); - gQKV_strideM_multiplier=3; 
} else { - grad_q = at::empty(query.sizes(), query.options()); - grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options()) - : at::empty(key.sizes(), key.options()); - grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options()) - : at::empty(value.sizes(), value.options()); + grad_q = at::empty_like(query); + grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); + grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); } auto launchKernel = [&](auto _k, int computeCapability) { @@ -203,7 +198,7 @@ std::tuple _efficient_attention_backward( ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - p.gQKV_strideM_multiplier = gQKV_strideM_multiplier; + p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 1 : 3; TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); @@ -262,28 +257,5 @@ std::tuple _efficient_attention_backward( return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } - -std::tuple _scaled_dot_product_efficient_attention_backward_cuda( - const at::Tensor& grad_out_, - const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& out, - const at::Tensor& logsumexp, - bool causal){ - if (!grad_out_.defined()) { - return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); - } - auto grad_out = grad_out_.transpose(1, 2); - auto out_t = out.transpose(1, 2); - auto q_t = query.transpose(1, 2); - auto k_t = key.transpose(1, 2); - auto v_t = value.transpose(1, 2); - - Tensor grad_q, grad_k, grad_v; - std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal); - return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index 7cc0c250664e..aaf7d833fe83 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -26,7 +26,6 @@ * ******************************************************************************/ -#include #ifdef USE_FLASH_ATTENTION #include #include @@ -116,7 +115,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.is_causal = is_causal; } -std::tuple +std::vector mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -242,7 +241,9 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fprop(launch_params, /*configure=*/false); - return std::make_tuple(o, softmax_lse, s); + std::vector result = {o, softmax_lse}; + if (return_softmax) {result.push_back(s);} + return result; } } // namespace fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h index b0555463be04..226d4ddd2b55 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h @@ -7,7 +7,7 @@ namespace fmha { TORCH_API 
-std::tuple +std::vector mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 55e9aeb184a2..5d62a6cbd0dc 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -91,31 +91,6 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { return true; } -inline bool check_for_nested_inputs(sdp_params params, bool debug){ - if (params.query.is_nested() || params.key.is_nested() || params.value.is_nested()) { - TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors."); - return false; - } - return true; -} - -inline bool check_requires_grad(sdp_params params, bool debug) { - if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { - TORCH_CHECK(!debug, "Flash Attention does not currently support training."); - return false; - } - return true; -} - -inline bool check_requires_grad_and_nested(sdp_params params, bool debug) { - // If we fail both checks then we return false - if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){ - TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs."); - return false; - } - return true; -} - inline bool check_for_attn_mask(sdp_params params, bool debug) { if (params.has_attn_mask) { TORCH_CHECK(!debug, "Flash Attention does not support attention mask."); @@ -223,15 +198,13 @@ inline bool use_flash_attention(sdp_params params, bool debug) { return false; #endif // Define gate functions that determine if a flash kernel can be ran - constexpr std::array constraints {{ + constexpr std::array constraints {{ check_runtime_disabled_flash, - check_requires_grad, check_tensor_shapes, check_for_attn_weights, check_for_attn_mask, check_head_dim_size, check_gpu_sm75_or_greater, - check_for_nested_inputs, check_for_seq_len_1_nested_tensor}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { @@ -259,15 +232,14 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { at::kHalf, at::kFloat, at::kBFloat16}; // Define gate functions that determine if a flash kernel can be ran - constexpr std::array constraints{{ + std::vector> constraints{ check_gpu_sm50_or_greater, check_runtime_disabled_mem_efficient, - check_requires_grad_and_nested, check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, check_for_seq_len_1_nested_tensor, - check_for_non_zero_dropout}}; + check_for_non_zero_dropout}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/benchmarks/transformer/sdp_backwards.py b/benchmarks/transformer/sdp_backwards.py deleted file mode 100644 index 2f745e157b28..000000000000 --- a/benchmarks/transformer/sdp_backwards.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import numpy as np -import random -import torch.utils.benchmark as benchmark -from torch.profiler import profile, record_function, ProfilerActivity - - -class CompositeMHA(torch.nn.Module): - def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj): - super().__init__() - self.in_proj_weight = in_proj_weight - 
self.in_proj_bias = in_proj_bias - self.out_proj = out_proj - self.num_heads = num_heads - - def forward(self, query, key, value, mask): - if not (query is key and key is value): - raise NotImplementedError( - "query, key and value must be the same Tensor for now." - ) - if mask is not None: - raise NotImplementedError("mask is currently not supported.") - - query_projected = torch.nn.functional.linear( - query, self.in_proj_weight, self.in_proj_bias - ) - - batch_size = query_projected.size(0) - embed_dim = query_projected.size(2) - head_dim = embed_dim // (self.num_heads * 3) - - query, key, value = query_projected.chunk(3, -1) - - query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - attn, _ = torch.nn.functional._scaled_dot_product_attention( - query, - key, - value, - attn_mask=None, - dropout_p=0.0, - need_attn_weights=False, - is_causal=False, - ) - - attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim) - # Match return signature of nn.MHA - return self.out_proj(attn) - - -def build_composite_mha_from_nn_mha(pt): - assert pt._qkv_same_embed_dim - in_proj_weight = pt.in_proj_weight - assert in_proj_weight is not None - assert pt.batch_first - return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj) - - -def forw_back(model, input, upward): - output = model(*input) - output.backward(upward) - - -# Context manger not working in timer - - -def forw_back_fused(model, input, upward): - with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): - output = model(*input) - output.backward(upward) - - -def forw_back_eager(model, input, upward): - with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): - output = model(*input) - output.backward(upward) - - -def run_timing( - min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype -): - dropout_p = 0.0 - mask = None - - pt = torch.nn.MultiheadAttention( - embed_dim=embed_dimension, - num_heads=num_heads, - batch_first=True, - dropout=dropout_p, - ) - npt = pt.cuda().to(dtype) - cpt = build_composite_mha_from_nn_mha(npt) - x = torch.randn( - batch_size, - max_sequence_len, - embed_dimension, - dtype=dtype, - device="cuda", - requires_grad=True, - ) - - with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): - rand_fused_upward = cpt(x, x, x, mask).clone().detach() - - with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): - rand_eager_upward = cpt(x, x, x, mask).clone().detach() - - t0 = benchmark.Timer( - stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)", - globals={ - "forw_back_fused": forw_back_fused, - "cpt": cpt, - "x": x, - "rand_fused_upward": rand_fused_upward, - "mask": mask, - }, - label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " - f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", - num_threads=torch.get_num_threads(), - ) - - t1 = benchmark.Timer( - stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)", - globals={ - "forw_back_eager": forw_back_eager, - "cpt": cpt, - "x": x, - "rand_eager_upward": rand_eager_upward, - "mask": mask, - }, - label=f"Eager SDP forward and backward batch_size={batch_size} 
max_sequence_len={max_sequence_len} " - f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", - num_threads=torch.get_num_threads(), - ) - - m0 = t0.blocked_autorange(min_run_time=min_run_time) - m1 = t1.blocked_autorange(min_run_time=min_run_time) - - print(m0) - print(m1) - - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - print("Profile for Fused".center(200, "-")) - with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): - with profile( - activities=activities, record_shapes=False, with_stack=True - ) as prof: - with record_function("Fused SDP forward and backward"): - for _ in range(20): - forw_back(cpt, (x, x, x, mask), rand_fused_upward) - print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) - - print("Profile for eager".center(200, "-")) - with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): - with profile( - activities=activities, record_shapes=False, with_stack=True - ) as prof: - with record_function("Fused SDP forward and backward"): - for _ in range(20): - forw_back(cpt, (x, x, x, mask), rand_eager_upward) - print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) - - -def main(): - seed = 123 - np.random.seed(seed) - torch.manual_seed(seed) - random.seed(seed) - - min_run_time = 10 - batch_size = 64 - num_heads = 32 - max_seq_len = 256 - embed_dim = 1024 - dtype = torch.bfloat16 - - print( - f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} " - f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}" - ) - run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype) - - -if __name__ == "__main__": - main() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 853f5206969b..90080ab0934f 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -317,9 +317,6 @@ ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)), - ("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)), - ("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)), - ("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)), ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)), ] diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index e9451b596b4a..5e3aa1ff898f 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -401,7 +401,6 @@ def wrapped_fn(*args, **kwargs): skip('nn.functional.max_unpool2d'), # fails everywhere except on windows skip('nn.functional.max_unpool3d'), # fails everywhere except on mac xfail("native_batch_norm"), - xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented @@ -556,7 +555,6 @@ def f(inp, *args, **kwargs): xfail('nn.functional.ctc_loss'), # Not Implemented xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other' xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides - skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), # AssertionError: 
Tensor-likes are not close! # Mismatched elements: 1 / 15 (6.7%) # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) @@ -651,7 +649,7 @@ def fn(inp, *args, **kwargs): skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op - xfail('nn.functional._scaled_dot_product_attention'), # randomness + skip('nn.functional._scaled_dot_product_attention'), # randomness # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. @@ -1128,7 +1126,6 @@ def test(): skip('nn.functional.rrelu'), # randomness skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness - skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), skip('nn.functional.alpha_dropout'), # randomness skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format skip('to_sparse', ''), # non-dense output @@ -1252,7 +1249,6 @@ def get_vjp(cotangents, *primals): xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward - skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides @@ -1373,7 +1369,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.dropout2d'), # calls random op xfail('nn.functional.dropout3d'), # calls random op xfail('nn.functional.dropout'), # calls random op - xfail('nn.functional._scaled_dot_product_attention'), # randomness + skip('nn.functional._scaled_dot_product_attention'), # randomness xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition xfail('nn.functional.alpha_dropout'), # calls randomn op xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op diff --git a/test/test_meta.py b/test/test_meta.py index 0e3cfb6ef140..6d21d5c7bd75 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -294,6 +294,7 @@ def test_tensor_outlives_converter(self): aten._fft_c2r.default, aten._fft_r2c.default, aten._linalg_svd.default, + aten._scaled_dot_product_attention_forward.default, aten.binary_cross_entropy.default, aten.complex.default, aten.copysign.Tensor, diff --git a/test/test_transformers.py b/test/test_transformers.py index f6bc0cc2d663..abb4c71ec19a 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1059,11 +1059,6 @@ def rand_tensor(shape): if fused_kernel == "flash": with sdp_kernel(enable_mem_efficient=False, enable_math=False): - # TODO Flash for the nested path is currently not working due to cuda memory issues - if type == "nested": - self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( - query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)) - return actual = torch.nn.functional._scaled_dot_product_attention( query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, 
need_attn_weights=False, is_causal=False) elif fused_kernel == "mem_efficient": @@ -1102,73 +1097,28 @@ def rand_tensor(shape): @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) - def test_sdp_math_gradcheck(self, contiguous_inputs: bool): + def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True) qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) query, key, value = qkv.chunk(3, dim=-1) - - query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + query = query.view(batch_size, -1, num_heads, head_dim) + key = key.view(batch_size, -1, num_heads, head_dim) + value = value.view(batch_size, -1, num_heads, head_dim) if contiguous_inputs: query = query.contiguous() key = key.contiguous() value = value.contiguous() - with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): - assert gradcheck(lambda *args, **kwargs: - wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), - (query, key, value, None, 0.0, False, False) - ) - - @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") - @parametrize("contiguous_inputs", [True, False]) - def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool): - batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) - - qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) - qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() - - query, key, value = qkv.chunk(3, dim=-1) - query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) - - query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - - query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - - if contiguous_inputs: - query = query.contiguous() - key = key.contiguous() - value = value.contiguous() - - query_lp = query_lp.contiguous() - key_lp = key_lp.contiguous() - value_lp = value_lp.contiguous() - - with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): - out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False) - - with sdp_kernel(enable_math=False, enable_mem_efficient=True, enable_flash=False): - out_lp, atten_lp = torch.nn.functional._scaled_dot_product_attention( - query_lp, key_lp, value_lp, None, 0.0, False, False) - - rand_upward = torch.rand_like(out) - rand_upward_lp = rand_upward.to(torch.float32) - - out.backward(rand_upward) - out_lp.backward(rand_upward_lp) - - # Cast up and compare - self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) + # Normally 
we would transpose the inputs but the fused kernels expect + # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel + # in fp32 + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), + (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) @parametrize("type", ["dense", "nested"]) def test_fused_sdp_choice(self, type: str): @@ -1194,7 +1144,7 @@ def test_fused_sdp_choice(self, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - if SM80OrLater and not type == "nested": + if SM80OrLater: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION else: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 52c0f76bf070..a0892b32a835 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2613,13 +2613,9 @@ nested_strides: non_differentiable # Transformers -- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) - output_differentiability: [True, False] - query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal) - - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) output_differentiability: [True, False] - query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal) + query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4edba78ec4ae..4ccb5ef3840f 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12011,21 +12011,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # This is only failing on Linux Bionic 3.10 Cuda 11.6 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', - device_type='cuda', dtypes=(torch.float32,)), # AssertionError: JIT Test does not execute any logic DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), # Doesn't support autocasting DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensorNonErroring', 'test_fake_autocast', device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), - # Forward works for dtype=float64 which is the math path - DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), # No meta function DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), 
DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), - DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', device_type='cuda'), DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),), ), UnaryUfuncInfo( From 1856fa5df7fda9950da26eff2ef885e845bf6b6c Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sun, 20 Nov 2022 23:36:47 +0000 Subject: [PATCH 388/453] Temporary increase ASAN shard 5 to 4xlarge (#89387) ASAN shard 5 also see OOM now https://hud.pytorch.org/pytorch/pytorch/commit/7b0d577c226fae78f377b26feab4122c4203ad59, may be we should increase all 5 of them to 4xlarge until https://github.com/pytorch/pytorch/issues/88309 is resolved Pull Request resolved: https://github.com/pytorch/pytorch/pull/89387 Approved by: https://github.com/kit1980 --- .github/workflows/pull.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 3208cb198bb4..3642c7fc1769 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -78,7 +78,7 @@ jobs: { config: "default", shard: 2, num_shards: 5, runner: "linux.2xlarge" }, { config: "default", shard: 3, num_shards: 5, runner: "linux.2xlarge" }, { config: "default", shard: 4, num_shards: 5, runner: "linux.4xlarge" }, - { config: "default", shard: 5, num_shards: 5, runner: "linux.2xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "linux.4xlarge" }, { config: "functorch", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, ]} From 51e961dd7bb9abaf999e6028208b2778a57c32b2 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Mon, 21 Nov 2022 00:58:03 +0000 Subject: [PATCH 389/453] use std/libdevice erf in inductor (#89388) By itself, libdevice version of erf has the same perf as our decomposition, but in real workloads it leads to better fusion groups (due to fewer ops in the fused kernel). Bonus: a few fp64 test skips removed, because our decomposition wasn't accurate enough for fp64, but libdevice version is. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89388 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor_opinfo.py | 4 ---- torch/_inductor/codegen/cpp.py | 4 ++++ torch/_inductor/codegen/triton.py | 4 ++++ torch/_inductor/decomposition.py | 24 ---------------------- torch/_inductor/lowering.py | 7 +++++++ 5 files changed, 15 insertions(+), 28 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 67b64c73a8ef..188fcd8b67dc 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -165,7 +165,6 @@ def process(device_type): "corrcoef": {f32, f64, i32, i64}, "cov": {f32, f64, i32, i64}, "equal": {b8, f16, f32, f64, i32, i64}, - "erf": {b8, f64}, "fft.fft": {f32, f64}, "fft.fft2": {b8, f32, f64, i32, i64}, "fft.fftn": {b8, f32, f64, i32, i64}, @@ -214,7 +213,6 @@ def process(device_type): "nn.functional.adaptive_avg_pool2d": {f16, f64}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": {f32, f64}, - "nn.functional.gelu": {f64}, "nn.functional.local_response_norm": {i64}, "nn.functional.one_hot": {i64}, "nn.functional.pairwise_distance": {f16}, @@ -346,8 +344,6 @@ def process(device_type): "unique_consecutive": {b8, f16, f32, f64, i32, i64}, "view_as_complex": {f16, f32, f64}, # AssertionError: Tensor-likes are not close! 
- "erf": {b8, f64}, - "nn.functional.gelu": {f64}, "nn.functional.triplet_margin_loss": {f16}, } diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 65a9335d6cbf..cf8e6616d677 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -356,6 +356,10 @@ def exp(x): # return f"Sleef_expf_u10({x})" return f"std::exp({x})" + @staticmethod + def erf(x): + return f"std::erf({x})" + @staticmethod def sqrt(x): return f"std::sqrt({x})" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index b79b03232a8a..2504bd2dcf8c 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -218,6 +218,10 @@ def masked(mask, body, other): def lgamma(x): return f"tl.libdevice.lgamma({x})" + @staticmethod + def erf(x): + return f"tl.libdevice.erf({x})" + @staticmethod def logical_and(a, b): return f"{a} & {b}" diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 09ee53579345..188072b3d489 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -366,30 +366,6 @@ def round_dec(x, decimals=0): return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals) -@register_decomposition([aten.special_erf, aten.erf]) -def special_erf(x): - # TODO(jansel): this might be crazy slow. Triton doesn't have the - # cuda ::erf() builtin. I've made a feature request for this, - # so it may be coming soon. - - # from https://www.johndcook.com/blog/2009/01/19/stand-alone-error-function-erf/ - a1 = 0.254829592 - a2 = -0.284496736 - a3 = 1.421413741 - a4 = -1.453152027 - a5 = 1.061405429 - p = 0.3275911 - - sign = torch.sign(x) - x = torch.abs(x) - - # A & S 7.1.26 - t = 1.0 / (1.0 + p * x) - y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * torch.exp(-x * x) - - return sign * y - - @register_decomposition([aten.rsub.Tensor, aten.rsub.Scalar]) def rsub(a, b): if isinstance(b, numbers.Number): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 5168f37cd392..a76a9baea953 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3505,6 +3505,13 @@ def sum_(x, axis=None, keepdims=False, *, dtype=None): register_pointwise( aten.lgamma, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ) +erf = register_pointwise( + aten.erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +) +register_lowering( + aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT +)(erf) + register_pointwise( aten.log, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, From 1db5ce095fb0e721c92304bceca7798456929e73 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Nov 2022 03:08:31 +0000 Subject: [PATCH 390/453] [vision hash update] update the pinned vision hash (#89287) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89287 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index cc0724ac842d..80fe47b2cee2 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -d710f3d1edc06afa244468cb96603ba6dbd4d9d5 +5b4f79d9ba8cbeeb8d6f0fbba3ba5757b718888b From e0251de42f56c8de0bd9b2783bfa2ae67e4813c5 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Sun, 20 Nov 2022 22:54:45 +0000 Subject: [PATCH 391/453] [Easy] Use prepend arg to register forward hooks in quantize.py (#89391) Differential Revision: [D41431110](https://our.internmc.facebook.com/intern/diff/D41431110) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89391 Approved by: https://github.com/awgu --- torch/ao/quantization/quantize.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 51eb2c1c1ec6..8b149b44ad3d 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -143,11 +143,13 @@ def register_activation_post_process_hook(module, pre_hook=False): assert hasattr(module, 'activation_post_process'), \ 'Expect activation_post_process attribute already attached to the module' if pre_hook: - handle = module.register_forward_pre_hook(_observer_forward_pre_hook) - module._forward_pre_hooks.move_to_end(handle.id, last=False) + handle = module.register_forward_pre_hook( + _observer_forward_pre_hook, prepend=True + ) else: - handle = module.register_forward_hook(_observer_forward_hook) - module._forward_hooks.move_to_end(handle.id, last=False) + handle = module.register_forward_hook( + _observer_forward_hook, prepend=True + ) def add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None): From 79770d3636626b2130e58d5acdf1d6a56953329d Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 20 Nov 2022 20:46:02 -0500 Subject: [PATCH 392/453] TorchDynamo: enable conv+relu6 fusion (#89265) This PR is about enabled conv+relu6 which improves mobilenet'e performance. 
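As a rough illustration of the pattern this fusion targets (a sketch, not code from this PR: the module, shapes, and variable names below are made up, and it assumes an MKL-DNN-capable CPU build): a convolution whose output feeds directly into `nn.ReLU6`. Internally the `ReLU6` is mapped to `Hardtanh(min_val=0, max_val=6)` and fused into the convolution by the inductor overrides.

```python
import torch
import torch._dynamo
from torch import nn

class ConvReLU6(nn.Module):
    # Illustrative module: conv followed by ReLU6, the pattern this PR fuses.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.relu6 = nn.ReLU6()

    def forward(self, x):
        return self.relu6(self.conv(x))

mod = ConvReLU6().eval()
x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    # ReLU6 is rewritten to hardtanh(0, 6) and folded into the conv kernel.
    opt_mod = torch._dynamo.optimize("inductor")(mod)
    out = opt_mod(x)
```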
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89265 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_torchinductor.py | 1 + torch/_inductor/overrides.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fc0ae82a2598..d5c3cd673aca 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -93,6 +93,7 @@ def has_bf16_support(): torch.nn.Hardtanh(min_val=-0.5, max_val=4, inplace=False), torch.nn.GELU(approximate="none"), torch.nn.GELU(approximate="tanh"), + torch.nn.ReLU6(), ] diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index 9a8bc6266ac0..cff3f6f47023 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -59,6 +59,8 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None): super(UnaryAttr, self).__init__() def __call__(self, unary_module: nn.Module): + if type(unary_module) is nn.ReLU6: + unary_module = nn.Hardtanh(min_val=0, max_val=6) assert all(hasattr(unary_module, item) for item in self.scalars_attr) scalars = [getattr(unary_module, item) for item in self.scalars_attr] @@ -983,6 +985,7 @@ def rand_like(x, **kwargs): nn.LeakyReLU: UnaryAttr("leaky_relu", scalars_attr=["negative_slope"]), nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"), + nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), } From bc716383a6a3063b35cedfe8d163c61a4ff8f301 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Mon, 21 Nov 2022 03:31:50 +0000 Subject: [PATCH 393/453] Redefine the simdlen semantic (#89263) This PR is targeting to automatically enable vectorization optimization for TorchInductor. It refined the semantics of `config.cpp.simdlen`. Originally, `None` means to disable vectorization while a specific value means the number of elements to be vectorized once time. But it depends on the data. Regarding 256bit SVE/SIMD ISA for ARM and X86, the `simdlen` should be 16 for Float while 32 for BFloat. Hence, this PR defined the `simdlen` as the bit width. The detailed semantics are as follows. - **_simdlen = None_**: Automatically determine the SIMD bit width. Detect HW information and pick the proper vectorization ISA. Specific for X86, the priority of AVX512 is higher than AVX2. - **_simdlen <=1_**: Explicitly disable SIMD - **_simdlen > 1_**: Explicitly specify the SIMD bit width. It equals the disabled semantic if the bit width does not match the ISA width. 
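A small sketch of how the three settings interact with the `codecache` helpers added in this PR (`valid_vec_isa_list`, `pick_vec_isa`, `bit_width`); which ISA actually gets picked depends on the host CPU and on the dry-compile check succeeding, so treat the comments as illustrative rather than guaranteed output:

```python
import torch
from torch._inductor import codecache, config

print([str(isa) for isa in codecache.valid_vec_isa_list()])  # ISAs the host CPU supports

# simdlen = None: pick the widest valid ISA automatically
# (on x86, AVX512 is preferred over AVX2 when both are available).
config.cpp.simdlen = None
isa = codecache.pick_vec_isa()
print(isa, isa.bit_width() if isa else None)

# simdlen <= 1: vectorization is explicitly disabled.
config.cpp.simdlen = 1
assert not codecache.pick_vec_isa()

# simdlen > 1: interpreted as a bit width; it only takes effect when it
# matches a supported ISA width (512 -> AVX512, 256 -> AVX2); otherwise
# vectorization stays disabled.
config.cpp.simdlen = 256
print(codecache.pick_vec_isa())  # avx2, if the CPU supports it
config.cpp.simdlen = 257
print(codecache.pick_vec_isa())  # matches no ISA width -> disabled
```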
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89263 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_minifier.py | 4 +- test/inductor/test_torchinductor.py | 94 +++++++++++- torch/_inductor/codecache.py | 215 +++++++++++++++++++++------- torch/_inductor/codegen/common.py | 6 + torch/_inductor/codegen/cpp.py | 92 +++++++++--- 5 files changed, 330 insertions(+), 81 deletions(-) diff --git a/test/inductor/test_minifier.py b/test/inductor/test_minifier.py index 55c0a1b6bb05..18c5e5f33cad 100644 --- a/test/inductor/test_minifier.py +++ b/test/inductor/test_minifier.py @@ -24,7 +24,7 @@ def cpp_runtime_error(x): CPP_ACCURACY_ERROR = """\ def cpp_accuracy_error(x): - return f"{x} + 1" + return f"{x} + decltype({x})(1)" """ TRITON_COMPILE_ERROR = """\ @@ -60,8 +60,10 @@ def _gen_codegen_fn_patch_code(self, old_fn_name, new_fn_code, device): patch_code = f"""\ import torch._inductor.codegen.{"cpp" if device == "cpu" else "triton"} as codegen overrides = codegen.{"CppOverrides" if device == "cpu" else "TritonOverrides"} +vec_overrides = codegen.{"CppVecOverrides" if device == "cpu" else "TritonOverrides"} {new_fn_code} overrides.{old_fn_name} = staticmethod({new_fn_name}) +vec_overrides.{old_fn_name} = staticmethod({new_fn_name}) """ return f"""\ {patch_code} diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d5c3cd673aca..3b47cd867c73 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4526,7 +4526,11 @@ def fn(x): v = torch.randn(10) result = fn(v) - assert same(result, mod(v)) + # TODO: OMP parallel reduction order is not deterministic. + # Hence, the accurarcy might vary up and down. For short term, + # we increase the tolerance and will fix it later by using + # aten parallel. 
+ assert same(result, mod(v), tol=5e-1) def test_inplace_add_alpha(self): def fn(x, y): @@ -4596,7 +4600,79 @@ def test_complex_memory_overlap(self): self.assertFalse(complex_memory_overlap(gathered.t())) @unittest.skipIf( - not codecache.get_cpu_proc_info(), "Does not support vectorization" + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch.object(config, "dynamic_shapes", True) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + @patch.object(functorch_config, "use_dynamic_shapes", True) + def test_vec_dynamic_shapes(self): + def fn(x): + return torch.softmax(x, -1) + + value = torch.randn((2, 10)) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value) + + real_out = fn(value) + compiled_out = opt_fn(value) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count < 1 + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_simd(self): + vec_avx512 = codecache.supported_vec_isa_list[0] + vec_avx2 = codecache.supported_vec_isa_list[1] + self.assertTrue(vec_avx512.bit_width() == 512) + self.assertTrue(vec_avx2.bit_width() == 256) + self.assertTrue(vec_avx512.nelements() == 16) + self.assertTrue(vec_avx2.nelements() == 8) + self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) + self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) + + with patch.object(config.cpp, "simdlen", None): + isa = codecache.pick_vec_isa() + if vec_avx512 in codecache.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with patch.object(config.cpp, "simdlen", 0): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 1): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 257): + isa = codecache.pick_vec_isa() + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 513): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + self.assertFalse(isa) + + with patch.object(config.cpp, "simdlen", 512): + isa_list = codecache.valid_vec_isa_list() + if vec_avx512 in isa_list: + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_avx512) + + with patch.object(config.cpp, "simdlen", 256): + isa_list = codecache.valid_vec_isa_list() + if vec_avx2 in isa_list: + isa = codecache.pick_vec_isa() + self.assertTrue(isa == vec_avx2) + + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_sign_cpu_only(self): @@ -4607,7 +4683,7 @@ def fn(x): x[0, 0] = torch.nan x[1, -1] = torch.nan - with patch.object(config.cpp, "simdlen", 8): + with patch.object(config.cpp, "simdlen", None): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x) @@ -4620,7 +4696,7 @@ def fn(x): # other platforms support, we just need to add the ISA info to the supported_vector_isa # and include proper aten vectorization head file. 
@unittest.skipIf( - not codecache.get_cpu_proc_info(), "Does not support vectorization" + not codecache.valid_vec_isa_list(), "Does not support vectorization" ) @patch("torch.cuda.is_available", lambda: False) def test_vec_kernel_cpu_only(self): @@ -4659,7 +4735,15 @@ def fn(x1, x2): x1 = torch.randn((10, 20)) x2 = torch.randn((10, 20)) - with patch.object(config.cpp, "simdlen", 8): + with patch.object(config.cpp, "simdlen", 1): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x1, x2) + compiled = compile_fx_inner(traced, [x1, x2]) + assert same(fn(x1, x2)[0], compiled([x1, x2])[0], equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count == 0 + + with patch.object(config.cpp, "simdlen", None): torch._dynamo.reset() metrics.reset() traced = make_fx(fn)(x1, x2) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 2826f3599912..232a611b06c6 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1,5 +1,5 @@ import base64 -import enum +import dataclasses import functools import getpass import hashlib @@ -18,7 +18,7 @@ from ctypes import cdll from threading import Thread from time import sleep, time -from typing import Any, Dict +from typing import Any, Callable, Dict, List import torch from torch.utils import cpp_extension @@ -147,79 +147,181 @@ def is_gcc(): return re.search(r"(gcc|g\+\+)", cpp_compiler()) -class _SupportedVecIsa(enum.Enum): - AVX512 = 1 - AVX2 = 2 - INVALID = -1 +class VecISA(object): + _bit_width: int + _macro: str + _arch_flags: str + _dtype_nelements: Dict[torch.dtype, int] + + # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions + # Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions + # like exp, pow, sin, cos and etc. + # But PyTorch and TorchInductor might use different compilers to build code. If + # PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so + # will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass + # avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest + # gcc/g++ compiler by default while it could support the AVX512 compilation. + # Therefore, there would be a conflict sleef version between PyTorch and + # TorchInductor. Hence, we dry-compile the following code to check whether current + # HW platform and PyTorch both could support AVX512 or AVX2. 
And suppose ARM + # also needs the logic + _avx_code = """ +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) +#include +#include +#endif + +__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0}; + +extern "C" void __avx_chk_kernel() { + auto tmp0 = at::vec::Vectorized(1); + auto tmp1 = tmp0.exp(); + tmp1.store(in_out_ptr0); +} +""" + + _avx_py_load = """ +import torch +from ctypes import cdll +cdll.LoadLibrary("__lib_path__") +""" + + def bit_width(self): + return self._bit_width + + def nelements(self, dtype: torch.dtype = torch.float): + return self._dtype_nelements[dtype] + def build_macro(self): + return self._macro + + def build_arch_flags(self): + return self._arch_flags + + def __hash__(self) -> int: + return hash(str(self)) + + @functools.lru_cache(None) def __bool__(self): - return self != _SupportedVecIsa.INVALID + key, input_path = write(VecISA._avx_code, "cpp", extra="") + from filelock import FileLock + + lock_dir = get_lock_dir() + lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) + with lock: + output_path = input_path[:-3] + "so" + build_cmd = cpp_compile_command( + input_path, output_path, warning_all=False, vec_isa=self + ).split(" ") + try: + # Check build result + subprocess.check_output(build_cmd, stderr=subprocess.STDOUT) + subprocess.check_call( + [ + "python", + "-c", + VecISA._avx_py_load.replace("__lib_path__", output_path), + ], + stderr=subprocess.DEVNULL, + ) + except Exception as e: + return False - @staticmethod - def isa_str(supported_isa: enum.Enum): - if supported_isa == _SupportedVecIsa.AVX512: - return "avx512" - elif supported_isa == _SupportedVecIsa.AVX2: - return "avx2" - else: - return "" + return True + + +@dataclasses.dataclass +class VecAVX512(VecISA): + _bit_width = 512 + _macro = "CPU_CAPABILITY_AVX512" + _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma" + _dtype_nelements = {torch.float: 16, torch.bfloat16: 32} + + def __str__(self) -> str: + return "avx512" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ - @staticmethod - def vec_macro(supported_isa: enum.Enum): - if supported_isa == _SupportedVecIsa.AVX512: - return "CPU_CAPABILITY_AVX512" - elif supported_isa == _SupportedVecIsa.AVX2: - return "CPU_CAPABILITY_AVX2" - else: - return "" + +@dataclasses.dataclass +class VecAVX2(VecISA): + _bit_width = 256 + _macro = "CPU_CAPABILITY_AVX2" + _arch_flags = "-mavx2 -mfma" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16} + + def __str__(self) -> str: + return "avx2" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +class InvalidVecISA(VecISA): + _bit_width = 0 + _macro = "" + _arch_flags = "" + _dtype_nelements = {} + + def __str__(self) -> str: + return "INVALID_VEC_ISA" + + def __bool__(self): + return False + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + +invalid_vec_isa = InvalidVecISA() +supported_vec_isa_list = [VecAVX512(), VecAVX2()] # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. 
-@functools.lru_cache(1) -def get_cpu_proc_info(): +@functools.lru_cache(None) +def valid_vec_isa_list(): if sys.platform != "linux": return [] - isa_info = [] + isa_list = [] with open("/proc/cpuinfo") as _cpu_info: _cpu_info_content = _cpu_info.read() - if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX512) in _cpu_info_content: - isa_info.append(_SupportedVecIsa.AVX512) - - if _SupportedVecIsa.isa_str(_SupportedVecIsa.AVX2) in _cpu_info_content: - isa_info.append(_SupportedVecIsa.AVX2) + for isa in supported_vec_isa_list: + if str(isa) in _cpu_info_content and isa: + isa_list.append(isa) + return isa_list - return isa_info +def pick_vec_isa(): + _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list() + if not _valid_vec_isa_list: + return invalid_vec_isa -def supported_vector_isa(): - # TODO: Add ARM Vec here. - # Dict(k: isa, v: number of float element) - vec_isa_info = { - _SupportedVecIsa.AVX512: 16, - _SupportedVecIsa.AVX2: 8, - } + # If the simdlen is None, it indicates determin the vectroization length automatically + if config.cpp.simdlen is None: + assert _valid_vec_isa_list + return _valid_vec_isa_list[0] - if config.cpp.simdlen is None or config.cpp.simdlen <= 1: - return _SupportedVecIsa.INVALID - - cpu_info_content = get_cpu_proc_info() - for isa in vec_isa_info.keys(): - if isa in cpu_info_content and config.cpp.simdlen == vec_isa_info[isa]: + for isa in _valid_vec_isa_list: + if config.cpp.simdlen == isa.bit_width(): return isa - return _SupportedVecIsa.INVALID + return invalid_vec_isa -def cpp_compile_command(input, output, include_pytorch=False): - valid_isa = supported_vector_isa() - if include_pytorch or valid_isa: +def cpp_compile_command( + input, + output, + warning_all=True, + shared=True, + include_pytorch=False, + vec_isa: VecISA = invalid_vec_isa, +): + if include_pytorch or vec_isa != invalid_vec_isa: ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")] lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")] libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"] - macros = _SupportedVecIsa.vec_macro(valid_isa) + macros = vec_isa.build_macro() if macros: macros = f"-D{macros}" else: @@ -235,11 +337,13 @@ def cpp_compile_command(input, output, include_pytorch=False): lpaths = " ".join(["-L" + p for p in lpaths]) libs = " ".join(["-l" + p for p in libs]) + shared_lib = "-shared -fPIC" if shared else "" + warning_all_flag = "-Wall" if warning_all else "" return re.sub( r"[ \n]+", " ", f""" - {cpp_compiler()} {input} -shared -fPIC -Wall -std=c++14 -Wno-unused-variable + {cpp_compiler()} {input} {shared_lib} {warning_all_flag} -std=c++14 -Wno-unused-variable {ipaths} {lpaths} {libs} {macros} -march=native -O3 -ffast-math -fno-finite-math-only -fopenmp -D C10_USING_CUSTOM_GENERATED_MACROS @@ -266,7 +370,12 @@ def _load_library(path): @classmethod def load(cls, source_code): - key, input_path = write(source_code, "cpp", extra=cpp_compile_command("i", "o")) + picked_vec_isa = pick_vec_isa() + key, input_path = write( + source_code, + "cpp", + extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa), + ) if key not in cls.cache: from filelock import FileLock @@ -276,7 +385,7 @@ def load(cls, source_code): output_path = input_path[:-3] + "so" if not os.path.exists(output_path): cmd = cpp_compile_command( - input=input_path, output=output_path + input=input_path, output=output_path, vec_isa=picked_vec_isa ).split(" ") try: subprocess.check_output(cmd, stderr=subprocess.STDOUT) diff --git a/torch/_inductor/codegen/common.py 
b/torch/_inductor/codegen/common.py index 2803970295cc..cf98833964ca 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -417,6 +417,12 @@ def __init__(self, name): def __str__(self): return self.name + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other) -> bool: + return type(other) == type(self) and other.name == self.name + def update_on_args(self, args, kwargs): pass diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index cf8e6616d677..f82591ddff36 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -620,7 +620,7 @@ def codegen_loops(self, code, worksharing): ) reductions.mark_reduction(self.reduction_vars) - if config.cpp.simdlen: + if codecache.pick_vec_isa(): # TODO(jansel): detect stride-1 dimension and vectorize that if reductions: reductions.loops[-1].simd = True @@ -711,7 +711,8 @@ class CppVecKernel(CppKernel): def __init__(self, args, num_threads): super(CppVecKernel, self).__init__(args, num_threads) - self.simd_len = config.cpp.simdlen + assert codecache.pick_vec_isa() + self.simd_nelements = codecache.pick_vec_isa().nelements() self.reduction_omp_dec: Dict[str, str] = {} metrics.generated_cpp_vec_kernel_count += 1 @@ -727,10 +728,10 @@ def is_var_irrevelant(self, var: sympy.Symbol, index: sympy.Expr): def transform_index(self, index: sympy.Expr): expanded_index = sympy.expand(index) - assert self.simd_len - assert self.simd_len > 0 + assert self.simd_nelements + assert self.simd_nelements >= 1 most_inner_var = self.itervars[-1] - replacement = {most_inner_var: most_inner_var * self.simd_len} + replacement = {most_inner_var: most_inner_var * self.simd_nelements} new_index = sympy_subs(expanded_index, replacement) return new_index @@ -951,21 +952,24 @@ def __init__(self, args=None, num_threads=None): super(CppKernelProxy, self).__init__(args, num_threads) self.simd_vec_kernel: CppVecKernel = None self.simd_omp_kernel: CppKernel = None + self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() - def vectorize_most_inner_loop(self, loop_nest): - loop_nest.split_most_inner_loop(config.cpp.simdlen) + def vectorize_most_inner_loop(self, loop_nest, dtype=torch.float): + assert self.picked_vec_isa + nelements = self.picked_vec_isa.nelements(dtype) + loop_nest.split_most_inner_loop(nelements) loop_with_tail = loop_nest.loops[-1] assert isinstance(loop_with_tail, LoopLevelWithTail) loop_with_tail.main_loop.simd_vec = True loop_with_tail.tail_loop.simd_omp = True - # We chope the loop into two cubes by the config.cpp.simdlen - main loop and tail loop. + # We chope the loop into two cubes by the nelements - main loop and tail loop. # Regarding the main loop, it is straightforward that it could be vectorized with - # config.cpp.simdlen. But for the tail loop, it still could be vectorized. For example, - # if the config.cpp.simdlen is 8(256bits), then the tail loop still could be vectorized + # nelements. But for the tail loop, it still could be vectorized. For example, + # if the nelements is 8(256bits), then the tail loop still could be vectorized # as 4(128bits). 
- loop_with_tail.tail_loop.simd_len = int(config.cpp.simdlen / 2) + loop_with_tail.tail_loop.simd_nelements = int(nelements / 2) loop_with_tail.tail_loop.simd_vec = False loop_with_tail.main_loop_body = self.simd_vec_kernel @@ -975,7 +979,7 @@ def vectorize_most_inner_loop(self, loop_nest): def codegen_loops(self, code, worksharing): threads = parallel_num_threads() - if self.simd_vec_kernel is None: + if self.simd_vec_kernel is None or not self.picked_vec_isa: assert self.simd_omp_kernel return self.simd_omp_kernel.codegen_loops(code, worksharing) @@ -997,12 +1001,52 @@ def codegen_loops(self, code, worksharing): ), LoopNest(loops[reduction_depth:]) loops_nest_reduce.mark_reduction(self.simd_vec_kernel.reduction_vars) - if config.cpp.simdlen: - # TODO(jansel): detect stride-1 dimension and vectorize that - if loops_nest_reduce: - loops_nest_reduce.loops[-1].simd = True - elif loops_nest_non_reduce: - loops_nest_non_reduce.loops[-1].simd = True + assert self.picked_vec_isa + # Do not apply vectorization since the range of most inner is too small. Meanwhile, + # If the range of the most inner is less then the codecache.pick_vec_isa().nelements(), + # the generated code for some reduction will be as follows that leads to incrrect result. + # + # LINE01: float tmp1 = 0; + # LINE02: auto tmp1_vec = at::vec::Vectorized(tmp1); + # LINE03: for(long i1=0; i1<2; i1+=1) + # LINE04: { + # LINE05: for(long i2=0; i2<0; i2+=1) + # LINE06: { + # LINE07: auto tmp0 = at::vec::Vectorized::loadu(in_ptr0 + (8*i0) + (16*i2) + (32*i1)); + # LINE08: tmp1_vec += tmp0; + # LINE09: } + # LINE10: tmp1 = vec_reduce_all([](Vectorized& x, Vectorized&y) {return x + y;}, tmp1_vec); + # LINE11: #pragma omp simd simdlen(8) reduction(+:tmp1) + # LINE12: for(long i2=0; i2<8; i2+=1) + # LINE13: { + # LINE14: auto tmp0 = in_ptr0[i2 + (8*i0) + (32*i1)]; + # LINE15: tmp1 += tmp0; + # LINE16: } + # LINE17: } + # LINE18: out_ptr3[i0] = tmp1; + # + # tmp1_vec(LINE02) will always be zero as it is initialized with tmp1 value and the range(LINE05) + # is 0. Hence, the LINE10 will always reset tmp1 to 0. But tmp1(LINE01) is global value. So the result + # will be incorrect. We skip thie case. + most_inner_loop = ( + loops_nest_reduce.loops[-1] + if loops_nest_reduce + else loops_nest_non_reduce.loops[-1] + ) + main_loop_range = ir.IndexingDiv( + most_inner_loop.size, self.picked_vec_isa.nelements() + ) + loop_interval = sympy.simplify(main_loop_range) + # TODO(Eikan): To support dynamic shape. 
+ if not loop_interval.is_integer or loop_interval <= 0: + metrics.generated_cpp_vec_kernel_count -= 1 + return self.simd_omp_kernel.codegen_loops(code, worksharing) + + # TODO(jansel): detect stride-1 dimension and vectorize that + if loops_nest_reduce: + loops_nest_reduce.loops[-1].simd = True + elif loops_nest_non_reduce: + loops_nest_non_reduce.loops[-1].simd = True par_depth = 0 reduction_par_depth = 0 @@ -1142,8 +1186,7 @@ def can_fuse_vertical(cls, node1, node2): return cls.can_fuse_horizontal(node1, node2) and not node1.is_reduction() def can_vec(self, nodes): - # TODO: Query cpu arch and vec length from aten - if not codecache.supported_vector_isa(): + if not codecache.pick_vec_isa(): return False _, (group, reduction_group) = max( @@ -1353,7 +1396,8 @@ class LoopLevel: steps: sympy.Expr = sympy.Integer(1) parallel: int = 0 simd_omp: bool = False - simd_len: int = config.cpp.simdlen + picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa() + simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0 simd_vec: bool = False collapsed: bool = False reduction_vars: Dict[str, str] = None @@ -1367,7 +1411,11 @@ def lines(self): ) else: reduction = "" - simd = f"simd simdlen({self.simd_len}) " if self.simd_omp else "" + simd = ( + f"simd simdlen({self.simd_nelements}) " + if self.simd_omp and self.simd_nelements > 1 + else "" + ) if self.parallel: # TODO(jansel): look into chunk size and other schedules line1 = f"#pragma omp for{reduction} " From 31708a731076b7feed3051b81d309a9babb4efc0 Mon Sep 17 00:00:00 2001 From: XiaobingSuper Date: Sun, 20 Nov 2022 20:46:03 -0500 Subject: [PATCH 394/453] TorchDynamo: enable conv+silu fusion (#89278) This PR will improve the tf_efficientnet_b0 performance by fusing conv+silu. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89278 Approved by: https://github.com/jgong5, https://github.com/jansel --- aten/src/ATen/native/mkldnn/Utils.cpp | 1 + test/inductor/test_torchinductor.py | 1 + torch/_inductor/overrides.py | 1 + 3 files changed, 3 insertions(+) diff --git a/aten/src/ATen/native/mkldnn/Utils.cpp b/aten/src/ATen/native/mkldnn/Utils.cpp index 5db6e0b07ff1..2c626884d8f0 100644 --- a/aten/src/ATen/native/mkldnn/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/Utils.cpp @@ -132,6 +132,7 @@ const std::map& fusion_unary_attr_map() { {"relu", ATTR_FUNC(relu)}, {"sigmoid", ATTR_FUNC(sigmoid)}, {"tanh", ATTR_FUNC(tanh)}, + {"swish", ATTR_FUNC(swish)}, {"hardswish", ATTR_FUNC(hardswish)}, {"leaky_relu", attr_func_leaky_relu}, {"hardtanh", attr_func_hardtanh}, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 3b47cd867c73..399032890ca8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -94,6 +94,7 @@ def has_bf16_support(): torch.nn.GELU(approximate="none"), torch.nn.GELU(approximate="tanh"), torch.nn.ReLU6(), + torch.nn.SiLU(), ] diff --git a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index cff3f6f47023..5bd97cd5009a 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -986,6 +986,7 @@ def rand_like(x, **kwargs): nn.Hardtanh: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), nn.GELU: UnaryAttr("gelu", algorithm_attr="approximate"), nn.ReLU6: UnaryAttr("hardtanh", scalars_attr=["min_val", "max_val"]), + nn.SiLU: UnaryAttr("swish"), } From a80e5e78137fb8adea6e7d638be483f866fe26e8 Mon Sep 17 00:00:00 2001 From: yanbing-j Date: Mon, 21 Nov 2022 09:52:34 +0000 Subject: 
[PATCH 395/453] Update ideep for future performance improvement (#87966) **Summary** The update includes API changes and optimzations to reduce framework overhead, which will benefit all mkldnn (onednn) ops in JIT mode and inductor CPU backend, etc. These benefits will be seen after switching to new ideep API by future PRs. **Test plan** For correctness, all UTs that call mkldnn ops, including test_ops.py, test_mkldnn*.py, test_quantization.py, etc. For performance, TorchBench has been run and no regression is found. Results are shown below. - Intel (R) Xeon (R) IceLake with 40 cores - Use multi-instance - Using tcmalloc & Intel OMP ![image](https://user-images.githubusercontent.com/12522207/201631004-bb77468d-953b-4757-a001-94d44615b5f6.png) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87966 Approved by: https://github.com/jgong5, https://github.com/XiaobingSuper --- third_party/ideep | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/ideep b/third_party/ideep index ececd0a4f53c..5ddc65efe042 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit ececd0a4f53c39f2d91caaddee0de1cd214f5b99 +Subproject commit 5ddc65efe0428bbce2942b3ce5e3ce15239abe2f From c2cf0bde1f4e9bed642648f299db0f6d5ecb5996 Mon Sep 17 00:00:00 2001 From: lezcano Date: Mon, 21 Nov 2022 10:54:32 +0000 Subject: [PATCH 396/453] Move the OpInfo same-storage error to the autograd test (#88306) This check was previously located at the `non_contiguous` test (quite and odd location). Even more, at https://github.com/pytorch/pytorch/pull/86378#discussion_r993658395, Kshiteej found that this assert was not doing anything really. We move it to the autograd test and make it a proper `self.assert`. We also disallow returning 1-tuples from sample_input functions, as they were breaking this assert. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88306 Approved by: https://github.com/mruberry --- test/test_ops.py | 5 --- test/test_testing.py | 17 +++------- .../_internal/common_methods_invocations.py | 31 ++++++++----------- torch/testing/_internal/common_utils.py | 9 ++++++ 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 5f9ad6ff4317..7e0a9952389c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -500,11 +500,6 @@ def test_noncontiguous_samples(self, device, dtype, op): noncontig_sample.kwargs, ) - # Verifies sample input tensors should have no grad or history - sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0] - assert sample_tensor.grad is None - assert sample_tensor.grad_fn is None - # validates forward expected = op(t_inp, *t_args, **t_kwargs) actual = op(n_inp, *n_args, **n_kwargs) diff --git a/test/test_testing.py b/test/test_testing.py index f05883919f17..6dc06a8a2aeb 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,7 +12,7 @@ import subprocess import sys import unittest.mock -from typing import Any, Callable, Iterator, List, Tuple, Generator, Sequence +from typing import Any, Callable, Iterator, List, Tuple, Generator import torch @@ -1923,32 +1923,23 @@ def test_sample_input_metadata(self) -> None: # Tests that validate the various sample generating functions on each OpInfo. 
class TestOpInfoSampleFunctions(TestCase): - def _assert_is_generator_or_singleton(self, item, property_name): - if isinstance(item, Sequence): - msg = ( - "{property_name} may only return lists for single items" - ", please use a coroutine which yields items instead") - self.assertTrue(len(item) <= 1, msg=msg) - else: - self.assertIsInstance(item, Generator) - @ops(op_db, dtypes=OpDTypes.any_one) def test_opinfo_sample_generators(self, device, dtype, op): # Test op.sample_inputs doesn't generate multiple samples when called samples = op.sample_inputs(device, dtype) - self._assert_is_generator_or_singleton(samples, "sample_inputs_func") + self.assertIsInstance(samples, Generator) @ops([op for op in op_db if op.reference_inputs_func is not None], dtypes=OpDTypes.any_one) def test_opinfo_reference_generators(self, device, dtype, op): # Test op.reference_inputs doesn't generate multiple samples when called samples = op.reference_inputs(device, dtype) - self._assert_is_generator_or_singleton(samples, "reference_inputs_func") + self.assertIsInstance(samples, Generator) @ops([op for op in op_db if op.error_inputs_func is not None], dtypes=OpDTypes.none) def test_opinfo_error_generators(self, device, op): # Test op.error_inputs doesn't generate multiple inputs when called samples = op.error_inputs(device) - self._assert_is_generator_or_singleton(samples, "error_inputs_func") + self.assertIsInstance(samples, Generator) instantiate_device_type_tests(TestOpInfoSampleFunctions, globals()) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4ccb5ef3840f..0f845f765829 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -958,11 +958,11 @@ def sample_inputs_sparse_sampled_addmm(op_info, device, dtype, requires_grad, ** def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) - return (SampleInput(make_arg(S, M), make_arg(M)),) + yield SampleInput(make_arg(S, M), make_arg(M)) def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) - return (SampleInput(make_arg(M, S, M), make_arg(M, M, S)),) + yield SampleInput(make_arg(M, S, M), make_arg(M, M, S)) def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -1569,9 +1569,9 @@ def sample_inputs_logcumsumexp(self, device, dtype, requires_grad, **kwargs): yield SampleInput(t, dim) def sample_inputs_trace(self, device, dtype, requires_grad, **kwargs): - return (SampleInput((make_tensor((S, S), dtype=dtype, device=device, - low=None, high=None, - requires_grad=requires_grad))),) + yield SampleInput((make_tensor((S, S), dtype=dtype, device=device, + low=None, high=None, + requires_grad=requires_grad))) def error_inputs_trace(op, device): @@ -5281,6 +5281,11 @@ def error_inputs_complex(op_info, device, is_ref=False, **kwargs): out=make_arg(M, S, dtype=torch.complex64)), error_type=RuntimeError, error_regex=error_out) +def sample_inputs_logaddexp(op_info, device, dtype, requires_grad, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + shape = (S, S) + yield SampleInput(make_arg(shape), make_arg(shape)) + def 
sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): def make_arg(shape): # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck @@ -5322,7 +5327,7 @@ def error_inputs_neg(op_info, device, **kwargs): msg = ("Negation, the `\\-` operator, on a bool tensor is not supported." " If you are trying to invert a mask, use the `\\~` or" " `logical_not\\(\\)` operator instead.") - return (ErrorInput(si, error_regex=msg),) + yield ErrorInput(si, error_regex=msg) def sample_inputs_diag(op_info, device, dtype, requires_grad, **kwargs): make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad, low=None, high=None) @@ -8318,7 +8323,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), # Tests that assume input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), @@ -9928,7 +9932,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_linspace, skips=( # Tests that assume input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), @@ -9956,7 +9959,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_logpace, skips=( # Tests that assume input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), @@ -10060,18 +10062,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfROCM=floating_types_and(torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, - sample_inputs_func=lambda op_info, device, dtype, requires_grad=False, **kwargs: - (SampleInput(make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad), - args=(make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad),)),)), + sample_inputs_func=sample_inputs_logaddexp), OpInfo('logaddexp2', dtypes=floating_types_and(torch.bfloat16), dtypesIfCUDA=floating_types_and(torch.bfloat16), dtypesIfROCM=floating_types_and(torch.bfloat16), supports_forward_ad=True, supports_fwgrad_bwgrad=True, - sample_inputs_func=lambda op_info, device, dtype, requires_grad=False, **kwargs: - (SampleInput(make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad), - args=(make_tensor((S, S), dtype=dtype, device=device, requires_grad=requires_grad),)),)), + sample_inputs_func=sample_inputs_logaddexp), UnaryUfuncInfo('logical_not', ref=np.logical_not, decorators=(precisionOverride({torch.bfloat16: 7e-1, @@ -14573,7 +14571,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_ones_zeros, skips=( # Tests that assume 
input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), @@ -14594,7 +14591,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_ones_zeros, skips=( # Tests that assume input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), @@ -14615,7 +14611,6 @@ def reference_flatten(input, start_dim=0, end_dim=-1): sample_inputs_func=sample_inputs_full, skips=( # Tests that assume input is a tensor or sequence of tensors - DecorateInfo(unittest.expectedFailure, "TestCommon", "test_noncontiguous_samples"), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 35ec53381c1f..e53887a5fdbb 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -3749,6 +3749,15 @@ def is_inplace(variant): all_args = tuple(chain((sample.input,), sample.args, sample.kwargs.values())) gradcheck_args = tuple(x for x in all_args if (isinstance(x, torch.Tensor) and x.requires_grad)) + # Verifies sample input tensors should have no grad + # This may happen if the same tensor is used in two different SampleInputs + for t in gradcheck_args: + self.assertIsNone(t.grad, + "A sampled input has a gradient before running autograd. " + "This usually means that (at least) one input tensor is reused " + "across different SampleInputs. " + "Please create a new tensor for each SampleInput.") + def _input_recomposition_helper(inputs, inp, input_idx): if is_iterable_of_tensors(inp): tensor_list = [] From 6796979ee1063890fd04bbf21f298f669129df8f Mon Sep 17 00:00:00 2001 From: Jiong Gong Date: Mon, 21 Nov 2022 14:20:33 +0000 Subject: [PATCH 397/453] [Inductor] Limit the number of compile threads to the available cpu cores (#89377) `config.compile_threads` gets the number of compile threads via `min(32,os.cpu_count())` while `os.cpu_count()` is the total number of cpu cores in the system, not the available ones. This would cause compile thread contention when the available cpu cores are less than `min(32,os.cpu_count())`, e.g., available cpu cores are limited with numactl or taskset, making the compilation very slow. This PR tries to use `len(os.sched_getaffinity(0))` if `os.sched_getaffinity` is available which returns the available number of cpu cores. 
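For illustration, a minimal standalone sketch of the thread-count heuristic this change adopts; the function name `available_compile_threads` and the `cap` argument exist only for this example, the real logic is the one-liner in `torch/_inductor/config.py` shown in the diff below:

```
import os
import sys

def available_compile_threads(cap=32):
    # Illustrative helper; the actual change lives in torch/_inductor/config.py.
    if sys.platform == "win32":
        return 1
    # os.sched_getaffinity(0) respects numactl/taskset CPU masks,
    # unlike os.cpu_count(), which reports all cores in the system.
    cpus = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count()
    return min(cap, cpus)

print(available_compile_threads())
```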
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89377 Approved by: https://github.com/soumith --- torch/_inductor/config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c552101c1cae..a0062c4fe4e2 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -59,7 +59,16 @@ comment_origin = False -compile_threads = min(32, os.cpu_count()) if sys.platform != "win32" else 1 +compile_threads = ( + min( + 32, + len(os.sched_getaffinity(0)) + if hasattr(os, "sched_getaffinity") + else os.cpu_count(), + ) + if sys.platform != "win32" + else 1 +) # If kernel is fused, the name is generated from the origin node op names # for larger kernels limit this From f3db03612f9c6fb8717e1e13a9295da3c9c05193 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Nov 2022 16:38:20 +0000 Subject: [PATCH 398/453] Revert "[ao] maintain BC for is_activation_post_process (#89260)" This reverts commit c5fafb4e1694f141d8a1a31142cce4049d9057ed. Reverted https://github.com/pytorch/pytorch/pull/89260 on behalf of https://github.com/DanilBaibak due to breaking internal builds --- torch/ao/quantization/quantize.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index 8b149b44ad3d..ae080ccaa2ca 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -27,12 +27,7 @@ float_qparams_weight_only_qconfig_4bit, _activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations - -from torch.ao.quantization.observer import ( # noqa: F401 - _is_activation_post_process, - _is_activation_post_process as is_activation_post_process, - # TODO remove this once problems from name change are resolved -) +from torch.ao.quantization.observer import _is_activation_post_process __all__ = [ "get_default_custom_config_dict", From 9d209e78348ee5c3e1ead700d240fb476b3bc4de Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 21 Nov 2022 16:48:26 +0000 Subject: [PATCH 399/453] Revert "[ao] making _is_activation_post_process private (#87520)" This reverts commit 45c62a337756ff9db97cd64d2d42d9e65dda0a85. 
Reverted https://github.com/pytorch/pytorch/pull/87520 on behalf of https://github.com/bigfootjon due to Diff reverted internally --- test/allowlist_for_publicAPI.json | 4 ++-- test/quantization/ao_migration/test_ao_migration.py | 2 +- test/quantization/ao_migration/test_quantization.py | 2 +- test/quantization/fx/test_quantize_fx.py | 6 +++--- torch/ao/ns/fx/graph_passes.py | 4 ++-- torch/ao/ns/fx/utils.py | 8 ++++---- torch/ao/quantization/__init__.py | 1 + torch/ao/quantization/fx/_model_report/detector.py | 4 ++-- torch/ao/quantization/fx/convert.py | 6 +++--- torch/ao/quantization/fx/prepare.py | 4 ++-- torch/ao/quantization/fx/qconfig_mapping_utils.py | 6 +++--- torch/ao/quantization/fx/utils.py | 6 +++--- torch/ao/quantization/observer.py | 2 +- torch/ao/quantization/quantize.py | 9 +++++++-- torch/quantization/quantize.py | 2 +- 15 files changed, 36 insertions(+), 30 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 2e1394a72e17..94ff57700af6 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -786,7 +786,7 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", - "_is_activation_post_process", + "is_activation_post_process", "load_observer_state_dict", "no_observer_set", "prepare", @@ -894,7 +894,7 @@ "convert", "get_observer_dict", "get_unique_devices_", - "_is_activation_post_process", + "is_activation_post_process", "prepare", "prepare_qat", "propagate_qconfig_", diff --git a/test/quantization/ao_migration/test_ao_migration.py b/test/quantization/ao_migration/test_ao_migration.py index 260ab32056f6..accb13da0dcb 100644 --- a/test/quantization/ao_migration/test_ao_migration.py +++ b/test/quantization/ao_migration/test_ao_migration.py @@ -19,7 +19,7 @@ def test_function_import_quantize(self): 'convert', 'get_observer_dict', 'get_unique_devices_', - '_is_activation_post_process', + 'is_activation_post_process', 'prepare', 'prepare_qat', 'propagate_qconfig_', diff --git a/test/quantization/ao_migration/test_quantization.py b/test/quantization/ao_migration/test_quantization.py index 95c5c7bd6015..9c246e1b7cd8 100644 --- a/test/quantization/ao_migration/test_quantization.py +++ b/test/quantization/ao_migration/test_quantization.py @@ -22,7 +22,7 @@ def test_function_import_quantize(self): 'convert', 'get_observer_dict', 'get_unique_devices_', - '_is_activation_post_process', + 'is_activation_post_process', 'prepare', 'prepare_qat', 'propagate_qconfig_', diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 2d91ba80b7e0..bab4467894e2 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -55,6 +55,7 @@ get_default_qat_qconfig, get_default_qconfig_mapping, get_default_qat_qconfig_mapping, + is_activation_post_process, fuse_modules, fuse_modules_qat, prepare, @@ -147,7 +148,6 @@ default_fixed_qparams_range_0to1_observer, default_fixed_qparams_range_neg1to1_observer, MinMaxObserver, - _is_activation_post_process, ) # test utils @@ -3249,7 +3249,7 @@ def _check_node_not_observed(model, arg_node, node): _check_node_not_observed(model, new_node, node) elif arg_node.op == "call_module": self.assertTrue( - not _is_activation_post_process(getattr(model, arg_node.target)), + not is_activation_post_process(getattr(model, arg_node.target)), "Arg: {0} of node: {1} is observed but is not a float tensor".format( arg_node, node ), @@ -5008,7 +5008,7 @@ def forward(self, 
x): qconfig_dict = func(backend) m = prepare_fx(m, qconfig_dict, example_inputs=(torch.randn(1, 1, 1, 1))) for name, mod in m.named_modules(): - if _is_activation_post_process(mod) and mod.dtype == torch.quint8: + if is_activation_post_process(mod) and mod.dtype == torch.quint8: if backend == "fbgemm": lower_bnd = 0 upper_bnd = 127 diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py index 3f4e15685902..c78b19d2701b 100644 --- a/torch/ao/ns/fx/graph_passes.py +++ b/torch/ao/ns/fx/graph_passes.py @@ -24,7 +24,7 @@ from torch.ao.ns.fx.mappings import ( get_node_type_to_io_type_map, ) -from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.quantize import is_activation_post_process from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set @@ -38,7 +38,7 @@ def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]: if node.op == 'call_module': assert isinstance(node.target, str) module = getattr_from_fqn(gm, node.target) - if _is_activation_post_process(module): + if is_activation_post_process(module): node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0) fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0] # type: ignore[index] return fqn # type: ignore[return-value] diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py index 90574dc20248..2993764b8a12 100644 --- a/torch/ao/ns/fx/utils.py +++ b/torch/ao/ns/fx/utils.py @@ -13,10 +13,10 @@ from torch.fx.graph import Node from torch.ao.quantization import ( ObserverBase, - FakeQuantizeBase + FakeQuantizeBase, ) -from torch.ao.quantization.observer import _is_activation_post_process from torch.ao.quantization.utils import getattr_from_fqn +from torch.ao.quantization.quantize import is_activation_post_process from .ns_types import NSNodeTargetType, NSResultsType @@ -256,14 +256,14 @@ def return_first_non_observer_node( """ if node.op == "call_module": node_obj = getattr_from_fqn(gm, node.target) # type: ignore[arg-type] - if _is_activation_post_process(node_obj): + if is_activation_post_process(node_obj): assert len(node.args) == 1 assert isinstance(node.args[0], Node) node = node.args[0] # code duplication intended, not worth refactoring assert isinstance(node.target, str) node_obj = getattr_from_fqn(gm, node.target) - if _is_activation_post_process(node_obj): + if is_activation_post_process(node_obj): assert len(node.args) == 1 assert isinstance(node.args[0], Node) node = node.args[0] diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index bc8403f32af8..1ba2a60ed3d1 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -114,6 +114,7 @@ "get_quantized_operator", "get_static_quant_module_class", "get_unique_devices_", + "is_activation_post_process", "load_observer_state_dict", "no_observer_set", "per_channel_weight_observer_range_neg_127_to_127", diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index d398819ddcdd..c92733bbc1c3 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -23,7 +23,7 @@ default_equalization_qconfig, EqualizationQConfig, ) -from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.quantize import is_activation_post_process # Names for observer insert keys DETECTOR_TARGET_NODE_KEY = "target_node" @@ -1273,7 +1273,7 @@ def _supports_insertion(self, module: nn.Module) -> 
bool: # case for insertion of module # check if the module has any children and isn't observer num_children = len(list(module.children())) - return num_children == 0 and not _is_activation_post_process(module) + return num_children == 0 and not is_activation_post_process(module) def get_qconfig_info(self, model) -> Dict[str, DetectorQConfigInfo]: r""" Returns the DetectorQConfigInfo for each module_fqn relavent diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index f09785679e37..faa267c492c6 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -64,6 +64,7 @@ ) from torch.ao.quantization.quantize import ( _remove_qconfig, + is_activation_post_process, ) from torch.ao.quantization.stubs import DeQuantStub from .custom_config import ( @@ -73,7 +74,6 @@ from .lower_to_fbgemm import lower_to_fbgemm # importing the lib so that the quantized_decomposed ops are registered from ._decomposed import quantized_decomposed_lib # noqa: F401 -from torch.ao.quantization.observer import _is_activation_post_process # TODO: revisit this list. Many helper methods shouldn't be public @@ -359,7 +359,7 @@ def maybe_get_observer_for_node( for maybe_obs_node, _ in node.users.items(): if maybe_obs_node.op == 'call_module': maybe_obs = modules[str(maybe_obs_node.target)] - if _is_activation_post_process(maybe_obs): + if is_activation_post_process(maybe_obs): return maybe_obs return None @@ -787,7 +787,7 @@ def convert( elif node.op == "call_module": mod = _get_module(node, modules) assert mod is not None - if _is_activation_post_process(mod): + if is_activation_post_process(mod): observed_node = node.args[0] if observed_node in statically_quantized_custom_module_nodes: _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 005a9cef45e3..c908e3f3b764 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -16,7 +16,6 @@ ) from ..observer import ( ObserverBase, - _is_activation_post_process ) from ..qconfig import ( _is_reuse_input_qconfig, @@ -79,6 +78,7 @@ ) from torch.ao.quantization.quantize import ( + is_activation_post_process, convert ) @@ -148,7 +148,7 @@ def is_activation_post_process_node(node: Node, modules: Dict[str, torch.nn.Module]) -> bool: return isinstance(node, torch.fx.Node) and node.op == "call_module" and \ - _is_activation_post_process(modules[str(node.target)]) + is_activation_post_process(modules[str(node.target)]) def is_input_arg_dtype_supported_by_backend( arg: Argument, diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 26c7effd44db..0b0407c0b106 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -3,8 +3,8 @@ from typing import Callable, Any, Dict, Tuple, Set, List from torch.ao.quantization import QConfig from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals -from torch.ao.quantization.observer import ( - _is_activation_post_process, +from torch.ao.quantization.quantize import ( + is_activation_post_process, ) from torch.ao.quantization.backend_config import ( DTypeConfig, @@ -158,7 +158,7 @@ def generate_node_name_to_qconfig( elif node.op == 'call_module': # if the node is an observer, just continue - don't add it to the qconfig_map - if 
_is_activation_post_process(modules[node.target]): + if is_activation_post_process(modules[node.target]): continue qconfig = _maybe_adjust_qconfig_for_module_type_or_name( qconfig_mapping, type(modules[node.target]), node.target, global_qconfig) diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index b8bfa4c9d053..73fdb0700144 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -30,7 +30,7 @@ is_per_channel, to_underlying_dtype, ) -from torch.ao.quantization.observer import _is_activation_post_process +from torch.ao.quantization.quantize import is_activation_post_process from torch.fx import GraphModule, map_arg @@ -447,7 +447,7 @@ def all_node_args_have_no_tensors(node: Node, modules: Dict[str, torch.nn.Module result = False elif node.op == 'call_module': assert isinstance(node.target, str) - if _is_activation_post_process(modules[node.target]): + if is_activation_post_process(modules[node.target]): result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] elif node.op == 'call_module': result = False @@ -1040,7 +1040,7 @@ def _activation_post_process_satisfies_dtype_config_constraints( satisfies_constraints = True if activation_post_process_ctr is not None: activation_post_process = activation_post_process_ctr() - assert _is_activation_post_process(activation_post_process) + assert is_activation_post_process(activation_post_process) # If dtypes don't match, don't check the activation_post_process and return True early if activation_post_process.dtype != dtype_with_constraints.dtype: return True diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 42962fe7c29a..ea2a26bf3896 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -1442,7 +1442,7 @@ def _is_observer_script_module(mod, obs_type_name): def _is_activation_post_process(module): return ( isinstance(module, torch.ao.quantization.ObserverBase) - or isinstance(module, torch.ao.quantization.FakeQuantizeBase) + or isinstance(module, torch.ao.quantization.FakeQuantize) or _is_observer_script_module(module, "quantization.observer") ) diff --git a/torch/ao/quantization/quantize.py b/torch/ao/quantization/quantize.py index ae080ccaa2ca..d18f93987465 100644 --- a/torch/ao/quantization/quantize.py +++ b/torch/ao/quantization/quantize.py @@ -27,10 +27,10 @@ float_qparams_weight_only_qconfig_4bit, _activation_is_memoryless) from torch.nn.utils.parametrize import type_before_parametrizations -from torch.ao.quantization.observer import _is_activation_post_process __all__ = [ "get_default_custom_config_dict", + "is_activation_post_process", "propagate_qconfig_", "register_activation_post_process_hook", "add_observer_", @@ -62,6 +62,11 @@ def get_default_custom_config_dict(): """ return _DEFAULT_CUSTOM_CONFIG_DICT +def is_activation_post_process(module): + return (isinstance(module, torch.ao.quantization.ObserverBase) or + isinstance(module, torch.ao.quantization.FakeQuantizeBase)) + + def _propagate_qconfig_helper(module, qconfig_dict, qconfig_parent=None, prefix='', prepare_custom_config_dict=None): r"""This is a helper function for `propagate_qconfig_` @@ -319,7 +324,7 @@ def _remove_activation_post_process(module): # TODO: maybe we should change activation_post_process to _activation_post_process # to prevent it from being used by user if hasattr(module, 'activation_post_process') and \ - _is_activation_post_process(module.activation_post_process): + 
is_activation_post_process(module.activation_post_process): delattr(module, 'activation_post_process') # remove activation_post_proceess pre and post hooks diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index 24d7049ec50e..d9fcf1d04d8b 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -17,7 +17,7 @@ from torch.ao.quantization.quantize import convert from torch.ao.quantization.quantize import get_observer_dict from torch.ao.quantization.quantize import get_unique_devices_ -from torch.ao.quantization.quantize import _is_activation_post_process +from torch.ao.quantization.quantize import is_activation_post_process from torch.ao.quantization.quantize import prepare from torch.ao.quantization.quantize import prepare_qat from torch.ao.quantization.quantize import propagate_qconfig_ From e4d9dbd7d236e86fac0055feb7dd8f64516d375e Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 21 Nov 2022 17:25:28 +0000 Subject: [PATCH 400/453] Port torchdynamo's torchbench script to userbenchmark (#89239) Summary: This Diff ports the torchbench.py script from torchdynamo to torchbench to support the development of internal models. Currently, only works with the `--only` option, and can only test one model at a time. Note that the noisy logs are from upstream model code, not the benchmark code. In the internal environment, `torch._dynamo.config.base_dir` is not writable, so we add an option to specify the output directory. Test Plan: ``` $ buck2 run mode/opt //caffe2/benchmarks/dynamo:torchbench -- --performance --only ads_dhen_5x --part over --output-directory /tmp/tb-test/ cuda eval ads_dhen_5x 1/ 1 +0 frames 2s 1 graphs 1 graph calls 412/ 411 = 100% ops 100% time ``` ``` $ buck2 run mode/opt //caffe2/benchmarks/dynamo:torchbench -- --performance --only cmf_10x --part over --output-directory /tmp/tb-test/ cuda eval cmf_10x 1/ 1 +0 frames 1s 1 graphs 1 graph calls 306/ 305 = 100% ops 100% time ``` Reviewed By: jansel Differential Revision: D41294311 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89239 Approved by: https://github.com/jansel --- benchmarks/dynamo/common.py | 39 ++++++++++++++++++++++++++++----- benchmarks/dynamo/torchbench.py | 21 +++++++++++++++--- 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index f4d1bfad37d7..8731d545c456 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1505,6 +1505,15 @@ def get_example_inputs(self): "--output", help="Overrides the output filename", ) + parser.add_argument( + "--output-directory", + help="Overrides the directory to place output files.", + ) + parser.add_argument( + "--part", + default=None, + help="Specify the part of the model to run.", + ) parser.add_argument( "--export-profiler-trace", action="store_true", @@ -1918,7 +1927,12 @@ def run(runner, args, original_dir=None): output_filename = args.output if output_filename: - output_filename = os.path.join(torch._dynamo.config.base_dir, output_filename) + if args.output_directory: + output_filename = os.path.join(args.output_directory, output_filename) + else: + output_filename = os.path.join( + torch._dynamo.config.base_dir, output_filename + ) if args.find_batch_sizes and args.only: for device in args.devices: @@ -1955,11 +1969,24 @@ def run(runner, args, original_dir=None): example_inputs = tree_map(lambda x: x.to(device=device), example_inputs) else: try: - device, name, model, example_inputs, batch_size = 
runner.load_model( - device, - model_name, - batch_size=batch_size, - ) + if args.part: + ( + device, + name, + model, + example_inputs, + batch_size, + ) = runner.load_model( + device, model_name, batch_size=batch_size, part=args.part + ) + else: + ( + device, + name, + model, + example_inputs, + batch_size, + ) = runner.load_model(device, model_name, batch_size=batch_size) except NotImplementedError as e: print(e) import traceback diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index b7d4a3be7933..24a049f14ba2 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -227,12 +227,16 @@ def load_model( device, model_name, batch_size=None, + part=None, ): is_training = self.args.training use_eval_mode = self.args.use_eval_mode dynamic_shapes = self.args.dynamic_shapes - module = importlib.import_module(f"torchbenchmark.models.{model_name}") + try: + module = importlib.import_module(f"torchbenchmark.models.{model_name}") + except ModuleNotFoundError: + module = importlib.import_module(f"torchbenchmark.models.fb.{model_name}") benchmark_cls = getattr(module, "Model", None) if not hasattr(benchmark_cls, "name"): benchmark_cls.name = model_name @@ -248,13 +252,24 @@ def load_model( # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags" torch.backends.__allow_nonbracketed_mutation_flag = True + extra_args = [] + if part: + extra_args = ["--part", part] if is_training: benchmark = benchmark_cls( - test="train", device=device, jit=False, batch_size=batch_size + test="train", + device=device, + jit=False, + batch_size=batch_size, + extra_args=extra_args, ) else: benchmark = benchmark_cls( - test="eval", device=device, jit=False, batch_size=batch_size + test="eval", + device=device, + jit=False, + batch_size=batch_size, + extra_args=extra_args, ) if dynamic_shapes: if not hasattr(benchmark, "get_dynamic_shapes_module"): From cf9476554fce9a9c909eebd7439f4b3f4d208f6c Mon Sep 17 00:00:00 2001 From: Taylor Robie Date: Mon, 21 Nov 2022 09:23:16 -0800 Subject: [PATCH 401/453] update kineto pinned commit (#89435) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89435 Approved by: https://github.com/malfet --- third_party/kineto | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/kineto b/third_party/kineto index 0703c7899906..6c1629809068 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 0703c78999061b8329dfab7ec5046fc5764a5573 +Subproject commit 6c1629809068efd78a8d56b4aa479c7ec49ae562 From 1d9e1fca97a2a01ea75b0938e38feee1d5288ebd Mon Sep 17 00:00:00 2001 From: Driss Guessous Date: Mon, 21 Nov 2022 20:02:09 +0000 Subject: [PATCH 402/453] Update sdp dispatch logic to enable fused backward (#89154) # Summary Reorganizes how the sdp dispatch logic is down in order to enable backwards for fused kernels Pull Request resolved: https://github.com/pytorch/pytorch/pull/89154 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 52 ++--- .../cuda/NestedTensorTransformerFunctions.cpp | 100 ++++++--- .../ATen/native/transformers/attention.cpp | 65 ++++-- .../native/transformers/cuda/attention.cu | 46 ++--- .../transformers/cuda/attention_backward.cu | 40 +++- .../transformers/cuda/flash_attn/fmha_api.cpp | 7 +- .../transformers/cuda/flash_attn/fmha_api.h | 2 +- .../ATen/native/transformers/cuda/sdp_utils.h | 34 +++- benchmarks/transformer/sdp_backwards.py | 189 ++++++++++++++++++ 
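As a usage illustration only (the shapes, dtype, and the tiny `t()` helper are made up for this sketch), the reorganized dispatch is exercised the same way the new `benchmarks/transformer/sdp_backwards.py` below does: pin the backend with `torch.backends.cuda.sdp_kernel` and run forward plus backward through the private SDP entry point:

```
import torch

def t():  # illustrative shape: (batch, num_heads, seq_len, head_dim)
    return torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)

q, k, v = t(), t(), t()

# Restrict dispatch to the memory-efficient kernel; with requires_grad inputs the
# forward also computes logsumexp so the fused backward can be used.
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True):
    attn, _ = torch.nn.functional._scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False
    )
attn.sum().backward()
```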
.../check_forward_backward_compatibility.py | 3 + test/functorch/test_ops.py | 8 +- test/test_meta.py | 1 - test/test_transformers.py | 76 +++++-- tools/autograd/derivatives.yaml | 6 +- .../_internal/common_methods_invocations.py | 5 + 15 files changed, 498 insertions(+), 136 deletions(-) create mode 100644 benchmarks/transformer/sdp_backwards.py diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index f625c9faff41..8c759cd09c48 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -13252,18 +13252,39 @@ CPU, NestedTensorCPU, Meta: _fused_sdp_choice_cpp CUDA, NestedTensorCUDA: _fused_sdp_choice_cuda -# Register the math kernel for cpu -- func: _scaled_dot_product_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) +- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) variants: function + +- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool return_softmax=False, bool is_causal=False) -> (Tensor, Tensor, Tensor) dispatch: - CUDA: _scaled_dot_product_attention_forward_cuda - CPU: _scaled_dot_product_attention_forward_math - NestedTensorCUDA: _scaled_dot_product_attention_forward_nested - NestedTensorCPU: _scaled_dot_product_attention_forward_math - Meta: _scaled_dot_product_attention_forward_math + CUDA: _scaled_dot_product_flash_attention_cuda + NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda -- func: _scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool need_attn_weights=False, bool is_causal=False) -> (Tensor, Tensor) +- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_cuda + NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda + +- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: _scaled_dot_product_efficient_attention_backward_cuda + +# Returns ouput, softmax_logsumexp, softmax +- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, bool return_softmax, float dropout_p, bool is_causal) -> (Tensor, Tensor, Tensor) variants: function + dispatch: + CUDA: _flash_attention_forward + +# Returns ouput, logsumexp if compute_logsumexp +- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? 
max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_forward + +- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False) -> (Tensor, Tensor, Tensor) + variants: function + dispatch: + CUDA: _efficient_attention_backward - func: _triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor variants: function @@ -13290,21 +13311,6 @@ structured: True variants: function -- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal) -> Tensor - variants: function - dispatch: - CUDA: flash_scaled_dot_product_attention - -- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_forward - -- func: _efficient_attention_backward(Tensor grad, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor out, bool is_causal=False) -> (Tensor, Tensor, Tensor) - variants: function - dispatch: - CUDA: _efficient_attention_backward - - func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? 
incr_value=None) -> (Tensor, Tensor, Tensor) variants: function dispatch: diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index c2bf4e08ce04..9c72454560d3 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -214,26 +214,6 @@ Tensor NestedTensor_to_padded_tensor_cuda( return NestedTensor_to_padded_tensor_generic(t, padding, output_size); } -std::tuple _scaled_dot_product_attention_forward_nested( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - - // Determine which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - // TODO: enable flash attention kernel - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_efficient_helper_nested_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "Unsupported backend for scaled_dot_product_attention"); - return std::make_tuple(Tensor(), Tensor()); - } -} namespace{ /** @@ -340,19 +320,80 @@ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { } } // namespace -std::tuple mem_efficient_helper_nested_unpacked( + +std::tuple _scaled_dot_product_flash_attention_nestedtensor_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { + TORCH_CHECK(false, "There are currently cuda memory errors being returned from this path.") // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) const int64_t num_heads = query.size(1); const int64_t head_dim = query.size(3); + // Query -> Query (Batch x {Q_seq_len} x Num_heads x Dim_per_head) + // Key -> Key (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + // Value -> Value (Batch x {KV_seq_len} x Num_heads x Dim_per_head) + Tensor q_t = query.transpose(1, 2).contiguous(); + Tensor k_t = key.transpose(1, 2).contiguous(); + Tensor v_t = value.transpose(1, 2).contiguous(); + + // K and V have to have the same Nnz, should probably torch_check + // assume in order to not iterate over v + + auto cumulative_and_max_q = cumulative_and_max_seq_len(q_t); + auto cumulative_and_max_k = cumulative_and_max_seq_len(k_t); + + Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q); + Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k); + + const int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q); + const int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k); + + const int64_t Nnz_q = cumulative_sequence_length_q[-1].item(); + const int64_t Nnz_kv = cumulative_sequence_length_k[-1].item(); + + auto query_buffer_reshaped = + get_buffer(q_t).view({Nnz_q, num_heads, head_dim}); + auto key_buffer_reshaped = + 
get_buffer(k_t).view({Nnz_kv, num_heads, head_dim}); + auto value_buffer_reshaped = + get_buffer(v_t).view({Nnz_kv, num_heads, head_dim}); + + auto attention_and_lse_and_softmax = + at::_flash_attention_forward( + query_buffer_reshaped, + key_buffer_reshaped, + value_buffer_reshaped, + cumulative_sequence_length_q, + cumulative_sequence_length_k, + max_seqlen_batch_q, + max_seqlen_batch_k, + return_softmax, + dropout_p, + is_causal); + // Reshape output to convert nnz to batch_size and seq_len + Tensor attention = std::get<0>(attention_and_lse_and_softmax); + attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()).transpose(1,2); + return std::tie(attention, std::get<1>(attention_and_lse_and_softmax), std::get<2>(attention_and_lse_and_softmax)); +} + +std::tuple _scaled_dot_product_efficient_attention_nestedtensor_cuda( + const Tensor& query, + const Tensor& key, + const Tensor& value, + bool compute_log_sumexp, + bool is_causal) { + // Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head) + // Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + // Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head) + const int64_t num_heads = query.size(1); + const int64_t head_dim = query.size(3); + Tensor q_t = query.transpose(1, 2); Tensor k_t = key.transpose(1, 2); Tensor v_t = value.transpose(1, 2); @@ -432,7 +473,7 @@ std::tuple mem_efficient_helper_nested_unpacked( {Nnz_kv, num_heads, head_dim}, {nnz_v_stride, head_v_stride, head_dim_stride}, value_impl->get_storage_offsets()[0]); - std::tuple attention_and_weights = + std::tuple attention_and_logsumexp= at::_efficient_attention_forward( query_buffer_reshaped.unsqueeze(0), key_buffer_reshaped.unsqueeze(0), @@ -440,14 +481,14 @@ std::tuple mem_efficient_helper_nested_unpacked( cumulative_sequence_length_q, cumulative_sequence_length_k, max_seqlen_batch_q, - false, - false); + compute_log_sumexp, + is_causal); // Reshape output to convert nnz to batch_size and seq_len - Tensor attention = std::get<0>(attention_and_weights); + Tensor attention = std::get<0>(attention_and_logsumexp); attention = wrap_buffer(attention.view(-1), get_nested_size_tensor(q_t).clone()) .transpose(1, 2); - return std::tie(attention, std::get<1>(attention_and_weights)); + return std::tie(attention, std::get<1>(attention_and_logsumexp)); } Tensor flash_attention_helper( @@ -492,7 +533,7 @@ Tensor flash_attention_helper( // If we are passing in query, key, value all the same tensors then we have // packed them into one tensor and need to slice for flash attention Tensor attention = - at::_flash_scaled_dot_product_attention( + std::get<0>(at::_flash_attention_forward( q, k, v, @@ -500,8 +541,9 @@ Tensor flash_attention_helper( cumulative_sequence_length_q, max_seqlen_batch_q, max_seqlen_batch_q, + false /*return_softmax*/, dropout_p, - is_causal); + is_causal)); // Output of flash_attention is a regular tensor lets wrap it back up to // form a nested tensor diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 89a0e4691018..9c5be12ef24d 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -678,20 +678,6 @@ std::tuple native_decoder_only_multi_head_attent // L: Target sequence length // E: Embedding dimension std::tuple _scaled_dot_product_attention( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - if 
(query_.requires_grad() || key.requires_grad() || value.requires_grad()){ - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - } - return at::_scaled_dot_product_attention_forward(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); -} - -int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ - return static_cast(sdp::SDPBackend::math); -} - -std::tuple _scaled_dot_product_attention_forward_math( const Tensor& query_, const Tensor& key, const Tensor& value, @@ -699,14 +685,49 @@ std::tuple _scaled_dot_product_attention_forward_math( double dropout_p, bool need_attn_weights, bool is_causal) { - return at::_scaled_dot_product_attention_math( - query_, - key, - value, - attn_mask_, - dropout_p, - need_attn_weights, - is_causal); + // TODO: The second return is the attention weights if the math kernel is + // used. The fused kernels do not return this Tensor so for the fused kernels + // The second return SHOULD always be an empty Tensor, unless need_attn_weights + // is true (in which case the fused kernels would not be called). This blows up + // op_info tests. + int64_t choice_int = at::_fused_sdp_choice( + query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); + sdp::SDPBackend backend = static_cast(choice_int); + switch (backend) { + case sdp::SDPBackend::flash_attention: { + auto out_lse_softmax = at::_scaled_dot_product_flash_attention( + query_, key, value, dropout_p, need_attn_weights, is_causal); + return std::make_tuple( + std::move(std::get<0>(out_lse_softmax)), + std::move(std::get<2>(out_lse_softmax))); + } + case sdp::SDPBackend::efficient_attention: { + bool compute_logsumexp = + (query_.requires_grad() || key.requires_grad() || + value.requires_grad()); + return at::_scaled_dot_product_efficient_attention( + query_, key, value, compute_logsumexp, is_causal); + } + case sdp::SDPBackend::math: + return at::_scaled_dot_product_attention_math( + query_, + key, + value, + attn_mask_, + dropout_p, + need_attn_weights, + is_causal); + default: + TORCH_CHECK( + false, + "No viable backend for scaled_dot_product_attention was found."); + return std::make_tuple(Tensor(), Tensor()); + } +} + +int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, + const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal){ + return static_cast(sdp::SDPBackend::math); } std::tuple _scaled_dot_product_attention_math( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 602cf319f74a..8dcb99b3380d 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -678,12 +678,12 @@ std::tuple native_multi_head_attention_cuda( return std::make_tuple(std::move(proj), std::move(qkt)); } -std::tuple flash_attention_helper_dense_unpacked( +std::tuple _scaled_dot_product_flash_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, double dropout_p, - bool need_atten_weights, + bool return_softmax, bool is_causal) { // Query (Batch x Num_heads x Q_seq_len x Dim_per_head) // Key (Batch x Num_heads x KV_seq_len x Dim_per_head) @@ -726,8 +726,9 @@ std::tuple flash_attention_helper_dense_unpacked( Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim}); 
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim}); - Tensor attention = - at::_flash_scaled_dot_product_attention( + Tensor attention, log_sumexp, softmax; + std::tie(attention, log_sumexp, softmax) = + at::_flash_attention_forward( query_reshaped, key_reshaped, value_reshaped, @@ -735,15 +736,17 @@ std::tuple flash_attention_helper_dense_unpacked( cumulative_sequence_length_k, max_seqlen_batch_q, max_seqlen_batch_k, + return_softmax, dropout_p, is_causal); // Reshape output to convert nnz to batch_size and seq_len attention = attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2); - return std::tuple(attention, Tensor()); + return std::make_tuple(attention, log_sumexp, softmax); } -std::tuple mem_eff_helper( + +std::tuple _scaled_dot_product_efficient_attention_cuda( const Tensor& query, const Tensor& key, const Tensor& value, @@ -767,26 +770,7 @@ std::tuple mem_eff_helper( compute_log_sumexp, is_causal); attention = attention.transpose(1,2); - return std::make_tuple(std::move(attention), Tensor()); -} - -std::tuple _scaled_dot_product_attention_forward_cuda( - const Tensor& query_, const Tensor& key, const Tensor& value, - const c10::optional& attn_mask_, double dropout_p, bool need_attn_weights, bool is_causal) { - // Determine which efficient kernel to use - sdp::sdp_params kernel_params{query_, key, value, attn_mask_.has_value(), dropout_p, need_attn_weights, is_causal}; - auto backend = select_sdp_backend(kernel_params); - switch(backend){ - case sdp::SDPBackend::flash_attention: - return flash_attention_helper_dense_unpacked(query_, key, value, dropout_p, need_attn_weights, is_causal); - case sdp::SDPBackend::efficient_attention: - return mem_eff_helper(query_, key , value, need_attn_weights, is_causal); - case sdp::SDPBackend::math: - return at::_scaled_dot_product_attention_math(query_, key, value, attn_mask_, dropout_p, need_attn_weights, is_causal); - default: - TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found."); - return std::make_tuple(Tensor(), Tensor()); - } + return std::make_tuple(std::move(attention), std::move(log_sumexp)); } int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, @@ -802,7 +786,7 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te return static_cast(backend); } -Tensor flash_scaled_dot_product_attention( +std::tuple _flash_attention_forward( const Tensor& query, const Tensor& key, const Tensor& value, @@ -810,11 +794,12 @@ Tensor flash_scaled_dot_product_attention( const Tensor& cumulative_sequence_length_k, const int64_t max_seqlen_batch_q, const int64_t max_seqlen_batch_k, + bool return_softmax, double dropout_p, bool is_causal) { #if defined(USE_FLASH_ATTENTION) auto softmax_scale = std::pow(query.size(-1), -0.5); - std::vector output = fmha::mha_fwd( + return fmha::mha_fwd( query, key, value, @@ -826,12 +811,11 @@ Tensor flash_scaled_dot_product_attention( softmax_scale, false, is_causal, - false, + return_softmax, c10::nullopt); - return output[0]; #endif TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.") - return Tensor(); + return std::make_tuple(Tensor(), Tensor(), Tensor()); } std::tuple _efficient_attention_forward( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index af005b2669b2..a063aacb901e 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ 
b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -10,6 +10,7 @@ #include #include +#include #ifdef USE_FLASH_ATTENTION #include #endif @@ -73,14 +74,14 @@ std::tuple _efficient_attention_backward( const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, - const at::Tensor& logsumexp, const at::Tensor& out, + const at::Tensor& logsumexp, bool causal) { #if defined(USE_FLASH_ATTENTION) if (!grad_out_.defined()) { return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } - // ndim + // ndim TORCH_CHECK(query.dim() == grad_out_.dim()); TORCH_CHECK(query.dim() == key.dim()); TORCH_CHECK(query.dim() == value.dim()); @@ -128,6 +129,7 @@ std::tuple _efficient_attention_backward( // initialized bool grad_kv_needs_init = causal && N > M; at::Tensor grad_q, grad_k, grad_v; + int8_t gQKV_strideM_multiplier = 1; if (!grad_kv_needs_init && query.size(1) == key.size(1) && query.size(3) == value.size(3) && query.storage().is_alias_of(key.storage()) && @@ -141,10 +143,13 @@ std::tuple _efficient_attention_backward( grad_q = chunk.select(2, 0); grad_k = chunk.select(2, 1); grad_v = chunk.select(2, 2); + gQKV_strideM_multiplier=3; } else { - grad_q = at::empty_like(query); - grad_k = grad_kv_needs_init ? at::zeros_like(key) : at::empty_like(key); - grad_v = grad_kv_needs_init ? at::zeros_like(value) : at::empty_like(value); + grad_q = at::empty(query.sizes(), query.options()); + grad_k = grad_kv_needs_init ? at::zeros(key.sizes(), key.options()) + : at::empty(key.sizes(), key.options()); + grad_v = grad_kv_needs_init ? at::zeros(value.sizes(), value.options()) + : at::empty(value.sizes(), value.options()); } auto launchKernel = [&](auto _k, int computeCapability) { @@ -198,7 +203,7 @@ std::tuple _efficient_attention_backward( ASSIGN_CHECK_OVERFLOW(p.gQ_strideH, grad_q.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gK_strideH, grad_k.stride(2)); ASSIGN_CHECK_OVERFLOW(p.gV_strideH, grad_v.stride(2)); - p.gQKV_strideM_multiplier = grad_q.is_contiguous() ? 
1 : 3; + p.gQKV_strideM_multiplier = gQKV_strideM_multiplier; TORCH_INTERNAL_ASSERT(p.gQ_strideM() == grad_q.stride(1)); TORCH_INTERNAL_ASSERT(p.gK_strideM() == grad_k.stride(1)); TORCH_INTERNAL_ASSERT(p.gV_strideM() == grad_v.stride(1)); @@ -257,5 +262,28 @@ std::tuple _efficient_attention_backward( return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); } + +std::tuple _scaled_dot_product_efficient_attention_backward_cuda( + const at::Tensor& grad_out_, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& out, + const at::Tensor& logsumexp, + bool causal){ + if (!grad_out_.defined()) { + return std::make_tuple(Tensor{}, Tensor{}, Tensor{}); + } + auto grad_out = grad_out_.transpose(1, 2); + auto out_t = out.transpose(1, 2); + auto q_t = query.transpose(1, 2); + auto k_t = key.transpose(1, 2); + auto v_t = value.transpose(1, 2); + + Tensor grad_q, grad_k, grad_v; + std::tie(grad_q, grad_k, grad_v) = at::_efficient_attention_backward(grad_out, q_t, k_t, v_t, out_t, logsumexp, causal); + return std::make_tuple(grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2)); +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp index aaf7d833fe83..7cc0c250664e 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp @@ -26,6 +26,7 @@ * ******************************************************************************/ +#include #ifdef USE_FLASH_ATTENTION #include #include @@ -115,7 +116,7 @@ void set_params_fprop(FMHA_fprop_params ¶ms, params.is_causal = is_causal; } -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i @@ -241,9 +242,7 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q run_fmha_fprop(launch_params, /*configure=*/false); - std::vector result = {o, softmax_lse}; - if (return_softmax) {result.push_back(s);} - return result; + return std::make_tuple(o, softmax_lse, s); } } // namespace fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h index 226d4ddd2b55..b0555463be04 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h @@ -7,7 +7,7 @@ namespace fmha { TORCH_API -std::vector +std::tuple mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h index 5d62a6cbd0dc..55e9aeb184a2 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h @@ -91,6 +91,31 @@ inline bool check_for_seq_len_1_nested_tensor(sdp_params params, bool debug) { return true; } +inline bool check_for_nested_inputs(sdp_params params, bool debug){ + if (params.query.is_nested() || params.key.is_nested() || 
params.value.is_nested()) { + TORCH_CHECK(!debug, "We are not enabling nested Tensors for Flash Attention because of cuda memory errors."); + return false; + } + return true; +} + +inline bool check_requires_grad(sdp_params params, bool debug) { + if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { + TORCH_CHECK(!debug, "Flash Attention does not currently support training."); + return false; + } + return true; +} + +inline bool check_requires_grad_and_nested(sdp_params params, bool debug) { + // If we fail both checks then we return false + if (!check_for_nested_inputs(params, false) && !check_requires_grad(params,false)){ + TORCH_CHECK(!debug, "Memory efficient attention currently doesn't support training with NT inputs."); + return false; + } + return true; +} + inline bool check_for_attn_mask(sdp_params params, bool debug) { if (params.has_attn_mask) { TORCH_CHECK(!debug, "Flash Attention does not support attention mask."); @@ -198,13 +223,15 @@ inline bool use_flash_attention(sdp_params params, bool debug) { return false; #endif // Define gate functions that determine if a flash kernel can be ran - constexpr std::array constraints {{ + constexpr std::array constraints {{ check_runtime_disabled_flash, + check_requires_grad, check_tensor_shapes, check_for_attn_weights, check_for_attn_mask, check_head_dim_size, check_gpu_sm75_or_greater, + check_for_nested_inputs, check_for_seq_len_1_nested_tensor}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { @@ -232,14 +259,15 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) { at::kHalf, at::kFloat, at::kBFloat16}; // Define gate functions that determine if a flash kernel can be ran - std::vector> constraints{ + constexpr std::array constraints{{ check_gpu_sm50_or_greater, check_runtime_disabled_mem_efficient, + check_requires_grad_and_nested, check_for_attn_weights, check_tensor_shapes, check_for_attn_mask, check_for_seq_len_1_nested_tensor, - check_for_non_zero_dropout}; + check_for_non_zero_dropout}}; for (auto& constraint : constraints) { if (!constraint(params, debug)) { return false; diff --git a/benchmarks/transformer/sdp_backwards.py b/benchmarks/transformer/sdp_backwards.py new file mode 100644 index 000000000000..2f745e157b28 --- /dev/null +++ b/benchmarks/transformer/sdp_backwards.py @@ -0,0 +1,189 @@ +import torch +import numpy as np +import random +import torch.utils.benchmark as benchmark +from torch.profiler import profile, record_function, ProfilerActivity + + +class CompositeMHA(torch.nn.Module): + def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj): + super().__init__() + self.in_proj_weight = in_proj_weight + self.in_proj_bias = in_proj_bias + self.out_proj = out_proj + self.num_heads = num_heads + + def forward(self, query, key, value, mask): + if not (query is key and key is value): + raise NotImplementedError( + "query, key and value must be the same Tensor for now." 
+ ) + if mask is not None: + raise NotImplementedError("mask is currently not supported.") + + query_projected = torch.nn.functional.linear( + query, self.in_proj_weight, self.in_proj_bias + ) + + batch_size = query_projected.size(0) + embed_dim = query_projected.size(2) + head_dim = embed_dim // (self.num_heads * 3) + + query, key, value = query_projected.chunk(3, -1) + + query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + attn, _ = torch.nn.functional._scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + need_attn_weights=False, + is_causal=False, + ) + + attn = attn.transpose(1, 2).reshape(batch_size, -1, self.num_heads * head_dim) + # Match return signature of nn.MHA + return self.out_proj(attn) + + +def build_composite_mha_from_nn_mha(pt): + assert pt._qkv_same_embed_dim + in_proj_weight = pt.in_proj_weight + assert in_proj_weight is not None + assert pt.batch_first + return CompositeMHA(pt.num_heads, pt.in_proj_weight, pt.in_proj_bias, pt.out_proj) + + +def forw_back(model, input, upward): + output = model(*input) + output.backward(upward) + + +# Context manger not working in timer + + +def forw_back_fused(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + output = model(*input) + output.backward(upward) + + +def forw_back_eager(model, input, upward): + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + output = model(*input) + output.backward(upward) + + +def run_timing( + min_run_time, batch_size, embed_dimension, num_heads, max_sequence_len, dtype +): + dropout_p = 0.0 + mask = None + + pt = torch.nn.MultiheadAttention( + embed_dim=embed_dimension, + num_heads=num_heads, + batch_first=True, + dropout=dropout_p, + ) + npt = pt.cuda().to(dtype) + cpt = build_composite_mha_from_nn_mha(npt) + x = torch.randn( + batch_size, + max_sequence_len, + embed_dimension, + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + rand_fused_upward = cpt(x, x, x, mask).clone().detach() + + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + rand_eager_upward = cpt(x, x, x, mask).clone().detach() + + t0 = benchmark.Timer( + stmt="forw_back_fused(cpt, (x,x,x,mask), rand_fused_upward)", + globals={ + "forw_back_fused": forw_back_fused, + "cpt": cpt, + "x": x, + "rand_fused_upward": rand_fused_upward, + "mask": mask, + }, + label=f"Fused SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + t1 = benchmark.Timer( + stmt="forw_back_eager(cpt, (x,x,x,mask), rand_eager_upward)", + globals={ + "forw_back_eager": forw_back_eager, + "cpt": cpt, + "x": x, + "rand_eager_upward": rand_eager_upward, + "mask": mask, + }, + label=f"Eager SDP forward and backward batch_size={batch_size} max_sequence_len={max_sequence_len} " + f"num_heads={num_heads} embed_dimension={embed_dimension} dtype={dtype}", + num_threads=torch.get_num_threads(), + ) + + m0 = t0.blocked_autorange(min_run_time=min_run_time) + m1 = t1.blocked_autorange(min_run_time=min_run_time) + + print(m0) + 
print(m1) + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + print("Profile for Fused".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_mem_efficient=True): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_fused_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + print("Profile for eager".center(200, "-")) + with torch.backends.cuda.sdp_kernel(enable_math=True, enable_mem_efficient=False): + with profile( + activities=activities, record_shapes=False, with_stack=True + ) as prof: + with record_function("Fused SDP forward and backward"): + for _ in range(20): + forw_back(cpt, (x, x, x, mask), rand_eager_upward) + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)) + + +def main(): + seed = 123 + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + min_run_time = 10 + batch_size = 64 + num_heads = 32 + max_seq_len = 256 + embed_dim = 1024 + dtype = torch.bfloat16 + + print( + f"Running timing for batch_size={batch_size} max_sequence_len={max_seq_len} " + f"num_heads={num_heads} embed_dimension={embed_dim} dtype={dtype}" + ) + run_timing(min_run_time, batch_size, embed_dim, num_heads, max_seq_len, dtype) + + +if __name__ == "__main__": + main() diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 90080ab0934f..853f5206969b 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -317,6 +317,9 @@ ("aten::_upsample_nearest_exact1d_backward", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d", datetime.date(2022, 12, 15)), ("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)), + ("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 12, 15)), + ("aten::_scaled_dot_product_attention_forward", datetime.date(2022, 12, 15)), + ("aten::_efficient_attention_backward", datetime.date(2022, 12, 15)), ("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)), ] diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 5e3aa1ff898f..e9451b596b4a 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -401,6 +401,7 @@ def wrapped_fn(*args, **kwargs): skip('nn.functional.max_unpool2d'), # fails everywhere except on windows skip('nn.functional.max_unpool3d'), # fails everywhere except on mac xfail("native_batch_norm"), + xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.rrelu'), # in-place test errors out with no formula implemented @@ -555,6 +556,7 @@ def f(inp, *args, **kwargs): xfail('nn.functional.ctc_loss'), # Not Implemented xfail('native_layer_norm', ''), # Expected a proper Tensor but got None for argument #1 'other' xfail('sparse.sampled_addmm', ''), # sparse tensors have no strides + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), # AssertionError: Tensor-likes are not close! 
# Mismatched elements: 1 / 15 (6.7%) # Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed) @@ -649,7 +651,7 @@ def fn(inp, *args, **kwargs): skip("nn.functional.feature_alpha_dropout", "with_train"), # calls random op skip("nn.functional.fractional_max_pool2d"), # calls random op skip("nn.functional.fractional_max_pool3d"), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness # It looks like you're either (1) calling .item() on a Tensor or # (2) attempting to use a Tensor in some data-dependent control flow or # (3) encountering this error in PyTorch internals. @@ -1126,6 +1128,7 @@ def test(): skip('nn.functional.rrelu'), # randomness skip('nn.functional.feature_alpha_dropout', 'with_train'), # randomness skip('nn.functional.feature_alpha_dropout', 'without_train'), # randomness + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), skip('nn.functional.alpha_dropout'), # randomness skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format skip('to_sparse', ''), # non-dense output @@ -1249,6 +1252,7 @@ def get_vjp(cotangents, *primals): xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward + skip('nn.functional._scaled_dot_product_attention', device_type='cuda'), xfail('nn.functional.multi_margin_loss', ''), # NYI: forward AD with multi_margin_loss skip('linalg.householder_product', '', device_type='cuda'), # flaky, I'm not sure why xfail('sparse.sampled_addmm', ''), # Sparse tensors have no strides @@ -1369,7 +1373,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail('nn.functional.dropout2d'), # calls random op xfail('nn.functional.dropout3d'), # calls random op xfail('nn.functional.dropout'), # calls random op - skip('nn.functional._scaled_dot_product_attention'), # randomness + xfail('nn.functional._scaled_dot_product_attention'), # randomness xfail('nn.functional.embedding_bag'), # Forward AD not implemented and no decomposition xfail('nn.functional.alpha_dropout'), # calls randomn op xfail('nn.functional.feature_alpha_dropout', 'with_train'), # calls random op diff --git a/test/test_meta.py b/test/test_meta.py index 6d21d5c7bd75..0e3cfb6ef140 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -294,7 +294,6 @@ def test_tensor_outlives_converter(self): aten._fft_c2r.default, aten._fft_r2c.default, aten._linalg_svd.default, - aten._scaled_dot_product_attention_forward.default, aten.binary_cross_entropy.default, aten.complex.default, aten.copysign.Tensor, diff --git a/test/test_transformers.py b/test/test_transformers.py index abb4c71ec19a..0260c822498d 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1059,6 +1059,11 @@ def rand_tensor(shape): if fused_kernel == "flash": with sdp_kernel(enable_mem_efficient=False, enable_math=False): + # TODO Flash for the nested path is currently not working due to cuda memory issues + if type == "nested": + self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, is_causal=False)) + return actual = torch.nn.functional._scaled_dot_product_attention( query_lp, key_lp, value_lp, attn_mask=None, dropout_p=0.0, need_attn_weights=False, 
is_causal=False) elif fused_kernel == "mem_efficient": @@ -1097,28 +1102,73 @@ def rand_tensor(shape): @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") @parametrize("contiguous_inputs", [True, False]) - def test_efficient_attention_gradcheck(self, contiguous_inputs: bool): + def test_sdp_math_gradcheck(self, contiguous_inputs: bool): - batch_size, seq_len, num_heads, head_dim = 8, 8, 4, 64 - rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float16, requires_grad=True, packed=True) + batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) query, key, value = qkv.chunk(3, dim=-1) - query = query.view(batch_size, -1, num_heads, head_dim) - key = key.view(batch_size, -1, num_heads, head_dim) - value = value.view(batch_size, -1, num_heads, head_dim) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + if contiguous_inputs: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + assert gradcheck(lambda *args, **kwargs: + wrapper_set_seed(torch.nn.functional._scaled_dot_product_attention, *args, **kwargs), + (query, key, value, None, 0.0, False, False) + ) + + @unittest.skipIf(not TEST_CUDA or TEST_WITH_ROCM or IS_WINDOWS, "Flash Attention was not built for this system") + @parametrize("contiguous_inputs", [True, False]) + def test_sdp_fused_grad_against_math(self, contiguous_inputs: bool): + batch_size, seq_len, num_heads, head_dim = 4, 4, 2, 16 + rand_tensor = partial(self.rand_tensor, device="cuda", dtype=torch.float64, requires_grad=True, packed=True) + + qkv = rand_tensor((batch_size, seq_len, num_heads, head_dim)) + qkv_lp = qkv.detach().clone().to(torch.float32).requires_grad_() + + query, key, value = qkv.chunk(3, dim=-1) + query_lp, key_lp, value_lp = qkv_lp.chunk(3, dim=-1) + + query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + + query_lp = query_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + key_lp = key_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + value_lp = value_lp.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) if contiguous_inputs: query = query.contiguous() key = key.contiguous() value = value.contiguous() - # Normally we would transpose the inputs but the fused kernels expect - # (batch, seq_len, num_heads, head_dim) bump the tolerance since we can only run kernel - # in fp32 - assert gradcheck(lambda *args, **kwargs: - wrapper_set_seed(torch.ops.aten._efficient_attention_forward, *args, **kwargs), - (query, key, value, None, None, None, True, False), fast_mode=True, atol=8e-5, rtol=1e-3) + query_lp = query_lp.contiguous() + key_lp = key_lp.contiguous() + value_lp = value_lp.contiguous() + + with sdp_kernel(enable_math=True, enable_mem_efficient=False, enable_flash=False): + out, atten = torch.nn.functional._scaled_dot_product_attention(query, key, value, None, 0.0, False, False) + + with sdp_kernel(enable_math=False, 
enable_mem_efficient=True, enable_flash=False): + out_lp, atten_lp = torch.nn.functional._scaled_dot_product_attention( + query_lp, key_lp, value_lp, None, 0.0, False, False) + + rand_upward = torch.rand_like(out) + rand_upward_lp = rand_upward.to(torch.float32) + + out.backward(rand_upward) + out_lp.backward(rand_upward_lp) + + # Cast up and compare + self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5) @parametrize("type", ["dense", "nested"]) def test_fused_sdp_choice(self, type: str): @@ -1144,7 +1194,7 @@ def test_fused_sdp_choice(self, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) - if SM80OrLater: + if SM80OrLater and not type == "nested": assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION else: assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index a0892b32a835..52c0f76bf070 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2613,9 +2613,13 @@ nested_strides: non_differentiable # Transformers +- name: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False) -> (Tensor, Tensor) + output_differentiability: [True, False] + query, key, value: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, result0, result1, is_causal) + - name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, bool compute_log_sumexp=False, bool causal=False) -> (Tensor, Tensor) output_differentiability: [True, False] - query, key, value: _efficient_attention_backward(grad, query, key, value, result1, result0, causal) + query, key, value: _efficient_attention_backward(grad, query, key, value, result0, result1, causal) # fft - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 0f845f765829..998f1cde65f7 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -12009,16 +12009,21 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # This is only failing on Linux Bionic 3.10 Cuda 11.6 DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', device_type='cuda', active_if=_get_torch_cuda_version() >= (11, 6)), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples', + device_type='cuda', dtypes=(torch.float32,)), # AssertionError: JIT Test does not execute any logic DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'), # Doesn't support autocasting DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensorNonErroring', 'test_fake_autocast', device_type='cpu'), DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_autocast'), + # Forward works for dtype=float64 which is the math path + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), # No meta function DecorateInfo(unittest.skip("Skipped!"), 'TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive'), DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), 
DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'), DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake'), + DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', device_type='cuda'), DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),), ), UnaryUfuncInfo( From 3d247a8bcd6f07ffae8c7144ac08ba9fdeeb2025 Mon Sep 17 00:00:00 2001 From: Keval Morabia Date: Mon, 21 Nov 2022 20:40:04 +0000 Subject: [PATCH 403/453] Fix unconvertible_ops as per #89261 (#89299) Fixes #89261 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89299 Approved by: https://github.com/justinchuby, https://github.com/BowenBao --- test/onnx/test_utility_funs.py | 12 ++++++++++++ torch/onnx/utils.py | 4 +++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 51adaef317af..5d1cdc5e8ea5 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -124,6 +124,18 @@ def test_it_returns_empty_list_when_all_ops_convertible( _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=12) self.assertEqual(unconvertible_ops, []) + def test_it_returns_empty_list_when_model_contains_supported_inplace_ops(self): + class SkipConnectionModule(torch.nn.Module): + def forward(self, x): + out = x + out += x + out = torch.nn.functional.relu(out, inplace=True) + + module = SkipConnectionModule() + x = torch.randn(4, 4) + _, unconvertible_ops = utils.unconvertible_ops(module, (x,), opset_version=13) + self.assertEqual(unconvertible_ops, []) + @parameterized.parameterized_class( [ diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 67dd719bae9f..36d7fdb75762 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -1333,7 +1333,9 @@ def unconvertible_ops( # eliminated in the conversion passes. Users may still see errors caused # by prim ops even though they don't show up in the list. continue - if not registration.registry.is_registered_op(domain_op, opset_version): + if not registration.registry.is_registered_op( + domain_op.rstrip("_"), opset_version + ): # We consider all registered ops supported, even though some of them are # only partially supported, because there is not yet a good way to check # if an op is fully supported. From 1267dcf2971b181d7379928f3452ce622add91e9 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 20 Nov 2022 23:19:24 +0000 Subject: [PATCH 404/453] [inductor] Fix nan handling for aten.sign (#88937) ATen gives `sign(nan) == 0` but inductor's cuda codegen would give `sign(nan) == 1`. 
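For reference, a minimal standalone sketch (plain PyTorch ops, not the actual inductor codegen strings) of why the two-comparison formulation in this patch matches ATen: both comparisons are false for NaN, so their difference is 0, while the old `where`-chain lets NaN fall through to its final branch and return 1.

```python
import torch

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 0.0, -2.0, 3.0])

# ATen reference behaviour (compare with the new test_sgn_extremal test below).
eager = torch.sgn(x)

# Equivalent of the new lowering: sign(x) = (0 < x) - (x < 0).
# Both comparisons are False for NaN, so the NaN lane becomes 0.
two_cmp = (x > 0).to(x.dtype) - (x < 0).to(x.dtype)

# Equivalent of the old lowering: where(x == 0, 0, where(x < 0, -1, 1)).
# NaN fails both comparisons and lands in the final "1" branch.
old_style = torch.where(
    x == 0, torch.zeros_like(x),
    torch.where(x < 0, torch.full_like(x, -1.0), torch.ones_like(x)),
)

print(eager)      # NaN lane is 0
print(two_cmp)    # matches eager, including the NaN case
print(old_style)  # NaN lane is 1 -- the mismatch this patch fixes
```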
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88937 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 8 ++++++++ torch/_inductor/codegen/common.py | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 399032890ca8..ec024c67b81c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -13,6 +13,8 @@ from typing import Any, Callable from unittest.mock import patch +import numpy as np + import torch import torch._dynamo @@ -668,6 +670,12 @@ def fn(a): self.common(fn, [torch.linspace(-10, 10, 41)]) + def test_sgn_extremal(self): + def fn(a): + return (torch.sgn(a),) + + self.common(fn, [torch.tensor([np.nan, np.inf, -np.inf, 0])]) + def test_max_min(self): def fn(a, b): return (torch.maximum(a, b), torch.minimum(a, b)) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index cf98833964ca..da64f3e63584 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -90,7 +90,9 @@ def square(x): @staticmethod def sign(x): - return ops.where(f"{x} == 0", "0", ops.where(f"{x} < 0", "-1", "1")) + left = ops.where(ops.lt("0", x), "1", "0") + right = ops.where(ops.lt(x, "0"), "1", "0") + return ops.sub(left, right) @staticmethod def bitwise_not(x): From c068fa900f1352240a04123a74d4d1f83b295222 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Sun, 20 Nov 2022 23:36:41 +0000 Subject: [PATCH 405/453] [inductor] Misc division lowering fixes (#88603) 1. `aten.div.Tensor_mode` should allow broadcasting 2. `div` can use `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` 3. `prims.div` on integers should be truncating division 4. Add lowering for `true_divide` which is aliased to `div` 5. 
register lowering for inplace version of `div_mode` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88603 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 59 +++++++++++++++++++++++++++++ torch/_inductor/lowering.py | 46 ++++++++++------------ 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ec024c67b81c..2196f4f8a026 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -23,6 +23,7 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.passes.shape_prop import ShapeProp from torch.nn import functional as F +from torch.testing import make_tensor from torch.testing._internal.common_utils import ( TEST_WITH_ASAN, TEST_WITH_ROCM, @@ -1166,6 +1167,45 @@ def fn(a, b): self.common(fn, (1024, 100)) + def test_div_zero_dim(self): + def fn(a, b): + return ( + aten.div(a, b, rounding_mode=None), + aten.div(a, b, rounding_mode="floor"), + aten.div(a, b, rounding_mode="trunc"), + a / b, + a // b, + ) + + for dtype in (torch.float32, torch.int64): + self.common( + fn, + ( + make_tensor(10, device="cpu", dtype=dtype), + make_tensor((), device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + self.common( + fn, + ( + make_tensor((), device="cpu", dtype=dtype), + make_tensor(10, device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + + def test_div_prim(self): + def fn(a, b): + return (torch.ops.prims.div(a, b),) + + for dtype in (torch.float32, torch.int64): + self.common( + fn, + ( + make_tensor(100, device="cpu", dtype=dtype), + make_tensor(100, device="cpu", dtype=dtype, exclude_zero=True), + ), + ) + def test_both_scalars(self): def fn(a, b): return ( @@ -2589,6 +2629,25 @@ def fn(a, b): shape = [1, 2, 6, 6] self.common(fn, (torch.randn(shape), torch.randn(shape))) + def test_fmod_zero_dim(self): + def fn(a, b): + return (torch.fmod(a, b),) + + self.common( + fn, + ( + make_tensor(10, device="cpu", dtype=torch.float32), + make_tensor((), device="cpu", dtype=torch.float32), + ), + ) + self.common( + fn, + ( + make_tensor((), device="cpu", dtype=torch.float32), + make_tensor(10, device="cpu", dtype=torch.float32), + ), + ) + def test_log2(self): def fn(x): return torch.log2(x), torch.log2(x + 1) - 2 diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index a76a9baea953..0bd92007c986 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3354,7 +3354,7 @@ def truncdiv(a, b): return ops.truncdiv(a, b) -@register_lowering(aten.div.Tensor_mode) +@register_lowering(aten.div, broadcast=True) def div_mode(a, b, rounding_mode=None): both_integer = is_integer_type(a) and is_integer_type(b) both_boolean = is_boolean_type(a) and is_boolean_type(b) @@ -3370,23 +3370,6 @@ def div_mode(a, b, rounding_mode=None): return div(a, b) -@register_lowering([aten.div], broadcast=True) -def div(a, b): - def fn(*args): - return ops.div(*args) - - dtype = get_promoted_dtype( - a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT - ) - # truediv produces a float tensor even if both operands are integer types - if is_integer_type(a) and is_integer_type(b): - dtype = torch.get_default_dtype() - return make_pointwise(fn, override_return_dtype=dtype)( - a if isinstance(a, Number) else to_dtype(a, dtype), - b if isinstance(b, Number) else to_dtype(b, dtype), - ) - - @register_lowering([aten.mul], broadcast=True) def mul(a, b): both_bool = is_boolean_type(a) and 
is_boolean_type(b) @@ -3397,21 +3380,29 @@ def mul(a, b): return make_pointwise(fn)(a, b) -# TODO(lezcano) I believe the casting behaviour of prims.div is wrong -# https://github.com/pytorch/pytorch/issues/84412 -# div prim performs truncation division on integer inputs -# and true division for floating and complex inputs +# NOTE: prims.div maps to a / b in C, so performs truncation division on +# integer inputs and true division for floating and complex inputs. @register_lowering([prims.div], broadcast=True) def div_prim(a, b): is_integral = is_boolean_type(a) or is_integer_type(a) if is_integral: - return div_mode(a, b, rounding_mode="floor") - else: - return div(a, b) + return truncdiv(a, b) + + def fn(*args): + return ops.div(*args) + + return make_pointwise(fn)(a, b) + + +div = register_lowering( + [aten.true_divide, aten.div.Tensor], + broadcast=True, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, +)(div_prim) -@register_lowering([aten.fmod, prims.fmod]) +@register_lowering([aten.fmod, prims.fmod], broadcast=True) def fmod(a, b): is_integral = is_boolean_type(a) or is_integer_type(a) @@ -3564,7 +3555,8 @@ def fn(*args, **kwargs): register_inplace(aten.add_, add) register_inplace(aten.mul_, mul) -register_inplace(aten.div_, div) +register_inplace(aten.div_.Tensor, div) +register_inplace(aten.div_.Tensor_mode, div_mode) register_inplace(aten.sub_, sub) register_inplace(aten.relu_, relu) register_inplace(aten.sigmoid_, sigmoid) From 047e542a1a71448083d812783380b855e023eb14 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Mon, 21 Nov 2022 21:08:13 +0000 Subject: [PATCH 406/453] [tools] expose selective build library (#89351) Change the base module and visibility of `tools:gen_oplist_lib` so that it can be reused. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89351 Approved by: https://github.com/cccclai --- tools/BUCK.bzl | 5 +++-- tools/code_analyzer/gen_oplist.py | 4 ++-- tools/test/gen_oplist_test.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tools/BUCK.bzl b/tools/BUCK.bzl index 6d16e8fe3ff8..58a49fded0ee 100644 --- a/tools/BUCK.bzl +++ b/tools/BUCK.bzl @@ -62,10 +62,11 @@ def define_tools_targets( ("code_analyzer", "gen_oplist.py"), ("code_analyzer", "gen_op_registration_allowlist.py"), ]), - base_module = "", + base_module = "tools.code_analyzer", tests = [ ":gen_oplist_test", ], + visibility = ["PUBLIC"], deps = [ ":gen_selected_mobile_ops_header", torchgen_deps, @@ -75,7 +76,7 @@ def define_tools_targets( python_binary( name = "gen_oplist", - main_module = "gen_oplist", + main_module = "tools.code_analyzer.gen_oplist", visibility = ["PUBLIC"], deps = [ ":gen_oplist_lib", diff --git a/tools/code_analyzer/gen_oplist.py b/tools/code_analyzer/gen_oplist.py index 1e5d1277afcd..18104ab30cb6 100644 --- a/tools/code_analyzer/gen_oplist.py +++ b/tools/code_analyzer/gen_oplist.py @@ -127,7 +127,7 @@ def main(argv: List[Any]) -> None: default=False, required=False, ) - options = parser.parse_args() + options = parser.parse_args(argv) if os.path.isfile(options.model_file_list_path): print("Processing model file: ", options.model_file_list_path) @@ -186,4 +186,4 @@ def main(argv: List[Any]) -> None: if __name__ == "__main__": - main(sys.argv) + main(sys.argv[1:]) diff --git a/tools/test/gen_oplist_test.py b/tools/test/gen_oplist_test.py index d58e2ccc9067..33f9fb293edc 100644 --- a/tools/test/gen_oplist_test.py +++ b/tools/test/gen_oplist_test.py @@ -4,7 +4,7 @@ import unittest from unittest.mock import 
MagicMock -from gen_oplist import throw_if_any_op_includes_overloads +from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads class GenOplistTest(unittest.TestCase): From deae450899eb048754f046999a18fbda8c9b2d68 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 21 Nov 2022 19:19:29 +0000 Subject: [PATCH 407/453] [1/n] Thread PG: add test for allgather (#89439) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89439 Approved by: https://github.com/XilunWu, https://github.com/yhcharles, https://github.com/fduwjj --- test/distributed/test_multi_threaded_pg.py | 12 ++++++++++-- .../_internal/distributed/multi_threaded_pg.py | 8 ++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 6a0fe33cd8ad..dc4713b50439 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: distributed"] import sys +import torch import torch.distributed as dist if not dist.is_available(): @@ -16,7 +17,7 @@ DEFAULT_WORLD_SIZE = 4 -class TestObjectCollectivesWithWrapper(TestCase): +class TestCollectivesWithWrapper(TestCase): @spawn_threads_and_init_comms(world_size=4) def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None @@ -25,11 +26,18 @@ def test_broadcast_object_list(self): dist.broadcast_object_list(object_list=object_list) self.assertEqual(99, object_list[0]) -class TestObjectCollectivesWithBaseClass(MultiThreadedTestCase): +class TestCollectivesWithBaseClass(MultiThreadedTestCase): @property def world_size(self): return 4 + def test_allgather(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() + output_tensors = [torch.empty_like(input_tensor) for _ in range(self.world_size)] + dist.all_gather(output_tensors, input_tensor) + for rank, out_tensor in enumerate(output_tensors): + self.assertEqual(out_tensor, torch.ones(3, 3) * rank) + def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None object_list = [val] * dist.get_world_size() diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py index 7e18f870f2e7..7ad4bfa4cddb 100644 --- a/torch/testing/_internal/distributed/multi_threaded_pg.py +++ b/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -7,7 +7,11 @@ import torch import torch.distributed as dist -from torch._C._distributed_c10d import _create_work_from_future, Store +from torch._C._distributed_c10d import ( + _create_work_from_future, + AllgatherOptions, + Store, +) from torch.futures import Future from torch.utils._pytree import tree_flatten @@ -135,7 +139,7 @@ def _end_coll(cls, collective): if cls._cur_coll == collective: cls._cur_coll = None - def allgather(self, output_tensors, input_tensor, options): + def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()): coll = ProcessLocalGroup._start_coll(self._world, AllGather()) res = coll.join(self._rank, (output_tensors, input_tensor)) ProcessLocalGroup._end_coll(coll) From 3876f94c3d0eb329686d0699da2bab00849099b6 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 21 Nov 2022 19:19:29 +0000 Subject: [PATCH 408/453] [2/n] Thread PG: add test for broadcast (#89440) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89440 Approved by: https://github.com/XilunWu, https://github.com/yhcharles, 
https://github.com/fduwjj --- test/distributed/test_multi_threaded_pg.py | 7 +++++++ torch/testing/_internal/distributed/multi_threaded_pg.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index dc4713b50439..3e1e765eef51 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -38,6 +38,13 @@ def test_allgather(self): for rank, out_tensor in enumerate(output_tensors): self.assertEqual(out_tensor, torch.ones(3, 3) * rank) + def test_broadcast(self): + input_tensor = torch.ones(3, 3) * dist.get_rank() + for rank in range(self.world_size): + cloned_input = input_tensor.clone() + dist.broadcast(cloned_input, src=rank) + self.assertEqual(cloned_input, torch.ones(3, 3) * rank) + def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None object_list = [val] * dist.get_world_size() diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py index 7ad4bfa4cddb..ae465b95641b 100644 --- a/torch/testing/_internal/distributed/multi_threaded_pg.py +++ b/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -10,6 +10,7 @@ from torch._C._distributed_c10d import ( _create_work_from_future, AllgatherOptions, + BroadcastOptions, Store, ) from torch.futures import Future @@ -145,7 +146,7 @@ def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()): ProcessLocalGroup._end_coll(coll) return res - def broadcast(self, tensor_list, opts): + def broadcast(self, tensor_list, opts=BroadcastOptions()): coll = ProcessLocalGroup._start_coll(self._world, Broadcast(opts.rootRank)) res = coll.join(self._rank, tensor_list) ProcessLocalGroup._end_coll(coll) From 3e99d4db7671430901bb6292073f368ce1443e05 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 21 Nov 2022 19:19:29 +0000 Subject: [PATCH 409/453] [3/n] Thread PG: add scatter to threaded pg (#89441) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89441 Approved by: https://github.com/XilunWu, https://github.com/yhcharles, https://github.com/fduwjj --- test/distributed/test_multi_threaded_pg.py | 12 ++++++++-- .../distributed/multi_threaded_pg.py | 24 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 3e1e765eef51..1e16f5d03a8c 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -45,6 +45,16 @@ def test_broadcast(self): dist.broadcast(cloned_input, src=rank) self.assertEqual(cloned_input, torch.ones(3, 3) * rank) + def test_scatter(self): + if dist.get_rank() == 0: + scatter_list = [torch.ones(3, 3) * rank for rank in range(self.world_size)] + else: + scatter_list = None + output_tensor = torch.empty(3, 3) + + dist.scatter(output_tensor, scatter_list) + self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank()) + def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None object_list = [val] * dist.get_world_size() @@ -53,8 +63,6 @@ def test_broadcast_object_list(self): dist.broadcast_object_list(object_list=object_list) self.assertEqual(99, object_list[0]) - def test_something_else(self): - pass if __name__ == "__main__": run_tests() diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py 
b/torch/testing/_internal/distributed/multi_threaded_pg.py index ae465b95641b..321c61d993cf 100644 --- a/torch/testing/_internal/distributed/multi_threaded_pg.py +++ b/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -11,6 +11,7 @@ _create_work_from_future, AllgatherOptions, BroadcastOptions, + ScatterOptions, Store, ) from torch.futures import Future @@ -50,6 +51,23 @@ def work(self, data): with torch.no_grad(): dest_tensor.copy_(src_tensor) +class Scatter: + def __init__(self, src): + self.src = src + + def work(self, data): + src_in_tensor_list = data[self.src][1] + # Can't handle scatter with multiple input tensor list + assert len(src_in_tensor_list) == 1 + src_in_tensors = src_in_tensor_list[0] + + for rank, each_rank_data in enumerate(data): + out_tensor_list = each_rank_data[0] + # Can't handle scatter with multiple output tensor + assert len(out_tensor_list) == 1 + dest_tensor = out_tensor_list[0] + with torch.no_grad(): + dest_tensor.copy_(src_in_tensors[rank]) class Broadcast: def __init__(self, src): @@ -152,6 +170,12 @@ def broadcast(self, tensor_list, opts=BroadcastOptions()): ProcessLocalGroup._end_coll(coll) return res + def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()): + coll = ProcessLocalGroup._start_coll(self._world, Scatter(opts.rootRank)) + res = coll.join(self._rank, (output_tensors, input_tensors)) + ProcessLocalGroup._end_coll(coll) + return res + def __init__(self, rank, world): super(ProcessLocalGroup, self).__init__(rank, world) self._rank = rank From 821ba6b51beb1844f264fd57e1eccecb446e4870 Mon Sep 17 00:00:00 2001 From: Wanchao Liang Date: Mon, 21 Nov 2022 19:19:29 +0000 Subject: [PATCH 410/453] [4/n] Thread PG: add reduce_scatter to threaded pg (#89442) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89442 Approved by: https://github.com/yhcharles, https://github.com/fduwjj --- test/distributed/test_multi_threaded_pg.py | 8 +++++ .../distributed/multi_threaded_pg.py | 31 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/test/distributed/test_multi_threaded_pg.py b/test/distributed/test_multi_threaded_pg.py index 1e16f5d03a8c..f520698258ed 100644 --- a/test/distributed/test_multi_threaded_pg.py +++ b/test/distributed/test_multi_threaded_pg.py @@ -55,6 +55,14 @@ def test_scatter(self): dist.scatter(output_tensor, scatter_list) self.assertEqual(output_tensor, torch.ones(3, 3) * dist.get_rank()) + def test_reduce_scatter(self): + to_reduce_scatter = [torch.ones(3, 3) * rank for rank in range(self.world_size)] + output_tensor = torch.empty(3, 3) + + dist.reduce_scatter(output_tensor, to_reduce_scatter) + expected_tensor = torch.ones(3, 3) * dist.get_rank() * self.world_size + self.assertEqual(output_tensor, expected_tensor) + def test_broadcast_object_list(self): val = 99 if dist.get_rank() == 0 else None object_list = [val] * dist.get_world_size() diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py index 321c61d993cf..df45748ee6c6 100644 --- a/torch/testing/_internal/distributed/multi_threaded_pg.py +++ b/torch/testing/_internal/distributed/multi_threaded_pg.py @@ -11,6 +11,7 @@ _create_work_from_future, AllgatherOptions, BroadcastOptions, + ReduceScatterOptions, ScatterOptions, Store, ) @@ -69,6 +70,30 @@ def work(self, data): with torch.no_grad(): dest_tensor.copy_(src_in_tensors[rank]) +class ReduceScatter: + def __init__(self, op): + if op != dist.ReduceOp.SUM: + raise 
NotImplementedError("ReduceScatter only supports SUM on threaded pg for now.") + self.op = op + + def work(self, data): + start_reduction = [False for _ in range(len(data))] + for each_rank_data in data: + # Can't handle reduce_scatter with multiple scatter list + assert len(each_rank_data[1]) == 1 + to_scatter = each_rank_data[1][0] + for i in range(len(to_scatter)): + dest_tensor_on_rank_i = data[i][0] + # Can't handle reduce_scatter with multiple output tensor + assert len(dest_tensor_on_rank_i) == 1 + if not start_reduction[i]: + with torch.no_grad(): + dest_tensor_on_rank_i[0].copy_(to_scatter[i]) + start_reduction[i] = True + else: + with torch.no_grad(): + dest_tensor_on_rank_i[0].add_(to_scatter[i]) + class Broadcast: def __init__(self, src): self.src = src @@ -176,6 +201,12 @@ def scatter(self, output_tensors, input_tensors, opts=ScatterOptions()): ProcessLocalGroup._end_coll(coll) return res + def reduce_scatter(self, output_tensor, scatter_list, opts=ReduceScatterOptions()): + coll = ProcessLocalGroup._start_coll(self._world, ReduceScatter(opts.reduceOp)) + res = coll.join(self._rank, (output_tensor, scatter_list)) + ProcessLocalGroup._end_coll(coll) + return res + def __init__(self, rank, world): super(ProcessLocalGroup, self).__init__(rank, world) self._rank = rank From 186192bb26a71ec9b0131a6c49fdf19e76d208d7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 21 Nov 2022 22:43:58 +0000 Subject: [PATCH 411/453] [Dynamo] Fix bugs when calling tensor.data and tensor.layout (#89257) Fix bugs in [7k github models](https://github.com/pytorch/torchdynamo/issues/1884). * Legacy code still use ```tensor.data```, I think we can use ```tensor.detach``` to rewrite, not sure if there is anything I didn't anticipate. * Support ```tensor.layout```. The root cause of these issues are: dynamo wraps unimplemented ```tensor.x``` call into ```GetAttrVariable(TensorVariable, x)```, but this op was not inserted into FX graph. Hence, during the fake tensor propagation, it throws ```KeyError: 'example_value` ```. For these two popular attributes, Dynamo should support them anyway. However, if dynamo should support ___all___ ```tensor.x``` call and not fallback to ```GetAttrVariable```, I think it's debatable. If I turn off fake tensor propagation, it works well even not including this fix. So I'm curious if we should improve the fake propagation to cover similar cases. 
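For illustration, a minimal sketch of the two patterns this change targets, mirroring the new `test_tensor_data` and `test_tensor_layout` tests added below (the function names here are only examples):

```python
import torch
import torch._dynamo

def uses_data(x, y):
    # .data previously fell through to GetAttrVariable and broke fake tensor
    # propagation; it is now handled via detach().
    return x[y.data]

def uses_layout(x):
    # x.layout is now handled like dtype/device instead of falling back to
    # GetAttrVariable.
    return torch.zeros([x.size(0), x.size(1)], dtype=x.dtype,
                       layout=x.layout, device=x.device)

x = torch.rand(8)
y = torch.ones(8).to(torch.int)
opt_data = torch._dynamo.optimize("eager", nopython=True)(uses_data)
assert torch.equal(opt_data(x, y), uses_data(x, y))

x2 = torch.rand(2, 3)
opt_layout = torch._dynamo.optimize("eager", nopython=True)(uses_layout)
assert torch.equal(opt_layout(x2), uses_layout(x2))
```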
cc @mlazos @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @jansel @eellison ``` Traceback (most recent call last): File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile out_code = transform_code_object(code, transform) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object transformations(instructions, code_options) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform tracer.run() File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1523, in run super().run() File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 389, in run and self.step() File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 359, in step getattr(self, inst.opname)(inst) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 193, in wrapper return inner_fn(self, inst) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 865, in CALL_FUNCTION_KW self.call_function(fn, args, kwargs) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 301, in call_function self.push(fn.call_function(self, args, kwargs)) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/torch.py", line 407, in call_function tensor_variable = wrap_fx_proxy( File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/builder.py", line 636, in wrap_fx_proxy return wrap_fx_proxy_cls( File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/builder.py", line 676, in wrap_fx_proxy_cls example_value = get_fake_value(proxy.node, tx) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1024, in get_fake_value args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 613, in map_arg return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 621, in map_aggregate t = tuple(map_aggregate(elem, fn) for elem in a) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 621, in t = tuple(map_aggregate(elem, fn) for elem in a) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 627, in map_aggregate return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 627, in return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items()) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 631, in map_aggregate return fn(a) File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 613, in return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1022, in visit return n.meta["example_value"] KeyError: 'example_value\n\nfrom user code:\n File "./generated/test_BayesWatch_pytorch_prunes.py", line 108, in forward\n return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial], dtype=x.dtype, layout=x.layout, device=x.device)\n\nSet torch._dynamo.config.verbose=True for more information\n\n\nYou can suppress this exception and fall back to eager by setting:\n 
torch._dynamo.config.suppress_errors = True\n' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89257 Approved by: https://github.com/jansel --- test/dynamo/test_misc.py | 26 ++++++++++++++++++++++++++ torch/_dynamo/variables/tensor.py | 16 ++++++++++++---- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index e3274738fc21..1a04f25e7404 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1301,6 +1301,32 @@ def fn(x): self.assertTrue(same(ref0, res0)) self.assertTrue(same(ref1, res1)) + def test_tensor_data(self): + def fn(x, y): + return x[y.data] + + x = torch.rand(8) + y = torch.ones(8).to(torch.int) + ref = fn(x, y) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + + def test_tensor_layout(self): + def fn(x): + return torch.zeros( + [x.size()[0], x.size()[1]], + dtype=x.dtype, + layout=x.layout, + device=x.device, + ) + + x = torch.rand(2, 3) + ref = fn(x) + opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + def test_version_ci(self): # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index ab94aaf537d2..84de57c0f295 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -30,6 +30,7 @@ class TensorVariable(VariableTracker): "proxy", "dtype", "device", + "layout", "ndim", "size", "stride", @@ -52,6 +53,7 @@ def __init__( proxy: torch.fx.Proxy, dtype=None, device=None, + layout=None, ndim=None, size=None, stride=None, @@ -67,6 +69,7 @@ def __init__( self.proxy = proxy self.dtype = dtype self.device = device + self.layout = layout self.ndim = ndim self.size = size self.stride = stride @@ -101,6 +104,7 @@ def specialize(value: torch.Tensor): props = { "dtype": value.dtype, "device": value.device, + "layout": value.layout, "ndim": int(value.ndim), "requires_grad": value.requires_grad, "is_quantized": value.is_quantized, @@ -130,6 +134,8 @@ def var_getattr(self, tx, name): result = TorchVariable(self.dtype, **options) elif name == "device" and self.device is not None: result = TorchVariable(self.device, **options) + elif name == "layout" and self.layout is not None: + result = TorchVariable(self.layout, **options) elif name == "is_cuda" and self.device is not None: result = ConstantVariable(self.device.type == "cuda", **options) elif name == "shape" and self.size is not None: @@ -145,6 +151,8 @@ def var_getattr(self, tx, name): result = self.call_method(tx, "size", [], {}) elif name == "ndim" and self.ndim is None: result = self.call_method(tx, "dim", [], {}) + elif name == "data": + result = self.call_method(tx, "detach", [], {}) elif name == "T": args = [variables.ConstantVariable(i) for i in range(self.ndim - 1, -1, -1)] result = self.call_method(tx, "permute", args, {}) @@ -198,7 +206,7 @@ def call_method( tx.output.create_proxy( "call_method", name, - *proxy_args_kwargs([self] + args, kwargs), + *proxy_args_kwargs([self] + list(args), kwargs), current_tx=tx, ), **options, @@ -277,7 +285,7 @@ def call_method( tx.output.create_proxy( "call_function", operator.setitem, - *proxy_args_kwargs([self] + args, kwargs), + *proxy_args_kwargs([self] + list(args), kwargs), current_tx=tx, ) return ConstantVariable(None, **options) @@ -309,7 +317,7 @@ def call_method( 
tx.output.create_proxy( "call_method", name, - *proxy_args_kwargs([self] + args, kwargs), + *proxy_args_kwargs([self] + list(args), kwargs), current_tx=tx, ), **options, @@ -329,7 +337,7 @@ def call_method( tx.output.create_proxy( "call_method", name, - *proxy_args_kwargs([self] + args, kwargs), + *proxy_args_kwargs([self] + list(args), kwargs), current_tx=tx, ), **options, From fa4980cd5e7581b5195ed4d02d63bf73497549d0 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 21 Nov 2022 22:56:13 +0000 Subject: [PATCH 412/453] Add commit hash to dynamo dashboard (#89462) Title - also fix a small bug with dashboard outputs. Sample: https://github.com/pytorch/torchdynamo/issues/1831#issuecomment-1322732698 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89462 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 843dbd12909a..963dcf493705 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -336,7 +336,7 @@ def generate_dropdown_comment(title, body): return str_io.getvalue() -def build_summary(): +def build_summary(args): import git out_io = io.StringIO() @@ -352,31 +352,36 @@ def print_commit_hash(path, name): def env_var(name): out_io.write(f"{name} = {os.environ[name]}\n") - out_io.write("## Commit hashes ##\n") - print_commit_hash(".", "torch._dynamo") + out_io.write("\n") + out_io.write("### Run name ###\n") + out_io.write(get_archive_name(args, args.dtypes[0])) + out_io.write("\n") + + out_io.write("\n") + out_io.write("### Commit hashes ###\n") print_commit_hash("../pytorch", "pytorch") print_commit_hash("../functorch", "functorch") print_commit_hash("../torchbenchmark", "torchbench") out_io.write("\n") - out_io.write("## TorchDynamo config flags ##\n") + out_io.write("### TorchDynamo config flags ###\n") for key in dir(torch._dynamo.config): val = getattr(torch._dynamo.config, key) if not key.startswith("__") and isinstance(val, bool): out_io.write(f"torch._dynamo.config.{key} = {val}\n") out_io.write("\n") - out_io.write("## Torch version ##\n") + out_io.write("### Torch version ###\n") out_io.write(f"torch: {torch.__version__}\n") out_io.write("\n") - out_io.write("## Environment variables ##\n") + out_io.write("### Environment variables ###\n") env_var("TORCH_CUDA_ARCH_LIST") env_var("CUDA_HOME") env_var("USE_LLVM") out_io.write("\n") - out_io.write("## GPU details ##\n") + out_io.write("### GPU details ###\n") out_io.write(f"CUDNN VERSION: {torch.backends.cudnn.version()}\n") out_io.write(f"Number CUDA Devices: {torch.cuda.device_count()}\n") out_io.write(f"Device Name: {torch.cuda.get_device_name(0)}\n") @@ -415,6 +420,12 @@ def default_archive_name(dtype): return f"{prefix}_performance_{dtype}_{randint(100, 999)}" +def get_archive_name(args, dtype): + return ( + default_archive_name(dtype) if args.archive_name is None else args.archive_name + ) + + def archive(src_dir, dest_dir_prefix, archive_name, dtype): if archive_name is None: archive_name = default_archive_name(dtype) @@ -810,7 +821,7 @@ def gen_summary_files(self): def parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir): mode = get_mode(args) - build_summary() + build_summary(args) parser_class = ParsePerformanceLogs parser = parser_class( @@ -965,13 +976,13 @@ def generate_comment(self): f"suite: {suite}): {path}\n\n" ) + regressions_present 
= False for metric in [ "accuracy", "speedup", "compilation_latency", "compression_ratio", ]: - regressions_present = False dfs = [] for compiler in self.args.flag_compilers: if last2[compiler] is None: @@ -1148,11 +1159,7 @@ def __init__(self, args): def update_lookup_file(self): dtype = self.args.dtypes[0] day, _ = archive_data(self.args.archive_name) - target_dir = ( - default_archive_name(dtype) - if self.args.archive_name is None - else self.args.archive_name - ) + target_dir = get_archive_name(self.args, dtype) # Update lookup csv the folder to arhived logs subprocess.check_call( f'echo "{day},performance,{dtype},{target_dir}" >> {self.lookup_file}', @@ -1198,6 +1205,7 @@ def gen_comment(self): "gh_metric_regression.txt", "gh_training.txt", "gh_graphs.txt", + "gh_build_summary.txt", ] all_lines = [] for f in files: From ea50549ce62aeeccfe27035a0a975e83b9c2c987 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 21 Nov 2022 18:12:21 -0500 Subject: [PATCH 413/453] Suppress guards when creating fake tensors (#89349) When we create fake tensors, we may call operators that introduce guards, to accurately reconstruct views. But these guards are spurious: if a user is able to present a tensor that "looks the same", they have implicitly fulfilled the contract that the view is creatable. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89349 Approved by: https://github.com/voznesenskym --- torch/_subclasses/fake_tensor.py | 6 ++++- torch/fx/experimental/symbolic_shapes.py | 29 ++++++++++++++++++------ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 9a0ac050e6b9..758f4431f688 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -230,7 +230,11 @@ def mk_fake_tensor(make_meta_t): constant=t if make_constant else None, ) - out = self.meta_converter(t, shape_env=shape_env, callback=mk_fake_tensor) + ctx = contextlib.nullcontext() + if shape_env is not None: + ctx = shape_env.suppress_guards() + with ctx: + out = self.meta_converter(t, shape_env=shape_env, callback=mk_fake_tensor) if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") if make_constant: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index bd52760502c6..f25302a88397 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -6,6 +6,8 @@ import builtins import math import functools +import threading +from contextlib import contextmanager from functools import lru_cache, partial import traceback import collections @@ -439,6 +441,18 @@ def __init__(self): # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable self.val_to_var: Dict[int, "sympy.Expr"] = {0: sympy.Integer(0), 1: sympy.Integer(1)} + self.tls = threading.local() + + def _suppress_guards_tls(self): + return getattr(self.tls, "suppress_guards", False) + + @contextmanager + def suppress_guards(self): + self.tls.suppress_guards = True + try: + yield + finally: + self.tls.suppress_guards = False def _get_key(self): """ @@ -673,11 +687,12 @@ def evaluate_expr(self, expr: "sympy.Expr"): # TODO: optimize this; avoid formatting traces until we need them # NB: drop two frames; evaluate_expr and the Sym* function that # actually called us - stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2])) - if 
concrete_val is sympy.true: - self.guards.append((expr, stack)) - elif concrete_val is sympy.false: - self.guards.append((sympy.Not(expr), stack)) - else: - self.guards.append((sympy.Eq(expr, concrete_val), stack)) + if not self._suppress_guards_tls(): + stack = ''.join(traceback.format_list(traceback.extract_stack()[:-2])) + if concrete_val is sympy.true: + self.guards.append((expr, stack)) + elif concrete_val is sympy.false: + self.guards.append((sympy.Not(expr), stack)) + else: + self.guards.append((sympy.Eq(expr, concrete_val), stack)) return concrete_val From dbc354b262f7f5e49aa781785cfce6299fdc2aa8 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 22 Nov 2022 00:13:38 +0000 Subject: [PATCH 414/453] Mitigate flaky test_ops_fwd_gradients on macOS (#89410) This has been flaky on macOS for a while ([hud](https://hud.pytorch.org/failure/RuntimeError%3A%20test_ops_fwd_gradients%20failed)) and I can reproduce this locally. The issue was raised by https://github.com/pytorch/pytorch/issues/66033 and it seems to point to macos itself https://github.com/graphia-app/graphia/issues/33. So switching to single thread when running `test_ops_fwd_gradients` on macOS as a mitigation for the flaky tests. ### Testing `pytest test_ops_fwd_gradients.py -k test_fn_fwgrad_bwgrad -vv --flake-finder` to run all `test_fn_fwgrad_bwgrad` tests 50 times to make sure they all pass (no flaky anymore) https://hud.pytorch.org/tests shows that `test_ops_fwd_gradients` on macOS takes about 15m to finish or 8 minute if using 2 shards like in the test. There is no obvious difference in the test duration: ``` 2022-11-21T21:34:18.6078080Z Running test_ops_fwd_gradients ... [2022-11-21 21:34:18.600663] 2022-11-21T21:34:21.6805770Z Executing ['/Users/runner/work/_temp/conda_environment_3517515737/bin/python', '-bb', 'test_ops_fwd_gradients.py', '-v', '--use-pytest', '-vv', '-rfEX', '-x', '--reruns=2', '--shard-id=0', '--num-shards=2', '-k=not _linalg_cholesky_', '--import-slow-tests', '--import-disabled-tests'] ... [2022-11-21 21:34:21.680156] 2022-11-21T21:34:21.6806380Z Ignoring disabled issues: [] 2022-11-21T21:34:21.6815250Z Executing ['/Users/runner/work/_temp/conda_environment_3517515737/bin/python', '-bb', 'test_ops_fwd_gradients.py', '-v', '--use-pytest', '-vv', '-rfEX', '-x', '--reruns=2', '--shard-id=1', '--num-shards=2', '-k=not _linalg_cholesky_', '--import-slow-tests', '--import-disabled-tests'] ... [2022-11-21 21:34:21.681174] 2022-11-21T21:34:21.6815830Z Ignoring disabled issues: [] ..... 2022-11-21T21:40:42.2422700Z =============================== warnings summary =============================== ..... 2022-11-21T21:40:42.2424670Z - generated xml file: /Users/runner/work/pytorch/pytorch/test/test-reports/python-pytest/test_ops_fwd_gradients/test_ops_fwd_gradients-47b619449ea7db1f.xml - 2022-11-21T21:40:42.2424850Z = 831 passed, 596 skipped, 5 deselected, 17 xfailed, 1 warning in 374.54s (0:06:14) = ..... 2022-11-21T21:42:00.1923310Z =============================== warnings summary =============================== ..... 2022-11-21T21:42:00.1925370Z - generated xml file: /Users/runner/work/pytorch/pytorch/test/test-reports/python-pytest/test_ops_fwd_gradients/test_ops_fwd_gradients-d24ee6419a602a6e.xml - 2022-11-21T21:42:00.1925540Z = 828 passed, 603 skipped, 7 deselected, 20 xfailed, 1 warning in 452.94s (0:07:32) = .... 
2022-11-21T21:42:09.9035670Z FINISHED PRINTING LOG FILE of test_ops_fwd_gradients (/Users/runner/work/pytorch/pytorch/test/test-reports/test_ops_fwd_gradients_ha_3rfhb) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89410 Approved by: https://github.com/soulitzer --- test/test_ops_fwd_gradients.py | 8 +++++++- torch/testing/_internal/common_methods_invocations.py | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/test/test_ops_fwd_gradients.py b/test/test_ops_fwd_gradients.py index c3fca7235461..4b7b1c785d5f 100644 --- a/test/test_ops_fwd_gradients.py +++ b/test/test_ops_fwd_gradients.py @@ -4,7 +4,7 @@ import torch from torch.testing._internal.common_utils import ( - TestGradients, run_tests, skipIfTorchInductor) + TestGradients, run_tests, skipIfTorchInductor, IS_MACOS) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, OpDTypes) @@ -12,6 +12,12 @@ # TODO: fixme https://github.com/pytorch/pytorch/issues/68972 torch.set_default_dtype(torch.float32) +# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033 +# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The +# issue seems to point to macOS itself https://github.com/graphia-app/graphia/issues/33 +if IS_MACOS: + torch.set_num_threads(1) + # gradcheck requires double precision _gradcheck_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 998f1cde65f7..c0c53efa503e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -11711,7 +11711,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # (see sample_inputs_max_unpool_grad to find out more). DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), - DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad', device_type='cpu'), )), @@ -11743,7 +11744,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # and if there are several indices pointing to the same memory, # gradcheck is oblivious about that and cannot perturb them all at once # (see sample_inputs_max_unpool_grad to find out more). - DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), DecorateInfo(unittest.skip("Skipped!"), 'TestCompositeCompliance', 'test_forward_ad'), @@ -11780,7 +11782,8 @@ def reference_flatten(input, start_dim=0, end_dim=-1): # and if there are several indices pointing to the same memory, # gradcheck is oblivious about that and cannot perturb them all at once # (see sample_inputs_max_unpool_grad to find out more). 
- DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'), + DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD', + active_if=(not IS_MACOS)), DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'), DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_grad'), DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_forward_ad'), From b189a7444da8b17c535e7d04c9ab705289ec53e1 Mon Sep 17 00:00:00 2001 From: Khushi Date: Tue, 22 Nov 2022 00:15:30 +0000 Subject: [PATCH 415/453] [fix] tril & tril : out of bound check (#89384) Fixes #83326 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89384 Approved by: https://github.com/ngimel --- aten/src/ATen/native/TriangularOps.cpp | 2 ++ test/functorch/test_vmap.py | 4 ++++ torch/testing/_internal/common_methods_invocations.py | 8 ++++++++ 3 files changed, 14 insertions(+) diff --git a/aten/src/ATen/native/TriangularOps.cpp b/aten/src/ATen/native/TriangularOps.cpp index fbdd204f6430..59d2b8a0d224 100644 --- a/aten/src/ATen/native/TriangularOps.cpp +++ b/aten/src/ATen/native/TriangularOps.cpp @@ -23,10 +23,12 @@ namespace at { namespace meta { TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) { + TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions") set_output_raw_strided(0, self.sizes(), {}, self.options()); } TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) { + TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions") set_output_raw_strided(0, self.sizes(), {}, self.options()); } diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 0c38c5101cf8..4c2c680ca637 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3295,6 +3295,8 @@ def test(): @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)}) @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({ xfail('native_batch_norm'), + xfail('tril'), # Exception not raised on error input + xfail('triu'), # Exception not raised on error input # The error inputs are vectors, that pass when batched as they are treated as a matrix xfail('trace'), })) @@ -3342,6 +3344,8 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('tensor_split'), xfail('to_sparse'), xfail('vdot'), + xfail('tril'), # Exception not raised on error input + xfail('triu'), # Exception not raised on error input xfail('__getitem__', ''), xfail('all'), xfail('any'), diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index c0c53efa503e..3d3c13bb7208 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6087,6 +6087,12 @@ def sample_inputs_tril_triu(op_info, device, dtype, requires_grad, **kwargs): for shape, args in cases: yield SampleInput(make_arg(shape), args=args) +def error_inputs_tril_triu(opinfo, device, **kwargs): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # error inputs for input.ndim <= 2 + yield ErrorInput(SampleInput(make_arg((4,))), error_regex="input tensor must have at least 2 dimensions") + def sample_inputs_trilu_indices(op_info, device, dtype, requires_grad, **kwargs): # (row, col, offset) args_list = ((0, 0), @@ -15371,12 +15377,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half), 
supports_forward_ad=True, supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, sample_inputs_func=sample_inputs_tril_triu), OpInfo('triu', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.chalf, torch.bool, torch.half), supports_forward_ad=True, supports_fwgrad_bwgrad=True, + error_inputs_func=error_inputs_tril_triu, sample_inputs_func=sample_inputs_tril_triu), OpInfo('triu_indices', dtypes=_dispatch_dtypes((torch.int32, torch.int64)), From 57ed94804e8195f227c7a75899a319cc0a3b833a Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 21 Nov 2022 16:04:46 -0500 Subject: [PATCH 416/453] Bind DispatchKey.Functionalonalize in pybind11 (#89452) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/89452 Approved by: https://github.com/albanD, https://github.com/bdhirsh --- torch/csrc/utils/python_dispatch.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 381e82e1fcdb..e4ce9ccf5217 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -495,6 +495,7 @@ void initDispatchBindings(PyObject* module) { DEF_ONE(FuncTorchDynamicLayerFrontMode) DEF_ONE(FuncTorchDynamicLayerBackMode) DEF_ONE(PythonDispatcher) + DEF_ONE(Functionalize) // clang-format on #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) From 7174572b1ef4cff545e4ca8fc77c135e58fcbefb Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 21 Nov 2022 21:37:32 +0000 Subject: [PATCH 417/453] Add torchvis support to dist bench (#89324) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89324 Approved by: https://github.com/davidberard98, https://github.com/albanD --- benchmarks/dynamo/distributed.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py index 360fd846dbe8..dee44210e93c 100644 --- a/benchmarks/dynamo/distributed.py +++ b/benchmarks/dynamo/distributed.py @@ -18,6 +18,17 @@ from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup +def torchviz_model(args, model, inputs, rank): + from torchviz import make_dot + + outputs = model(*inputs) + loss = reduce_to_scalar_loss(outputs) + parameter_names = dict(model.named_parameters()) + dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True) + if rank == 0: + dot.render("torchviz.dot") + + def profile_model(args, model, inputs, rank): with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: for i in range(args.repeat): @@ -87,7 +98,8 @@ def print_compile(gm, ex): t_total = timed( model, model_iter_fn, inputs, times=args.repeat, return_result=False ) - + if args.torchviz: + torchviz_model(args, model, inputs, rank) if args.profile: profile_model(args, model, inputs, rank) @@ -105,6 +117,9 @@ def print_compile(gm, ex): ) parser.add_argument("--verbose", action="store_true") parser.add_argument("--batch_size", default=None) + parser.add_argument( + "--torchviz", action="store_true", help="Dump autograd graph with torchviz" + ) parser.add_argument("--profile", action="store_true", help="Run the profiler") parser.add_argument("--trace_file", default="profile.json", help="Run the profiler") parser.add_argument("--repeat", default=10, help="Repeats for timing run") From 
58a74f34f981de2c24b8f57c37687421c87a782a Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Nov 2022 11:05:38 -0800 Subject: [PATCH 418/453] [17/N] Add _reduce_scatter_base custom op with CPU/CUDA implementation (#88903) Differential Revision: [D41415325](https://our.internmc.facebook.com/intern/diff/D41415325) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88903 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 15 ++++++++++ torch/csrc/distributed/c10d/Ops.cpp | 38 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 6 ++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 34 ++++++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 8 +++++- 5 files changed, 100 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index fb28e744b5ed..85ebb6b75bc5 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2974,6 +2974,21 @@ def test_allgather_base(self): dist.all_gather_into_tensor(output_tensor, tensor) self.assertEqual(output_tensor, tensor) + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_reduce_scatter_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + dist.reduce_scatter_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) if __name__ == "__main__": assert ( diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index f825afca2a1d..5d343c344ec8 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -111,6 +111,19 @@ std::tuple, c10::intrusive_ptr> reduce_scatter_( output_tensors, work); } +c10::intrusive_ptr _reduce_scatter_base_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + c10::intrusive_ptr gather_( const std::vector>& output_tensors, const std::vector& input_tensors, @@ -210,6 +223,10 @@ TORCH_LIBRARY(c10d, m) { m.def( "reduce_scatter_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); + m.def( + "_reduce_scatter_base_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, _reduce_scatter_base_)); m.def( "reduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_)); @@ -350,6 +367,27 @@ c10::intrusive_ptr reduce_scatter( opts.timeout.count())); } +c10::intrusive_ptr _reduce_scatter_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ReduceScatterOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::_reduce_scatter_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + return op.call( + output_tensor, + input_tensor, + process_group, + c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr reduce( const c10::intrusive_ptr& process_group, at::TensorList tensors, diff --git 
a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index 72f09e341d7d..f6425e0ea350 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -44,6 +44,12 @@ TORCH_API c10::intrusive_ptr reduce_scatter( const std::vector>& input_tensors, const ReduceScatterOptions& opts = {}); +TORCH_API c10::intrusive_ptr _reduce_scatter_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ReduceScatterOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce( const c10::intrusive_ptr& process_group, at::TensorList tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 78e26c9656d8..c3db5c438124 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -259,6 +259,32 @@ reduce_scatter_cuda_( output_tensors, work); } +c10::intrusive_ptr _reduce_scatter_base_cpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + +c10::intrusive_ptr _reduce_scatter_base_cuda_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + return process_group->_reduce_scatter_base( + output_tensor, + input_tensor, + ReduceScatterOptions{ + *reduce_op.get(), std::chrono::milliseconds(timeout)}); +} + c10::intrusive_ptr gather_cpu_( const std::vector>& output_tensors, const std::vector& input_tensors, @@ -439,6 +465,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("reduce_scatter_", reduce_scatter_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("_reduce_scatter_base_", _reduce_scatter_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("_reduce_scatter_base_", _reduce_scatter_base_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("gather_", gather_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index d39fc322d326..ae98000112fc 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1411,7 +1411,13 @@ that adds a prefix to each key inserted to the store. .def( "_reduce_scatter_base", - &::c10d::ProcessGroup::_reduce_scatter_base, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ::c10d::ReduceScatterOptions& opts) { + return ::c10d::ops::_reduce_scatter_base( + self, output_tensor, input_tensor, opts); + }, py::arg("outputTensor"), py::arg("inputTensor"), py::arg("opts") = ::c10d::ReduceScatterOptions(), From 06dffb3319a38bf909939f64320e0fde88679b94 Mon Sep 17 00:00:00 2001 From: "Edward Z. 
Yang" Date: Mon, 21 Nov 2022 17:54:25 -0500 Subject: [PATCH 419/453] dont clone symints, dont clobber symint proxies (#88230) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88230 Approved by: https://github.com/albanD --- test/functorch/test_aotdispatch.py | 7 +------ torch/fx/experimental/proxy_tensor.py | 13 ++++++++++++- torch/fx/experimental/symbolic_shapes.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index e03fe1e15385..84b1ba893cce 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -694,7 +694,7 @@ def f(a, b, c, d): # # TODO(whc)- are the saved-tensors/saved-symints correct here? # i just made the test pass based on what default partition did - [False, True, True, False, False] + [False] * 5 + [True] * 3, + [False, True, True, False, False] + [False] * 4 + [True] * 4, [is_sym_node(n) for n in fw_graph_out_nodes] ) @@ -996,7 +996,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('addr', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('amax', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('amin', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('as_strided', ''), # Tensor-likes are not close! xfail('baddbmm', ''), # aten.baddbmm.default - couldn't find symbolic meta function/decomposition xfail('block_diag', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('cartesian_prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides @@ -1102,10 +1101,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('min', 'reduction_with_dim'), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mode', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('mv', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - - # Deleting this in a followup - xfail('nn.functional.poisson_nll_loss', ''), - xfail('nn.functional._scaled_dot_product_attention', ''), # Cannot call sizes() on tensor with symbolic ... xfail('nn.functional.adaptive_avg_pool3d', ''), # aten._adaptive_avg_pool3d_backward.default - couldn't ... xfail('nn.functional.adaptive_max_pool1d', ''), # Cannot call sizes() on tensor with symbolic sizes/strides diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index daa17f94b7bb..012984ebe6f0 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -65,7 +65,18 @@ def set_proxy_slot(obj, tracer, proxy): assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) d = obj.__dict__.setdefault(proxy_slot, weakref.WeakKeyDictionary()) # type: ignore[call-overload] assert isinstance(d, weakref.WeakKeyDictionary) - d[tracer] = proxy + # NB: Never clobber pre-existing proxy. Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + # + # However, we DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + # This doesn't really apply to SymInts/SymFloats though, which are immutable. 
+ if tracer not in d or isinstance(obj, torch.Tensor): + d[tracer] = proxy def has_proxy_slot(obj, tracer): assert isinstance(obj, (torch.Tensor, SymNode)), type(obj) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index f25302a88397..41121808e24e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -177,7 +177,7 @@ def wrap_float(self, num): return SymNode(sympy.Float(num), self.shape_env, float, constant=num) def clone(self): - return SymNode(self.expr, self.shape_env, self.pytype, constant=self.constant) + return self def str(self): return f"{self.expr}" From 120d200620159597f416f9142f1d5708182ca047 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 22 Nov 2022 02:20:45 +0000 Subject: [PATCH 420/453] Revert "Added conv constraint that infers layouts (#89031)" (#89451) This reverts commit 716f70f19a4b63268da2a753afdbe9b385a831ab. Fixes performance regression and compilation latency increase. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89451 Approved by: https://github.com/soumith, https://github.com/jansel --- test/inductor/test_torchinductor.py | 4 +- torch/_inductor/graph.py | 29 +------ torch/_inductor/ir.py | 3 - torch/_inductor/lowering.py | 118 ++++++++++++++++++---------- 4 files changed, 81 insertions(+), 73 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 2196f4f8a026..0d28f156ecc0 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -68,6 +68,7 @@ from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten + requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow @@ -5309,8 +5310,6 @@ def get_kernels(self, fn, args) -> typing.List[CachingAutotuner]: return kernels def test_divisibile_by_16_covers_numel_args(self): - torch._dynamo.reset() - def fn(a: torch.Tensor) -> torch.Tensor: return torch.sum(a) @@ -5330,7 +5329,6 @@ def fn(a: torch.Tensor) -> torch.Tensor: kernels[1].meta["configs"][0].divisible_by_16 ) self.assertEqual(arguments_that_are_divisible_by_16_in_kernel1, (0, 1)) - torch._dynamo.reset() if __name__ == "__main__": diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index a47d9c1a02e1..7a5791de8a38 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -20,12 +20,7 @@ MissingOperatorWithoutDecomp, ) from .ir import Constant, FixedLayout, InputBuffer, Pointwise, Reduction, TensorBox -from .lowering import ( - layout_constraints, - lowerings, - make_fallback, - needs_realized_inputs, -) +from .lowering import lowerings, make_fallback, needs_realized_inputs from .sizevars import SizeVarAllocator from .utils import dynamo_utils, gather_origins, get_dtype_size, sympy_product from .virtualized import V @@ -306,12 +301,7 @@ def finalize(self): def run_node(self, n: torch.fx.Node): with ir.IRNode.current_origins({n}): - if n.op == "call_function" and n.target in layout_constraints: - args, kwargs = self.fetch_args_kwargs_from_env(n) - args, kwargs = layout_constraints[n.target](n, *args, **kwargs) - result = self.call_function(n.target, args, kwargs) - else: - result = super().run_node(n) + result = super().run_node(n) # Realize if (1) any user need inputs realized, or (2) there is # already too many reads and rematerializing can be bad. 
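As a rough sketch of the mechanism this revert returns to in lowering.py (simplified here; the real handler also wraps results into TensorBox IR via FallbackKernel, which is omitted), the restored `inps_hook` pattern shown in the hunks below looks like this:

```python
# Simplified sketch of the restored inps_hook pattern from lowering.py.
# An optional hook rewrites (args, kwargs) -- e.g. forcing dense or contiguous
# inputs -- before the eager kernel is invoked as a fallback.
def fallback_handler(kernel, inps_hook=None):
    def handler(*args, **kwargs):
        if inps_hook is not None:
            args, kwargs = inps_hook(*args, **kwargs)
        # The real handler wraps this call into FallbackKernel IR nodes.
        return kernel(*args, **kwargs)
    return handler

# e.g. make_fallback(aten.cumsum, inps_hook=require_dense) routes cumsum to the
# eager kernel with its inputs coerced to a dense (stride-1) layout first.
```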
@@ -320,20 +310,7 @@ def run_node(self, n: torch.fx.Node): for user in n.users: if user.target in needs_realized_inputs: result.realize_hint() - # This inclusion is somewhat controversial (from - # discussion between Horace, Natalia, and Elias). - # Currently, it's not very clear why this is helpful. - # The general idea here is that even though a node may - # have FlexibleLayout, we still often *treat* it as if - # it was contiguous. This appears to sometime result in - # suboptimal behavior. - # - # When we do a better job selecting layout, we should - # revisit this. - result = ir.ExternKernel.require_stride_order( - result, ir.get_stride_order(n.meta["val"].stride()) - ) - if user.op == "output": + elif user.op == "output": if isinstance(result.data.data, (Pointwise, Reduction)): result.realize() diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index d54724671768..8327fe0d7b52 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -2478,9 +2478,6 @@ def require_stride1(cls, x): @classmethod def require_stride_order(cls, x, order): - if x.get_numel() == 0: # Layout doesn't matter - return x - # require x to have the layout as strided_ordered as order if is_storage_and_layout(x): if isinstance( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0bd92007c986..80743a563e73 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -23,6 +23,7 @@ from .decomposition import decompositions, get_decompositions from .ir import ( ExpandView, + get_stride_order, IndexingConstant, IndexingDiv, PermuteView, @@ -37,7 +38,6 @@ log = logging.getLogger(__name__) lowerings = {} -layout_constraints = {} fallbacks = set() aten = torch.ops.aten prims = torch.ops.prims @@ -53,14 +53,6 @@ def add_needs_realized_inputs(fn): needs_realized_inputs.add(getattr(fn, overload)) -def add_layout_constraint(fn, constraint): - if isinstance(fn, torch._ops.OpOverloadPacket): - for overload in fn.overloads(): - layout_constraints[getattr(fn, overload)] = constraint - else: - layout_constraints[fn] = constraint - - add_needs_realized_inputs( [ aten.as_strided, @@ -1021,10 +1013,12 @@ def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr): register_onednn_fusion_ops() -def fallback_handler(kernel): +def fallback_handler(kernel, inps_hook=None): fallbacks.add(kernel) def handler(*args, **kwargs): + if inps_hook is not None: + args, kwargs = inps_hook(*args, **kwargs) return pytree.tree_map( TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs) ) @@ -1032,7 +1026,7 @@ def handler(*args, **kwargs): return handler -def make_fallback(kernel, layout_constraint=None): +def make_fallback(kernel, inps_hook=None): assert ( kernel not in decompositions ), f"both a fallback and a decomp for same kernel: {kernel}" @@ -1042,9 +1036,9 @@ def make_fallback(kernel, layout_constraint=None): ) add_needs_realized_inputs(kernel) - if layout_constraint is not None: - add_layout_constraint(kernel, layout_constraint) - return register_lowering(kernel, type_promotion_kind=None)(fallback_handler(kernel)) + return register_lowering(kernel, type_promotion_kind=None)( + fallback_handler(kernel, inps_hook) + ) @register_lowering(aten.native_dropout, type_promotion_kind=None) @@ -1195,14 +1189,72 @@ def inner_fn(index): ) -def require_dense(_, *args, **kwargs): +def conv_backward(*args, **kwargs): + # output striding complex and has a lot of build dependent options, + # take the output strides to determine what to set the inputs + with 
torch._subclasses.FakeTensorMode(): + args_fake, kwargs_fake = pytree.tree_map_only( + ir.IRNode, + lambda t: ir.ir_node_to_tensor(t, guard_shape=False), + (args, kwargs), + ) + output = aten.convolution_backward(*args_fake, **kwargs_fake) + + def constraints( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, + ): + out = ( + output[0] + if output[0] is not None + else output[1] + if output[1] is not None + else output[2] + ) + if out is not None: + stride_order = get_stride_order(out.stride()) + grad_output = ir.ExternKernel.require_stride_order( + grad_output, stride_order + ) + weight = ir.ExternKernel.require_stride_order(weight, stride_order) + # Only make input contiguous when it is necessary for the backwards computation + if output_mask[1]: + input = ir.ExternKernel.require_stride_order(input, stride_order) + + return ( + grad_output, + input, + weight, + bias_sizes, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + output_mask, + ), {} + + return constraints(*args, **kwargs) + + +def require_dense(*args, **kwargs): args, kwargs = pytree.tree_map_only( ir.IRNode, lambda t: ir.ExternKernel.require_stride1(t), (args, kwargs) ) return args, kwargs -def require_contiguous(_, *args, **kwargs): +def require_contiguous(*args, **kwargs): args, kwargs = pytree.tree_map_only( ir.IRNode, lambda t: ir.ExternKernel.require_contiguous(t), (args, kwargs) ) @@ -1212,42 +1264,26 @@ def require_contiguous(_, *args, **kwargs): if has_torchvision_roi_align(): make_fallback(torch.ops.torchvision.roi_align) - -def constrain_to_fx_strides(fx_node, *args, **kwargs): - def apply_constraint(arg, fx_arg): - if isinstance(arg, ir.IRNode): - stride_order = ir.get_stride_order(fx_arg.meta["val"].stride()) - return ir.ExternKernel.require_stride_order(arg, stride_order) - return arg - - args = [apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)] - kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} - return args, kwargs - - # TODO(jansel): we should implement decomps or lowerings for these # https://github.com/pytorch/torchdynamo/issues/327 make_fallback(aten._adaptive_avg_pool2d_backward, require_dense) -make_fallback(aten.convolution_backward, constrain_to_fx_strides) +make_fallback(aten.convolution_backward, inps_hook=conv_backward) make_fallback(aten._cudnn_rnn, require_dense) -make_fallback(aten._cudnn_rnn_backward, require_contiguous) -make_fallback(aten.cumsum, require_dense) -make_fallback(aten._embedding_bag, require_contiguous) -make_fallback(aten._embedding_bag_forward_only, require_contiguous) +make_fallback(aten._cudnn_rnn_backward, inps_hook=require_contiguous) +make_fallback(aten.cumsum, inps_hook=require_dense) +make_fallback(aten._embedding_bag, inps_hook=require_contiguous) +make_fallback(aten._embedding_bag_forward_only, inps_hook=require_contiguous) make_fallback(aten._fused_moving_avg_obs_fq_helper) make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) -make_fallback(aten.grid_sampler_2d_backward, require_dense) +make_fallback(aten.grid_sampler_2d_backward, inps_hook=require_dense) make_fallback(aten.randperm) make_fallback(aten.sort) make_fallback(aten.sort.stable) make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors) -make_fallback(aten._thnn_fused_lstm_cell, require_dense) +make_fallback(aten._thnn_fused_lstm_cell, inps_hook=require_dense) make_fallback(aten.topk) 
-make_fallback(aten.upsample_bicubic2d_backward, require_contiguous) -make_fallback(aten.upsample_bilinear2d_backward, require_dense) - - -add_layout_constraint(aten.convolution, constrain_to_fx_strides) +make_fallback(aten.upsample_bicubic2d_backward, inps_hook=require_contiguous) +make_fallback(aten.upsample_bilinear2d_backward, inps_hook=require_dense) @register_lowering(aten.convolution) From 496c8ae760bf646d7a45aad0c2e0320a67b66fd2 Mon Sep 17 00:00:00 2001 From: maxren Date: Mon, 21 Nov 2022 10:58:05 -0800 Subject: [PATCH 421/453] [xnnpack][lite-int] Handle Constant Data (#89445) Handling constant data for xnnpack delegation. This allows us to handle new modules like such: ``` class Module(torch.nn.Module): def __init__(self): super().__init__() self._constant = torch.ones(4, 4, 4) def forward(self, x): return x + self._constant ``` this is the precursor work to handling convolution, as we need to serialize constant data(weights) Differential Revision: [D41050349](https://our.internmc.facebook.com/intern/diff/D41050349/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89445 Approved by: https://github.com/digantdesai --- test/jit/xnnpack/test_xnnpack_delegate.py | 28 +++++++++++++++ .../xnnpack/compiler/xnn_compiler.cpp | 13 ++++--- .../xnnpack/serialization/serializer.cpp | 35 +++++++++++-------- .../xnnpack/serialization/serializer.h | 17 ++++++--- .../xnnpack/xnnpack_graph_builder.cpp | 26 +++++++++++--- 5 files changed, 89 insertions(+), 30 deletions(-) diff --git a/test/jit/xnnpack/test_xnnpack_delegate.py b/test/jit/xnnpack/test_xnnpack_delegate.py index 997cc757e629..c54d9ba1b088 100644 --- a/test/jit/xnnpack/test_xnnpack_delegate.py +++ b/test/jit/xnnpack/test_xnnpack_delegate.py @@ -8,6 +8,34 @@ torch.ops.load_library("//caffe2:xnnpack_backend") class TestXNNPackBackend(unittest.TestCase): + def test_xnnpack_constant_data(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self._constant = torch.ones(4, 4, 4) + + def forward(self, x): + return x + self._constant + + scripted_module = torch.jit.script(Module()) + + lowered_module = torch._C._jit_to_backend( + "xnnpack", + scripted_module, + { + "forward": { + "inputs" : [torch.randn(4, 4, 4)], + "outputs": [torch.randn(4, 4, 4)] + } + } + ) + + for i in range(0, 20): + sample_input = torch.randn(4, 4, 4) + actual_output = scripted_module(sample_input) + expected_output = lowered_module(sample_input) + self.assertTrue(torch.allclose(actual_output, expected_output, atol=1e-03, rtol=1e-03)) + def test_xnnpack_lowering(self): class Module(torch.nn.Module): def __init__(self): diff --git a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp index 0f654dff0ac0..a64bf35431fd 100644 --- a/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp +++ b/torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp @@ -42,24 +42,23 @@ void XNNCompiler::compileModel( case fb_xnnpack::XValueUnion::XNNTensorValue: { auto tensor_value = value->xvalue_as_XNNTensorValue(); - const void* data_ptr = nullptr; - auto buffer_idx = tensor_value->constant_buffer_idx(); - if (buffer_idx != 0) { - // TODO: @maxren implement data handling - TORCH_CHECK(false, "Constant data handling not yet implemented") - } std::vector dims_data; for (auto dim : *tensor_value->dims()) { dims_data.push_back(static_cast(dim)); } uint32_t id = XNN_INVALID_VALUE_ID; + const auto& constant_buffer = *flatbuffer_graph->constant_buffer(); + auto 
buffer_idx = tensor_value->constant_buffer_idx(); + const auto buffer_ptr = buffer_idx == 0 + ? nullptr + : constant_buffer[buffer_idx]->storage()->data(); status = xnn_define_tensor_value( /*subgraph=*/subgraph_ptr, /*datatype=*/xnn_datatype_fp32, /*num_dims=*/tensor_value->num_dims(), /*dims=*/dims_data.data(), - /*data=*/data_ptr, + /*data=*/buffer_ptr, /*external_id=*/tensor_value->external_id(), /*flags=*/tensor_value->flags(), /*id_out=*/&id); diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp index 63cb62c5698e..637f7cdf4c52 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp @@ -24,26 +24,33 @@ void XNNSerializer::serializeAddNode( _nodes.push_back(flatbufferNode); } +size_t XNNSerializer::serializeData(const uint8_t* data_ptr, size_t num_bytes) { + size_t constant_buffer_idx = 0; + // Handling the tensor _values with data + if (data_ptr != nullptr) { + // steps: + // 1. creating flatbuffer byte-vector for tensor data + auto storage = _builder.CreateVector(data_ptr, num_bytes); + + // 2. put it in the common buffer + constant_buffer_idx = _constantBuffer.size(); + _constantBuffer.emplace_back(CreateBuffer(_builder, storage)); + + // 3. record size into bufferSizes + _bufferSizes.push_back(num_bytes); + assert(_bufferSizes.size() == _constantBuffer.size()); + } + return constant_buffer_idx; +} + void XNNSerializer::serializeTensorValue( uint32_t xnn_datatype, size_t num_dims, std::vector dims, - void* data, + size_t data_buffer_idx, uint32_t external_id, uint32_t flags, uint32_t id_out) { - // we will reserve buffers without data to index 0 - int constant_buffer_idx = 0; - // Handling the tensor _values with data - // TODO @maxren fill out when handling tensors with data - if (data != nullptr) { - assert(false); // not supported yet - // steps: - // 1. creating buffer to store the 16 bit aligned data - // 2. increment buffer_idx, to reflect no buffer being added - // 3. 
record size into bufferSizes - } - std::vector serialized_dims; serialized_dims.reserve(dims.size()); for (auto dim : dims) { @@ -55,7 +62,7 @@ void XNNSerializer::serializeTensorValue( XNNDatatype(xnn_datatype), num_dims, &serialized_dims, - constant_buffer_idx, + data_buffer_idx, external_id, flags, id_out); diff --git a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h index 08a3875d3267..5a683c3dc323 100644 --- a/torch/csrc/jit/backends/xnnpack/serialization/serializer.h +++ b/torch/csrc/jit/backends/xnnpack/serialization/serializer.h @@ -17,15 +17,18 @@ class XNNSerializer { public: // Constructors // initial buffersize of 1024 which will grow - // automatically + // automatically, constant buffer and buffer sizes initialized with dummy + // values as 0 index is reserved for non-constant tensors XNNSerializer() : XNNSerializer(1024) {} explicit XNNSerializer(size_t bufferSize) : _builder(bufferSize), _nodes(), _values(), - _constantBuffer(), - _bufferSizes() {} + _constantBuffer({CreateBuffer( + _builder, + {})}), // index 0 is reserved for non-const data + _bufferSizes({0}) {} // Serializing Nodes @@ -43,7 +46,7 @@ class XNNSerializer { uint32_t xnn_datatype, size_t num_dims, std::vector dims, - void* data, + size_t buffer_data_idx, uint32_t external_id, uint32_t flags, uint32_t id_out); @@ -54,6 +57,12 @@ class XNNSerializer { std::vector output_ids, size_t num_extern_ids); + // decoupled data serialization with tensor values. This way constant tensor + // data can be referenced by multiple intermediate tensors. This call + // serializes the num_bytes of the data_ptr and returns the index it was + // placed in. + size_t serializeData(const uint8_t* data_ptr, size_t num_bytes); + private: // xnnpack version we are serializing const char* _version_sha1 = "ae108ef49aa5623b896fc93d4298c49d1750d9ba"; diff --git a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp index 45a4bd2fa795..7c7bb2d02e4c 100644 --- a/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp +++ b/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp @@ -225,6 +225,22 @@ void XNNGraph::defineAllTensorValues() { // update flag for if tensor is either graph input/output uint32_t flags = 0; + // Check if value was produced by prim::Constant + void* value_data = nullptr; + size_t buffer_idx = 0; + size_t num_bytes = 0; + if (val->node()->kind() == prim::Constant) { + c10::optional constant = val->node()->t(attr::value); + auto const_val = constant->toIValue().toTensor(); + // Need tensor data to be contiguous for serialization + auto cont_const_val = const_val.contiguous(); + value_data = cont_const_val.data_ptr(); + + num_bytes = const_val.storage().nbytes(); + buffer_idx = _serializer.serializeData( + static_cast(value_data), num_bytes); + } + if (isGraphInput(val) || isGraphOutput(val)) { if (isGraphInput(val)) { flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT; @@ -239,21 +255,21 @@ void XNNGraph::defineAllTensorValues() { /*datatype=*/xnn_datatype_fp32, /*num_dims=*/num_dims, /*dims=*/tensor_shape.data(), - /*data=*/nullptr, // currently no constant data + /*data=*/value_data, /*external_id=*/ext_id, /*flags=*/flags, /*id_out=*/&id); + TORCH_CHECK( + status == xnn_status_success, + "failed to define xnn_tensor_id for: " + val->debugName()); _serializer.serializeTensorValue( xnn_datatype_fp32, num_dims, tensor_shape, - nullptr, + buffer_idx, ext_id, flags, id); - TORCH_CHECK( - status 
== xnn_status_success, - "failed to define xnn_tensor_id for: " + val->debugName()); _val_to_ids.insert({val, id}); } } From 82713a1cc4589f084ecbcb591d1f9b12570cac43 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 22 Nov 2022 02:23:21 +0000 Subject: [PATCH 422/453] [inductor][compilation time] Fallback when kernel size for avg/max pool is large (#89448) This fixes compilation time for yolov3 from 400 seconds to 48 seconds. yolov3 has a 13x13 max_pool2d kernel, which was creating really large Triton code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89448 Approved by: https://github.com/ngimel --- test/inductor/test_torchinductor.py | 93 +++++++++++++++++++++++++++++ torch/_dynamo/utils.py | 4 +- torch/_inductor/codecache.py | 14 ++++- torch/_inductor/lowering.py | 64 ++++++++++++++++++++ 4 files changed, 173 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 0d28f156ecc0..0aaf74886c7c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1920,6 +1920,7 @@ def fn(x): self.common( fn, (torch.randn(2, 4, 16, 16),), + check_lowp=False, ) # lowering to avg_pool2d case @@ -1934,6 +1935,19 @@ def fn(x): (torch.randn(2, 4, 6, 6),), ) + def test_adaptive_avg_pool2d2(self): + # Big kernel size, use fallback + def fn(x): + return aten._adaptive_avg_pool2d(x, (4, 4)) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (torch.randn(2, 4, 21, 21),), + check_lowp=False, + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_max_pool2d1(self): def fn(x): return aten.max_pool2d_with_indices(x, [3, 3], [2, 2]) @@ -1981,6 +1995,18 @@ def fn(x): (torch.randn([16, 64, 55, 55]),), ) + def test_max_pool2d6(self): + # Too big kernel size, use fallback + def fn(x): + return aten.max_pool2d_with_indices(x, [13, 13], []) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (torch.randn([16, 64, 55, 55]),), + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_avg_pool2d1(self): def fn(x): return aten.avg_pool2d(x, [3, 3], [2, 2]) @@ -2035,6 +2061,18 @@ def fn(x): (-torch.arange(1 * 8 * 8, dtype=torch.float32).view(1, 1, 8, 8),), ) + def test_avg_pool2d7(self): + # Large kernel size, use fallback + def fn(x): + return aten.avg_pool2d(x, [13, 13], [1, 1], [0, 0]) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + (-torch.arange(1 * 24 * 24, dtype=torch.float32).view(1, 1, 24, 24),), + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) + def test_alexnet_prefix(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( @@ -3936,6 +3974,7 @@ def fn(a, b, c): a, b, [5, 5], [1, 1], [2, 2], [1, 1], False, c ) + torch._inductor.metrics.generated_kernel_count = 0 x = torch.randn([2, 64, 3, 4]) result, indices = aten.max_pool2d_with_indices( x, @@ -3953,6 +3992,34 @@ def fn(a, b, c): indices, ], ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_max_pool2d_with_indices_backward5(self): + # Window size is too big. 
Should fallback + def fn(a, b, c): + return aten.max_pool2d_with_indices_backward( + a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c + ) + + torch._inductor.metrics.generated_kernel_count = 0 + x = torch.randn([2, 64, 20, 20]) + result, indices = aten.max_pool2d_with_indices( + x, + [13, 13], + [1, 1], + 2, + 1, + False, + ) + self.common( + fn, + [ + torch.randn_like(result), + x, + indices, + ], + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) def test_avg_pool2d_backward(self): def fn(a, b): @@ -4009,6 +4076,7 @@ def fn(a, b): None, ) + torch._inductor.metrics.generated_kernel_count = 0 self.common( fn, [ @@ -4016,6 +4084,31 @@ def fn(a, b): torch.randn([1, 2016, 21, 21]), ], ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + def test_avg_pool2d_backward4(self): + def fn(a, b): + return aten.avg_pool2d_backward( + a, + b, + [13, 13], + [1, 1], + [0, 0], + True, + False, + None, + ) + + torch._inductor.metrics.generated_kernel_count = 0 + self.common( + fn, + [ + torch.randn([1, 16, 12, 12]), + torch.randn([1, 16, 24, 24]), + ], + check_lowp=False, + ) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 0) def test_mm_views(self): def fn(a, b): diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 889bb5683b6b..cbf5a0b46148 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -88,7 +88,9 @@ def time_wrapper(*args, **kwargs): compilation_metrics[key] = [] t0 = time.time() r = func(*args, **kwargs) - compilation_metrics[key].append(time.time() - t0) + latency = time.time() - t0 + # print(f"Dynamo timer: key={key}, latency={latency:.2f} sec") + compilation_metrics[key].append(latency) return r return time_wrapper diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 232a611b06c6..c020ff52f3af 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -22,7 +22,6 @@ import torch from torch.utils import cpp_extension - from . import config, cuda_properties, exc LOCK_TIMEOUT = 600 @@ -449,17 +448,30 @@ def _load_kernel(source_code): return kernel +def _load_kernel_name(source_code): + return TritonCodeCache.get_name(PyCodeCache.load(source_code)) + + class TritonFuture: def __init__(self, source_code, future): self.source_code = source_code self.future = future + # @dynamo_utils.dynamo_timed def result(self): + t0 = time() if hasattr(self, "kernel"): return self.kernel # If the worker failed this will throw an exception. 
self.future.result() kernel = self.kernel = _load_kernel(self.source_code) + latency = time() - t0 + if latency > 50: + name = _load_kernel_name(self.source_code) + log.warning( + f"Detected long compilation time of {latency} seconds for kernel name {name}" + ) + log.warning(self.source_code) del self.source_code, self.future return kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 80743a563e73..221f064e2e73 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -2580,6 +2580,9 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): return x_out, ceil_mode +fallback_max_pool2d_with_indices = fallback_handler(aten.max_pool2d_with_indices) + + @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) def max_pool2d_with_indices( x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False @@ -2608,6 +2611,13 @@ def max_pool2d_with_indices( x_loader = x.make_loader() new_size = list(batch) + [h_out, w_out] + window_size = kernel_size[0] * kernel_size[1] + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode + ) def fn(idx, return_index): *prefix, bh, bw = idx @@ -2645,6 +2655,11 @@ def fn(idx, return_index): return r1, r2 +fallback_max_pool2d_with_indices_backward = fallback_handler( + aten.max_pool2d_with_indices_backward +) + + @register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None) def max_pool2d_with_indices_backward( grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices @@ -2685,6 +2700,14 @@ def max_pool2d_with_indices_backward( ] ) + window_size = h_window_size * w_window_size + + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_max_pool2d_with_indices_backward( + grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices + ) + def fn(idx): *prefix, h, w = idx index_test = ops.index_expr(h * width + w, torch.int32) @@ -2807,6 +2830,9 @@ def fn_sum(idx, loader): return fn_sum +fallback_adaptive_avg_pool2d = fallback_handler(aten._adaptive_avg_pool2d) + + @register_lowering(aten._adaptive_avg_pool2d) def _adaptive_avg_pool2d(x, output_size): assert isinstance(x, TensorBox) @@ -2846,6 +2872,11 @@ def end_index(index, out_dim, inp_dim): w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in) w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in) + window_size = h_kernel_max * w_kernel_max + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_adaptive_avg_pool2d(x, output_size) + fn_sum = _adaptive_pooling_idx_sum( [h_kernel_max, w_kernel_max], [h_start_index, w_start_index], @@ -2916,6 +2947,9 @@ def fn(idx): return rv +fallback_avg_pool2d = fallback_handler(aten.avg_pool2d) + + @register_lowering(aten.avg_pool2d, type_promotion_kind=None) def avg_pool2d( x, @@ -2953,6 +2987,19 @@ def avg_pool2d( new_size = list(batch) + [h_out, w_out] dtype = x.get_dtype() + window_size = kernel_size[0] * kernel_size[1] + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. 
+ return fallback_avg_pool2d( + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + def fn_sum(idx, loader): *prefix, bh, bw = idx total = None @@ -2992,6 +3039,9 @@ def fn(idx): return rv +fallback_avg_pool2d_backward = fallback_handler(aten.avg_pool2d_backward) + + @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None) def avg_pool2d_backward( grad_output, @@ -3045,6 +3095,20 @@ def avg_pool2d_backward( ] ) + window_size = h_window_size * w_window_size + if window_size > 25: + # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback. + return fallback_avg_pool2d_backward( + grad_output, + x, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + ) + def compute_pool_size_without_padding(ph, pw): """ This computes the scaling factor that we will divide an element From 00b9473ad68da319a1dc3f655cc1a97490ae9669 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 22 Nov 2022 03:05:50 +0000 Subject: [PATCH 423/453] [PT-D][Tensor Parallelism][2/N] Sync TP API change to PT prod (#89467) This is part of TP Beta Release efforts. ref: https://github.com/pytorch/tau/issues/576 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89467 Approved by: https://github.com/wanchaol --- .../_tensor/parallel/test_2d_parallel.py | 27 +- .../_tensor/parallel/test_parallelize_api.py | 4 +- .../_tensor/parallel/test_tp_examples.py | 26 +- .../_tensor/parallel/test_tp_style.py | 8 +- .../distributed/_tensor/parallel/__init__.py | 30 +- torch/distributed/_tensor/parallel/api.py | 326 ++++++++++++------ torch/distributed/_tensor/parallel/style.py | 8 +- torch/distributed/_tensor/parallel/utils.py | 12 +- 8 files changed, 280 insertions(+), 161 deletions(-) diff --git a/test/distributed/_tensor/parallel/test_2d_parallel.py b/test/distributed/_tensor/parallel/test_2d_parallel.py index 7a3779c296c3..ea41d5388660 100644 --- a/test/distributed/_tensor/parallel/test_2d_parallel.py +++ b/test/distributed/_tensor/parallel/test_2d_parallel.py @@ -16,6 +16,10 @@ Shard, Replicate, ) +from torch.distributed._tensor.parallel import ( + PairwiseParallel, + parallelize_module, +) import torch.distributed.distributed_c10d as distributed_c10d @@ -32,17 +36,6 @@ TP_DEGREE = 2 LR = 3e-5 -OPS_NOT_SHARD = [ - "net3.weight", - "net3.bias", -] - -SHARD_PARAMS = [ - "net1.weight", - "net1.bias", - "net2.weight", -] - class SimpleModel(torch.nn.Module): def __init__(self): @@ -108,10 +101,9 @@ def shard_module(m, pg): m.net2 = _aggregate_local_tensor(m.net2) -def _shard_wrap_module(module, module_shard, fsdp_wrap, tp_pg, fsdp_pg): +def _shard_wrap_module(module, module_shard, fsdp_wrap, mesh_2d, fsdp_pg): if module_shard: - # Fetch the module sharding planner. 
- shard_module(module, tp_pg) + parallelize_module(module, mesh_2d, PairwiseParallel(), tp_mesh_dim=1) if fsdp_wrap and module_shard: return FSDP(module, process_group=fsdp_pg) @@ -134,11 +126,10 @@ def init_model(model_parallel_size=TP_DEGREE): ) fsdp_pg = twod_mesh.get_dim_groups()[0] - tp_pg = twod_mesh.get_dim_groups()[1] # Create Input - model = _shard_wrap_module(model, True, True, tp_pg, fsdp_pg) - return model, tp_pg, fsdp_pg + model = _shard_wrap_module(model, True, True, twod_mesh, fsdp_pg) + return model, fsdp_pg def is_nested_tensor(val: Any) -> bool: @@ -200,7 +191,7 @@ def test_2d_fsdp_integration_correctness(self) -> None: model = SimpleModel().cuda(self.rank) model = FSDP(model) torch.manual_seed(0) - model_2d, _, dp_pg = init_model() + model_2d, dp_pg = init_model() optim = torch.optim.Adam(model.parameters(), lr=0.0001) optim_2d = torch.optim.Adam(model_2d.parameters(), lr=0.0001) diff --git a/test/distributed/_tensor/parallel/test_parallelize_api.py b/test/distributed/_tensor/parallel/test_parallelize_api.py index fb3e8f4721c8..036f4ef79a49 100644 --- a/test/distributed/_tensor/parallel/test_parallelize_api.py +++ b/test/distributed/_tensor/parallel/test_parallelize_api.py @@ -86,7 +86,7 @@ def test_parallelize_mlp(self): device_mesh = DeviceMesh( self.device_type, torch.arange(self.world_size) ) - _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) + model_tp = _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) # Ensure the parameter is properly distributed. self.assertEqual( @@ -125,7 +125,7 @@ def __init__(self) -> None: _parallelize_mlp(model_tp, device_mesh, DummyParallel()) with self.assertRaisesRegex( - RuntimeError, "We only support even number of Linear for MLP." + RuntimeError, "More than one nn.Linear needed for a MLP." ): _parallelize_mlp( torch.nn.Linear(10, 5), device_mesh, PairwiseParallel() diff --git a/test/distributed/_tensor/parallel/test_tp_examples.py b/test/distributed/_tensor/parallel/test_tp_examples.py index 696171e4ca88..74cd44dfd57d 100644 --- a/test/distributed/_tensor/parallel/test_tp_examples.py +++ b/test/distributed/_tensor/parallel/test_tp_examples.py @@ -11,18 +11,14 @@ skip_unless_torch_gpu, ) from torch.distributed._tensor import ( - distribute_module, DeviceMesh, Replicate, ) from torch.distributed._tensor.parallel import ( + PairwiseParallel, TensorParallelMultiheadAttention, - tp_shard_self_attn, - replicate_input, - replicate_output, + parallelize_module, ) -from torch.distributed._tensor.parallel import PairwiseParallel -from torch.distributed._tensor.parallel.api import _parallelize_mlp class MLPModule(torch.nn.Module): @@ -70,7 +66,7 @@ def test_mlp_megatron_e2e(self): self.device_type, torch.arange(0, NUM_DEVICES), ) - _parallelize_mlp(model_tp, device_mesh, PairwiseParallel()) + model_tp = parallelize_module(model_tp, device_mesh, PairwiseParallel()) optim = torch.optim.SGD(model.parameters(), lr=LR) optim_tp = torch.optim.SGD(model_tp.parameters(), lr=LR) @@ -182,13 +178,7 @@ def test_self_attn_megatron_e2e(self): # Shard module and initialize optimizer. 
device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) - distribute_module( - model_tp, - device_mesh, - partition_fn=tp_shard_self_attn, - input_fn=replicate_input, - output_fn=replicate_output, - ) + parallelize_module(model_tp, device_mesh, PairwiseParallel()) device_mesh = model_tp.qkv.weight.device_mesh replicate = [Replicate()] * device_mesh.ndim @@ -339,13 +329,7 @@ def test_self_attn_replacement_megatron_e2e(self): # Shard module and initialize optimizer. device_mesh = DeviceMesh(self.device_type, list(range(NUM_DEVICES))) - distribute_module( - model_tp, - device_mesh, - partition_fn=tp_shard_self_attn, - input_fn=replicate_input, - output_fn=replicate_output, - ) + parallelize_module(model_tp, device_mesh, PairwiseParallel()) device_mesh = model_tp.attn.qkv.weight.device_mesh replicate = [Replicate()] * device_mesh.ndim diff --git a/test/distributed/_tensor/parallel/test_tp_style.py b/test/distributed/_tensor/parallel/test_tp_style.py index 314fe470955b..e52aef1a6f3f 100644 --- a/test/distributed/_tensor/parallel/test_tp_style.py +++ b/test/distributed/_tensor/parallel/test_tp_style.py @@ -64,7 +64,9 @@ def test_make_input_shard_1d(self): def _test_prepare_output( self, func, spec, dim=None, device_mesh_input_none=False ): - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size) + ) tensor = torch.rand(8, 16, device=self.device_type) dtensor = distribute_tensor(tensor, device_mesh, spec) device_mesh_input = None if device_mesh_input_none else device_mesh @@ -135,7 +137,9 @@ def test_make_output_tensor(self): # Common logic for testing prepare output funcs errors. def _test_prepare_output_error(self, func): tensor = torch.rand(8, 16, device=self.device_type) - device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) + device_mesh = DeviceMesh( + self.device_type, torch.arange(self.world_size) + ) dtensor = distribute_tensor(tensor, device_mesh, [Shard(0)]) output = [dtensor] with self.assertRaisesRegex( diff --git a/torch/distributed/_tensor/parallel/__init__.py b/torch/distributed/_tensor/parallel/__init__.py index 0ef0e8ff0b9e..3c72143f345f 100644 --- a/torch/distributed/_tensor/parallel/__init__.py +++ b/torch/distributed/_tensor/parallel/__init__.py @@ -3,20 +3,32 @@ TensorParallelMultiheadAttention, ) -from torch.distributed._tensor.parallel.api import ( - tp_shard_self_attn, - replicate_input, - replicate_output, -) - from torch.distributed._tensor.parallel.style import ( + ColwiseParallel, ParallelStyle, PairwiseParallel, RowwiseParallel, - ColwiseParallel, - make_input_shard_1d, make_input_replicate_1d, - make_output_shard_1d, + make_input_shard_1d, make_output_replicate_1d, + make_output_shard_1d, make_output_tensor, ) + +from torch.distributed._tensor.parallel.api import ( + parallelize_module, +) + +__all__ = [ + "ColwiseParallel", + "TensorParallelMultiheadAttention", + "ParallelStyle", + "PairwiseParallel", + "RowwiseParallel", + "make_input_replicate_1d", + "make_input_shard_1d", + "make_output_replicate_1d", + "make_output_tensor", + "make_output_shard_1d", + "parallelize_module", +] diff --git a/torch/distributed/_tensor/parallel/api.py b/torch/distributed/_tensor/parallel/api.py index 68d444882c4c..a7a896ebf859 100644 --- a/torch/distributed/_tensor/parallel/api.py +++ b/torch/distributed/_tensor/parallel/api.py @@ -1,106 +1,127 @@ # Copyright (c) Meta Platforms, Inc. 
and affiliates import torch import torch.nn as nn -from typing import Sequence, Tuple +from typing import Union, Dict from torch.distributed._tensor import ( distribute_module, distribute_tensor, - DTensor, Shard, Replicate, DeviceMesh, - Placement, ) from torch.distributed._tensor.parallel import TensorParallelMultiheadAttention -from torch.distributed._tensor.parallel.style import ParallelStyle, PairwiseParallel +from torch.distributed._tensor.parallel.style import PairwiseParallel, ParallelStyle from torch.distributed._tensor.parallel.utils import _create_1d_device_mesh -def replicate_input( - inputs: Sequence[torch.Tensor], device_mesh: DeviceMesh -) -> Tuple[DTensor, ...]: - replicate = [Replicate()] * device_mesh.ndim - return tuple( - DTensor.from_local(tensor, device_mesh, replicate) for tensor in inputs - ) +__all__ = [ + "parallelize_module", +] -def replicate_output(output: DTensor, device_mesh: DeviceMesh) -> torch.Tensor: - if isinstance(output, DTensor): - replicate = [Replicate()] * output.device_mesh.ndim - # TODO: can the output be left incontiguous? - return ( - output.redistribute(output.device_mesh, replicate) - .to_local() - .contiguous() - ) +def parallelize_module( # type: ignore[return] + module: nn.Module, + device_mesh: DeviceMesh, + parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]], + tp_mesh_dim: int = 0, +) -> nn.Module: + """ + The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize module + or sub_modules based on a parallelize_plan which contains the parallel_style + which indicates how user want the module or sub_module to be parallelized. + User can also specify different parallel_style per module fully qualifed name (FQN). + The API supports 2D parallelism natively by accepting an n-dimension device_mesh + and users just need to specify the dimension where we perform tensor parallelism on. + Args: + module (nn.Module): + :class:`nn.Module` object to be parallelized. + device_mesh (DeviceMesh): + :class:`DeviceMesh` object which describes the mesh topology + of devices for the DTensor. + parallelize_plan (Union[ParallelStyle, Dict[str, ParallelStyle]]): + The plan used to parallelize the module. It can be either a + :class:`ParallelStyle` object which contains how + we prepare input/output for Tensor Parallelism or it can be a + dict of module FQN and its corresponding :class:`ParallelStyle` object. + tp_mesh_dim (int): + the dimension of ``device_mesh`` where we perform + Tensor Parallelism on. -def tp_shard_self_attn( - name: str, module: nn.Module, device_mesh: DeviceMesh -) -> None: - col_wise_sharding: Sequence[Placement] = [Shard(0)] - row_wise_sharding: Sequence[Placement] = [Shard(1)] - replicate: Sequence[Placement] = [Replicate()] * device_mesh.ndim - - def _shard_self_attn_params(name: str, module: nn.Module) -> None: - if isinstance(module, nn.Linear): - if name == "qkv": - sharded_weight = nn.Parameter( - distribute_tensor( - module.weight, device_mesh, col_wise_sharding - ) + Return: + A :class:`nn.Module` object parallelized. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from from torch.distributed._tensor.parallel import parallelize_module, PairwiseParallel + >>> + >>> # Define the module. + >>> m = Model(...) + >>> m = parallelize_module(m, PairwiseParallel()) + >>> + + .. warning:: + ``PairwiseParallel`` comes with constraints for now. If you need finer + granularity, you need to pass in a dict of module FQN and parallel style instead. 
+ """ + + if device_mesh.ndim > 1: + device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim) + + if isinstance(parallelize_plan, ParallelStyle): + if _is_mha_for_pairwise_parallel(module): + return _parallelize_multihead_attn(module, device_mesh) + elif _is_mlp_for_pairwise_parallel(module): + return _parallelize_mlp(module, device_mesh) + else: + for n, m in module.named_children(): + module.register_module( + n, parallelize_module(m, device_mesh, parallelize_plan) ) - module.register_parameter("weight", sharded_weight) - if module.bias is not None: - sharded_bias = nn.Parameter( - distribute_tensor( - module.bias, device_mesh, col_wise_sharding - ) - ) - module.register_parameter("bias", sharded_bias) - elif name == "proj": - sharded_weight = nn.Parameter( - distribute_tensor( - module.weight, device_mesh, row_wise_sharding - ) + return module + # TODO: Add parallelize linear logic when https://github.com/pytorch/tau/pull/624/ merged. + elif isinstance(parallelize_plan, dict): + for module_path, parallelize_style in parallelize_plan.items(): + sub_module = module.get_submodule(module_path) + module.register_module( # type: ignore[call-arg] # pyre-ignore[20] + parallelize_module( # type: ignore[arg-type] + module_path, sub_module, device_mesh, parallelize_style # type: ignore[arg-type] # pyre-ignore[6] ) - module.register_parameter("weight", sharded_weight) - if module.bias is not None: - replicated_bias = nn.Parameter( - distribute_tensor(module.bias, device_mesh, replicate) - ) - module.register_parameter("bias", replicated_bias) - - if isinstance(module, TensorParallelMultiheadAttention): # shard TPMA - for n, m in module.named_children(): - _shard_self_attn_params(n, m) + ) + return module else: - for n, m in module.named_children(): # replace with TPMA - if isinstance(m, nn.MultiheadAttention): - tp_multi_head_attention = TensorParallelMultiheadAttention( - m.embed_dim, - m.num_heads, - device=torch.device(device_mesh.device_type), - tp_size=device_mesh.size(0), # group size on dim 0 - add_bias_kv=m.bias_k is not None, - ) - tp_multi_head_attention.copy(m) - module.register_module(n, tp_multi_head_attention) + raise RuntimeError( # pyre-ignore[7] + f"Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for parallelize_plan, {type(parallelize_plan)} found!" + ) -def _has_even_num_linears(module: nn.Module) -> bool: +def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool: """ - We traverse through all the children of the given module and count the - number of Linear module. If the number is even, we return True. + Check whether the mha module is the one can be handled for Pairwise parallel. + + Args: + module (nn.Module): + :class:``nn.Module`` object to be checked. + + Return: + A boolean object which specifies whether the module is MHA supported by Pairwise parallel or not. + """ + return isinstance(module, TensorParallelMultiheadAttention) or isinstance( + module, nn.MultiheadAttention + ) + + +def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool: + """ + Traverse through all the immediate children of the given module and count the + number of Linear module. If the number is more than one, we return True. Args: module (nn.Module): :class:``nn.Module`` object to be traversed and counted. Return: - A boolean object which specifies whether the module contains - event-number of Linears in its children. + A boolean object which specifies whether the module is MLP or not. .. warning:: The traversal is not recursive for now. 
@@ -108,15 +129,66 @@ def _has_even_num_linears(module: nn.Module) -> bool: linear_submodules = list( filter(lambda x: isinstance(x, nn.Linear), module.children()) ) - return len(linear_submodules) > 0 and len(linear_submodules) % 2 == 0 + return len(linear_submodules) > 1 -def _parallelize_mlp( +def _rowwise_parallelize_linear_fn( + name: str, + module: nn.Module, + device_mesh: DeviceMesh, +) -> None: + """ + This function parallelizes the input :class:``nn.Linear`` module in :class:``RowwiseParallel`` style. + + Args: + name (str): name of the input module. + module (nn.Module): the :class:``nn.Linear`` object to be parallelized. + device_mesh (DeviceMesh): :class:``DeviceMesh`` object which describes the mesh topology + of devices for the DTensor. + + Return: + None + """ + for name, param in module.named_parameters(): + dist_spec = ( + [Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item] + ) + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, dist_spec) + ) + module.register_parameter(name, dist_param) + + +def _colwise_parallelize_linear_fn( + name: str, + module: nn.Module, + device_mesh: DeviceMesh, +) -> None: + """ + This function parallelizes the input :class:``nn.Linear`` module in :class:``ColwiseParallel`` style. + + Args: + name (str): name of the input module. + module (nn.Module): the :class:``nn.Linear`` object to be parallelized. + device_mesh (DeviceMesh): :class:``DeviceMesh`` object which describes the mesh topology + of devices for the DTensor. + + Return: + None + """ + for name, param in module.named_parameters(): + dist_param = torch.nn.Parameter( + distribute_tensor(param, device_mesh, [Shard(0)]) + ) + module.register_parameter(name, dist_param) + + +def _parallelize_multihead_attn( module: nn.Module, device_mesh: DeviceMesh, parallel_style: ParallelStyle = PairwiseParallel(), tp_mesh_dim: int = 0, -) -> None: +) -> nn.Module: """ This function assumes the input module is a sequence of nn.Linear and we parallelize the module based on the given parallel style. @@ -137,37 +209,90 @@ def _parallelize_mlp( Tensor Parallelism on. Return: - None + A :class:``nn.Module`` object parallelized. .. warning:: We only support ``PairwiseParallel`` right now. """ - # Define partition functions needed. - def _rowwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] - for name, param in module.named_parameters(): - dist_spec = ( - [Shard(1)] if name == "weight" else [Replicate()] # type: ignore[list-item] - ) - dist_param = torch.nn.Parameter( - distribute_tensor(param, device_mesh, dist_spec) - ) - module.register_parameter(name, dist_param) + if not isinstance(parallel_style, PairwiseParallel): + raise NotImplementedError( + "Only support PairwiseParallel for Multihead Attention parallelization." + ) + + if device_mesh.ndim > 1: + device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim) + + if isinstance(module, nn.MultiheadAttention): + tp_multi_head_attention = TensorParallelMultiheadAttention( + module.embed_dim, + module.num_heads, + device=torch.device(device_mesh.device_type), + tp_size=device_mesh.size(tp_mesh_dim), + add_bias_kv=module.bias_k is not None, + ) + tp_multi_head_attention.copy(module) + module = tp_multi_head_attention + + if isinstance(module, TensorParallelMultiheadAttention): # shard TPMA + for n, m in module.named_children(): + if n == "qkv": + # Col-wise Parallelize the qkv layer. 
+ distribute_module( + m, + device_mesh, + _colwise_parallelize_linear_fn, + input_fn=parallel_style._prepare_input, # type: ignore[arg-type, misc] # pyre-ignore[6] + ) + elif n == "proj": + # Row-wise Parallelize the proj layer + distribute_module( + m, + device_mesh, + _rowwise_parallelize_linear_fn, + output_fn=parallel_style._prepare_output, # type: ignore[arg-type, misc] # pyre-ignore[6] + ) + return module - def _colwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] - for name, param in module.named_parameters(): - dist_param = torch.nn.Parameter( - distribute_tensor(param, device_mesh, [Shard(0)]) - ) - module.register_parameter(name, dist_param) +def _parallelize_mlp( + module: nn.Module, + device_mesh: DeviceMesh, + parallel_style: ParallelStyle = PairwiseParallel(), + tp_mesh_dim: int = 0, +) -> nn.Module: + """ + This function assumes the input module is a sequence of nn.Linear + and we parallelize the module based on the given parallel style. + We don't change the FQN of each sub-module and replace each parameter + in place. + + Args: + module (nn.Module): + :class:``nn.Module`` object to be parallelized. + device_mesh (DeviceMesh): + :class:``DeviceMesh`` object which describes the mesh topology + of devices for the DTensor. + parallel_style (ParallelStyle): + :class:``ParallelStyle`` object which contains how + we prepare input/output for Tensor Parallelism. + tp_mesh_dim (int): + the dimension of ``device_mesh`` where we perform + Tensor Parallelism on. + + Return: + A :class:``nn.Module`` object parallelized. + + .. warning:: + We only support ``PairwiseParallel`` right now. + """ if not isinstance(parallel_style, PairwiseParallel): raise NotImplementedError( "Only support PairwiseParallel for MLP parallelization." ) - if not _has_even_num_linears(module): - raise RuntimeError("We only support even number of Linear for MLP.") + if not _is_mlp_for_pairwise_parallel(module): + raise RuntimeError("More than one nn.Linear needed for a MLP.") if device_mesh.ndim > 1: device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim) @@ -175,13 +300,15 @@ def _colwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] linear_submodules = list( filter(lambda x: isinstance(x, nn.Linear), module.children()) ) - for i, m in enumerate(linear_submodules): + mlp_last_even_layer = (len(linear_submodules) // 2) * 2 + for i in range(mlp_last_even_layer): + m = linear_submodules[i] if i % 2 == 0: # Col-wise Parallelize the linear layer distribute_module( m, device_mesh, - _colwise_parallelize_fn, + _colwise_parallelize_linear_fn, input_fn=parallel_style._prepare_input # type: ignore[arg-type, misc] # pyre-ignore[6] if i == 0 else None, @@ -191,8 +318,9 @@ def _colwise_parallelize_fn(name, module, device_mesh): # pyre-ignore[2, 3] distribute_module( m, device_mesh, - _rowwise_parallelize_fn, + _rowwise_parallelize_linear_fn, output_fn=parallel_style._prepare_output # type: ignore[arg-type, misc] # pyre-ignore[6] - if i == (len(linear_submodules) - 1) + if i == (mlp_last_even_layer - 1) else None, ) + return module diff --git a/torch/distributed/_tensor/parallel/style.py b/torch/distributed/_tensor/parallel/style.py index 5ea96434118a..60b6a1c88dfd 100644 --- a/torch/distributed/_tensor/parallel/style.py +++ b/torch/distributed/_tensor/parallel/style.py @@ -5,8 +5,8 @@ from typing import Union, Optional from torch.distributed._tensor import DTensor, Shard, Replicate, DeviceMesh from torch.distributed._tensor.parallel.utils import ( - _Prepare_Input_Func_Type, - 
_Prepare_Output_Func_Type, + _PrepareInputType, + _PrepareOutputType, _prepare_input_validate, _prepare_output_validate, ) @@ -18,8 +18,8 @@ class ParallelStyle(ABC): Users can extend this class to build their own parallel style with customized input/output preparations. """ - _prepare_input: _Prepare_Input_Func_Type - _prepare_output: _Prepare_Output_Func_Type + _prepare_input: _PrepareInputType + _prepare_output: _PrepareOutputType @abstractmethod def __init__(self, _prepare_input, _prepare_output) -> None: diff --git a/torch/distributed/_tensor/parallel/utils.py b/torch/distributed/_tensor/parallel/utils.py index 2680ae41ffbe..c4cca5c88eda 100644 --- a/torch/distributed/_tensor/parallel/utils.py +++ b/torch/distributed/_tensor/parallel/utils.py @@ -4,18 +4,18 @@ from torch.distributed._tensor import DeviceMesh, DTensor from typing import Callable, Optional, Union -_Prepare_Input_Func_Type = Callable[ +_PrepareInputType = Callable[ [Union[torch.Tensor, DTensor], Optional[DeviceMesh], Optional[int]], DTensor ] -_Prepare_Output_Func_Type = Callable[ +_PrepareOutputType = Callable[ [DTensor, Optional[DeviceMesh], Optional[int]], Union[torch.Tensor, DTensor] ] def _prepare_input_validate( - _prepare_input_func: _Prepare_Input_Func_Type, -) -> _Prepare_Input_Func_Type: + _prepare_input_func: _PrepareInputType, +) -> _PrepareInputType: """ Inject common validation logics for `_prepare_input` funcs via this decorator, including verifying that input needs to be either @@ -66,8 +66,8 @@ def wrapper(*args, **kwargs): # pyre-ignore[2, 3] def _prepare_output_validate( - _prepare_output_func: _Prepare_Output_Func_Type, -) -> _Prepare_Output_Func_Type: + _prepare_output_func: _PrepareOutputType, +) -> _PrepareOutputType: """ Inject common validation logics for _prepare_output funcs via this decorator, including verifying that output needs to be a DTensor From 338f61904421bef1b46c9d614470b523c0696654 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 22 Nov 2022 03:38:53 +0000 Subject: [PATCH 424/453] [vision hash update] update the pinned vision hash (#89471) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/master/.github/workflows/_update-commit-hash.yml). Update the pinned vision hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89471 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/vision.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 80fe47b2cee2..30711c5bbfd9 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -5b4f79d9ba8cbeeb8d6f0fbba3ba5757b718888b +4a310f26049371959617921d0eb9b001f4d262c6 From ce342ed2d3a4a0dd8151abe80bfe0bb06a7b0ae9 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Tue, 22 Nov 2022 03:39:15 +0000 Subject: [PATCH 425/453] Fix retrying logic for successful unittest tests under --rerun-disabled-tests mode (#89454) When looking into Rockset data for disabled test unittest, for example `testAdd`, I see that it's re-run only 3 times instead of 50+ times as expected under rerun-disabled -test mode ``` [ { "name": "testAdd", "classname": "TestLazyReuseIr", "filename": "lazy/test_reuse_ir.py", "flaky": false, "num_green": 3, "num_red": 0 } ] ``` It turns out that I made a mistake mixing `RERUN_DISABLED_TESTS` and `report_only` into `(RERUN_DISABLED_TESTS or report_only) and num_retries_left < MAX_NUM_RETRIES` in https://github.com/pytorch/pytorch/pull/88646. The retrying logic for successful tests under rerun-disabled-tests mode is never executed because num_retries_left would be equal to MAX_NUM_RETRIES (not smaller) if the very first run successes. Thus, the sample test `testAdd` finishes right away (1 success count) * `report_only` and `RERUN_DISABLED_TESTS` are 2 different things and shouldn't be mixed together. RERUN_DISABLED_TESTS has the higher priority. * We also don't want to retry skipped tests under rerun-disabled-tests mode because they are only skipped due to `check_if_enable` check `Test is enabled but --rerun-disabled-tests verification mode is set, so only disabled tests are run` ### Testing * CI https://github.com/pytorch/pytorch/actions/runs/3518228784 generates https://gha-artifacts.s3.amazonaws.com/pytorch/pytorch/3518228784/1/artifact/test-reports-test-default-4-4-linux.4xlarge.nvidia.gpu_9627285587.zip in which `testAdd` is correctly called multiple times and `TestLazyReuseIr` is skipped correctly * Locally ``` # export CI=1 # export PYTORCH_RETRY_TEST_CASES=1 # export PYTORCH_OVERRIDE_FLAKY_SIGNAL=1 # export PYTORCH_TEST_RERUN_DISABLED_TESTS=1 $ python test/run_test.py --verbose -i lazy/test_reuse_ir Ignoring disabled issues: [] Selected tests: lazy/test_reuse_ir Prioritized test from test file changes. reordering tests for PR: prioritized: [] the rest: ['lazy/test_reuse_ir'] Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/slow-tests.json to /Users/huydo/Storage/mine/pytorch/test/.pytorch-slow-tests.json Downloading https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/disabled-tests-condensed.json to /Users/huydo/Storage/mine/pytorch/test/.pytorch-disabled-tests.json parallel (file granularity) tests: lazy/test_reuse_ir serial (file granularity) tests: Ignoring disabled issues: [] Ignoring disabled issues: [] Running lazy/test_reuse_ir ... [2022-11-21 13:21:07.165877] Executing ['/Users/huydo/miniconda3/envs/py3.9/bin/python', '-bb', 'lazy/test_reuse_ir.py', '-v', '--import-slow-tests', '--import-disabled-tests', '--rerun-disabled-tests'] ... 
[2022-11-21 13:21:07.166279] Expand the folded group to see the log file of lazy/test_reuse_ir ##[group]PRINTING LOG FILE of lazy/test_reuse_ir (/Users/huydo/Storage/mine/pytorch/test/test-reports/lazy-test_reuse_ir_6cf_dxa1) Running tests... ---------------------------------------------------------------------- Test results will be stored in test-reports/python-unittest/lazy.test_reuse_ir testAdd (__main__.TestLazyReuseIr) ... ok (1.215s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 50 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 49 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 48 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 47 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 46 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 45 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 44 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 43 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 42 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 41 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 40 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 39 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 38 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 37 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 36 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 35 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 34 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 33 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 32 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 31 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 30 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 29 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 28 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 27 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 26 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 25 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 24 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 23 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 22 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 21 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 20 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 19 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 18 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... 
testAdd succeeded - num_retries_left: 17 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 16 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 15 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 14 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 13 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 12 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 11 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 10 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 9 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 8 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 7 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 6 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 5 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 4 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 3 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 2 ok (0.001s) testAdd (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 1 ok (0.001s) testAddSub (__main__.TestLazyReuseIr) ... testAdd succeeded - num_retries_left: 0 skip: Test is enabled but --rerun-disabled-tests verification mode is set, so only disabled tests are run (0.001s) testAddSubFallback (__main__.TestLazyReuseIr) ... skip: Test is enabled but --rerun-disabled-tests verification mode is set, so only disabled tests are run (0.001s) testBatchNorm (__main__.TestLazyReuseIr) ... 
skip: Test is enabled but --rerun-disabled-tests verification mode is set, so only disabled tests are run (0.001s) ---------------------------------------------------------------------- Ran 54 tests in 1.264s OK (skipped=3) ``` Here is the sample rockset query ``` WITH added_row_number AS ( SELECT *, ROW_NUMBER() OVER(PARTITION BY name, classname, filename ORDER BY _event_time DESC) AS row_number FROM commons.rerun_disabled_tests ) SELECT name, classname, filename, flaky, num_green, num_red FROM added_row_number WHERE row_number = 1 AND name = 'testAdd' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89454 Approved by: https://github.com/clee2000 --- torch/testing/_internal/common_utils.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index e53887a5fdbb..2c72296d1e30 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2165,13 +2165,22 @@ def _run_with_retry(self, result=None, num_runs_left=0, report_only=True, num_re result.addExpectedFailure(self, err) self._run_with_retry(result=result, num_runs_left=num_retries_left, report_only=report_only, num_red=num_red + 1, num_green=num_green) - elif (RERUN_DISABLED_TESTS or report_only) and num_retries_left < MAX_NUM_RETRIES: - # Always re-run up to MAX_NUM_RETRIES when running under report only or rerun disabled tests modes + elif RERUN_DISABLED_TESTS and num_retries_left <= MAX_NUM_RETRIES and not result.skipped: + # Always re-run up to MAX_NUM_RETRIES when running under rerun disabled tests modes if the test successes. + # The parameter num_retries_left can be equal to MAX_NUM_RETRIES here because num_runs_left is initially + # set to MAX_NUM_RETRIES + 1, i.e. the first run successes + # + # Also if the result is skipped, this is due to check_if_enable skipping non-disabled tests, thus we + # want to ignore them, not retrying and skipping multiple times print(f" {self._testMethodName} succeeded - num_retries_left: {num_retries_left}") - if RERUN_DISABLED_TESTS: - result.addSuccess(self) - else: - result.addUnexpectedSuccess(self) + result.addSuccess(self) + self._run_with_retry(result=result, num_runs_left=num_retries_left, report_only=report_only, + num_red=num_red, num_green=num_green + 1) + elif report_only and num_retries_left < MAX_NUM_RETRIES: + # The original logic here is that num_retries_left must be smaller than MAX_NUM_RETRIES indicating + # that at least one retry has been spent + print(f" {self._testMethodName} succeeded - num_retries_left: {num_retries_left}") + result.addUnexpectedSuccess(self) self._run_with_retry(result=result, num_runs_left=num_retries_left, report_only=report_only, num_red=num_red, num_green=num_green + 1) elif not report_only and num_retries_left < MAX_NUM_RETRIES: From 1dae59ba168fe3c4c11c102f935101c3e4f3b105 Mon Sep 17 00:00:00 2001 From: Iris Date: Tue, 22 Nov 2022 03:52:32 +0000 Subject: [PATCH 426/453] [Checkpoint][2D][1/N] Add dedup_tensors for distributed checkpoint to core distributed (#89399) This PR moves dedup_tensors and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint. This removes duplicated shards in list of SavePlan. It is used when saving DT with replicated placement. Docstring and comments will be added in the following PRs. 
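A minimal sketch of the intended behavior, mirroring the unit test added in this PR (all names below come from the planner helpers and test in this diff; nothing outside the patch is assumed):

```
import dataclasses
import torch
from torch.distributed.checkpoint.dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor

# Both ranks propose a write item for the same replicated shard "tensor_0".
shared = dataclasses.replace(
    _create_write_item_for_tensor("tensor_0", torch.rand(4)), type=WriteItemType.SHARD
)
rank0 = SavePlan([shared, _create_write_item_for_tensor("r0", torch.rand(10))])
rank1 = SavePlan([shared, _create_write_item_for_tensor("r1", torch.rand(10))])

deduped = dedup_tensors([rank0, rank1])
# The duplicated "tensor_0" item is kept only on the first plan that wrote it,
# so rank0 keeps both of its items while rank1 is left with just "r1".
assert len(deduped[0].items) == 2 and len(deduped[1].items) == 1
```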
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89399 Approved by: https://github.com/wanchaol --- .../checkpoint/test_dedup_tensors.py | 45 +++++++++++++++++++ torch/distributed/checkpoint/dedup_tensors.py | 38 ++++++++++++++++ 2 files changed, 83 insertions(+) create mode 100644 test/distributed/checkpoint/test_dedup_tensors.py create mode 100644 torch/distributed/checkpoint/dedup_tensors.py diff --git a/test/distributed/checkpoint/test_dedup_tensors.py b/test/distributed/checkpoint/test_dedup_tensors.py new file mode 100644 index 000000000000..a0d72147efeb --- /dev/null +++ b/test/distributed/checkpoint/test_dedup_tensors.py @@ -0,0 +1,45 @@ +# Owner(s): ["oncall: distributed"] + +import dataclasses +import torch +from torch.distributed.checkpoint.dedup_tensors import dedup_tensors +from torch.distributed.checkpoint.planner import SavePlan, WriteItemType +from torch.distributed.checkpoint.planner_helpers import ( + _create_write_item_for_tensor, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +# TODO: add comments for create_plan +def create_plan(second_fqn) -> SavePlan: + # the first write item is for a duplicated shard (that covers the whole tensor) + write_item_1 = _create_write_item_for_tensor("tensor_0", torch.rand(4)) + write_item_1 = dataclasses.replace(write_item_1, type=WriteItemType.SHARD) + + # the second write item has different keys + write_item_2 = _create_write_item_for_tensor(second_fqn, torch.rand(10)) + + return SavePlan([write_item_1, write_item_2]) + + +# TODO: add comments for TestDedupTensor +class TestDedupTensor(TestCase): + def test_dedup_shards(self): + rank0 = create_plan("r0") + rank1 = create_plan("r1") + + dedup_plans = dedup_tensors([rank0, rank1]) + + self.assertEqual(2, len(dedup_plans[0].items)) + self.assertEqual(1, len(dedup_plans[1].items)) + + self.assertIn( + "tensor_0", (item.index.fqn for item in dedup_plans[0].items) + ) + self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items)) + + self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items)) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/checkpoint/dedup_tensors.py b/torch/distributed/checkpoint/dedup_tensors.py new file mode 100644 index 000000000000..4b60e49d3105 --- /dev/null +++ b/torch/distributed/checkpoint/dedup_tensors.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +from typing import Dict, List +import dataclasses + +from torch.distributed.checkpoint.metadata import MetadataIndex +from torch.distributed.checkpoint.planner import SavePlan + +__all__ = ["dedup_tensors"] + +# TODO add docstring for dedup_tensors +def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + all_plans = list(all_plans) + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for wi in plan.items: + key_to_plan.setdefault(wi.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove deplicates by always keeping the first entry. + # Compute the per-rank remove set. 
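+    # The first plan that wrote a given index keeps its item; every later plan
+    # that produced the same index records the index here so the duplicate
+    # write item can be filtered out of that plan below.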
+ plan_to_keys: Dict[int, List[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [ + wi for wi in all_plans[plan_idx].items if wi.index not in key_set + ] + all_plans[plan_idx] = dataclasses.replace( + all_plans[plan_idx], items=new_items + ) + + return all_plans From e545caa50f3cd893ca0419543e57af08a7de85b5 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 22 Nov 2022 03:57:01 +0000 Subject: [PATCH 427/453] dynamo/torchxla integration: trace on xla rather than eager (#88904) In #87741 we added the inference support for dynamo/torchxla integration. Later on in #88449 we attempt to add the training support. That attempt is not smooth because - we try 2 things together 1. let dynamo trace the model on xla rather than eager 2. enable training - It turns out neither of these two tasks are trivial enough. Furthermore, item 2 (enable training) depends on item 1 (tracing on xla). We enable training via AOTAutograd. AOTAutograd lift all model parameters/buffers as graph inputs. Without item 1 being done, we would need copy all graph inputs (including model parameters/buffers) from eager device to xla devices. That hurts performance a lot. Have a cache to map eager parameter to XLA parameter does not solve the problem since the update on either will not sync automatically to the other. They will easily go out of sync. This PR let dynamo trace the model on XLA rather than eager. This is a preparation step to enabling training. Also, tracing on XLA makes the data movement more efficient. We see 1.5x geomean speedup compared to previous 1.38x. 
``` +-------------------------+--------------------+-------------------------+ | Model | XLA (trace once) | XLA (trace everytime) | +=========================+====================+=========================+ | resnet18 | 1.38 | 1.008 | +-------------------------+--------------------+-------------------------+ | resnet50 | 1.227 | 0.998 | +-------------------------+--------------------+-------------------------+ | resnext50_32x4d | 1.544 | 1.008 | +-------------------------+--------------------+-------------------------+ | alexnet | 1.085 | 1.045 | +-------------------------+--------------------+-------------------------+ | mobilenet_v2 | 2.028 | 1.013 | +-------------------------+--------------------+-------------------------+ | mnasnet1_0 | 1.516 | 0.995 | +-------------------------+--------------------+-------------------------+ | squeezenet1_1 | 0.868 | 1.01 | +-------------------------+--------------------+-------------------------+ | vgg16 | 1.099 | 1.008 | +-------------------------+--------------------+-------------------------+ | BERT_pytorch | 3.26 | 1.027 | +-------------------------+--------------------+-------------------------+ | timm_vision_transformer | 2.182 | 1.015 | +-------------------------+--------------------+-------------------------+ | geomean | 1.50389 | 1.01261 | +-------------------------+--------------------+-------------------------+ ``` Example command ``` GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/88904 Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel --- .github/ci_commit_pins/xla.txt | 2 +- benchmarks/dynamo/common.py | 106 +++++++++--------- torch/_dynamo/optimizations/backends.py | 26 +---- .../optimizations/torchxla_integration.py | 62 +++++----- torch/_dynamo/utils.py | 20 +++- 5 files changed, 103 insertions(+), 113 deletions(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 6e29f8ee3c31..f680f0ddccb2 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -dd9b67ff0d6ba4da6a46ca1b22e35c98dbed0d77 +50855d7babfa7970cba18528c659989b91c83824 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 8731d545c456..3fad203c5d87 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -266,13 +266,35 @@ def print_summary(filename): pass +def tensor_is_on_xla(tensors): + if not isinstance(tensors, (tuple, list)): + tensors = [tensors] + return any(map(lambda x: x.device.type == "xla", tensors)) + + def timed(model, model_iter_fn, example_inputs, times=1, return_result=False): synchronize() + if tensor_is_on_xla(example_inputs): + import torch_xla.core.xla_model as xm + + xm.mark_step() + reset_rng_state() t0 = time.perf_counter() # Dont collect outputs to correctly measure timing for _ in range(times): result = model_iter_fn(model, example_inputs, collect_outputs=False) + if tensor_is_on_xla(result): + # If the model is on XLA device, it's possible that after running + # the model, the computation is accumulated but not performed yet. + # Flush all the accumulated computations to make the time measurement + # accurate. 
+ import torch_xla + + result_list = result + if not isinstance(result, (tuple, list)): + result_list = [result] + torch_xla._XLAC._xla_sync_multi(result_list, []) synchronize() t1 = time.perf_counter() return (t1 - t0, result) if return_result else t1 - t0 @@ -384,6 +406,13 @@ def randomize_input(inputs): ) +def maybe_mark_step(args): + if args.trace_on_xla: + import torch_xla.core.xla_model as xm + + xm.mark_step() + + def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): """ Measure speedups over eager. @@ -398,9 +427,6 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs): should_check_result = should_randomize_input = args.randomize_input is_correct = True - baseline_model_iter_fn = get_baseline_model_iter_fn(args, model_iter_fn) - baseline_model = get_baseline_model(args, model) - import contextlib @contextlib.contextmanager @@ -419,16 +445,25 @@ def maybe_profile(*args, **kwargs): if should_randomize_input else example_inputs ) + # need call mark_step to perform the computation + # on randomize_input. Otherwise the first call using the + # inputs will incur high penalty then the next one. + maybe_mark_step(args) # interleave the runs to handle frequency scaling and load changes timings[rep, 0], expected_output = timed( - baseline_model, baseline_model_iter_fn, inputs, return_result=True + model, model_iter_fn, inputs, return_result=True ) + + # call mark_step between the 2 calls to make the comparison fair. + maybe_mark_step(args) + timings[rep, 1], actual_output = timed( model, frozen_model_iter_fn, inputs, return_result=True ) if should_check_result: is_correct = is_correct and same(expected_output, actual_output) + if args.export_profiler_trace: name = args.profiler_trace_name + "_" + model.name + ".json" name = os.path.join(torch._dynamo.config.base_dir, name) @@ -843,56 +878,6 @@ def maybe_init_distributed(should_init_distributed, port="6789", rank=0, world_s torch.distributed.destroy_process_group() -def xla_wrapper(model_iter_fn): - """ - Wrap the model_iter_fn to run the model on XLA devices. - """ - - def wrapper(xla_mod, inputs, collect_outputs=True): - import torch_xla.core.xla_model as xm - - # Make sure the model is already moved to the xla device. Moving - # the model to xla device can be very expensive since model parameters - # need to be copied. We should not do that inside the wrapper since - # the wrapper will be calles for each set of inputs. 
- assert ( - next(xla_mod.parameters()).device.type == "xla" - ), "The model should be already on xla device" - - xla_dev = xm.xla_device() - eager_dev = inputs[0].device - xla_inputs = tree_map(lambda x: x.to(device=xla_dev), inputs) - xla_out = model_iter_fn(xla_mod, xla_inputs, collect_outputs) - if isinstance(xla_out, torch.Tensor): - return xla_out.to(device=eager_dev) - elif hasattr(xla_out, "__dict__"): - for k in xla_out.__dict__.keys(): - if xla_out.__dict__[k] is None: - continue - xla_out.__dict__[k] = tree_map( - lambda x: x.to(device=eager_dev), xla_out.__dict__[k] - ) - return xla_out - else: - raise RuntimeError(f"Can not handle type {type(xla_out)}") - - return wrapper - - -def get_baseline_model_iter_fn(args, model_iter_fn): - return xla_wrapper(model_iter_fn) if args.use_xla_baseline else model_iter_fn - - -def get_baseline_model(args, model): - if args.use_xla_baseline: - import torch_xla.core.xla_model as xm - - xla_dev = xm.xla_device() - return copy.deepcopy(model).to(device=xla_dev) - else: - return model - - class BenchmarkRunner: def __init__(self): self.model_iter_fn = None @@ -1544,9 +1529,9 @@ def get_example_inputs(self): help="Disables cudagraphs for Inductor", ) parser.add_argument( - "--use-xla-baseline", + "--trace-on-xla", action="store_true", - help="Whether to run baseline on XLA devices or eager devices", + help="Whether to trace the model on XLA or on eager device", ) group_fuser = parser.add_mutually_exclusive_group() @@ -1995,6 +1980,15 @@ def run(runner, args, original_dir=None): logging.warn(f"{args.only} failed to load") continue # bad benchmark implementation + if args.trace_on_xla: + import torch_xla.core.xla_model as xm + + xla_dev = xm.xla_device() + model = model.to(device=xla_dev) + example_inputs = tree_map( + lambda x: x.to(device=xla_dev), example_inputs + ) + current_name = name current_device = device current_batch_size = batch_size diff --git a/torch/_dynamo/optimizations/backends.py b/torch/_dynamo/optimizations/backends.py index 55974c69d76e..e97940b7311f 100644 --- a/torch/_dynamo/optimizations/backends.py +++ b/torch/_dynamo/optimizations/backends.py @@ -785,33 +785,9 @@ def ltc_model(*inputs): return ltc_model -@functools.lru_cache(None) -def _init_torchxla(): - global xm - try: - import torch_xla.core.xla_model as xm - except ModuleNotFoundError as e: - print(f"torchxla backend fails. Can not import {e.name}") - raise - - @create_backend def torchxla_trivial(subgraph): - _init_torchxla() - - xla_dev = xm.xla_device() - - xla_model = copy.deepcopy(subgraph.model).to(device=xla_dev) - - def xla_model_wrapper(*inputs): - orig_device = inputs[0].device if len(inputs) > 0 else "cpu" - xla_inputs = tuple(inp.to(device=xla_dev) for inp in inputs) - - xla_out = xla_model(*xla_inputs) - result = tuple(out.to(device=orig_device) for out in xla_out) - return result - - return xla_model_wrapper + return subgraph.model @create_backend diff --git a/torch/_dynamo/optimizations/torchxla_integration.py b/torch/_dynamo/optimizations/torchxla_integration.py index d3cac23e7c4b..f93e4d385ad8 100644 --- a/torch/_dynamo/optimizations/torchxla_integration.py +++ b/torch/_dynamo/optimizations/torchxla_integration.py @@ -1,7 +1,7 @@ -import copy import dataclasses import functools +import itertools import os import time from typing import Any, Dict, List @@ -19,7 +19,7 @@ class GraphInputMatcher: arguments for the current call. tensor_id_to_arg_idx maps the tensor id to the parameter index. 
- graph_input_tensor_ids, graph_input_ivalues list the tensor_id and ivalue for each of the + graph_input_tensor_ids, graph_input_xla_values list the tensor_id and ivalue for each of the TS/XLA graph inputs. """ @@ -30,17 +30,17 @@ class GraphInputMatcher: # most likely const tensors and we can get its content from graph_input_tensors # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get # the tensor from method arguments - graph_input_ivalues: List[Any] + graph_input_xla_values: List[Any] # get the real graph input tensors def __call__(self, args): real_input = [] - for tensor_id, traced_ivalue in zip( - self.graph_input_tensor_ids, self.graph_input_ivalues + for tensor_id, traced_xla_value in zip( + self.graph_input_tensor_ids, self.graph_input_xla_values ): arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None) if arg_idx is None: - inp = traced_ivalue + inp = traced_xla_value else: inp = args[arg_idx] real_input.append(inp) @@ -73,12 +73,25 @@ def import_torchxla(): import torch_xla.debug.metrics as metrics -def extract_compiled_graph(model: torch.fx.GraphModule, example_inputs): +def is_xla_tensor(tensor: torch.Tensor) -> bool: + return tensor.device.type == "xla" + + +def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): import_torchxla() - orig_device = example_inputs[0].device - xla_dev = xm.xla_device() - xla_model = copy.deepcopy(model).to(device=xla_dev) - xla_args = [arg.to(device=xla_dev) for arg in example_inputs] + + assert all( + map( + is_xla_tensor, + filter( + lambda x: isinstance(x, torch.Tensor), + itertools.chain(xla_model.parameters(), xla_args), + ), + ) + ), "All tensors should be on xla" + + # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids + xm.mark_step() args_tensor_ids = [ torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in xla_args ] @@ -88,6 +101,7 @@ def extract_compiled_graph(model: torch.fx.GraphModule, example_inputs): tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} xla_out = xla_model(*xla_args) + fallback_ops = get_fallback_ops() if len(fallback_ops) > 0: raise RuntimeError( @@ -121,28 +135,29 @@ def extract_compiled_graph(model: torch.fx.GraphModule, example_inputs): ( graph_input_tensor_ids, - graph_input_ivalues, + graph_input_xla_values, ) = torch_xla._XLAC._get_tensors_xla_device_data_node(args_and_out) if debug: print(f"graph_input_tensor_ids {graph_input_tensor_ids}") assert len(graph_input_tensor_ids) == len( - graph_input_ivalues - ), f"{len(graph_input_tensor_ids)} v.s. {len(graph_input_ivalues)}" + graph_input_xla_values + ), f"{len(graph_input_tensor_ids)} v.s. 
{len(graph_input_xla_values)}" graph_input_matcher = GraphInputMatcher( - tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_ivalues + tensor_id_to_arg_idx, graph_input_tensor_ids, graph_input_xla_values ) # compiles+runs graph rooted at tensors in 'args_and_out' torch_xla._XLAC._xla_sync_multi(args_and_out, []) + torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) # input all cpu tensors def optimized_mod(*args): + torch_xla._XLAC._xla_sync_multi(args, []) enter_ts = time.time() if len(args_and_out) == 0: return () assert len(args) > 0 # can not handle no args case for now - eager_device = args[0].device graph_input = graph_input_matcher(args) start_ts = time.time() res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input) @@ -151,9 +166,7 @@ def optimized_mod(*args): f"torchxla reuse compiled graph run_cached_graph takes {time.time() - start_ts} seconds" ) - prepare_output_ts = time.time() - - copy_args_ts = time.time() + args_inplace_update_ts = time.time() assert len(res) == len(args_and_out) ncopy = 0 @@ -161,17 +174,16 @@ def optimized_mod(*args): args[arg_index].copy_(res[res_index]) if debug: - print(f"Copy {ncopy} args takes {time.time() - copy_args_ts} seconds") + print( + f"Copy {ncopy} args takes {time.time() - args_inplace_update_ts} seconds" + ) - # need to convert xla tensor back to eager tensor - copy_res_ts = time.time() # First few elements might be xla_args that needs to be in place updated - result = [x.to(device=eager_device) for x in res[len(xla_args_need_update) :]] + result = res[len(xla_args_need_update) :] if debug: - print(f"Copy results takes {time.time() - copy_res_ts} seconds") - print(f"prepare output takes {time.time() - prepare_output_ts} seconds") print(f"optimized_mod takes {time.time() - enter_ts} seconds overall") + xm.mark_step() return result return optimized_mod diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index cbf5a0b46148..481794707efd 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -399,7 +399,20 @@ def clone_tensor(x): def clone_input(x): """copy while preserving strides""" + + def torch_clone(x): + y = torch.clone(x) + if x.is_leaf: + y.requires_grad_(x.requires_grad) + if x.is_leaf and x.grad is not None: + y.grad = clone_input(x.grad) + return y + with torch.no_grad(): + if x.device.type == "xla": + # Access data_ptr() for a xla tensor will cause crash + return torch_clone(x) + needed_size = sum( (shape - 1) * stride for shape, stride in zip(x.size(), x.stride()) ) @@ -421,12 +434,7 @@ def clone_input(x): # RuntimeError: unsupported operation: more than one element of the written-to # tensor refers to a single memory location. Please clone() the tensor before # performing the operation. 
- y = torch.clone(x) - if x.is_leaf: - y.requires_grad_(x.requires_grad) - if x.is_leaf and x.grad is not None: - y.grad = clone_input(x.grad) - return y + return torch_clone(x) return result From 40cf214f2d18b3b8af5354ddc5dad8156ea32520 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Mon, 21 Nov 2022 03:31:51 +0000 Subject: [PATCH 428/453] Support masked_fill to address the GPT2 performance issue (#89274) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89274 Approved by: https://github.com/jgong5, https://github.com/jansel --- test/inductor/test_torchinductor.py | 23 +++++++++++++ torch/_inductor/codegen/cpp.py | 51 ++++++++++++++++++++++++---- torch/_inductor/codegen/cpp_prefix.h | 12 +++++++ 3 files changed, 79 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 0aaf74886c7c..ed68c2844236 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4833,6 +4833,29 @@ def test_auto_simd(self): isa = codecache.pick_vec_isa() self.assertTrue(isa == vec_avx2) + @unittest.skipIf( + not codecache.valid_vec_isa_list(), "Does not support vectorization" + ) + @patch("torch.cuda.is_available", lambda: False) + def test_masked_fill_softmax(self): + def fn(value, mask): + mask = mask.to(torch.bool) + x = torch.masked_fill(value, mask, -33.0) + return torch.softmax(x, -1) + + value = torch.randn((2, 17)) + mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8) + with patch.object(config.cpp, "simdlen", None): + torch._dynamo.reset() + metrics.reset() + opt_fn = torch._dynamo.optimize("inductor")(fn) + opt_fn(value, mask) + + real_out = fn(value, mask) + compiled_out = opt_fn(value, mask) + assert same(real_out, compiled_out, equal_nan=True) + assert metrics.generated_cpp_vec_kernel_count >= 1 + @unittest.skipIf( not codecache.valid_vec_isa_list(), "Does not support vectorization" ) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index f82591ddff36..3568cfdc08ef 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -311,6 +311,10 @@ def maximum(a, b): def square(a): return f"{a}.pow(2)" + @staticmethod + def where(a, b, c): + return f"decltype({b})::blendv({c}, {b}, {a})" + @staticmethod def sign(x): code = BracesBuffer() @@ -330,6 +334,11 @@ def sign(x): V.kernel.compute.splice(code) return result + @staticmethod + def to_dtype(x, dtype): + assert dtype in [torch.bool], f"{__name__} does not support {dtype}" + return f"({x})" + class CppOverrides(OpOverrides): """Map element-wise ops to C++""" @@ -745,7 +754,16 @@ def load(self, name: str, index: sympy.Expr): if expanded_index == new_index: line = f"at::vec::Vectorized({var}[{cexpr(index)}])" else: - line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" + if V.graph.get_dtype(name) in [torch.bool, torch.uint8]: + g_tmp_buf = f"g_tmp_buffer_{var}" + nelements = codecache.pick_vec_isa().nelements() + self.loads.writeline(f"float {g_tmp_buf}[{nelements}] = {{0}};") + self.loads.writeline( + f"flag_to_float({var} + {cexpr(new_index)}, {g_tmp_buf}, {nelements});" + ) + line = f"at::vec::Vectorized::loadu({g_tmp_buf})" + else: + line = f"at::vec::Vectorized::loadu({var} + {cexpr(new_index)})" return self.cse.generate(self.loads, line) @@ -842,9 +860,6 @@ def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr): return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index) def could_vec(self, name: str, 
index: sympy.Expr): - if V.graph.get_dtype(name) is not torch.float: - return False - assert self.itervars is not None # Not a loop if len(self.itervars) == 0: @@ -854,12 +869,24 @@ def could_vec(self, name: str, index: sympy.Expr): return self.is_legal_data_access(most_inner_var, index) def load(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) + if not V.graph.get_dtype(name) in [ + torch.float, + torch.float32, + torch.bool, + torch.uint8, + ]: + self.simd_vec = False + return self.simd_vec + index = self.rename_indexing(index) self.simd_vec = self.simd_vec and self.could_vec(name, index) return self.simd_vec def store(self, name, index, value, mode=None): + if not V.graph.get_dtype(name) in [torch.float, torch.float32]: + self.simd_vec = False + return self.simd_vec + assert "buf" in name index = self.rename_indexing(index) @@ -932,15 +959,24 @@ def constant(val, dtype): @staticmethod def index_expr(expr, dtype): self.simd_vec = False - return self.cse.newvar() + tmp_var = self.cse.newvar() + return tmp_var @staticmethod def indirect_indexing(index_var): + self.simd_vec = False return sympy.Symbol(str(index_var)) @staticmethod def masked(mask, body, other): - return V.kernel.cse.newvar() + tmp_var = self.cse.newvar() + return tmp_var + + @staticmethod + def to_dtype(x, dtype): + if dtype != torch.bool: + self.simd_vec = False + return x self.exit_stack.enter_context(V.set_ops_handler(VecCheckerProxy())) self.exit_stack.enter_context(V.set_kernel_handler(self)) @@ -1088,6 +1124,7 @@ def codegen_loops(self, code, worksharing): if reduction_par_depth > 0 and reduction_par_depth != len( loops_nest_reduce.loops ): + metrics.generated_cpp_vec_kernel_count -= 1 return self.simd_omp_kernel.codegen_loops(code, worksharing) with contextlib.ExitStack() as stack: diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 1905aefcda5c..c1c9c3bae112 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -57,3 +57,15 @@ template void atomic_add(volatile T *addr, T offset) { } while (!atomic_addr->compare_exchange_weak(expected, desired, std::memory_order_relaxed)); } + +// This function is used to convert bool or uint8 to float mask for +// vectorization. The caller needs to make sure the src represents TRUE/FALSE +// correctly. +template +void flag_to_float(const T* src, float* dst, int64_t n) { +#pragma unroll + for (int64_t i = 0; i < n; i++) { + uint32_t* dst_u32 = (uint32_t*)dst; + dst_u32[i] = *(src + i) ? 
0xFFFFFFFF : 0; + } +} From f2cf1b0f5e98094cf7a97439ebdf3679ceee04b0 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 22 Nov 2022 05:48:43 +0000 Subject: [PATCH 429/453] Revert submodule updates introduced by #89157 (#89449) Reverts updates that were introduced by https://github.com/pytorch/pytorch/pull/89157 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89449 Approved by: https://github.com/kit1980, https://github.com/huydhn, https://github.com/clee2000 --- third_party/gloo | 2 +- third_party/pybind11 | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/gloo b/third_party/gloo index 5b1435132631..4a5e339b7642 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 5b143513263133af2b95547e97c07cebeb72bf72 +Subproject commit 4a5e339b764261d20fc409071dc7a8b8989aa195 diff --git a/third_party/pybind11 b/third_party/pybind11 index aa304c9c7d72..80dc998efced 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit aa304c9c7d725ffb9d10af08a3b34cb372307020 +Subproject commit 80dc998efced8ceb2be59756668a7e90e8bef917 From 7b0650d5cf4897089f32c011504d2b2d185cc60a Mon Sep 17 00:00:00 2001 From: Mike Iovine Date: Tue, 22 Nov 2022 06:26:10 +0000 Subject: [PATCH 430/453] Back out "[static-runtime] change the backend for permute_copy" (#89463) Summary: This permute copy change seems to be causing huge regressions on machines without AVX512. Revert to mitigate. This shouldn't be problematic since the improvement from changing it was super small anyways. Differential Revision: D41450088 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89463 Approved by: https://github.com/hlu1 --- .../static_runtime/test_static_runtime.cc | 5 -- torch/csrc/jit/runtime/static/ops.cpp | 53 ------------------- 2 files changed, 58 deletions(-) diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index dc4ce01df72c..ef3bc75f921b 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -2164,12 +2164,7 @@ TEST(StaticRuntime, Permute) { c10::List dims_b{0, 2, 1}; std::vector args_b{b, dims_b}; - auto c = at::randn({3, 3, 3}); - c10::List dims_c{0, -1, 1}; - std::vector args_c{c, dims_c}; - testStaticRuntime(permute_script, args_a); - testStaticRuntime(permute_script, args_c); testStaticRuntime(permute_script, args_a, args_b); permute_script = R"JIT( diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 834a71b08161..e2a154ad069e 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1675,36 +1675,6 @@ REGISTER_OPERATOR_FUNCTOR( }; }); -namespace { - -std::vector permute_output_sizes( - c10::IntArrayRef self_sizes, - c10::IntArrayRef dims) { - const auto nDim = dims.size(); - TORCH_CHECK( - self_sizes.size() == nDim, - "permute input and output tensors must have the same rank, got input rank=", - self_sizes.size(), - "; output rank=", - nDim); - std::vector dims_seen(nDim, false); - std::vector output_sizes; - output_sizes.reserve(nDim); - for (size_t i = 0; i < nDim; ++i) { - auto dim = c10::maybe_wrap_dim(dims[i], nDim); - TORCH_CHECK( - !dims_seen[dim], - "permute dims must be unique, found duplicate dim=", - dim); - - output_sizes.push_back(self_sizes[dim]); - dims_seen[dim] = true; - } - return output_sizes; -} - -} // namespace - 
// Out variants for view ops are registered to a separate registry because // their outputs (views) can't participate in memory reuse. REGISTER_OPERATOR_FUNCTOR( @@ -1729,29 +1699,6 @@ REGISTER_OPERATOR_FUNCTOR( }; }); -REGISTER_OPERATOR_FUNCTOR( - static_runtime::permute_copy, - sr_permute_copy, - [](Node* n) -> SROperator { - if (!n->matches(torch::schema( - "static_runtime::permute_copy(Tensor self, int[] dims) -> Tensor"))) { - LogAndDumpSchema(n); - return nullptr; - } - return [](ProcessedNode* p_node) { - const auto& self = p_node->Input(0).toTensor(); - const auto dims = p_node->Input(1).toDimVector(); - - if (p_node->Output(0).isNone()) { - p_node->Output(0) = create_empty_from(self); - } - auto& output = p_node->Output(0).toTensor(); - at::native::resize_( - output, permute_output_sizes(self.sizes(), dims), c10::nullopt); - at::native::permute_copy_out(self, dims, output); - }; - }); - REGISTER_OPERATOR_FUNCTOR( static_runtime::flatten_copy, aten_flatten, From 6b085d5cadffb10591c450623f93a21dd3dd786d Mon Sep 17 00:00:00 2001 From: Iris Date: Tue, 22 Nov 2022 07:49:06 +0000 Subject: [PATCH 431/453] [Checkpoint][2D][2/N] Add traverse for distributed checkpoint to core distributed (#89398) This PR moves traverse and its test to torch.distributed.checkpoint. This is a pre-req for enabling 2D checkpoint. This is used when flatten nested dict and flatten sharded tensors. Docstring and comments will be added in the following PRs. Test: ``` python3 test/distributed/_tensor/parallel/test_2d_parallel.py ``` and CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/89398 Approved by: https://github.com/wanchaol --- test/distributed/checkpoint/test_traverse.py | 176 +++++++++++++++++++ torch/distributed/checkpoint/traverse.py | 170 ++++++++++++++++++ 2 files changed, 346 insertions(+) create mode 100644 test/distributed/checkpoint/test_traverse.py create mode 100644 torch/distributed/checkpoint/traverse.py diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py new file mode 100644 index 000000000000..a73cb89befba --- /dev/null +++ b/test/distributed/checkpoint/test_traverse.py @@ -0,0 +1,176 @@ +# Owner(s): ["oncall: distributed"] + +from collections import OrderedDict +import torch + +import torch.distributed.checkpoint.traverse as traverse +from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from torch.testing._internal.common_utils import run_tests, TestCase + + +# TODO: add comments for TestTraverse +class TestTraverse(TestCase): + def test_traverse_shallow(self) -> None: + state_dict = { + "key0": 1, + "key1": [1, 2], + "key2": {1: 2, 2: 3}, + "key3": torch.tensor([1]), + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0",), data) + self.assertEqual(data[("key0",)], 1) + + self.assertIn(("key1",), data) + self.assertEqual(data[("key1",)], [1, 2]) + + self.assertIn(("key2",), data) + self.assertEqual(data[("key2",)], {1: 2, 2: 3}) + + self.assertIn(("key3",), data) + self.assertEqual(data[("key3",)], torch.tensor([1])) + + def test_traverse_nested_list(self) -> None: + state_dict = { + "key1": [ + torch.tensor([1]), + [33, torch.tensor([2]), [44, 55]], + [66, 77], + ], + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + traverse.traverse_state_dict(state_dict, collect_data) + + self.assertNotIn(("key1"), data) + + 
self.assertIn(("key1", 0), data) + self.assertEqual(data[("key1", 0)], torch.tensor([1])) + + self.assertIn(("key1", 1, 0), data) + self.assertEqual(data[("key1", 1, 0)], 33) + + self.assertIn(("key1", 1, 1), data) + self.assertEqual(data[("key1", 1, 1)], torch.tensor([2])) + + self.assertIn(("key1", 1, 2), data) + self.assertEqual(data[("key1", 1, 2)], [44, 55]) + self.assertNotIn(("key1", 1, 2, 0), data) + + self.assertIn(("key1", 2), data) + self.assertEqual(data[("key1", 2)], [66, 77]) + + def test_traverse_nested_dict(self) -> None: + state_dict = { + "key0": {"key1": 99, "key2": torch.tensor([1])}, + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + traverse.traverse_state_dict(state_dict, collect_data) + + self.assertNotIn(("key0",), data) + + self.assertIn(("key0", "key1"), data) + self.assertEqual(data[("key0", "key1")], 99) + + self.assertIn(("key0", "key2"), data) + self.assertEqual(data[("key0", "key2")], torch.tensor([1])) + + def test_traverse_doesnt_ignore_intermediate_collections(self) -> None: + state_dict: STATE_DICT_TYPE = { + "key0": [{"key1": {"key2": torch.tensor([1])}}] + } + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0", 0, "key1", "key2"), data) + self.assertEqual( + data[("key0", 0, "key1", "key2")], + torch.tensor([1]), + ) + + def test_traverse_with_ordered_dict(self) -> None: + state_dict = OrderedDict( + { + "key0": [ + 99, + torch.tensor([3]), + ] + } + ) + + data = {} + + def collect_data(path, value): + nonlocal data + data[path] = value + + traverse.traverse_state_dict(state_dict, collect_data) + + self.assertIn(("key0", 0), data) + self.assertEqual(data[("key0", 0)], 99) + + self.assertIn(("key0", 1), data) + self.assertEqual(data[("key0", 1)], torch.tensor([3])) + + def test_set_element(self) -> None: + state_dict: STATE_DICT_TYPE = {} + + traverse.set_element(state_dict, ("k",), 10) + self.assertEqual(state_dict["k"], 10) + + traverse.set_element(state_dict, ("k1", 2), 1) + self.assertEqual(state_dict["k1"], [None, None, 1]) + + traverse.set_element(state_dict, ("k1", 1), 99) + self.assertEqual(state_dict["k1"], [None, 99, 1]) + + traverse.set_element(state_dict, ("k1", 3), 88) + self.assertEqual(state_dict["k1"], [None, 99, 1, 88]) + + traverse.set_element(state_dict, ("k2", "k3"), 3) + self.assertEqual(state_dict["k2"], {"k3": 3}) + + traverse.set_element(state_dict, ("k2", "k4", 0, 0), 99) + self.assertEqual(state_dict["k2"]["k4"][0], [99]) + + def test_get_element(self) -> None: + state_dict = {"a": [0, 1], "b": [2, {"c": "d"}]} + self.assertEqual(traverse.get_element(state_dict, ("a",)), [0, 1]) + self.assertEqual(traverse.get_element(state_dict, ("b", 0)), 2) + self.assertEqual(traverse.get_element(state_dict, ("b", 1, "c")), "d") + + self.assertIsNone(traverse.get_element(state_dict, ("c",))) + self.assertIsNone(traverse.get_element(state_dict, ("a", 33))) + self.assertIsNone(traverse.get_element(state_dict, ("b", 88))) + self.assertIsNone(traverse.get_element(state_dict, ("b", 0, 2))) + self.assertIsNone(traverse.get_element(state_dict, ("b", 1, 2))) + self.assertIsNone(traverse.get_element(state_dict, ("b", 1, "d"))) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/distributed/checkpoint/traverse.py b/torch/distributed/checkpoint/traverse.py new file mode 100644 index 000000000000..75dc42453348 --- /dev/null +++ b/torch/distributed/checkpoint/traverse.py @@ -0,0 
+1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import torch + +from typing import ( + Callable, + Collection, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + TypeVar, + Union, + cast, +) +from torch.distributed.checkpoint.metadata import ( + STATE_DICT_TYPE, +) +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._tensor import DTensor + +PATH_ITEM = Union[str, int] +OBJ_PATH = Tuple[PATH_ITEM, ...] +T = TypeVar("T") + +STATE_DICT_ITEM = object +CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM] + +__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] + + +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: + return isinstance(value, torch.Tensor) + + +# TODO: update docstring for traverse.py +def traverse_state_dict( + state_dict: STATE_DICT_TYPE, + visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None], + keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors, +) -> None: + """ + Invoke ``visitor`` for each value recursively in ``state_dict``. + Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates + to false for all elements. + By default, all collections with at least one ``torch.Tensor`` element are traversed. + Visitor takes a path argument that is a tuple of the keys used to reach it. + """ + # a value is terminal if it has no other containers values inside it + def _is_terminal(value: STATE_DICT_ITEM) -> bool: + values: Collection[STATE_DICT_ITEM] + if isinstance(value, Mapping): + values = value.values() + elif isinstance(value, list): + values = value + else: + return True + + for entry in values: + if isinstance(entry, (Mapping, list)) and not _is_terminal(entry): + return False + if keep_traversing is not None and keep_traversing(entry): + return False + return True + + def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None: + if _is_terminal(value): + visitor(path, value) + elif isinstance(value, Mapping): + for k, v in value.items(): + _traverse_obj(path + (str(k),), v) + elif isinstance(value, list): + for i, v in enumerate(value): + _traverse_obj(path + (i,), v) + + for key, value in state_dict.items(): + _traverse_obj((str(key),), value) + + +def set_element( + root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM +) -> None: + """ + Set ``value`` in ``root_dict`` along the ``path`` object path. + """ + cur_container = cast(CONTAINER_TYPE, root_dict) + + def extend_list(lst: List[STATE_DICT_ITEM], idx: int) -> None: + while len(lst) <= idx: + lst.append(None) + + for i in range(1, len(path)): + prev_key = path[i - 1] + key = path[i] + def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else []) + + if isinstance(cur_container, Mapping): + cur_container = cast( + CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) + ) + else: + extend_list(cur_container, prev_key) + if cur_container[prev_key] is None: + cur_container[prev_key] = def_val + cur_container = cur_container[prev_key] + + key = path[-1] + if type(key) == int: + extend_list(cast(List[STATE_DICT_ITEM], cur_container), key) + + cur_container[key] = value + + +def get_element( + root_dict: STATE_DICT_TYPE, + path: OBJ_PATH, + default_value: Optional[T] = None, +) -> Optional[T]: + """ + Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found. 
+ """ + cur_value = cast(CONTAINER_TYPE, root_dict) + for part in path: + if type(part) is int: + if not isinstance(cur_value, list) or len(cur_value) < part: + return default_value + elif not isinstance(cur_value, Mapping) or part not in cur_value: + return default_value + + cur_value = cast(CONTAINER_TYPE, cur_value[part]) + return cast(Optional[T], cur_value) + + +def _print_nested( + value: STATE_DICT_ITEM, + prefix: str = "", + print_fun: Callable[[str], None] = print, +) -> None: + if type(value) is ShardedTensor: + print_fun(f"{prefix} ShardedTensor size: {value.size()}") + for shard in value.local_shards(): + _print_nested( + shard.tensor, + f"{shard.metadata.shard_offsets} ", + print_fun=print_fun, + ) + elif type(value) is (DTensor): + print_fun(f"{prefix} DistributedTensor size: {value.size()}") + # TODO: add local offset for _local_tensor in print_nested. + _print_nested( + value._local_tensor, + print_fun=print_fun, + ) + elif isinstance(value, torch.Tensor): + print_fun(f"{prefix} Tensor size: {value.size()}") + else: + print_fun(f"{prefix} Type: {type(value)}") + + +def print_tensor( + path: OBJ_PATH, + value: STATE_DICT_ITEM, + print_fun: Callable[[str], None] = print, +) -> None: + """ + Callback that can be used with travese_state_dict to print its content. + By default the content is printed using the builtin ``print`` but this can + be change by passing a different ``print_fun` callable. + """ + _print_nested(value, prefix=str(path), print_fun=print_fun) From 1d6a188d08829b1aee28eb1e6255d5bf43a77f16 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sat, 19 Nov 2022 01:00:03 +0000 Subject: [PATCH 432/453] Reland Dispatch torch.norm to linalg.vector_norm and linalg.matrix_norm (#81761) (#84624) Reland https://github.com/pytorch/pytorch/pull/81761 Differential Revision: [D39332292](https://our.internmc.facebook.com/intern/diff/D39332292) Pull Request resolved: https://github.com/pytorch/pytorch/pull/84624 Approved by: https://github.com/kit1980 --- aten/src/ATen/autocast_mode.cpp | 5 +- .../functorch/BatchRulesDecompositions.cpp | 1 - .../ATen/functorch/BatchRulesReduceOps.cpp | 2 +- aten/src/ATen/native/LinearAlgebra.cpp | 4 +- test/functorch/test_vmap.py | 5 ++ test/onnx/test_operators.py | 6 +++ test/onnx/test_pytorch_onnx_onnxruntime.py | 8 +++ test/onnx/test_utility_funs.py | 2 - test/test_decomp.py | 6 +-- test/test_linalg.py | 15 +++--- test/test_reductions.py | 8 ++- torch/functional.py | 49 +++++++++++++++--- .../_internal/common_methods_invocations.py | 51 +++++++------------ 13 files changed, 100 insertions(+), 62 deletions(-) diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index ca75c38258ff..ee8b4b30b152 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -450,6 +450,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { KERNEL2(cumprod, dimname, fp32_set_opt_dtype) KERNEL(cumsum, fp32_set_opt_dtype) KERNEL2(cumsum, dimname, fp32_set_opt_dtype) + KERNEL(linalg_vector_norm, fp32_set_opt_dtype) + KERNEL(linalg_matrix_norm, fp32_set_opt_dtype) + KERNEL2(linalg_matrix_norm, str_ord, fp32_set_opt_dtype) // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even // when autocasting. 
// KERNEL2(norm, ScalarOpt_dtype, fp32_set_opt_dtype) @@ -576,8 +579,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(fft_irfftn, fp32) KERNEL_CPU(fft_hfft, fp32) KERNEL_CPU(fft_ihfft, fp32) - KERNEL_CPU(linalg_matrix_norm, fp32) - KERNEL_CPU2(linalg_matrix_norm, str_ord, fp32) KERNEL_CPU(linalg_cond, fp32) KERNEL_CPU2(linalg_cond, p_str, fp32) KERNEL_CPU(linalg_matrix_rank, fp32) diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 05ee8d07a410..d5a38e9804dd 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -253,7 +253,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE2(conv2d, padding); OP_DECOMPOSE2(conv3d, padding); OP_DECOMPOSE(_convolution_mode); - OP_DECOMPOSE(frobenius_norm); OP_DECOMPOSE(type_as); OP_DECOMPOSE(linalg_diagonal); OP_DECOMPOSE(diagonal_copy); diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index 8654b7845501..9126507e73be 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -168,7 +168,7 @@ void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack #define REDUCTION_BOXED_ARGS(op, dim_pos) \ m.impl(#op, torch::CppFunction::makeFromBoxedFunction>()); -// Skipping frobenius/nuclear/all/any since they don't have opinfo tests right now :P +// Skipping all/any since they don't have opinfo tests right now :P Tensor dist_decomp(const Tensor& self, const Tensor& other, const Scalar& p) { return at::norm((self - other), p); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index c21bc4b47531..7e47170cd72e 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -2770,7 +2770,7 @@ Tensor& linalg_norm_out(const Tensor& X, c10::string_view ord, OptionalIntArrayR //////////////////////////////////////////////////////////////////////////////// // Frobenius Norm // -// Just used in linalg.norm. It should not be removed. // +// Just used in torch..norm. It should not be removed. // //////////////////////////////////////////////////////////////////////////////// Tensor frobenius_norm(const Tensor& self) { @@ -2816,7 +2816,7 @@ Tensor &frobenius_norm_out(const Tensor& self, //////////////////////////////////////////////////////////////////////////////// // Nuclear Norm // -// Just used in linalg.norm. It should not be removed. // +// Just used in torch.norm. It should not be removed. 
// //////////////////////////////////////////////////////////////////////////////// Tensor nuclear_norm(const Tensor& self, bool keepdim) { diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 4c2c680ca637..9b3293a7db75 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3878,6 +3878,11 @@ def f(e_): skip('linalg.multi_dot'), # accepts list of tensor inputs, has its own special test xfail('linalg.vander'), xfail('linalg.vecdot'), + # throws in vmap on CUDA + # IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2) + # https://github.com/pytorch/pytorch/runs/8110653462?check_suite_focus=true + # but it passes locally + skip('linalg.matrix_norm', ''), skip('linalg.ldl_solve', ''), }) def test_vmap_linalg_failure_1D_input(self, device, dtype, op): diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index cfb36732af4d..7375cf3fe4d7 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -654,10 +654,14 @@ def test_repeat_dim_overflow(self): x = torch.randn(1, 2, requires_grad=True) self.assertONNX(lambda x: x.repeat(1, 2, 3, 4), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_norm_p1(self): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.assertONNX(lambda x: x.norm(p=1, dim=2), (x)) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_norm_p2(self): x = torch.randn(1, 2, 3, 4, requires_grad=True) self.assertONNX(lambda x: x.norm(p=2, dim=2), (x)) @@ -957,6 +961,8 @@ def test_pixel_shuffle(self): lambda x: torch.pixel_shuffle(x, upscale_factor=2), x, opset_version=11 ) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm(self): x = torch.randn(2, 3, 4).float() self.assertONNX(lambda x: torch.norm(x, p="fro", dim=(0, 1), keepdim=True), x) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 16839dded0c4..184cc5f4ae67 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -6701,6 +6701,8 @@ def forward(self, x, y): y = torch.tensor(2) self.run_test(FullLikeModel(), (x, y)) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_l1_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6709,6 +6711,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_l2_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6717,6 +6721,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm(self): class NormModel(torch.nn.Module): def forward(self, x): @@ -6725,6 +6731,8 @@ def forward(self, x): x = torch.randn(4, 2, 3, requires_grad=True) self.run_test(NormModel(), x) + @unittest.skip("It started failing after #81761") + # TODO(#83661): Fix and enable the test def test_frobenius_norm_keepdim(self): class NormModel(torch.nn.Module): def forward(self, x): diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 5d1cdc5e8ea5..7e23b06e5541 100644 --- a/test/onnx/test_utility_funs.py +++ 
b/test/onnx/test_utility_funs.py @@ -240,7 +240,6 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL2") - self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_reduceL1(self): class NormModule(torch.nn.Module): @@ -258,7 +257,6 @@ def forward(self, x): for node in graph.nodes(): self.assertNotEqual(node.kind(), "onnx::ReduceL1") - self.assertEqual(len(list(graph.nodes())), 2) def test_constant_fold_slice(self): class NarrowModule(torch.nn.Module): diff --git a/test/test_decomp.py b/test/test_decomp.py index ad8cf27ae0f2..d69d72753e47 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -159,8 +159,8 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs) (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2, (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5, (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5, - (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-6, - (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-6, + (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-5, + (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-5, (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2, (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1, } @@ -303,9 +303,9 @@ def normalize_op_input_output(f, sample, requires_grad=True): (None, None, "meshgrid"), # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit) (None, None, "diag"), - # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32 ("cpu", torch.bfloat16, "_softmax_backward_data"), + (None, None, "norm"), } CROSS_REF_BACKWARD_EXCLUDE_SET = { diff --git a/test/test_linalg.py b/test/test_linalg.py index 273c74d4e614..41c3e8a2d9ba 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1357,17 +1357,16 @@ def run_test_case(input, ord, dim, keepdim): def test_norm_fused_type_promotion(self, device, dtype): x = torch.randn(10, device=device, dtype=dtype) - def profile_and_check(fn, x, kwargs, fn_name): + def profile_and_check(fn, x, kwargs): with torch.profiler.profile(activities=(torch.profiler.ProfilerActivity.CPU,)) as p: fn(x, **kwargs, dtype=torch.float) # smoke check that profiler returned some events - self.assertTrue(fn_name in map(lambda e: e.name, p.events())) + self.assertTrue("aten::linalg_vector_norm" in (e.name for e in p.events())) # test that there was no explicit copy - self.assertFalse("aten::to" in map(lambda e: e.name, p.events())) + self.assertFalse("aten::to" in (e.name for e in p.events())) - for f, kwargs, fn_name in zip((torch.norm, torch.linalg.vector_norm), ({"p" : 2}, {}), - ("aten::norm", "aten::linalg_vector_norm")): - profile_and_check(f, x, kwargs, fn_name) + for f, kwargs, in zip((torch.linalg.vector_norm, torch.norm), ({}, {"p" : 2})): + profile_and_check(f, x, kwargs) @skipMeta # https://github.com/pytorch/pytorch/issues/53739 @skipCPUIfNoLapack @@ -2310,10 +2309,10 @@ def test_nuclear_norm_exceptions_old(self, device): x = torch.tensor(lst, dtype=torch.double, device=device) for axes in (), (0,): self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) - self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) + self.assertRaises(RuntimeError, torch.norm, x, "nuc", (0, 1)) x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) - self.assertRaisesRegex(RuntimeError, "duplicate or 
invalid", torch.norm, x, "nuc", (0, 0)) + self.assertRaisesRegex(RuntimeError, "must be different", torch.norm, x, "nuc", (0, 0)) self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) @skipCUDAIfNoCusolver diff --git a/test/test_reductions.py b/test/test_reductions.py index 8d91f56545f0..7a360888e659 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -464,9 +464,9 @@ def test_dim_reduction_less_than_64(self, device): torch.norm] for op in ops: with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, 64) + op(x, dim=64) with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): - op(x, -1) + op(x, dim=-1) @onlyCPU @dtypes(torch.float, torch.bfloat16) @@ -1793,11 +1793,9 @@ def test_repeated_dim(self, device): x = torch.randn(3, 3, 3, 3, device=device) error_msg = r'appears multiple times in the list of dims' - norm_error_msg = r'Expected dims to be different, got' for op in ops: for dim in [(0, 0), (0, -4)]: - e_msg = norm_error_msg if op == torch.norm else error_msg - with self.assertRaisesRegex(RuntimeError, e_msg): + with self.assertRaisesRegex(RuntimeError, error_msg): op(x, dim=dim) # TODO: update this test to comapre against NumPy diff --git a/torch/functional.py b/torch/functional.py index 7e96d42fde30..ee04cb250c2c 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1393,10 +1393,11 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa Its documentation and behavior may be incorrect, and it is no longer actively maintained. - Use :func:`torch.linalg.norm`, instead, or :func:`torch.linalg.vector_norm` - when computing vector norms and :func:`torch.linalg.matrix_norm` when - computing matrix norms. Note, however, the signature for these functions - is slightly different than the signature for torch.norm. + Use :func:`torch.linalg.vector_norm` when computing vector norms and + :func:`torch.linalg.matrix_norm` when computing matrix norms. + For a function with a similar behavior as this one see :func:`torch.linalg.norm`. + Note, however, the signature for these functions is slightly different than the + signature for ``torch.norm``. Args: input (Tensor): The input tensor. Its data type must be either a floating @@ -1446,8 +1447,8 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa .. note:: Even though ``p='fro'`` supports any number of dimensions, the true mathematical definition of Frobenius norm only applies to tensors with - exactly two dimensions. :func:`torch.linalg.norm` with ``ord='fro'`` aligns - with the mathematical definition, since it can only be applied across + exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'`` + aligns with the mathematical definition, since it can only be applied across exactly two dimensions. Example:: @@ -1481,6 +1482,42 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa return handle_torch_function( norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) + # NB. All the repeated code and weird python is to please TorchScript. 
+ # For a more compact implementation see the relevant function in `_refs/__init__.py` + + # We don't do this for MPS or sparse tensors + if input.layout == torch.strided and input.device.type in ("cpu", "cuda", "meta"): + if dim is not None: + if isinstance(dim, int): + _dim = [dim] + else: + _dim = dim + else: + _dim = None # type: ignore[assignment] + + if isinstance(p, str): + if p == "fro" and (dim is None or isinstance(dim, int) or len(dim) <= 2): + if out is None: + return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out) + + # Here we either call the nuclear norm, or we call matrix_norm with some arguments + # that will throw an error + if _dim is None: + _dim = list(range(input.ndim)) + if out is None: + return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype) + else: + return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out) + else: + # NB. p should be Union[str, number], not Optional! + _p = 2.0 if p is None else p + if out is None: + return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype) + else: + return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out) + ndim = input.dim() # catch default case diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 3d3c13bb7208..4b2d0ebabc46 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -15878,24 +15878,20 @@ def reference_flatten(input, start_dim=0, end_dim=-1): "norm", sample_inputs_func=sample_inputs_norm, dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), + # TODO Benchmark again with the new implementation # Runs very slowly on slow gradcheck - alternatively reduce input sizes gradcheck_fast_mode=True, + check_batched_forward_grad=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result - # of dtype torch.float32 into an out= with dtype torch.long - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_out", - device_type="meta", - ), - ), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)),) ), OpInfo('norm', variant_test_name='nuc', - aten_name='nuclear_norm', sample_inputs_func=sample_inputs_norm_nuc, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack], check_batched_gradgrad=False, @@ -15907,19 +15903,14 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypes=floating_and_complex_types(), dtypesIfCUDA=floating_and_complex_types(), skips=( - # RuntimeError not raised : - # Expected RuntimeError when calling with input.device=cpu and out.device=cuda - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # RuntimeError: - # Arguments for call are not valid. - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)), # noqa: B950 - ) + # Dispatches in Python to matrix_norm. 
Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) ), OpInfo('norm', variant_test_name='fro', - aten_name='frobenius_norm', sample_inputs_func=sample_inputs_norm_fro, - dtypes=floating_and_complex_types_and(torch.bfloat16), + dtypes=floating_and_complex_types_and(torch.bfloat16, torch.float16), dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_forward_ad=True, # torch.autograd.gradcheck.GradcheckError: While computing batched gradients @@ -15933,33 +15924,29 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', dtypes=(torch.complex64, torch.complex128)), - # Expected RuntimeError when calling with input.device=cpu and out.device=cuda - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # Arguments for call are not valid. - DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.complex64, torch.float32,)), # noqa: B950 - )), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.complex64, torch.float32,)),) + ), OpInfo( "norm", variant_test_name="inf", sample_inputs_func=sample_inputs_norm_inf, dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), supports_forward_ad=True, + check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, # fast gradcheck produces NaNs gradcheck_fast_mode=False, skips=( - # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast from a result - # of dtype torch.float32 into an out= with dtype torch.long - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_out", - device_type="meta", - ), DecorateInfo( toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', ), + # Dispatches in Python to vector_norm. Not sure how to make this test happy + # Happens to pass on complex64. Also a mystery + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', + dtypes=(torch.float32,)) ), ), OpInfo('t', From 0f7dca17332152fdd28270eb95398efbe8212ca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Mon, 21 Nov 2022 04:22:00 +0000 Subject: [PATCH 433/453] Vectorized CPU code implementing right shift operator. 
(#88990) Pull Request resolved: https://github.com/pytorch/pytorch/pull/88990 Approved by: https://github.com/lezcano, https://github.com/peterbell10 --- aten/src/ATen/cpu/vec/vec256/vec256_int.h | 44 ++++++++++++++++++++ aten/src/ATen/cpu/vec/vec512/vec512_int.h | 24 +++++++++++ aten/src/ATen/cpu/vec/vec_base.h | 14 +++++++ aten/src/ATen/native/cpu/BinaryOpsKernel.cpp | 11 +++-- 4 files changed, 89 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 7737f4a0037c..f17cdc5bc156 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -1181,6 +1181,8 @@ Vectorized inline shift_256_16(const Vectorized& a, const Vect __m256i c0; if (left_shift) c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); c0 = _mm256_shuffle_epi8(c0, ctl_1_0); // Peform shifting the same way for input array elements with @@ -1190,6 +1192,8 @@ Vectorized inline shift_256_16(const Vectorized& a, const Vect __m256i c1; if (left_shift) c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); c1 = _mm256_and_si256(c1, keep_1); // Merge partial results into the final result. @@ -1271,6 +1275,8 @@ Vectorized inline shift_256_8(const Vectorized& a, const Vectori __m256i c0; if (left_shift) c0 = _mm256_sllv_epi32(a0, b0); + else + c0 = _mm256_srav_epi32(a0, b0); c0 = _mm256_shuffle_epi8(c0, ctl_3_0); // Peform shifting the same way for input array elements with @@ -1280,6 +1286,8 @@ Vectorized inline shift_256_8(const Vectorized& a, const Vectori __m256i c1; if (left_shift) c1 = _mm256_sllv_epi32(a1, b1); + else + c1 = _mm256_srav_epi32(a1, b1); c1 = _mm256_shuffle_epi8(c1, ctl_3_1); // Peform shifting the same way for input array elements with @@ -1289,6 +1297,8 @@ Vectorized inline shift_256_8(const Vectorized& a, const Vectori __m256i c2; if (left_shift) c2 = _mm256_sllv_epi32(a2, b2); + else + c2 = _mm256_srav_epi32(a2, b2); c2 = _mm256_shuffle_epi8(c2, ctl_3_2); // Peform shifting the same way for input array elements with @@ -1298,6 +1308,8 @@ Vectorized inline shift_256_8(const Vectorized& a, const Vectori __m256i c3; if (left_shift) c3 = _mm256_sllv_epi32(a3, b3); + else + c3 = _mm256_srav_epi32(a3, b3); c3 = _mm256_and_si256(c3, keep_3); // Merge partial results into the final result. @@ -1328,6 +1340,38 @@ Vectorized inline operator<<(const Vectorized& a, const Vectoriz return shift_256_8(a, b); } +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + // No vector instruction for right shifting int64_t, so emulating it + // instead. + + // Shift the number logically to the right, thus filling the most + // significant bits with 0s. Then, replace these bits with the sign + // bit. 
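  // Worked example of the emulation below (a sketch for one 64-bit lane, assuming b in [0, 63]):
  //   a = -8 = 0xFFFF'FFFF'FFFF'FFF8, b = 1
  //   logical a >> 1                    -> 0x7FFF'FFFF'FFFF'FFFC
  //   sign_bits (a < 0)                 -> all ones
  //   sign_ext = sign_bits << (64 - 1)  -> 0x8000'0000'0000'0000
  //   OR of the two                     -> 0xFFFF'FFFF'FFFF'FFFC = -4, i.e. arithmetic -8 >> 1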
+ __m256i sign_bits = _mm256_cmpgt_epi64(_mm256_set1_epi64x(0), a); + __m256i b_inv_mod_64 = _mm256_sub_epi64(_mm256_set1_epi64x(64), b); + __m256i sign_ext = _mm256_sllv_epi64(sign_bits, b_inv_mod_64); + __m256i c = _mm256_srlv_epi64(a, b); + c = _mm256_or_si256(c, sign_ext); + + return c; +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm256_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index 590c3254e379..bf03f8e290b6 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -1219,6 +1219,8 @@ Vectorized inline shift_512_8(const Vectorized& a, const Vectori __m512i c0; if (left_shift) c0 = _mm512_sllv_epi16(a0, b0); + else + c0 = _mm512_srav_epi16(a0, b0); c0 = _mm512_shuffle_epi8(c0, ctl_1_0); // Peform shifting the same way for input array elements with @@ -1228,6 +1230,8 @@ Vectorized inline shift_512_8(const Vectorized& a, const Vectori __m512i c1; if (left_shift) c1 = _mm512_sllv_epi16(a1, b1); + else + c1 = _mm512_srav_epi16(a1, b1); c1 = _mm512_and_si512(c1, keep_1); // Merge partial results into the final result. @@ -1256,6 +1260,26 @@ Vectorized inline operator<<(const Vectorized& a, const Vectoriz return shift_512_8(a, b); } +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi64(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi32(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return _mm512_srav_epi16(a, b); +} + +template <> +Vectorized inline operator>>(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index e9e87fa605f7..abf106e8d5b3 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -807,6 +807,14 @@ template Vectorized inline operator<<(const Vectorized &a, const return c; } +template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] >> b[i]; + } + return c; +} + template inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { a = a + b; @@ -839,6 +847,12 @@ inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { return a; } +template +inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { + a = a >> b; + return a; +} + template inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index c2497a6949f1..9b5f442ef02c 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -383,10 +383,13 @@ void logical_xor_kernel(TensorIterator& iter) { void rshift_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return a >> b; - }); + cpu_kernel_vec(iter, + 
[](scalar_t a, scalar_t b) -> scalar_t { + return a >> b; + }, + [](Vectorized a, Vectorized b) { + return a >> b; + }); }); } From 2d94fd3b198a31f28df10b7d9b3fcd526a82f24a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 22 Nov 2022 11:05:58 +0000 Subject: [PATCH 434/453] [Vulkan][TCC] Fix quantized shaders (#89456) Summary: Fix rounding issue in quantized shaders Test Plan: On Mac ``` cd ~/fbsource buck1 run -c pt.vulkan_full_precision=1 //xplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 ``` On Android ``` cd ~/fbsource buck1 build -c ndk.custom_libcxx=false -c pt.enable_qpl=0 -c pt.vulkan_full_precision=1 //xplat/caffe2:pt_vulkan_quantized_api_test_binAndroid\#android-arm64 --show-output adb push buck-out/gen/xplat/caffe2/pt_vulkan_quantized_api_test_binAndroid\#android-arm64 /data/local/tmp/vulkan_quantized_api_test adb shell "/data/local/tmp/vulkan_quantized_api_test" ``` Reviewed By: salilsdesai Differential Revision: D41047095 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89456 Approved by: https://github.com/kirklandsign, https://github.com/digantdesai --- aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl | 8 +++++--- aten/src/ATen/native/vulkan/glsl/quantized_add.glsl | 4 ++-- aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl | 2 +- aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl | 2 +- .../ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl | 2 +- aten/src/ATen/native/vulkan/glsl/quantized_div.glsl | 4 ++-- aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl | 4 ++-- aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl | 4 ++-- .../native/vulkan/glsl/quantized_upsample_nearest2d.glsl | 3 +-- 9 files changed, 17 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl index 910603aa29f2..f67954ad48c1 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantize_per_tensor.glsl @@ -19,11 +19,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); if (all(lessThan(pos, uBlock.size.xyz))) { - vec4 ret = texelFetch(uInput, pos, 0) / uBlock.scale.x + uBlock.zero_point.x; - uvec4 texel = uvec4(int(ret.x), int(ret.y), int(ret.z), int(ret.w)); + vec4 q_res = roundEven(texelFetch(uInput, pos, 0) / uBlock.scale.x) + uBlock.zero_point.x; + + uvec4 ret = uvec4(q_res); + imageStore( uOutput, pos, - texel); + ret); } } diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl index 8f6e51397d1c..a526dc2121bf 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_add.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 + deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl index bb139d914f07..63bf055761cc 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl +++ 
b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl @@ -64,7 +64,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } /* diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl index c2ccee79d56a..0d823620a517 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl @@ -65,7 +65,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } void main() { diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl index c8a2a98f9ef0..2ef6d3d60f32 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl @@ -60,7 +60,7 @@ vec4 dequantize(vec4 tex, float scale, int zero_point) { * Quantizes a float texel based on a scale and zero point. */ uvec4 quantize(vec4 tex, float scale, int zero_point) { - return uvec4(tex / scale + zero_point); + return uvec4(roundEven(tex / scale) + zero_point); } /* diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl index aa961eb34993..1998c5abbca3 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_div.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 / deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl index 459f56915d77..c1ce18dbb38c 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_mul.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 * deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); + uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl index 6bd00f33a89c..767181f080fd 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_sub.glsl @@ -34,9 +34,9 @@ void main() { vec4 deq_in_1 = uBlock.in_scale.y * (texel1 - uBlock.in_zero_point.y); vec4 res = deq_in_0 - deq_in_1; - vec4 q_res = res / uBlock.out_scale.x + uBlock.out_zero_point.x; + vec4 q_res = roundEven(res / uBlock.out_scale.x) + uBlock.out_zero_point.x; - uvec4 ret = uvec4(int(q_res.x), int(q_res.y), int(q_res.z), int(q_res.w)); 
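  // Numeric sketch of why roundEven matters here (made-up values):
  //   res / out_scale = 1.7, out_zero_point = 10
  //   old: uvec4(1.7 + 10.0)             truncated 11.7 down to 11
  //   new: uvec4(roundEven(1.7) + 10.0)  gives 12, in line with round-to-nearest-even
  //        quantization on the reference (non-Vulkan) paths.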
+ uvec4 ret = uvec4(q_res); imageStore( uOutput, diff --git a/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl index 28c167515405..46abbb1a8d76 100644 --- a/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl +++ b/aten/src/ATen/native/vulkan/glsl/quantized_upsample_nearest2d.glsl @@ -25,8 +25,7 @@ void main() { ivec2(0), uBlock.isize); - vec4 texel = texelFetch(uInput, ivec3(ipos, pos.z), 0); - uvec4 ret = uvec4(int(texel.r), int(texel.g), int(texel.b), int(texel.a)); + uvec4 ret = texelFetch(uInput, ivec3(ipos, pos.z), 0); imageStore( uOutput, From d9cbe7764e1af938d7edc23ffa873703d960df6c Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 22 Nov 2022 05:02:45 -0800 Subject: [PATCH 435/453] Make aten.copy preserve strides (hf_Longformer) (#89464) Fixes https://github.com/pytorch/torchdynamo/issues/1888 Signed-off-by: Edward Z. Yang Differential Revision: [D41460986](https://our.internmc.facebook.com/intern/diff/D41460986) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89464 Approved by: https://github.com/bdhirsh --- aten/src/ATen/native/Copy.cpp | 46 +++++++++------ aten/src/ATen/native/native_functions.yaml | 2 + test/test_functionalization.py | 65 +++++++++++++--------- test/test_fx_reinplace_pass.py | 3 +- torch/_inductor/decomposition.py | 11 ++++ 5 files changed, 81 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index dc30db8e1100..0c99943eb0cb 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -278,27 +278,39 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return self; } +// NB: cribbed from https://github.com/pytorch/pytorch/pull/88198 +at::Tensor clone_preserve_strides(const at::Tensor& self) { + TORCH_INTERNAL_ASSERT(self.has_storage()); + // In cases where the input tensor has internal memory overlap, we cannot actually + // preserve the strides/storage_offset of the input tensor, because + // *_scatter ops will try to copy_() into the cloned tensor. + // However, this should **never** show up in functionalized user code; + // most aten ops that try to mutate a tensor with internal memory overlap would error anyway. + // + // The one place that this does come up is in autograd - if there's a select_scatter + // in the forward, then autograd will generate one for the backward. + // If the input to the select_scatter is grad_output, then this could be an expanded tensor + // with internal overlap. + //if (at::has_internal_overlap(self) == at::MemOverlap::Yes) { + // return self.clone(); + //} + auto dtype_size = self.dtype().itemsize(); + auto nbytes = self.storage().sym_nbytes(); + TORCH_INTERNAL_ASSERT(nbytes % dtype_size == 0); + auto numel = nbytes / dtype_size; + auto self_full_size = self.as_strided_symint({numel}, {1}, 0); + auto clone = self_full_size.clone(); + auto out = clone.as_strided_symint(self.sym_sizes(), self.sym_strides(), self.sym_storage_offset()); + return out; +} + Tensor copy(const Tensor& self, const Tensor& src, bool non_blocking) { // copy() is the "functional" form of copy_(). It exists so we can properly functionalize copy_(), but: // (1) It isn't exposed to the frontend (no python bindings) // (2) It isn't exposed to the backend (it's a composite, that decomposes into to() and expand_as() calls. 
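  // With the new body added below, copy() clones self with its strides preserved and then
  // copy_()s src into that clone, so the functionalized result keeps self's layout. A
  // minimal illustration, mirroring the test_copy_stride_mismatch test added in this patch:
  //   y = torch.empty_strided((2, 2), (5, 1))
  //   y.copy_(x)   # after functionalization, y.stride() is still (5, 1)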
- // Note: This implementation doesn't currently preserve the strides of `self`. - // That might be fine for functorch (which already doesn't preserve strides in vmap), - // but it's worth looking into whether or not this implementation will be problematic for LazyTensor/XLA. - auto intermediate = src.to(self, non_blocking); - // We can't use expand() here. Why? - // The contract for copy_() is that the output tensor has the same amount of storage as the original tensor. - // e.g. This should work: - // a = torch.ones(4, 4) - // b = torch.ones(1, 4) - // c = torch.ones(4, 4) - // torch.ops.aten.copy(a, b).add_(c) - // We don't want to emit an extra copy every time though, so we only do it if the shapes are different. - if (self.sym_sizes() != intermediate.sym_sizes()) { - return at::expand_copy_symint(intermediate, self.sym_sizes()); - } else { - return intermediate; - } + auto r = clone_preserve_strides(self); + r.copy_(src, non_blocking); + return r; } Tensor& copy_(Tensor& self, const Tensor& src, bool non_blocking) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8c759cd09c48..730032528661 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1535,6 +1535,8 @@ - func: copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor variants: function + dispatch: + CompositeExplicitAutogradNonFunctional: copy - func: copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) variants: method diff --git a/test/test_functionalization.py b/test/test_functionalization.py index c5330664d1e8..0ab552d0d04a 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -114,6 +114,15 @@ def f(x): _functionalize(f, reapply_views=True)(torch.ones(3, 3)) + def test_copy_stride_mismatch(self): + def f(x): + y = torch.empty_strided((2, 2), (5, 1)) + y.copy_(x) + return y + + r = _functionalize(f, reapply_views=True)(torch.ones(2, 2)) + self.assertEqual(r.stride(), (5, 1)) + def test_view_clone_view_inplace(self): def f(input): shape = [1, 1024, 128, 128] @@ -149,13 +158,15 @@ def forward(self, a_1): expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) - view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) - view_copy_5 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) - clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format); view_copy_5 = None + copy = torch.ops.aten.copy.default(new_empty_strided, view_copy_3); new_empty_strided = view_copy_3 = None + view_copy_4 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) + view_copy_5 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]) + clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format) threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None - view_copy_6 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]); view_copy_3 = None + copy_1 = torch.ops.aten.copy.default(view_copy_5, threshold_backward); view_copy_5 = threshold_backward = None + view_copy_6 = torch.ops.aten.view_copy.default(copy, [16, 64, 128, 128]); copy = None detach_copy = 
torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None - view_copy_7 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None + view_copy_7 = torch.ops.aten.view_copy.default(copy_1, [1, 1024, 128, 128]); copy_1 = None view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None return detach_copy_1 @@ -829,8 +840,8 @@ def f(x): _z = torch._from_functional_tensor(z) self.assertTrue(are_aliased(_y, _z)) - # copy_() gets its own test, because it is special cased in functionalization. - # self.copy_(src) decomposes into src.to(self).expand_as(self). + # copy_() gets its own test, because it used to be special cased in functionalization. + # However, now it works pretty similar to other functional ops def test_copy_(self): def f(x): tmp = torch.zeros(2, 2) @@ -850,7 +861,8 @@ def f(x): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None return add """) @@ -862,8 +874,9 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None - return add + copy = torch.ops.aten.copy_.default(diagonal, a_1) + add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None + return diagonal """) # Test 2: copy_() with same dtype, different shape @@ -876,8 +889,8 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - expand_copy = torch.ops.aten.expand_copy.default(a_1, [2]) - add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None return add """) @@ -889,9 +902,9 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - expand_copy = torch.ops.aten.expand_copy.default(a_1, [2]) - add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None - return expand_copy + copy = torch.ops.aten.copy_.default(diagonal, a_1) + add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None + return diagonal """) # Test 3: copy_() with different dtype, same shape @@ -904,8 +917,8 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add.Tensor(_to_copy, a_1); _to_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = 
None return add """) # noqa: B950 @@ -917,9 +930,9 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - add = torch.ops.aten.add_.Tensor(_to_copy, a_1); a_1 = None - return _to_copy + copy = torch.ops.aten.copy_.default(diagonal, a_1) + add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None + return diagonal """) # noqa: B950 # Test 4: copy_() with different dtype, different shape @@ -932,9 +945,8 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal_copy = torch.ops.aten.diagonal_copy.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None - add = torch.ops.aten.add.Tensor(expand_copy, a_1); expand_copy = a_1 = None + copy = torch.ops.aten.copy.default(diagonal_copy, a_1); diagonal_copy = None + add = torch.ops.aten.add.Tensor(copy, a_1); copy = a_1 = None return add """) # noqa: B950 @@ -946,10 +958,9 @@ def forward(self, a_1): def forward(self, a_1): zeros = torch.ops.aten.zeros.default([2, 2], device = device(type='cpu'), pin_memory = False) diagonal = torch.ops.aten.diagonal.default(zeros); zeros = None - _to_copy = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) - expand_copy = torch.ops.aten.expand_copy.default(_to_copy, [2]); _to_copy = None - add = torch.ops.aten.add_.Tensor(expand_copy, a_1); a_1 = None - return expand_copy + copy = torch.ops.aten.copy_.default(diagonal, a_1) + add = torch.ops.aten.add_.Tensor(diagonal, a_1); a_1 = None + return diagonal """) # noqa: B950 def test_expand_symint(self): diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py index abb9696225c4..dc512cadea69 100644 --- a/test/test_fx_reinplace_pass.py +++ b/test/test_fx_reinplace_pass.py @@ -345,9 +345,8 @@ def forward(self): ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False) slice_1 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) slice_2 = torch.ops.aten.slice.Tensor(slice_1, 1, 2, 9223372036854775807); slice_1 = None + copy = torch.ops.aten.copy_.default(slice_2, ones); slice_2 = ones = None slice_3 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807) - slice_tensor = torch.ops.aten.slice.Tensor(slice_3, 1, 2, 9223372036854775807); slice_3 = None - copy__default = torch.ops.aten.copy_.default(slice_tensor, ones); slice_tensor = ones = None return zeros """) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 188072b3d489..6cddc0f489c5 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -416,6 +416,17 @@ def all_dim(input, dim, keeepdim=False): return torch.logical_not(torch.any(torch.logical_not(input), dim, keeepdim)) +# NB: this decomposition is not stride accurate, do not put it in the main +# library +@register_decomposition(aten.copy) +def copy(self, src, non_blocking=False): + intermediate = src.to(self, non_blocking) + if self.size() != 
intermediate.size(): + return aten.expand_copy.default(intermediate, self.size()) + else: + return intermediate + + @register_decomposition(aten.hardswish_) def hardswish_(x): return x.copy_(aten.hardswish(x)) From be22b5d39f37aa501d07fa3ff3b9448826d48eca Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Nov 2022 11:05:38 -0800 Subject: [PATCH 436/453] [18/N] Add allgather_coalesced custom op with CPU/CUDA implementations (#89317) Differential Revision: [D41415321](https://our.internmc.facebook.com/intern/diff/D41415321) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89317 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_gloo.py | 14 +++++++++++++ torch/csrc/distributed/c10d/Ops.cpp | 28 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 6 ++++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 26 +++++++++++++++++++++++ torch/csrc/distributed/c10d/init.cpp | 8 ++++++- 5 files changed, 81 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index c0a25fff9d82..545f125527af 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2417,6 +2417,20 @@ def test_collectives(self): def test_allreduce_coalesced(self): self._test_allreduce_coalesced(backend="gloo") + @requires_gloo() + def test_allgather_coalesced(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "gloo", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + input_tensor = torch.ones(10, 10, dtype=torch.float32) + output_tensor_list = [torch.zeros_like(input_tensor)] + dist.all_gather_coalesced([output_tensor_list], [input_tensor]) + self.assertEqual(output_tensor_list, [input_tensor]) + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 5d343c344ec8..4edb70c413bf 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -95,6 +95,15 @@ c10::intrusive_ptr _allgather_base_( return process_group->_allgather_base(output_tensor, input_tensor); } +c10::intrusive_ptr allgather_coalesced_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_( const std::vector& output_tensors, const std::vector>& input_tensors, @@ -220,6 +229,10 @@ TORCH_LIBRARY(c10d, m) { m.def( "_allgather_base_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_)); + m.def( + "allgather_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allgather_coalesced_)); m.def( "reduce_scatter_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); @@ -345,6 +358,21 @@ c10::intrusive_ptr _allgather_base( return op.call(output_tensor, input_tensor, process_group); } +c10::intrusive_ptr allgather_coalesced( + const c10::intrusive_ptr& process_group, + const std::vector>& output_lists, + const std::vector& input_list, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_coalesced_", "") + .typed( + const std::vector>&, + const std::vector&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + + return op.call(output_lists, input_list, process_group); +} + c10::intrusive_ptr 
reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index f6425e0ea350..ad6e2d3573ee 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -38,6 +38,12 @@ TORCH_API c10::intrusive_ptr _allgather_base( at::Tensor& inputTensor, const AllgatherOptions& opts = {}); +TORCH_API c10::intrusive_ptr allgather_coalesced( + const c10::intrusive_ptr& process_group, + const std::vector>& output_lists, + const std::vector& input_list, + const AllgatherOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index c3db5c438124..66269db1eae8 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -225,6 +225,24 @@ c10::intrusive_ptr _allgather_base_cuda_( return process_group->_allgather_base(output_tensor, input_tensor); } +c10::intrusive_ptr allgather_coalesced_cpu_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + +c10::intrusive_ptr allgather_coalesced_cuda_( + const std::vector>& output_lists, + const std::vector& input_list, + const c10::intrusive_ptr& process_group) { + return process_group->allgather_coalesced( + const_cast>&>(output_lists), + const_cast&>(input_list)); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_cpu_( const std::vector& output_tensors, @@ -457,6 +475,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("_allgather_base_", _allgather_base_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allgather_coalesced_", allgather_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allgather_coalesced_", allgather_coalesced_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("reduce_scatter_", reduce_scatter_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index ae98000112fc..f65354d97f97 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1310,7 +1310,13 @@ that adds a prefix to each key inserted to the store. 
.def( "allgather_coalesced", - &::c10d::ProcessGroup::allgather_coalesced, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector>& output_lists, + const std::vector& input_list, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::allgather_coalesced( + self, output_lists, input_list, opts); + }, py::arg("output_lists"), py::arg("input_list"), py::arg("opts") = ::c10d::AllgatherOptions(), From 5797f74924d1f19cbb10e689a0f8112665fc07d9 Mon Sep 17 00:00:00 2001 From: Howard Huang Date: Mon, 21 Nov 2022 11:05:39 -0800 Subject: [PATCH 437/453] [19/N] Add monitored_barrier custom op with CPU implementation (#89318) Differential Revision: [D41415324](https://our.internmc.facebook.com/intern/diff/D41415324) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89318 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_gloo.py | 11 ++++++++ torch/csrc/distributed/c10d/Ops.cpp | 37 +++++++++++++++++++++++++ torch/csrc/distributed/c10d/Ops.hpp | 5 ++++ torch/csrc/distributed/c10d/OpsImpl.cpp | 15 ++++++++++ torch/csrc/distributed/c10d/init.cpp | 2 +- 5 files changed, 69 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 545f125527af..bee76e788d19 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2431,6 +2431,17 @@ def test_allgather_coalesced(self): dist.all_gather_coalesced([output_tensor_list], [input_tensor]) self.assertEqual(output_tensor_list, [input_tensor]) + @requires_gloo() + def test_monitored_barrier(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "gloo", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + dist.monitored_barrier() + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 4edb70c413bf..6b4717a8e1d1 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -181,6 +181,17 @@ c10::intrusive_ptr barrier( BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); } +void monitored_barrier_( + at::Tensor /* unused */, + const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, + const std::vector& device_ids, + int64_t timeout, + bool wait_all_ranks) { + process_group->monitoredBarrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}, + wait_all_ranks); +} + c10::intrusive_ptr send( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -255,6 +266,10 @@ TORCH_LIBRARY(c10d, m) { m.def( "barrier", dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier)); + m.def( + "monitored_barrier_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, monitored_barrier_)); m.def("send", dispatch(c10::DispatchKey::CompositeExplicitAutograd, send)); m.def("recv_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_)); } @@ -497,6 +512,28 @@ c10::intrusive_ptr alltoall( output_tensors, input_tensors, process_group, opts.timeout.count()); } +void monitored_barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts, + bool wait_all_ranks) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::monitored_barrier_", "") + .typed&, + const std::vector&, + int64_t, + bool)>(); + // Default to using cpu implementation, monitored barrier is only for GLOO + at::Tensor tensor = at::empty({0}, 
at::TensorOptions().device(at::kCPU)); + op.call( + tensor, + process_group, + opts.device_ids, + opts.timeout.count(), + wait_all_ranks); +} + c10::intrusive_ptr barrier( const c10::intrusive_ptr& process_group, const BarrierOptions& opts) { diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index ad6e2d3573ee..b5426039f01e 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -83,6 +83,11 @@ TORCH_API c10::intrusive_ptr barrier( const c10::intrusive_ptr& process_group, const BarrierOptions& opts = {}); +TORCH_API void monitored_barrier( + const c10::intrusive_ptr& process_group, + const BarrierOptions& opts, + bool waitAllRanks); + TORCH_API c10::intrusive_ptr send( const c10::intrusive_ptr& process_group, at::TensorList tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 66269db1eae8..31386695a132 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -399,6 +399,17 @@ c10::intrusive_ptr barrier_cuda( BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); } +void monitored_barrier_cpu_( + at::Tensor /* unused */, + const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, + const std::vector& device_ids, + int64_t timeout, + bool wait_all_ranks) { + process_group->monitoredBarrier( + BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}, + wait_all_ranks); +} + // register functions to dispatcher namespace { TORCH_LIBRARY_IMPL(c10d, CPU, m) { @@ -531,6 +542,10 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("barrier", barrier_cuda); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("monitored_barrier_", monitored_barrier_cpu_); +} + } // namespace } // namespace ops diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index f65354d97f97..9a9699c5e12f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1539,7 +1539,7 @@ that adds a prefix to each key inserted to the store. 
bool waitAllRanks) { ::c10d::BarrierOptions opts; opts.timeout = timeout; - return self->monitoredBarrier(opts, waitAllRanks); + return ::c10d::ops::monitored_barrier(self, opts, waitAllRanks); }, py::arg("timeout") = ::c10d::kUnsetTimeout, py::arg("wait_all_ranks") = false, From 2823fc5e4c73a36ae1859889d34f4cc0d4145ae5 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 22 Nov 2022 00:30:12 +0000 Subject: [PATCH 438/453] [inductor] generate nan in the cpp backend (#89289) Summary: Fixes https://github.com/pytorch/torchdynamo/issues/1797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89289 Approved by: https://github.com/ngimel, https://github.com/jansel, https://github.com/jgong5 --- test/inductor/test_torchinductor.py | 11 +++++++++++ test/inductor/test_torchinductor_opinfo.py | 1 - torch/_inductor/codegen/cpp.py | 5 +++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed68c2844236..4f672afff80a 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3367,6 +3367,17 @@ def fn(x): ], ) + def test_isinf2(self): + def fn(x): + y = torch.tensor( + [1, float("inf"), 2, float("-inf"), float("nan")], device=self.device + ) + return x == y + + self.common( + fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),) + ) + def test_any(self): def fn(x): return ( diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 188fcd8b67dc..8d2ac24afb7e 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -195,7 +195,6 @@ def process(device_type): "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, "linalg.pinv.singular": {f32, f64}, - "logdet": {f32, f64}, "masked.norm": {f16}, "masked.normalize": {f16}, "masked_fill": {f16}, diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 3568cfdc08ef..c7e40899c86f 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1,6 +1,7 @@ import contextlib import dataclasses import functools +import math from copy import deepcopy from pathlib import Path from typing import Dict, List @@ -268,6 +269,8 @@ def constant(val, dtype): quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif val == float("-inf"): quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif math.isnan(val): + quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" elif val is True or val is False: quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})" else: @@ -459,6 +462,8 @@ def constant(val, dtype): return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" elif val == float("-inf"): return f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()" + elif math.isnan(val): + return f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()" elif val is True or val is False: return ops.to_dtype(str(val).lower(), dtype) return ops.to_dtype(repr(val), dtype) From c4e08387c1542eca67dc6e40661a50006bc879ff Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 21 Nov 2022 14:19:02 -0800 Subject: [PATCH 439/453] [quant][fx] Support producing reference quantized patterns for dynamic quantization (#89248) Summary: split the is_decomposed logic for `_replace_observer_with_quantize_dequantize_node` in a separate function and added support for dynamic quantization in 
the decomposed version of this function. In case of dynamic quantization, we'll produce the following reference quantized pattern in decomposed mode: ``` x -> choose_qparams -> quantize_per_tensor -> dequantize_per_tensor -> linear ``` Test Plan: python test/test_quantization.py -k test__convert_to_reference_decomposed_fx_dynamic_quant Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/89248 Approved by: https://github.com/vkuzo --- test/quantization/fx/test_quantize_fx.py | 35 +++ torch/ao/quantization/fx/convert.py | 338 ++++++++++++++++++----- 2 files changed, 310 insertions(+), 63 deletions(-) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index bab4467894e2..d31641ec2ae3 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -159,6 +159,7 @@ LinearReluModel, QuantizationTestCase, skipIfNoFBGEMM, + skipIfNoQNNPACK, skip_if_no_torchvision, train_one_epoch, run_ddp, @@ -5342,6 +5343,40 @@ def forward(self, x): res = m(*example_inputs) self.assertEqual(res, res_ref) + @skipIfNoQNNPACK + def test__convert_to_reference_decomposed_fx_dynamic_quant(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(5, 10) + + def forward(self, x): + return self.linear(x) + + # to avoid reduce_range + with override_quantized_engine("qnnpack"): + m = M().eval() + qconfig_mapping = get_default_qconfig_mapping("fbgemm") \ + .set_object_type(torch.nn.Linear, default_dynamic_qconfig) + example_inputs = (torch.randn(1, 5),) + m = prepare_fx(m, qconfig_mapping, example_inputs) + m(*example_inputs) + m_ref = copy.deepcopy(m) + m_ref = convert_to_reference_fx(m_ref) + m = _convert_to_reference_decomposed_fx(m) + expected_occurrence = { + ns.call_function(torch.ops.quantized_decomposed.choose_qparams.tensor): 1, + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.tensor): 1, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor): 1, + } + self.checkGraphModuleNodes( + m, + expected_node_occurrence=expected_occurrence) + # make sure it runs + res_ref = m_ref(*example_inputs) + res = m(*example_inputs) + self.assertEqual(res, res_ref) + def test_change_backend_config_for_fixed_qparam_ops(self): """ Making sure we can skip validation of qconfigs for fixedqparam ops based on BackendConfig diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index faa267c492c6..e7e0b482356a 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -74,7 +74,7 @@ from .lower_to_fbgemm import lower_to_fbgemm # importing the lib so that the quantized_decomposed ops are registered from ._decomposed import quantized_decomposed_lib # noqa: F401 - +import operator # TODO: revisit this list. Many helper methods shouldn't be public __all__ = [ @@ -91,27 +91,29 @@ "run_weight_observers", ] -def _replace_observer_with_quantize_dequantize_node( +def _replace_observer_with_quantize_dequantize_node_decomposed( model: torch.nn.Module, graph: Graph, node: Node, modules: Dict[str, torch.nn.Module], node_name_to_scope: Dict[str, Tuple[str, type]], - node_name_to_qconfig: Dict[str, QConfigAny], - is_decomposed: bool) -> None: + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: """ Replace activation_post_process module call node with quantize and - dequantize node + dequantize node working with decomposed Tensor Before: ... 
-> observer_0(x) -> ... After: - ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + ... -> torch.ops.quantized_decomposed.quantize_per_tensor(x, ...) -> + torch.ops.quantized_decomposed.dequantize_per_tensor() -> ... + + or quantize_per_channel and dequantize_per_channel """ assert modules is not None assert isinstance(node.target, str) module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) activation_post_process = modules[node.target] - # Skip replacing observers to quant/dequant nodes if the qconfigs of all + # skip replacing observers to quant/dequant nodes if the qconfigs of all # consumers and producers of this observer are None skip_replacement = all([ has_none_qconfig(n, node_name_to_qconfig) for n in @@ -124,89 +126,294 @@ def _replace_observer_with_quantize_dequantize_node( graph.erase_node(node) return - # otherwise, we can convert the observer module call to quantize/dequantize node + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + # 1. extract the information from activation_post_process module for generating # the quantize and dequantize operator dtype = activation_post_process.dtype # type: ignore[attr-defined] compute_dtype = None if hasattr(activation_post_process, "compute_dtype"): compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] - quantize_op : Optional[Union[Callable, str]] = None if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ not hasattr(activation_post_process, 'compute_dtype'): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. extract information for inserting q/dq node from activation_post_process node_type = "call_function" + quantize_op : Optional[Callable] = None + scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] + if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] + raise NotImplementedError("decomposed quantize_per_channel op not implemented yet") + else: + scale = float(scale) + zero_point = int(zero_point) + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + dtype_ = to_underlying_dtype(dtype) + qparams = { + "_scale_": scale, + "_zero_point_": zero_point, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]: + # TODO(future PR): switch compute_dtype to is_dynamic + + # uint8/int8/fp16 dynamic quantization + + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.tensor + # we only use choose_qparams for is_decomposed now, + # but we should probably align the non-decomposed path with this as well, + # and that can be done after we remove reduce_range flag + # 1. extract qparams from activation_post_process module + dtype_ = to_underlying_dtype(dtype) + assert dtype_ in [torch.uint8, torch.int8], \ + "only uint8 and int8 are supported in reference flow for " \ + "dynamic quantization right now" + quant_min = activation_post_process.quant_min # type: ignore[attr-defined] + quant_max = activation_post_process.quant_max # type: ignore[attr-defined] + # note: scale and zero_point are missing for quantize_per_tensor op + # we'll need to get this from choose_qparams op, which we'll add after + # this step + qparams = { + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype_ + } + + # 2. insert choose_qparams op and update the qparams list + with graph.inserting_before(node): + input_node = node.args[0] + choose_qparams_op_inputs = [node.args[0]] + for key, value in qparams.items(): + # we have quant_min, quant_max and dtype, all should be stored + # as literals + choose_qparams_op_inputs.append(value) + choose_qparams_node = graph.create_node( + "call_function", + torch.ops.quantized_decomposed.choose_qparams.tensor, + tuple(choose_qparams_op_inputs), + {} + ) + # choose_qparms returns (scale, zero_point) + scale_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 0), + {} + ) + zero_point_node = graph.create_node( + "call_function", + operator.getitem, + (choose_qparams_node, 1), + {} + ) + quant_min = qparams["_quant_min_"] + quant_max = qparams["_quant_max_"] + dtype = qparams["_dtype_"] + qparams = { + "_scale_": scale_node, + "_zero_point_": zero_point_node, + "_quant_min_": quant_min, + "_quant_max_": quant_max, + "_dtype_": dtype + } + + # 3. replace activation_post_process node to quantize and dequantize node + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # in this case we have a node in the graph since it's dynamically + # computed from the input, with choose_qparams op + qparam_node = value_or_node + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we + # store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + # need to use the tensor variant of this op, since scale and zero_point + # from choose_qparam are Tensors, instead of float/int, this is to + # prevent these nodes being traced away by downstream systems + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.tensor + dequantized_node = graph.call_function( + dequantize_op, + tuple(dq_inputs), + {} + ) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + elif dtype == torch.float16: + raise NotImplementedError("decomposed to float16 op not implemented yet") + + # should not reach since we have checks in the begining to make sure the + # activation_post_process is supported + +def _replace_observer_with_quantize_dequantize_node( + model: torch.nn.Module, + graph: Graph, + node: Node, + modules: Dict[str, torch.nn.Module], + node_name_to_scope: Dict[str, Tuple[str, type]], + node_name_to_qconfig: Dict[str, QConfigAny]) -> None: + """ Replace activation_post_process module call node with quantize and + dequantize node + + Before: + ... -> observer_0(x) -> ... + After: + ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... + """ + assert modules is not None + assert isinstance(node.target, str) + module_path, prefix = get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig) + activation_post_process = modules[node.target] + # skip replacing observers to quant/dequant nodes if the qconfigs of all + # consumers and producers of this observer are None + skip_replacement = all([ + has_none_qconfig(n, node_name_to_qconfig) for n in + list(node.args) + list(node.users.keys())]) + if skip_replacement or not _is_conversion_supported(activation_post_process): + # didn't find correponding quantize op and info for the activation_post_process + # so we just remove the observer + with graph.inserting_before(node): + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + return + + # otherwise, we can convert the activation_post_process module call to quantize/dequantize node + dtype = activation_post_process.dtype # type: ignore[attr-defined] + compute_dtype = None + if hasattr(activation_post_process, "compute_dtype"): + compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] + + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ + not hasattr(activation_post_process, "compute_dtype"): + # TODO: probably should cleanup this condition check, it's hard + # to reason about this if and the following elif + + # uint8/int8/int32 static quantization branch + + # 1. 
extract the information from activation_post_process module for generating + # the quantize and dequantize operator + node_type = "call_function" + quantize_op : Optional[Callable] = None scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] if is_per_channel(activation_post_process.qscheme): # type: ignore[attr-defined] ch_axis = int(activation_post_process.ch_axis) # type: ignore[attr-defined] qparams = {"_scale_": scale, "_zero_point_": zero_point, "_axis_": ch_axis, "_dtype_": dtype} - if is_decomposed: - raise NotImplementedError("decomposed quantize_per_channel op not implemented yet") - else: - quantize_op = torch.quantize_per_channel + quantize_op = torch.quantize_per_channel else: scale = float(scale) zero_point = int(zero_point) - if is_decomposed: - quant_min = activation_post_process.quant_min # type: ignore[attr-defined] - quant_max = activation_post_process.quant_max # type: ignore[attr-defined] - dtype = to_underlying_dtype(dtype) - qparams = { - "_scale_": scale, - "_zero_point_": zero_point, - "_quant_min": quant_min, - "_quant_max": quant_max, - "_dtype_": dtype - } - quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor - else: - qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} - quantize_op = torch.quantize_per_tensor + qparams = {"_scale_": scale, "_zero_point_": zero_point, "_dtype_": dtype} + quantize_op = torch.quantize_per_tensor + + # 2. replace activation_post_process node with quantize and dequantize + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ['_scale_', '_zero_point_']: + # For scale and zero_point values we register them as buffers in the root module. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value( + model, graph, module_path + prefix + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ quantize_op_inputs.append(value_or_node) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) elif compute_dtype in [torch.quint8, torch.qint8, torch.float16]: # TODO(future PR): switch compute_dtype to is_dynamic - # dynamic quantization + + # uint8/int8/fp16 dynamic quantization branch + node_type = "call_function" - if is_decomposed: - raise NotImplementedError("decomposed quantize_per_tensor_dynamic op not implemented yet") - else: - quantize_op = torch.quantize_per_tensor_dynamic + quantize_op = torch.quantize_per_tensor_dynamic # TODO: get reduce range from observer # reduce_range = activation_post_process.reduce_range reduce_range = torch.backends.quantized.engine in ("fbgemm", "x86") qparams = {"_dtype_": compute_dtype, "_reduce_range_": reduce_range} + + with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + quantize_op_inputs.append(value) + + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) elif dtype == torch.float16: node_type = "call_method" quantize_op = "to" qparams = {"_dtype_": dtype} - - # 2. replace observer node with quant - dequant node - with graph.inserting_before(node): - input_node = node.args[0] - quantize_op_inputs = [input_node] - for key, value in qparams.items(): - # TODO: we can add the information of whether a value needs to - # be registered as an attribute in qparams dict itself - if key in ['_scale_', '_zero_point_']: - # For scale and zero_point values we register them as buffers in the root module. - # TODO: maybe need more complex attr name here - qparam_node = create_getattr_from_value(model, graph, module_path + prefix + key, value) - quantize_op_inputs.append(qparam_node) - else: - # for qparams that are not scale/zero_point (like axis, dtype) we store them as literals in the graph. 
+ with graph.inserting_before(node): + input_node = node.args[0] + quantize_op_inputs = [input_node] + for key, value in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself quantize_op_inputs.append(value) - quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) - if is_decomposed: - # use the same qparams from quantize op - dq_inputs = [quantized_node] + quantize_op_inputs[1:] - dequantized_node = graph.call_function( - torch.ops.quantized_decomposed.dequantize_per_tensor, - tuple(dq_inputs), - {} - ) - else: + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) dequantized_node = graph.call_method("dequantize", args=(quantized_node,)) - node.replace_all_uses_with(dequantized_node) - graph.erase_node(node) + node.replace_all_uses_with(dequantized_node) + graph.erase_node(node) + + # should not reach since we have checks in the begining to make sure the + # activation_post_process is supported # this is a temporary hack for custom module, we may want to implement # this properly after the custom module class design is finalized @@ -792,9 +999,14 @@ def convert( if observed_node in statically_quantized_custom_module_nodes: _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) else: - _replace_observer_with_quantize_dequantize_node( - model, model.graph, node, modules, node_name_to_scope, - node_name_to_qconfig, is_decomposed) + if is_decomposed: + _replace_observer_with_quantize_dequantize_node_decomposed( + model, model.graph, node, modules, node_name_to_scope, + node_name_to_qconfig) + else: + _replace_observer_with_quantize_dequantize_node( + model, model.graph, node, modules, node_name_to_scope, + node_name_to_qconfig) elif isinstance(mod, DeQuantStub): _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph) elif is_observed_standalone_module(mod): From 9c0bf9387c1e39efda268a1fb300e8f87b7ef0e6 Mon Sep 17 00:00:00 2001 From: anjali411 Date: Tue, 22 Nov 2022 13:33:55 +0000 Subject: [PATCH 440/453] Meta impl for linalg_cholesky and linalg_cholesky_ex (#89430) Pull Request resolved: https://github.com/pytorch/pytorch/pull/89430 Approved by: https://github.com/ezyang --- test/functorch/test_aotdispatch.py | 2 +- test/test_proxy_tensor.py | 2 -- torch/_meta_registrations.py | 48 +++++++++++++++++++++++++++++- torch/_prims_common/__init__.py | 7 +++-- 4 files changed, 53 insertions(+), 6 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 84b1ba893cce..648dc04dc522 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -1041,7 +1041,7 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('inner', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kron', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta functio... 
+ xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default xfail('linalg.cond', ''), # Cannot call numel() on tensor with symbolic sizes/strides xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fa04c57d9426..21142f56e729 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1201,8 +1201,6 @@ def f(a, b, c, d, e): xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition - xfail('linalg.cholesky', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition - xfail('linalg.cholesky_ex', ''), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 9849df0a58af..6232462ede21 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -15,6 +15,7 @@ ELEMENTWISE_TYPE_PROMOTION_KIND, FloatLike, IntLike, + make_contiguous_strides_for, ) from torch._prims_common.wrappers import out_wrapper @@ -178,7 +179,8 @@ def meta_angle_out(self, out): return out.copy_(torch.angle(self)) -def squareCheckInputs(self, f_name): +# From aten/src/ATen/native/LinearAlgebraUtils.h +def squareCheckInputs(self: Tensor, f_name: str): assert ( self.dim() >= 2 ), f"{f_name}: The input tensor must have at least 2 dimensions." @@ -187,6 +189,22 @@ def squareCheckInputs(self, f_name): ), f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices" +# From aten/src/ATen/native/LinearAlgebraUtils.h +def checkFloatingOrComplex( + t: Tensor, f_name: str, allow_low_precision_dtypes: bool = True +): + dtype = t.dtype + check( + t.is_floating_point() or t.is_complex(), + lambda: f"{f_name}, : Expected a floating point or complex tensor as input. Got , {dtype}", + ) + if allow_low_precision_dtypes: + check( + dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble), + lambda: f"{f_name} : Low precision dtypes not supported. 
Got {dtype}", + ) + + def checkUplo(uplo: str): uplo_uppercase = uplo.upper() assert ( @@ -206,6 +224,34 @@ def meta_linalg_eigh(self, uplo="L"): return (values, vectors) +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_cholesky_ex.default) +def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False): + squareCheckInputs(A, "linalg.cholesky") + checkFloatingOrComplex(A, "linalg.cholesky") + + A_shape = A.shape + ndim = len(A_shape) + + # L + L_strides = make_contiguous_strides_for(A_shape, False) + L = A.new_empty(A_shape) + L.as_strided_(A_shape, L_strides) + + # infos + infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32) + return L, infos + + +# From aten/src/ATen/native/BatchLinearAlgebra.cpp +@register_meta(aten.linalg_cholesky.default) +def meta_linalg_cholesky(A: Tensor, upper=False): + # All the checks done on info in the corresponding C++ function + # are data dependent, so we skip info computation + L, infos = linalg_cholesky_ex(A, upper, False) + return L, infos + + # From aten/src/ATen/native/ReflectionPad.cpp @register_meta( [aten.reflection_pad2d_backward.default, aten.replication_pad2d_backward.default] diff --git a/torch/_prims_common/__init__.py b/torch/_prims_common/__init__.py index 6df72f6c158d..a17dad4f2a92 100644 --- a/torch/_prims_common/__init__.py +++ b/torch/_prims_common/__init__.py @@ -1341,15 +1341,17 @@ def reduction_dtypes( result_dtype = torch.bool return computation_dtype, result_dtype - +# This function's logic is borrowed from the following functions defined in C++: +# batched_matrix_contiguous_strides and contiguous_strides def make_contiguous_strides_for( shape: ShapeType, row_major: bool = True ) -> Tuple[int, ...]: """ - Returns the strides of a contriguous tensor if row_major + Returns the strides of a contiguous tensor if row_major If row_major=True, it returns the strides of a contiguous batch of Fortran-contiguous matrices This is often used when calling external libraries like BLAS/LAPACK/cuSolver... """ + # contiguous_strides from c10/util/strides.h validate_shape(shape) if not shape: return () @@ -1363,6 +1365,7 @@ def make_contiguous_strides_for( result = tuple(reversed(strides)) + # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h if row_major: return result else: From f4898daaeeea130a009330674726b5492982af85 Mon Sep 17 00:00:00 2001 From: PratsBhatt Date: Tue, 22 Nov 2022 18:00:01 +0000 Subject: [PATCH 441/453] Add cached conda env file for Buck CI workflow (#89422) Fixes - T137631262 Caching conda dependencies for build workflows. Conda dependencies have been gathered from the workflow https://github.com/pytorch/pytorch/blob/master/.github/workflows/_buck-build-test.yml The pull request updates the action from `conda-incubator/setup-miniconda@v2` to `pytorch/test-infra/.github/actions/setup-miniconda@main` as it supports caching. Test Plan: Running the `ciflow/periodic` which runs the ci builds `buck-build-test` workflow. Expected output is to have all the conda dependencies cached. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89422
Approved by: https://github.com/huydhn
---
 .github/requirements/README.md           |  2 ++
 .github/requirements/conda-env-Linux-X64 | 10 ++++++++++
 .github/workflows/_buck-build-test.yml   | 23 ++----------------------
 3 files changed, 14 insertions(+), 21 deletions(-)
 create mode 100644 .github/requirements/conda-env-Linux-X64

diff --git a/.github/requirements/README.md b/.github/requirements/README.md
index a4f3cb75d9a7..7300eee14562 100644
--- a/.github/requirements/README.md
+++ b/.github/requirements/README.md
@@ -17,6 +17,8 @@ The list of support files are as follows:
     test jobs to setup the conda environment
   * conda-env-macOS-X64. This is use by MacOS (x86-64) build and test jobs
     to setup the conda environment
+  * conda-env-Linux-X64. This is used by Linux buck build and test jobs
+    to setup the conda environment
 * Pip:
   * pip-requirements-macOS.txt. This is used by MacOS build and test jobs
     to setup the pip environment
diff --git a/.github/requirements/conda-env-Linux-X64 b/.github/requirements/conda-env-Linux-X64
new file mode 100644
index 000000000000..f2b3811263e5
--- /dev/null
+++ b/.github/requirements/conda-env-Linux-X64
@@ -0,0 +1,10 @@
+cffi=1.15.1
+cmake=3.22.1
+mkl=2022.1.0
+mkl-include=2022.1.0
+ninja=1.10.2
+numpy=1.23.3
+pyyaml=6.0
+requests=2.28.1
+setuptools=65.5.0
+typing_extensions=4.3.0
diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml
index f52bb6017c58..07f41299c711 100644
--- a/.github/workflows/_buck-build-test.yml
+++ b/.github/workflows/_buck-build-test.yml
@@ -21,29 +21,10 @@ jobs:
           distribution: 'temurin'
 
       - name: Setup miniconda
-        uses: conda-incubator/setup-miniconda@v2
+        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
         with:
-          auto-update-conda: true
           python-version: 3.8
-          activate-environment: build
-
-      - name: Install dependencies
-        uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482
-        with:
-          timeout_minutes: 10
-          max_attempts: 5
-          command: |
-            conda install -y \
-              cffi=1.15.1 \
-              cmake=3.22.1 \
-              mkl=2022.1.0 \
-              mkl-include=2022.1.0 \
-              ninja=1.10.2 \
-              numpy=1.23.3 \
-              pyyaml=6.0 \
-              requests=2.28.1 \
-              setuptools=65.5.0 \
-              typing_extensions=4.3.0
+          environment-file: .github/requirements/conda-env-${{ runner.os }}-${{ runner.arch }}
 
       - name: Install Buck
         uses: nick-fields/retry@3e91a01664abd3c5cd539100d10d33b9c5b68482

From 7c0bb61291d62c449b78ce4930c27cbbd8ffac92 Mon Sep 17 00:00:00 2001
From: mantaionut
Date: Tue, 22 Nov 2022 18:37:14 +0000
Subject: [PATCH 442/453] Force numpy prod to use 64 bit integers on Windows in some tests (#88089)

This fixes some prod and masked.prod tests on Windows. np.prod uses int32 on
Windows so it overflows. On Linux it uses int64 by default.
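As a rough illustration (not part of this patch): once the running product leaves the 32-bit range, a 32-bit accumulator silently wraps, which is what the new `prod_numpy` helper avoids by upcasting the input before calling `np.prod`:

```
import numpy as np

a = np.arange(1, 14)                # 13! = 6227020800 does not fit in 32 bits
print(np.prod(a, dtype=np.int32))   # 1932053504 (wrapped around)
print(np.prod(a, dtype=np.int64))   # 6227020800 (correct)
```

Passing the dtype explicitly makes the example platform-independent; without it, the accumulator dtype is the platform default, which is where Windows and Linux diverge.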
Fixes #77305 Fixes #77320 Fixes #77334 Fixes #77335 Fixes #77336 Fixes #77337 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88089 Approved by: https://github.com/mruberry --- .../_internal/common_methods_invocations.py | 3 ++- .../_internal/opinfo/definitions/_masked.py | 5 ++--- torch/testing/_internal/opinfo/utils.py | 20 +++++++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4b2d0ebabc46..4116d967fd8a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -105,6 +105,7 @@ from torch.testing._internal.opinfo.utils import ( np_unary_ufunc_integer_promotion_wrapper, reference_reduction_numpy, + prod_numpy ) from torch.testing._internal import opinfo from torch.testing._internal.opinfo.definitions.linalg import ( @@ -16468,7 +16469,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), sample_inputs_func=sample_inputs_prod, - ref=reference_reduction_numpy(np.prod), + ref=prod_numpy, skips=( # FIXME: prod does not support passing keepdim without passing dim DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'), diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index f4b590fe2520..5a5ce8bc7e16 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -26,8 +26,7 @@ sample_inputs_reduction, SampleInput, ) -from torch.testing._internal.opinfo.utils import reference_reduction_numpy - +from torch.testing._internal.opinfo.utils import prod_numpy, reference_reduction_numpy # Used for log_softmax, softmax, softmin def sample_inputs_softmax_variant( @@ -434,7 +433,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar ), ReductionOpInfo( "masked.prod", - ref=reference_reduction_numpy(np.prod), + ref=prod_numpy, method_variant=None, identity=1, nan_policy="propagate", diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py index a19da98cd1d0..0bbba7c769d8 100644 --- a/torch/testing/_internal/opinfo/utils.py +++ b/torch/testing/_internal/opinfo/utils.py @@ -258,3 +258,23 @@ def wrapper(x: np.ndarray, *args, **kwargs): return result return wrapper + + +def prod_numpy(a, *args, **kwargs): + """ + The function will call np.prod with type as np.int64 if the input type + is int or uint64 if is uint. This is necessary because windows np.prod uses by default + int32 while on linux it uses int64. 
+ This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320 + + Returns: + np.prod of input + """ + if "dtype" not in kwargs: + if np.issubdtype(a.dtype, np.signedinteger): + a = a.astype(np.int64) + elif np.issubdtype(a.dtype, np.unsignedinteger): + a = a.astype(np.uint64) + + fn = reference_reduction_numpy(np.prod) + return fn(a, *args, **kwargs) From f281f435a8c60cf5781688bee3e4ff258c52344f Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 22 Nov 2022 18:42:13 +0000 Subject: [PATCH 443/453] Fix benchmarks - xla tensor test (#89509) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/89509 Approved by: https://github.com/ngimel, https://github.com/shunting314 --- benchmarks/dynamo/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 3fad203c5d87..a167ab75b53f 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -269,6 +269,7 @@ def print_summary(filename): def tensor_is_on_xla(tensors): if not isinstance(tensors, (tuple, list)): tensors = [tensors] + tensors = [x for x in tensors if isinstance(x, torch.Tensor)] return any(map(lambda x: x.device.type == "xla", tensors)) From ef8b91fec73884f3043da8f541176ab7b4c57364 Mon Sep 17 00:00:00 2001 From: Fuzzkatt Date: Tue, 22 Nov 2022 19:05:56 +0000 Subject: [PATCH 444/453] enable previously failing UCC distributed_test.py tests (#89023) Enables previously failing UCC distributed_test.py tests that are now fixed due to either ProcessGroupUCC barrier blocking fix (https://github.com/pytorch/pytorch/pull/86961) or UCC-side timeout error handling fix: (https://github.com/openucx/ucc/pull/679/files). Bump upstream UCC version to build UCC with timeout error handling fix merged in. 
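For context, a minimal sketch of the kind of collective-plus-barrier sequence the re-enabled tests exercise (assumptions: a PyTorch build with UCC support and MASTER_ADDR/MASTER_PORT already set; this is illustrative, not a test from the suite):

```
import torch
import torch.distributed as dist

def run(rank: int, world_size: int) -> None:
    # Assumes a UCC-enabled build and rendezvous env vars are configured.
    dist.init_process_group("ucc", rank=rank, world_size=world_size)
    t = torch.ones(4) * rank
    dist.all_reduce(t)   # sum across ranks
    dist.barrier()       # relies on the blocking-barrier behavior fixed in #86961
    dist.destroy_process_group()
```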
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89023 Approved by: https://github.com/kwen2501, https://github.com/malfet --- .circleci/docker/build.sh | 2 +- .../_internal/distributed/distributed_test.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 61d9b73d73df..b41d5fe2c8c1 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -81,7 +81,7 @@ fi TRAVIS_DL_URL_PREFIX="https://s3.amazonaws.com/travis-python-archives/binaries/ubuntu/14.04/x86_64" _UCX_COMMIT=31e74cac7bee0ef66bef2af72e7d86d9c282e5ab -_UCC_COMMIT=12944da33f911daf505d9bbc51411233d0ed85e1 +_UCC_COMMIT=1c7a7127186e7836f73aafbd7697bbc274a77eee # It's annoying to rename jobs every time you want to rewrite a # configuration, so we hardcode everything here rather than do it diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index c67dfc7c40a3..814dd3d5ad5f 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -73,6 +73,7 @@ FILE_SCHEMA, IS_FBCODE, NO_MULTIPROCESSING_SPAWN, + IS_SANDCASTLE, parametrize, sandcastle_skip, sandcastle_skip_if, @@ -3748,7 +3749,7 @@ def _test_barrier_helper( @skip_if_no_gpu @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't supports GPU barrier") - @sandcastle_skip_if(BACKEND == "ucc", "flaky on PyTorch CI with timeout") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") def test_barrier_cuda(self): group, group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) @@ -3842,7 +3843,7 @@ def _test_all_reduce_multigpu_helper( @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") @sandcastle_skip_if(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") - @sandcastle_skip_if(BACKEND == "ucc", "UCC all_reduce multigpu skipped") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_no_gpu def test_all_reduce_multigpu(self): group, group_id, rank = self._init_global_test() @@ -3860,7 +3861,7 @@ def test_all_reduce_multigpu(self): @sandcastle_skip_if(BACKEND == "mpi", "MPI doesn't support broadcast multigpu") @sandcastle_skip_if(BACKEND == "nccl", "CUDA all_reduce multigpu skipped for NCCL") - @sandcastle_skip_if(BACKEND == "ucc", "UCC all_reduce multigpu skipped") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_no_gpu def test_all_reduce_multigpu_complex(self): group, group_id, rank = self._init_global_test() @@ -7717,14 +7718,14 @@ def _test_verify_model_across_rank(self, use_logger): @require_backend(DistTestCases.backend_feature["gpu"]) @require_backends_available(DistTestCases.backend_feature["gpu"]) - @sandcastle_skip_if(BACKEND == "ucc", "test timing out locally with ucc") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_lt_x_gpu(2) def test_verify_model_across_rank_with_logger(self): self._test_verify_model_across_rank(use_logger=True) @require_backend(DistTestCases.backend_feature["gpu"]) @require_backends_available(DistTestCases.backend_feature["gpu"]) - @sandcastle_skip_if(BACKEND == "ucc", "test timing out locally with ucc") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_lt_x_gpu(2) def 
test_verify_model_across_rank_without_logger(self): self._test_verify_model_across_rank(use_logger=False) @@ -7748,7 +7749,7 @@ def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo): @require_backend(DistTestCases.backend_feature["gpu"]) @require_backends_available(DistTestCases.backend_feature["gpu"]) - @sandcastle_skip_if(BACKEND == "ucc", "test failing locally with UCC") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_lt_x_gpu(2) def test_ddp_model_diff_shape_across_ranks(self): group_gloo = dist.new_group( @@ -7771,7 +7772,7 @@ def test_ddp_model_diff_shape_across_ranks(self): @require_backend(DistTestCases.backend_feature["gpu"]) @require_backends_available(DistTestCases.backend_feature["gpu"]) - @sandcastle_skip_if(BACKEND == "ucc", "test failing locally with UCC") + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") @skip_if_lt_x_gpu(2) def test_ddp_model_diff_num_params_across_ranks(self): group_gloo = dist.new_group( @@ -9185,11 +9186,8 @@ def _test_hook_pickling(self, hook, hook_state): BACKEND not in DistTestCases.backend_feature["cuda"], f"The {BACKEND} backend does not support DDP communication hook on CUDA devices" ) - @sandcastle_skip_if( - BACKEND == "ucc", - "flaky on PyTorch CI: No such file or directory: '/tmp/checkpoint.pt'" - ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @sandcastle_skip_if(BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally") def test_ddp_hook_pickling_powerSGD(self): hook = powerSGD.powerSGD_hook From c2ce79f06eb4a8cec2f9cfbdf3a1a4021a0a4cfa Mon Sep 17 00:00:00 2001 From: "Li-Huai (Allan) Lin" Date: Tue, 22 Nov 2022 19:33:21 +0000 Subject: [PATCH 445/453] Fix dev-discuss link in the maintainer docs (#89493) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/89493 Approved by: https://github.com/H-Huang --- docs/source/community/persons_of_interest.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index d011250d490d..02224696c61b 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -7,7 +7,7 @@ Responsibilities * Triage and fix high priority issues assigned to the module or library * Triage, review, and land high priority pull requests assigned to the module or library * Answer module or library questions on `discuss.pytorch.org `__ - and `dev-discuss.pytorch.org `__ + and `dev-discuss.pytorch.org `__ * Maintain public user and development documentation * Run meetings and share minutes plus roadmap on a half or quarterly basis From d053d513432bea75ae783529bf9f639f977a47d2 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 22 Nov 2022 20:25:38 +0000 Subject: [PATCH 446/453] (Further) limit world size in test_fsdp_pure_fp16 (#86280) Test still fails when run on 5 A100 GPUs, although it works with 5 V100s. Using 4 GPUs seems to be fine. 
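As an illustrative sketch only (the real test derives its base value from `FSDPTest` rather than this hypothetical helper), the capping pattern being adjusted looks like:

```
import torch

def capped_world_size(cap: int = 4) -> int:
    # Use at most `cap` ranks, and never more than the GPUs actually present.
    return min(cap, torch.cuda.device_count())
```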
Followup to #85957 Pull Request resolved: https://github.com/pytorch/pytorch/pull/86280 Approved by: https://github.com/awgu, https://github.com/kit1980 --- test/distributed/fsdp/test_fsdp_pure_fp16.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/distributed/fsdp/test_fsdp_pure_fp16.py b/test/distributed/fsdp/test_fsdp_pure_fp16.py index 1c663f826335..e0033ef3d4b7 100644 --- a/test/distributed/fsdp/test_fsdp_pure_fp16.py +++ b/test/distributed/fsdp/test_fsdp_pure_fp16.py @@ -33,8 +33,8 @@ class TestPureFP16(FSDPTest): @property def world_size(self): - # Test fails due to inaccuracies when using more than 5 GPUs - return min(5, super().world_size) + # Test fails due to inaccuracies when using more than 4 GPUs + return min(4, super().world_size) @skip_if_lt_x_gpu(2) @parametrize( From ac3004757ef64b1ed1ff884a39d2a34cdfb5f772 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 22 Nov 2022 20:27:27 +0000 Subject: [PATCH 447/453] Relax tolerance for test_out_addbmm_cpu_float32 (#86365) The test may fail due to slightly different values caused by different order of matrizes in SGEMM: > Mismatched elements: 1 / 50 (2.0%) > Greatest absolute difference: 1.430511474609375e-05 at index (4, 5) (up to 1e-05 allowed) > Greatest relative difference: 4.65393206065873e-06 at index (4, 5) (up to 1.3e-06 allowed) Observed on POWER (ppc64le) Pull Request resolved: https://github.com/pytorch/pytorch/pull/86365 Approved by: https://github.com/mruberry, https://github.com/kit1980 --- torch/testing/_internal/common_methods_invocations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4116d967fd8a..177dc669469e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8539,6 +8539,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): 'TestConsistency', 'test_output_match', ), + DecorateInfo( + toleranceOverride({torch.float32: tol(atol=1.5e-05, rtol=1e-05)}), + 'TestCommon', 'test_out'), ], skips=( # NVIDIA only assures that bfloat16 is supported by bmm if SM >= 5.3 From 177baf366ad16b868ab19a8776ae0e636f9d1951 Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 22 Nov 2022 20:29:07 +0000 Subject: [PATCH 448/453] Fix vectorized trigonometric functions for VSX (#86453) Replace the remaining hand-written code in vec256_float_vsx.h by calls to Sleef functions similar to what was done in #59382 & #82646 after #41541 This fixes wrong results for e.g. `sin(1e20)`. Fixes #85978 To fix #85978 I only needed to do the sin/cos functions to make the test pass but to not encounter the same issue again and again (see the previous PRs and issues) I checked the whole file for similar functions where a Sleef function could be used and changed those too. In the diff I've noticed the faulty whitespace so to make this complete I fixed that too, so it should now be done. 
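As a rough, illustrative check of the failure mode being fixed (not part of this patch), the sketch below compares the vectorized float32 sin/cos against a float64 reference for very large arguments; the tensor size, tolerance, and printout are assumptions chosen for illustration, not values from the test suite. On an affected VSX build the old polynomial kernels drift far from the reference for inputs like 1e20, while the Sleef-backed versions stay close.

```
# Illustrative sketch: checks vectorized float32 sin/cos against a float64
# reference for huge arguments (the failing case was sin(1e20) on ppc64le).
import torch

x32 = torch.full((1024,), 1e20, dtype=torch.float32)  # large enough to hit the SIMD path
x64 = x32.double()                                     # same rounded inputs, double precision

for fn in (torch.sin, torch.cos):
    got = fn(x32)
    ref = fn(x64).float()
    max_err = (got - ref).abs().max().item()
    print(f"{fn.__name__}: max abs error vs float64 reference = {max_err:.3e}")
    # Tolerance is an illustrative assumption, not the one used in the tests.
    assert max_err < 1e-4, f"{fn.__name__} lost precision for large arguments"
```

On platforms whose vectorized paths already call Sleef this check is expected to pass; the assertion is only meant to trip on builds with the hand-written argument reduction.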
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86453 Approved by: https://github.com/malfet --- .../cpu/vec/vec256/vsx/vec256_float_vsx.h | 224 ++---------------- 1 file changed, 21 insertions(+), 203 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 77cf3695ab91..8fe6cc25f0ee 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -256,29 +256,29 @@ class Vectorized { } Vectorized C10_ALWAYS_INLINE acos() const { - return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; + return {Sleef_acosf4_u10vsx(_vec0), Sleef_acosf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE asin() const { - return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; + return {Sleef_asinf4_u10vsx(_vec0), Sleef_asinf4_u10vsx(_vec1)}; } Vectorized atan() const { - return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; + return {Sleef_atanf4_u10vsx(_vec0), Sleef_atanf4_u10vsx(_vec1)}; } Vectorized atan2(const Vectorized& b) const { - return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; + return {Sleef_atan2f4_u10vsx(_vec0, b._vec0), Sleef_atan2f4_u10vsx(_vec1, b._vec1)}; } Vectorized copysign(const Vectorized &sign) const { return {Sleef_copysignf4_vsx(_vec0, sign._vec0), Sleef_copysignf4_vsx(_vec1, sign._vec1)}; } Vectorized lgamma() const { - return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; + return {Sleef_lgammaf4_u10vsx(_vec0), Sleef_lgammaf4_u10vsx(_vec1)}; } Vectorized erf() const { - return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; + return {Sleef_erff4_u10vsx(_vec0), Sleef_erff4_u10vsx(_vec1)}; } Vectorized erfc() const { - return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; + return {Sleef_erfcf4_u15vsx(_vec0), Sleef_erfcf4_u15vsx(_vec1)}; } Vectorized erfinv() const { @@ -301,133 +301,32 @@ class Vectorized { } Vectorized C10_ALWAYS_INLINE exp() const { - // implementation logic from avx_mathfun with some modifications from sleef - // Express e**x = e**g 2**n - /// = e**g e**( n loge(2) ) - /// = e**( g + n loge(2) ) - // - auto tmp_x = *this; - auto fx = (tmp_x * log2e_inv).round(); - - auto x = fx.madd(negln2f_hi, tmp_x); - x = fx.madd(negln2f_lo, x); - auto z = x * x; - auto y = x.madd(exp_p0, exp_p1); - y = y.madd(x, exp_p2); - y = y.madd(x, exp_p3); - y = y.madd(x, exp_p4); - y = y.madd(x, exp_p5); - y = y.madd(z, x) + one; - - // vm_pow2n 2^n - vint32 imm0 = vec_signed(fx._vec0); - vint32 imm1 = vec_signed(fx._vec1); - // this pow2n logic is from Sleef code - vint32 imm00 = imm0 >> 1; //>>1 - vint32 imm01 = imm1 >> 1; - vint32 imm10 = imm0 - imm00; - vint32 imm11 = imm1 - imm01; - imm00 = (imm00 + v0x7f) << vu_23; - imm01 = (imm01 + v0x7f) << vu_23; - imm10 = (imm10 + v0x7f) << vu_23; - imm11 = (imm11 + v0x7f) << vu_23; - // treat imm as float vector without conversion - - y._vec0 = (y._vec0 * (vfloat32)imm00) * (vfloat32)imm10; - y._vec1 = (y._vec1 * (vfloat32)imm01) * (vfloat32)imm11; - // boundary check - auto tmp = blendv(y, v_inf, (Vectorized(exp_hi) <= tmp_x)); - y = blendv(tmp, zero, (tmp_x < Vectorized(exp_lo))); - - return y; + return {Sleef_expf4_u10vsx(_vec0), Sleef_expf4_u10vsx(_vec1)}; } Vectorized expm1() const { - return exp() - one; + return {Sleef_expm1f4_u10vsx(_vec0), Sleef_expm1f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log() const { return {Sleef_logf4_u10vsx(_vec0), 
Sleef_logf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log10() const { - return {Sleef_log10f4_u10vsx(_vec0), Sleef_log10f4_u10vsx(_vec1)}; + return {Sleef_log10f4_u10vsx(_vec0), Sleef_log10f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log1p() const { - return {Sleef_log1pf4_u10vsx(_vec0), Sleef_log1pf4_u10vsx(_vec1)}; + return {Sleef_log1pf4_u10vsx(_vec0), Sleef_log1pf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE log2() const { - return {Sleef_log2f4_u10vsx(_vec0), Sleef_log2f4_u10vsx(_vec1)}; + return {Sleef_log2f4_u10vsx(_vec0), Sleef_log2f4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE ceil() const { return {vec_ceil(_vec0), vec_ceil(_vec1)}; } Vectorized C10_ALWAYS_INLINE cos() const { - // take the absolute value - auto x = abs(); - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - // scale by 4/Pi - auto y = x * _4div_pi; - // store the integer part of y in mm0 - // j=(j+1) & (~1) (see the cephes sources) - vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; - vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; - y._vec0 = vec_float(imm0); - y._vec1 = vec_float(imm1); - - imm0 = imm0 - vi_2; - imm1 = imm1 - vi_2; - Vectorized poly_mask; - // get the swap sign flag - vint32 tmp0 = vec_and(vec_nand(imm0, imm0), vi_4); - vint32 tmp1 = vec_and(vec_nand(imm1, imm1), vi_4); - sign_bit._vecb0 = (vbool32)vec_sl(tmp0, vu_29); - sign_bit._vecb1 = (vbool32)vec_sl(tmp1, vu_29); - // get the polynom selection mask - // there is one polynom for 0 <= x <= Pi / 4 - // and another one for Pi / 4 < x <= Pi / 2 - // Both branches will be computed. - - poly_mask._vecb0 = (vbool32)vec_cmpeq((imm0 & vi_2), vi_0); - poly_mask._vecb1 = (vbool32)vec_cmpeq((imm1 & vi_2), vi_0); - - // The magic pass: "Extended precision modular arithmetic" - // x = ((x - y * DP1) - y * DP2) - y * DP3; - x = y.madd(minus_cephes_dp1, x); - x = y.madd(minus_cephes_dp2, x); - x = y.madd(minus_cephes_dp3, x); - - // Evaluate the first polynom (0 <= x <= Pi/4) - auto z = x * x; - y = z.madd(coscof_p0, coscof_p1); - y = y.madd(z, coscof_p2); - y = y * z * z; - y = y - z * half + one; - - // Evaluate the second polynom (Pi/4 <= x <= 0) - auto y_2 = z.madd(sincof_p0, sincof_p1); - y_2 = y_2.madd(z, sincof_p2); - y_2 = y_2 * z; - y_2 = y_2.madd(x, x); - - // select the correct result from the two polynoms - y = blendv(y, y_2, poly_mask); - // update the sign - y = y ^ sign_bit; - - return y; + return {Sleef_cosf4_u10vsx(_vec0), Sleef_cosf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE cosh() const { - // cosh = 1/2 * (e^x + e^-x) - auto x = abs(); - auto e_x = x.exp(); - auto ret = (e_x + Vectorized(one) / e_x) * half; - // inf and nan checks -#if 0 - ret = blendv(ret, v_inf, x >= vf_89); - ret = blendv(ret, v_inf, ret.isnan()); - ret = blendv(ret, v_nan, this->isnan()); -#endif - return ret; + return {Sleef_coshf4_u10vsx(_vec0), Sleef_coshf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE floor() const { return {vec_floor(_vec0), vec_floor(_vec1)}; @@ -440,97 +339,16 @@ class Vectorized { return {vec_round(_vec0), vec_round(_vec1)}; } Vectorized C10_ALWAYS_INLINE sin() const { - // take the absolute value and xtract sign - auto x = abs(); - auto sign_bit = (*this) & sign_mask; - - // scale by 4/Pi - auto y = x * _4div_pi; - // store the integer part of y in mm0 - - // j=(j+1) & (~1) (see the cephes sources) - vint32 imm0 = (vec_signed(y._vec0) + vi_1) & vi_inv1; - vint32 imm1 = (vec_signed(y._vec1) + vi_1) & vi_inv1; - y._vec0 = vec_float(imm0); - y._vec1 = vec_float(imm1); - // get the swap 
sign flag - Vectorized swap_sign_bit, poly_mask; - swap_sign_bit._vecb0 = (vbool32)vec_sl(imm0 & vi_4, vu_29); - swap_sign_bit._vecb1 = (vbool32)vec_sl(imm1 & vi_4, vu_29); - // get the polynom selection mask - // there is one polynom for 0 <= x <= Pi/4 - // and another one for Pi/4 C10_ALWAYS_INLINE sinh() const { - auto temp_abs = abs(); - // get exponent - auto ret = temp_abs.exp(); - auto recp = Vectorized(half) / ret; - auto v = ret * half - recp; - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - auto z = temp_abs * temp_abs; - auto y = z.madd(p0, p1); - y = y.madd(z, p2); - y = (y * z).madd(temp_abs, temp_abs); - // check and select - auto result = blendv(y, v, temp_abs > one); - return result | sign_bit; + return {Sleef_sinhf4_u10vsx(_vec0), Sleef_sinhf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE tan() const { - return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; + return {Sleef_tanf4_u10vsx(_vec0), Sleef_tanf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE tanh() const { - auto x = *this; - auto vabs = abs(); - // get exponent - auto exp2x = (vabs + vabs).exp(); - auto vv = Vectorized(one) - Vectorized(two) / (exp2x + one); - // extract the sign bit (upper one) - auto sign_bit = (*this) & sign_mask; - auto z = vabs * vabs; - auto y = z.madd(tanh_p0, tanh_p1); - auto tmp = y.madd(z, tanh_p2); - y = z.madd(tmp, tanh_p3); - tmp = y.madd(z, tanh_p4); - y = tmp * z; - tmp = y.madd(x, x); - // add sign - vv = vv | sign_bit; - // check and select - auto sel_mask = vabs >= tanh_0p625; - auto max_mask = vabs > tanh_half_max; - auto max_ret = sign_bit ^ one; - return blendv(blendv(tmp, vv, sel_mask), max_ret, max_mask); + return {Sleef_tanhf4_u10vsx(_vec0), Sleef_tanhf4_u10vsx(_vec1)}; } Vectorized C10_ALWAYS_INLINE trunc() const { return {vec_trunc(_vec0), vec_trunc(_vec1)}; @@ -555,15 +373,15 @@ class Vectorized { } Vectorized fmod(const Vectorized& b) const { - return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; + return {Sleef_fmodf4_vsx(_vec0, b._vec0),Sleef_fmodf4_vsx(_vec1, b._vec1)}; } Vectorized hypot(const Vectorized& b) const { - return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; + return {Sleef_hypotf4_u05vsx(_vec0, b._vec0), Sleef_hypotf4_u05vsx(_vec1, b._vec1)}; } Vectorized nextafter(const Vectorized& b) const { - return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; + return {Sleef_nextafterf4_vsx(_vec0, b._vec0), Sleef_nextafterf4_vsx(_vec1, b._vec1)}; } Vectorized igamma(const Vectorized& x) const { From 77d7f2c65945438e0292b270998cea07c0d9d3d8 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 22 Nov 2022 21:17:36 +0000 Subject: [PATCH 449/453] [dashboard] Add commit date & fix date related issues (#89517) Add commit date to build summary of dashboard. Make the date of the run reflective of when the run started, not when the run ended. Use PST (UTC -8) to determine day, rather than GMT (UTC +0). 
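For context on the day-boundary part of this change (an illustrative sketch, not code from the diff): the archive name is derived from the day of year, so computing it from the UTC clock can file a run that starts in the Pacific evening under the next day's bucket, while a fixed UTC-8 offset keeps it on the Pacific date. The timestamps below are made-up examples.

```
# Sketch of the day-boundary difference described above (times are examples).
from datetime import datetime, timedelta, timezone

utc = timezone.utc
pst = timezone(timedelta(hours=-8))  # fixed UTC-8 offset, matching runner.py

# A run starting at 23:30 Pacific on Nov 21 is already Nov 22 in UTC.
start = datetime(2022, 11, 22, 7, 30, tzinfo=utc)
print("UTC day-of-year:", start.strftime("%j"))                     # 326 -> next day's bucket
print("UTC-8 day-of-year:", start.astimezone(pst).strftime("%j"))   # 325 -> Pacific date
```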
Test comment: https://github.com/pytorch/torchdynamo/issues/1831#issuecomment-1324176119 Pull Request resolved: https://github.com/pytorch/pytorch/pull/89517 Approved by: https://github.com/anijain2305 --- benchmarks/dynamo/runner.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmarks/dynamo/runner.py b/benchmarks/dynamo/runner.py index 963dcf493705..f39d8dbab05f 100755 --- a/benchmarks/dynamo/runner.py +++ b/benchmarks/dynamo/runner.py @@ -39,7 +39,7 @@ import sys import tempfile from collections import defaultdict -from datetime import datetime +from datetime import datetime, timedelta, timezone from os.path import abspath, exists from random import randint @@ -345,7 +345,9 @@ def print_commit_hash(path, name): if exists(path): repo = git.Repo(path, search_parent_directories=True) sha = repo.head.object.hexsha + date = repo.head.object.committed_datetime out_io.write(f"{name} commit: {sha}\n") + out_io.write(f"{name} commit date: {date}\n") else: out_io.write(f"{name} Absent\n") @@ -409,8 +411,9 @@ def archive_data(archive_name): else: day = "000" else: - day = datetime.today().strftime("%j") - prefix = datetime.today().strftime(f"day_{day}_%d_%m_%y") + now = datetime.now(tz=timezone(timedelta(hours=-8))) + day = now.strftime("%j") + prefix = now.strftime(f"day_{day}_%d_%m_%y") return day, prefix @@ -1297,6 +1300,9 @@ def extract(key): parse_logs(args, dtypes, suites, devices, compilers, flag_compilers, output_dir) elif args.run: generate_commands(args, dtypes, suites, devices, compilers, output_dir) + # generate memoized archive name now so that the date is reflective + # of when the run started + get_archive_name(args, dtypes[0]) # TODO - Do we need to worry about segfaults try: os.system("bash run.sh") From 00b7d8ef237f4f0fc3d247e016d504095b415d1f Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 22 Nov 2022 21:52:50 +0000 Subject: [PATCH 450/453] Shard windows periodic job more (#89455) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/89455 Approved by: https://github.com/huydhn --- .github/workflows/periodic.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index b5512b20eaae..80ad04c9be32 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -167,8 +167,9 @@ jobs: cuda-version: "11.7" test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 2, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 3, runner: "windows.8xlarge.nvidia.gpu" }, { config: "force_on_cpu", shard: 1, num_shards: 1, runner: "windows.4xlarge" }, ]} From 74e62a1fefb7100689169dc12fd70095de54079d Mon Sep 17 00:00:00 2001 From: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com> Date: Tue, 22 Nov 2022 22:15:38 +0000 Subject: [PATCH 451/453] [ROCm] Optimize layer norm backward kernel for ROCm (#87635) We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly 
seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR #68238](https://github.com/pytorch/pytorch/pull/68238#issue-1051621716) on AMD GPUs. This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs` (=`config_m`) is larger than 512 on AMD GPUs. There are a few PRs for LayerNorm kernel: - https://github.com/pytorch/pytorch/pull/26201 - https://github.com/pytorch/pytorch/pull/27634 - https://github.com/pytorch/pytorch/pull/68238 Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100. --- **Current** M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892 50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886 200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827 802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946 200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349 1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753 6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429 6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245 200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878 1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751 6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313 6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982 200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007 1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991 6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504 6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133 200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015 1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778 6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987 6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025 200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655 1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685 6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635 6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141 200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034 1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433 6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462 6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524 128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092 256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371 512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902 1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192 2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191 4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751 8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646 16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408 32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271 --------- **At this PR** M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283 50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595 200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579 802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404 200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602 1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742 6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279 6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426 200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018 1000 | 512 | 0.07042 | 
0.33442 | 0.074 | 0.33206 6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572 6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635 200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216 1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936 6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273 6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545 200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545 1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204 6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119 6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208 200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859 1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583 6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796 6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055 200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695 1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633 6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289 6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694 128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699 256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936 512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083 1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117 2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845 4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392 8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296 16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113 32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514 --------- **Performance Improvement (%)**
M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32 -- | -- | -- | -- 50432 | 384 | 32.178 | 22.049 50176 | 384 | 29.231 | 19.536 200704 | 192 | 44.188 | 43.962 802816 | 64 | 52.119 | 54.100 200 | 256 | -5.750 | -0.206 1000 | 256 | 0.031 | -0.797 6000 | 256 | 3.566 | 5.621 6272 | 256 | 3.865 | 4.836 200 | 512 | -1.615 | -1.010 1000 | 512 | -1.270 | 0.208 6000 | 512 | 3.534 | 5.581 6272 | 512 | 7.905 | 7.483 200 | 1024 | -2.883 | 0.254 1000 | 1024 | -0.767 | 0.493 6000 | 1024 | 0.237 | -2.381 6272 | 1024 | 3.840 | -1.707 200 | 1536 | -0.127 | -1.340 1000 | 1536 | -0.711 | -0.992 6000 | 1536 | -0.209 | -4.728 6272 | 1536 | 0.508 | -0.846 200 | 2048 | -1.262 | -1.176 1000 | 2048 | -0.358 | 0.312 6000 | 2048 | 8.350 | 6.487 6272 | 2048 | 1.588 | 5.713 200 | 3072 | 0.223 | -0.848 1000 | 3072 | -0.773 | -5.743 6000 | 3072 | 3.570 | -3.783 6272 | 3072 | 4.962 | -4.092 128 | 2097152 | -4.266 | 0.348 256 | 1048576 | 0.397 | 0.185 512 | 524288 | 17.325 | 16.605 1024 | 262144 | 23.070 | 19.195 2048 | 131072 | 27.469 | 24.605 4096 | 65536 | 32.023 | 27.465 8192 | 32768 | 24.459 | 28.274 16384 | 16384 | 21.439 | 9.514 32768 | 8192 | 6.818 | 0.491
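The percentages in this table are presumably computed as (before - after) / before * 100 over the fwdbwd columns of the two tables above; the snippet below (illustrative, not part of the patch) reproduces the first row as a sanity check.

```
# Sanity check for the first row (M=50432, N=384, fwdbwd half):
# values copied from the "Current" and "At this PR" tables above.
before, after = 1.372758, 0.93103
improvement = (before - after) / before * 100
print(f"{improvement:.3f}%")  # ~32.178%, matching the table
```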
--------- **Benchmark script of this PR** ``` # Ref: # 1. https://github.com/pytorch/pytorch/pull/26201 # 2. https://github.com/pytorch/pytorch/pull/68238 from distutils.command.config import config import torch from torch.nn import LayerNorm import timeit number_runs = 1000 # TODO: Modify this to save time! def test_forward(layer_norm_cuda, input_cuda): layer_norm_cuda(input_cuda); torch.cuda.synchronize() def test_backward(out_cuda, layer_norm_grad_cuda, create_graph): out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize() def test_fwdbwd(input_cuda, layer_norm_cuda, gO): input_cuda.grad = None layer_norm_cuda.zero_grad(set_to_none=True) out = layer_norm_cuda(input_cuda) out.backward(gO) torch.cuda.synchronize() def benchmark(config_m, config_n): print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)") if len(config_m) != len(config_n): print("Please make sure the lengths of config_m and config_m are the same.") for i in range(len(config_m)): normalized_shape = config_n[i] results = [config_m[i], config_n[i]] for dtype in (torch.half, torch.float): if dtype == torch.half: layer_norm_cuda = LayerNorm(normalized_shape).half().cuda() else: layer_norm_cuda = LayerNorm(normalized_shape).cuda() input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True) # print("cuda forward:") result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs) results.append(result_fwd / number_runs * 1000) gO = torch.rand_like(input_cuda) result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs) results.append(result_fwdbwd / number_runs * 1000) print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5])) print("Times are in microseconds (us).") # CVT config_m_cvt = [50432, 50176, 200704, 802816] config_n_cvt = [384, 384, 192, 64] # https://github.com/pytorch/pytorch/pull/68238#issue-1051621716 config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272] config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072] # https://github.com/pytorch/pytorch/pull/27634 config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192] config_m = config_m_cvt + config_m_68238 + config_m_27634 config_n = config_n_cvt + config_n_68238 + config_n_27634 benchmark(config_m, config_n) ``` CC: @jeffdaily Pull Request resolved: https://github.com/pytorch/pytorch/pull/87635 Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 234 +++++++++++++++++- 1 file changed, 233 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index fa70f075d4fa..693524818fb4 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -25,6 +25,8 @@ #endif #include +#include + namespace at { namespace native { @@ -832,6 +834,201 @@ void LayerNormKernelImpl( }); } +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, 
+ const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + T_ACC* warp_buf1, + T_ACC* warp_buf2, + const T* input, + const T* dout, + const int i1_end, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + T curr_mean = mean[i1]; + T curr_rstd = rstd[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*N+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + T curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_rstd; + } else { + warp_buf1[write_idx] = T(0); + warp_buf2[write_idx] = T(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = T(0); + warp_buf2[write_idx] = T(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + T_ACC* warp_buf1, + T_ACC* warp_buf2, + const T* input, + const T* dout, + const int i1_end, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + T_ACC curr_mean = mean[i1]; + T_ACC curr_rstd = rstd[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*N+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + T_ACC curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_rstd; + } + } + } +} + +template __global__ +void cuComputePartGradGammaBeta( + const T* __restrict__ dout, + const T* __restrict__ input, + const int64_t M, + const int64_t N, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T_ACC* part_grad_gamma, + T_ACC* part_grad_beta) +{ + const int numsegs_M = (M+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_M + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < M ? 
i1_beg_plus_one : M; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + alignas(sizeof(double)) extern __shared__ char shared[]; + T_ACC * buf = reinterpret_cast(&shared); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + T_ACC* warp_buf1 = (T_ACC*)buf; + T_ACC* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,N,mean,rstd); + } + __syncthreads(); + // inter-warp reductions + // sum within each warp + T_ACC acc1 = T_ACC(0); + T_ACC acc2 = T_ACC(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < N) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*N+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*N+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const T_ACC* part_grad_gamma, + const T_ACC* part_grad_beta, + const int part_size, + const int64_t M, + const int64_t N, + T* grad_gamma, + T* grad_beta) +{ + // sum partial gradients for gamma and beta + alignas(sizeof(double)) extern __shared__ char shared[]; + T_ACC * buf = reinterpret_cast(&shared); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < N) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + T_ACC sum_gamma = T_ACC(0); + T_ACC sum_beta = T_ACC(0); + const T_ACC* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * N + i2; + const T_ACC* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * N + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += part_grad_gamma_ptr[warp_offset*N]; + sum_beta += part_grad_beta_ptr[warp_offset*N]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * 
blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + template void LayerNormBackwardKernelImplInternal( const Tensor& dY, @@ -860,8 +1057,8 @@ void LayerNormBackwardKernelImplInternal( gamma.defined() ? gamma.template data_ptr() : nullptr; T* dX_data = dX->defined() ? dX->template data_ptr() : nullptr; cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream(); + const int warp_size = at::cuda::warp_size(); if (dX_data != nullptr) { - const int warp_size = at::cuda::warp_size(); const dim3 blocks(M); int nshared = (num_threads()/warp_size) * sizeof(T_ACC); layer_norm_grad_input_kernel<<>>(dY_data, @@ -889,6 +1086,40 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { +#if defined(USE_ROCM) + // For small batch size, do colwise reduce directly. + const int part_size = warp_size; + const dim3 threads2(warp_size, 4, 1); + const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1); + const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + + const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true); + Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); + Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); + cuComputePartGradGammaBeta<<>>( + dY_data, + X_data, + M,N, + mean_data, + rstd_data, + part_grad_gamma.template data_ptr(), + part_grad_beta.template data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm + const dim3 blocks3((N + threads2.x - 1) / threads2.x, 1, 1); + const int nshared3 = threads3.x * threads3.y * sizeof(T); + cuComputeGradGammaBeta<<>>( + part_grad_gamma.template data_ptr(), + part_grad_beta.template data_ptr(), + part_size, + M,N, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +#else if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { // This implementation relies on warp primitives and requires that M and N divide // exactly to warp size. @@ -925,6 +1156,7 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } +#endif } } } From d17ddf06367a871ec551a87c0c81d5f50cb661d5 Mon Sep 17 00:00:00 2001 From: Guobing Chen Date: Mon, 18 Jul 2022 14:46:23 +0800 Subject: [PATCH 452/453] Enable maxpool_2d in NNC With both quantization/non-quantization supported. 
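As a rough illustration of the kind of graph this lets the TensorExpr fuser consider (a hedged sketch, not taken from this patch; the C++ IR tests added below are the authoritative examples), a scripted function that quantizes, max-pools, and dequantizes now contains only ops on the fuser's supported list. Whether the pooling actually lands in a fusion group still depends on the fuser being enabled for the graph; the scale and zero point below mirror the values used in the new test.

```
# Illustrative sketch only: a TorchScript graph containing aten::max_pool2d,
# the op that this change adds to the NNC (TensorExpr) fuser's supported set.
import torch
import torch.nn.functional as F

@torch.jit.script
def pool_block(x: torch.Tensor) -> torch.Tensor:
    q = torch.quantize_per_tensor(x, 0.0549519, 52, torch.quint8)
    p = F.max_pool2d(q, kernel_size=2, stride=2)
    return torch.dequantize(p)

x = torch.rand(4, 2, 9, 9)
print(pool_block(x).shape)  # torch.Size([4, 2, 4, 4])
```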
--- test/cpp/tensorexpr/test_quantization.cpp | 74 ++++++++++++++++ torch/csrc/jit/passes/tensorexpr_fuser.cpp | 1 + torch/csrc/jit/tensorexpr/codegen.cpp | 1 + .../jit/tensorexpr/external_functions.cpp | 84 +++++++++++++++++++ torch/csrc/jit/tensorexpr/lowerings.cpp | 4 + .../jit/tensorexpr/operators/reduction.cpp | 72 ++++++++++++++++ .../csrc/jit/tensorexpr/operators/reduction.h | 6 ++ 7 files changed, 242 insertions(+) diff --git a/test/cpp/tensorexpr/test_quantization.cpp b/test/cpp/tensorexpr/test_quantization.cpp index a34a6e7bd7bd..dd9c32977e8c 100644 --- a/test/cpp/tensorexpr/test_quantization.cpp +++ b/test/cpp/tensorexpr/test_quantization.cpp @@ -449,5 +449,79 @@ TEST_F(Quantization, QuantCatDequantUInt8) { TORCH_CHECK_EQ(check, 1); } +TEST_F(Quantization, QuantMaxPool2dDequantUInt8) { + const auto graph_string = R"IR( + graph(%x : Float(4, 2, 9, 9, strides=[162, 81, 9, 1], device=cpu)): + %1 : int = prim::Constant[value=13]() + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int[] = prim::Constant[value=[0, 0]]() + %4 : int[] = prim::Constant[value=[2, 2]]() + %5 : bool = prim::Constant[value=0]() + %qz : int = prim::Constant[value=52]() + %qs : float = prim::Constant[value=0.0549519]() + %q : QUInt8(4, 2, 9, 9, strides=[162, 81, 9, 1], requires_grad=0, device=cpu) = aten::quantize_per_tensor(%x, %qs, %qz, %1) + %qu : QUInt8(4, 2, 4, 4, strides=[32, 16, 4, 1], requires_grad=0, device=cpu) = aten::max_pool2d(%q, %4, %4, %3, %2, %5) + %6 : Float(4, 2, 4, 4, strides=[32, 16, 4, 1], requires_grad=0, device=cpu) = aten::dequantize(%qu) + return (%6))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({4, 2, 9, 9}, TensorOptions(kCPU).dtype(at::kFloat)); + auto q = at::quantize_per_tensor(x, 0.0549519f, 52, at::kQUInt8); + auto qu = at::max_pool2d(q, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false); + auto y_expected = at::dequantize(qu); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "q:\n" << q << std::endl; + std::cout << "qu:\n" << qu << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); +} + +TEST_F(Quantization, MaxPool2dFloat) { + const auto graph_string = R"IR( + graph(%x : Float(4, 2, 9, 9, strides=[162, 1, 18, 2], device=cpu)): + %2 : int[] = prim::Constant[value=[1, 1]]() + %3 : int[] = prim::Constant[value=[0, 0]]() + %4 : int[] = prim::Constant[value=[2, 2]]() + %5 : bool = prim::Constant[value=0]() + %qu : Float(4, 2, 4, 4, strides=[32, 1, 8, 2], device=cpu) = aten::max_pool2d(%x, %4, %4, %3, %2, %5) + return (%qu))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph); + + auto x = at::rand({4, 2, 9, 9}, TensorOptions(kCPU).dtype(at::kFloat)); + x = x.contiguous(c10::MemoryFormat::ChannelsLast); + auto y_expected = at::max_pool2d(x, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false); + + TensorExprKernel k(graph); + std::vector inputs = {x}; + StmtPtr s = k.getCodeGenStmt(); + + std::vector stack = fmap(inputs); + k.run(stack); + auto y = stack[0].toTensor(); + + bool check = at::allclose(y_expected, y); + if (!check) { + std::cout << "x:\n" << x << std::endl; + std::cout << "y_expected:\n" << y_expected << std::endl; + std::cout << "y:\n" << y << std::endl; + } + TORCH_CHECK_EQ(check, 1); 
+} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 00335c92fe80..b50789d0503e 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -98,6 +98,7 @@ bool isSupported(Node* node) { "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor", "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", + "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", }; static const OperatorSet supported_misc_set{ "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index e1e7a01875ec..b5df7de91cce 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -245,6 +245,7 @@ std::unordered_map ExtCallMemoryReuse:: {"nnc_aten_quantized_mul_scalar", "nnc_aten_quantized_mul_scalar_out"}, {"nnc_aten_max_red", "nnc_aten_max_red_out"}, {"nnc_aten_conv1d", "nnc_aten_conv1d_out"}, + {"nnc_aten_max_pool2d", "nnc_aten_max_pool2d_out"}, }; } diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index 3b87c9458a55..10ece9864508 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1313,6 +1313,84 @@ void nnc_aten_mean( } } +void nnc_aten_max_pool2d( + int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t, + int64_t* extra_args) { + const int64_t x_qdtype = extra_args[2]; + c10::optional>> qdata; + if (x_qdtype != -1) { + qdata = { + {1u, + {((double*)extra_args)[0], + extra_args[1], + at::toQIntType(static_cast(x_qdtype))}}}; + } + auto tensors = constructTensors( + bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes, qdata); + + auto r = at::max_pool2d( + /*input=*/tensors[1], + /*kernel_size=*/at::IntArrayRef({extra_args[3], extra_args[4]}), + /*stride=*/at::IntArrayRef({extra_args[5], extra_args[6]}), + /*padding=*/at::IntArrayRef({extra_args[7], extra_args[8]}), + /*dilation=*/at::IntArrayRef({extra_args[9], extra_args[10]}), + /*ceil_mode=*/(bool)extra_args[11]); + + memcpy(buf_data[0], r.data_ptr(), r.element_size() * r.numel()); +} + +void nnc_aten_max_pool2d_out( + int64_t bufs_in_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t, + int64_t* extra_args) { + const int64_t x_qdtype = extra_args[2]; + c10::optional>> qdata; + if (x_qdtype != -1) { + qdata = { + {1u, + {((double*)extra_args)[0], + extra_args[1], + at::toQIntType(static_cast(x_qdtype))}}}; + } + const size_t bufs_out_num = 1u; + auto tensors = constructTensors2( + bufs_in_num, + buf_data, + buf_ranks, + buf_dims, + buf_strides, + buf_dtypes, + qdata, + bufs_out_num); + + at::Tensor r; + try { + r = at::max_pool2d( + tensors[1], + at::IntArrayRef({extra_args[3], extra_args[4]}), + at::IntArrayRef({extra_args[5], extra_args[6]}), + at::IntArrayRef({extra_args[7], extra_args[8]}), + at::IntArrayRef({extra_args[9], extra_args[10]}), + (bool)extra_args[11]); + } catch (...) 
{ + } + + buf_data[0] = r.data_ptr(); + c10::raw::intrusive_ptr::incref(r.getIntrusivePtr().get()); + buf_data[bufs_in_num + bufs_out_num] = r.getIntrusivePtr().get(); +} + void nnc_aten_max_red( int64_t bufs_num, void** buf_data, @@ -1613,6 +1691,12 @@ const static RegisterNNCExternalFunction nnc_triangular_solve( const static RegisterNNCExternalFunction nnc_embedding( "nnc_aten_embedding", nnc_aten_embedding); +const static RegisterNNCExternalFunction nnc_max_pool2d( + "nnc_aten_max_pool2d", + nnc_aten_max_pool2d); +const static RegisterNNCExternalFunction nnc_max_pool2d_out( + "nnc_aten_max_pool2d_out", + nnc_aten_max_pool2d_out); #if AT_MKLDNN_ENABLED() const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_conv_run( diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index de2d6e7cbdd8..224755f8f202 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -1888,6 +1888,10 @@ int nnc_lowerings_lazy_registration() { {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)"}, computeAdaptiveAvgPool2d); + RegisterNNCLoweringsFunction aten_max_pool2d( + {"aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor)"}, + computeMaxPool2d); + RegisterNNCLoweringsFunction aten_add( {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)", "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)"}, diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.cpp b/torch/csrc/jit/tensorexpr/operators/reduction.cpp index 6fa04899c99f..9dc4e7ca82c5 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/operators/reduction.cpp @@ -1,3 +1,4 @@ +#include #include using namespace torch::jit::tensorexpr; @@ -183,6 +184,77 @@ Tensor computeAdaptiveAvgPool2d( c10::fmap(out_size_param))); } +Tensor computeMaxPool2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device) { + auto x = c10::get(inputs[0]); + auto kernel_size = c10::get(inputs[1]); + auto stride = c10::get(inputs[2]); + auto padding = c10::get(inputs[3]); + auto dilation = c10::get(inputs[4]); + auto ceil_mode = c10::get(inputs[5]); + + // Expand the dims as needed, to facilitate external call params processing + if (kernel_size.size() == 1) { + kernel_size.push_back(kernel_size[0]); + } + if (padding.size() == 1) { + padding.push_back(padding[0]); + } + if (dilation.size() == 1) { + dilation.push_back(dilation[0]); + } + if (stride.empty()) { + stride.push_back(kernel_size[0]); + stride.push_back(kernel_size[1]); + }; + + Dtype dtype = (outputType == c10::nullopt) ? kFloat : Dtype(*outputType); + + ExprHandle qx_qscale = DoubleImm::make(0.0f); + ExprHandle qx_qzero = LongImm::make(1l); + int64_t qx_qdtype = -1l; + if (isQuantized(x)) { + qx_qscale = ExprHandle(x.node()->qscale()); + qx_qzero = ExprHandle(x.node()->qzero()); + qx_qdtype = (int64_t)immQDType(x); + } + + auto strides = x.is_contiguous(c10::MemoryFormat::ChannelsLast) + ? 
make_channels_last_strides(outputShape) + : make_contiguous_strides(outputShape); + + BufHandle ResultBuf = Buf::make( + "max_pool2d", + outputShape, + Dtype(*outputType), + c10::nullopt, // initializer + ExprVectorToExprHandleVector(strides), + qx_qscale, + qx_qzero); + + StmtPtr s = ExternalCall::make( + ResultBuf, + "nnc_aten_max_pool2d", + {x}, + {qx_qscale, + qx_qzero, + qx_qdtype, + kernel_size[0], + kernel_size[1], + stride[0], + stride[1], + padding[0], + padding[1], + dilation[0], + dilation[1], + (int64_t)ceil_mode}); + return Tensor(ResultBuf.node(), s); +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.h b/torch/csrc/jit/tensorexpr/operators/reduction.h index 6265c4d26585..7fe52192a0c3 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.h +++ b/torch/csrc/jit/tensorexpr/operators/reduction.h @@ -24,6 +24,12 @@ TORCH_API Tensor computeAdaptiveAvgPool2d( const std::vector& outputStrides, const c10::optional& outputType, at::Device device); +TORCH_API Tensor computeMaxPool2d( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device); Tensor computeMax( const std::vector& inputs, const std::vector& outputShape, From aadf9745b639c3ed7b4361302ddd9befafde91b7 Mon Sep 17 00:00:00 2001 From: Guobing Chen Date: Mon, 29 Aug 2022 13:19:43 +0800 Subject: [PATCH 453/453] fix maxpool2d output buf dtype --- torch/csrc/jit/tensorexpr/operators/reduction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/tensorexpr/operators/reduction.cpp b/torch/csrc/jit/tensorexpr/operators/reduction.cpp index 9dc4e7ca82c5..c05e5935b679 100644 --- a/torch/csrc/jit/tensorexpr/operators/reduction.cpp +++ b/torch/csrc/jit/tensorexpr/operators/reduction.cpp @@ -230,7 +230,7 @@ Tensor computeMaxPool2d( BufHandle ResultBuf = Buf::make( "max_pool2d", outputShape, - Dtype(*outputType), + dtype, c10::nullopt, // initializer ExprVectorToExprHandleVector(strides), qx_qscale,