From 51f0403f46d40354098406e26bf967e6de74429e Mon Sep 17 00:00:00 2001 From: LifengWang Date: Mon, 31 Mar 2025 09:45:40 +0000 Subject: [PATCH 001/332] Update the baseline for max_autotune ci workflow (#149107) Since the issue https://github.com/pytorch/pytorch/issues/148535 is fixed in PR https://github.com/pytorch/pytorch/pull/148923, update the baseline for max_autotune ci workflow. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149107 Approved by: https://github.com/chuanqi129, https://github.com/leslie-fang-intel, https://github.com/desertfire --- .github/workflows/inductor-nightly.yml | 3 +++ ...max_autotune_inductor_amp_freezing_torchbench_inference.csv | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 076e67b08ed2..55a37f031fdf 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -4,6 +4,9 @@ on: pull_request: paths: - .github/workflows/inductor-nightly.yml + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_huggingface_inference.csv + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_timm_inference.csv + - benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv workflow_dispatch: schedule: # Run every day at 7:00 AM UTC diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv index 00fc3c9e0949..96e54bf6f0df 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_max_autotune_inductor_amp_freezing_torchbench_inference.csv @@ -290,7 +290,7 @@ soft_actor_critic,pass,0 -speech_transformer,fail_to_run,5 +speech_transformer,pass,10 From c158eac0de2afe38d68952ca401888ed5777f6b0 Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Mon, 31 Mar 2025 09:49:40 +0000 Subject: [PATCH 002/332] [ROCm] use correct workspace for hipblaslt, silence warning (#150227) Follow up to #145130. That PR caused a warning on ROCm the first time hipblaslt was called for any workload, always. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/150227 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- aten/src/ATen/cuda/CUDABlas.cpp | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d39fe4be31c9..a374ee3c8b7c 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -222,19 +222,35 @@ static size_t _getWorkspaceSize() { return workspace_size; } +static at::DataPtr _getNewWorkspace() { + return c10::cuda::CUDACachingAllocator::get()->allocate(_getWorkspaceSize()); +} + +// See Note [hipblaslt handles]. +// ROCm's hipblas and hipblaslt do not share handles, unlike with CUDA. +// Using getCurrentCUDABlasLtHandle is on purpose. For CUDA it's the same as +// getCurrentCUDABlasHandle, but for ROCm it's a unique handle. 
void* _getWorkspaceWithoutHandle() { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle(); auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); +#ifdef USE_ROCM + // The first call to _getWorkspaceWithoutHandle could be empty, so allocate and store. + if (workspace_it == at::cuda::cublas_handle_stream_to_workspace().end()) { + workspace_it = at::cuda::cublas_handle_stream_to_workspace().insert(workspace_it, {key, _getNewWorkspace()}); + } +#else TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); +#endif return workspace_it->second.mutable_get(); } void* _getWorkspace(size_t& workspaceSize) { -// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2)) workspaceSize = _getWorkspaceSize(); +#ifndef USE_ROCM + // See Note [hipblaslt handles]. auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize(); if (cublasWorkspaceSize < workspaceSize) { TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize, @@ -245,9 +261,7 @@ void* _getWorkspace(size_t& workspaceSize) { " size will be limited to the CUBLAS workspace size."); workspaceSize = cublasWorkspaceSize; } -// #else -// workspaceSize = at::cuda::getChosenWorkspaceSize(); -// #endif +#endif auto workspace_ptr = _getWorkspaceWithoutHandle(); return workspace_ptr; } From bbb9b2476bed5750194bfb87aacd71c0fcfd60dd Mon Sep 17 00:00:00 2001 From: Yichen Yan Date: Mon, 31 Mar 2025 12:23:26 +0000 Subject: [PATCH 003/332] Unify use of `enableCollectiveHashDebug_` and trivial updates (#142865) Use `enableCollectiveHashDebug_` instead of checking env ad-hoc when `TORCH_DISTRIBUTED_DEBUG = DETAIL` Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/142865 Approved by: https://github.com/fegin, https://github.com/kwen2501 --- .../distributed/c10d/ProcessGroupNCCL.cpp | 20 +++---------------- .../distributed/c10d/ProcessGroupNCCL.hpp | 2 +- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 863bc1c4491c..ce521d594fa8 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -200,20 +200,6 @@ inline std::string getKeyFromDevice(at::Device& device) { return std::to_string(device.index()); } -inline at::DeviceIndex getIndexFromDeviceKey(const std::string& deviceKey) { - // initialize the device index to -1, which is an invalid value. - int index = -1; - try { - index = std::stoi(deviceKey); - } catch (const std::invalid_argument& e) { - LOG(ERROR) << c10::str( - "Invalid deviceKey: ", deviceKey, ",", e.what(), "."); - } catch (const std::out_of_range& e) { - LOG(ERROR) << "Out of range: " << e.what(); - } - return static_cast(index); -} - std::string getKeySendRecv(int myRank, int peer) { int lowRank = myRank < peer ? myRank : peer; int highRank = myRank < peer ? peer : myRank; @@ -781,7 +767,7 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { // upgrade. Once a NCCL version is qualified, this code should not be needed // at runtime. 
#ifdef PGNCCL_ENABLE_HASH - if (distDebugLevel_ >= DebugLevel::Detail) { + if (enableCollectiveHashDebug_.load()) { auto numel = getTensorsNumel(*outputs_); auto hashValue = hashTensors(*outputs_); PRINT_COLLECTIVE_HASH_SIGNATURE( @@ -921,7 +907,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 60 * 1000 /*60 Sec*/); coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000); - enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); + enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); // store_ usually is wrapped with PrefixStore and the prefix is different // across different ProcessGroupNCCL(PG) instances. We need to get the // underlying non-PrefixStore for sharing global information shared across @@ -3548,7 +3534,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // upgrade. Once a NCCL version is qualified, this code should not be needed at // runtime. #ifdef PGNCCL_ENABLE_HASH - if (enableCollecticeHashDebug_.load()) { + if (enableCollectiveHashDebug_.load()) { auto numel = getTensorsNumel(inputs); auto hashValue = hashTensors(inputs); PRINT_COLLECTIVE_HASH_SIGNATURE( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index f65d5955c8dd..ca870f702013 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1277,7 +1277,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Flag to enable the print of hash value of input/output of collectives for // verification. - std::atomic enableCollecticeHashDebug_{}; + std::atomic enableCollectiveHashDebug_{}; // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set bool avoidRecordStreams_ = false; From f74d5d576aedf053b7574f3eb06d12417d80625a Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Mon, 31 Mar 2025 13:36:11 +0000 Subject: [PATCH 004/332] Update torch-xpu-ops commit pin to 3ee2bd2 (#150300) Update the torch-xpu-ops commit to [3ee2bd2f13e1ed17a685986ff667a58bed5f2aa5](https://github.com/intel/torch-xpu-ops/commit/3ee2bd2f13e1ed17a685986ff667a58bed5f2aa5) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150300 Approved by: https://github.com/EikanWang --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 239a4b8aeb93..53b3ef7e4560 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -026b2c8c7c92a7b2cec5d26334006e3423251cc6 +3ee2bd2f13e1ed17a685986ff667a58bed5f2aa5 From e57fa18b40e37be7dc41ca0d5789acdd21ca8f9e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 31 Mar 2025 15:37:54 +0000 Subject: [PATCH 005/332] Revert "Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129)" This reverts commit 8a872261dcb3797557d1965af6832677a77efec1. 
Reverted https://github.com/pytorch/pytorch/pull/150129 on behalf of https://github.com/clee2000 due to breaking internal builds D72080428 ([comment](https://github.com/pytorch/pytorch/pull/150129#issuecomment-2766619006)) --- test/distributed/test_symmetric_memory.py | 37 +++------ .../c10d/CUDASymmetricMemoryOps.cu | 79 +++---------------- .../csrc/distributed/c10d/SymmetricMemory.cpp | 14 ---- 3 files changed, 23 insertions(+), 107 deletions(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index b5e961276f87..a5b77410b190 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -1,6 +1,5 @@ # Owner(s): ["module: c10d"] -import itertools import os from unittest import skipIf @@ -861,32 +860,22 @@ def test_multimem_one_shot_all_reduce( @skipIfRocm @skip_if_lt_x_gpu(4) - def test_one_shot_all_reduce(self) -> None: + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_one_shot_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: self._init_process() group_name = dist.group.WORLD.group_name - for dtype, size_bytes, align_bytes, copy, offset in itertools.product( - [torch.float, torch.bfloat16], - [4, 8192, 8196], - [4, 8, 16], - [True, False], - [0, 16], - ): - inp = symm_mem.empty( - size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device - ) - symm_mem.rendezvous(inp, group=group_name) - if not copy: - inp.normal_() - res = torch.ops.symm_mem.one_shot_all_reduce( - inp[offset:], "sum", group_name - ) - if copy: - local_inp = torch.randn_like(inp[offset:]) - res = torch.ops.symm_mem.one_shot_all_reduce_copy( - inp[offset:], local_inp, "sum", group_name - ) - self._verify_all_reduce_result(local_inp if copy else inp[offset:], res) + inp = symm_mem.empty( + size_bytes // dtype.itemsize, dtype=dtype, device=self.device + ).normal_() + symm_mem.rendezvous(inp, group=group_name) + + res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name) + self._verify_all_reduce_result(inp, res) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 0b2044a870eb..e416e07aea20 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -397,7 +397,7 @@ at::Tensor multimem_all_gather_out( // One-shot all-reduce is register-intensive because it stages values loaded // from peers in registers before performing reduction. Setting the thread // count to 512 to prevent/alleviate register spill. 
-constexpr size_t one_shot_all_reduce_max_num_blocks = 24; +constexpr size_t one_shot_all_reduce_max_num_blocks = 8; constexpr size_t one_shot_all_reduce_max_num_threads = 512; template @@ -405,7 +405,6 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ void one_shot_all_reduce_kernel( T** input_ptrs, T* output_ptr, - T* input_ptr, size_t input_offset, size_t numel, uint32_t** signal_pads, @@ -413,18 +412,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ size_t world_size) { static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - // copy input to shared ptr + + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; auto stride = blockDim.x * gridDim.x * numel_per_thread; - if (input_ptr) { - for (size_t i = offset; i < numel; i += stride) { - Vec vec_st = ld_vec(input_ptr + i); - st_vec(input_ptrs[rank] + input_offset + i, vec_st); - } - } - // TODO make it sync with one block for no-copy case - sync_remote_blocks(signal_pads, rank, world_size); - __syncthreads(); for (size_t i = offset; i < numel; i += stride) { auto vec = load_and_reduce( @@ -433,12 +426,11 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } -at::Tensor one_shot_all_reduce_out_impl( +at::Tensor one_shot_all_reduce_out( const at::Tensor& input, - const c10::optional& local_input, std::string reduce_op, std::string group_name, at::Tensor out) { @@ -448,21 +440,11 @@ at::Tensor one_shot_all_reduce_out_impl( out.is_contiguous(), "one_shot_all_reduce: output must be contiguous."); TORCH_CHECK( out.sizes() == input.sizes(), - "one_shot_all_reduce: input/output size mismatch, input.sizes(): ", - input.sizes(), - ", output.sizes(): ", - out.sizes()); + "one_shot_all_reduce: input/output size mismatch."); TORCH_CHECK( reduce_op == "sum", "one_shot_all_reduce: only sum is supported for now."); - if (local_input.has_value()) { - TORCH_CHECK( - local_input->is_contiguous(), - "one_shot_all_reduce: local input must be contiguous."); - TORCH_CHECK( - local_input->numel() <= input.numel(), - "one_shot_all_reduce: local input size must be smaller than symm buffer size."); - } + auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); TORCH_CHECK( symm_mem != nullptr, @@ -470,13 +452,6 @@ at::Tensor one_shot_all_reduce_out_impl( const size_t alignment = get_and_verify_alignment(input, "one_shot_all_reduce"); - if (local_input.has_value()) { - const size_t local_alignment = - get_and_verify_alignment(*local_input, "one_shot_all_reduce"); - TORCH_CHECK( - alignment == local_alignment, - "one_shot_all_reduce: local input and symm buffer must have the same alignment."); - } int num_blocks = 0, num_threads = 0; init_elementwise_launch_config( @@ -501,8 +476,6 @@ at::Tensor one_shot_all_reduce_out_impl( reinterpret_cast( symm_mem->get_buffer_ptrs_dev()), out.data_ptr(), - local_input.has_value() ? 
local_input->data_ptr() - : nullptr, input.storage_offset(), input.numel(), reinterpret_cast( @@ -516,42 +489,12 @@ at::Tensor one_shot_all_reduce_out_impl( return out; } -at::Tensor one_shot_all_reduce_out( - const at::Tensor& input, - std::string reduce_op, - std::string group_name, - at::Tensor out) { - return one_shot_all_reduce_out_impl( - input, c10::nullopt, reduce_op, group_name, out); -} - -at::Tensor one_shot_all_reduce_copy_out( - const at::Tensor& input, - const at::Tensor& local_input, - std::string reduce_op, - std::string group_name, - at::Tensor out) { - return one_shot_all_reduce_out_impl( - input, local_input, reduce_op, group_name, out); -} - at::Tensor one_shot_all_reduce( const at::Tensor& input, std::string reduce_op, std::string group_name) { auto out = at::empty_like(input); - return one_shot_all_reduce_out_impl( - input, c10::nullopt, reduce_op, group_name, out); -} - -at::Tensor one_shot_all_reduce_copy( - const at::Tensor& input, - const at::Tensor& local_input, - std::string reduce_op, - std::string group_name) { - auto out = at::empty_like(local_input); - return one_shot_all_reduce_out_impl( - input, local_input, reduce_op, group_name, out); + return one_shot_all_reduce_out(input, reduce_op, group_name, out); } constexpr size_t two_shot_all_reduce_max_num_blocks = 24; @@ -895,8 +838,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("multimem_all_gather_out", ::multimem_all_gather_out); m.impl("one_shot_all_reduce", ::one_shot_all_reduce); m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out); - m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy); - m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 76eb7205a398..9d400395e073 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -217,14 +217,6 @@ at::Tensor one_shot_all_reduce_meta( return at::empty_like(input); } -at::Tensor one_shot_all_reduce_copy_meta( - const at::Tensor& symm_buffer, - const at::Tensor& local_input, - std::string reduce_op, - std::string group_name) { - return at::empty_like(local_input); -} - TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)"); @@ -238,11 +230,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor"); m.def( "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)"); - m.def( - "one_shot_all_reduce_copy(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name) -> Tensor"); - m.def( - "one_shot_all_reduce_copy_out(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)"); - m.def( "two_shot_all_reduce_(Tensor(a!) 
input, str reduce_op, str group_name) -> Tensor(a!)"); @@ -269,7 +256,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { TORCH_LIBRARY_IMPL(symm_mem, Meta, m) { m.impl("one_shot_all_reduce", one_shot_all_reduce_meta); - m.impl("one_shot_all_reduce_copy", one_shot_all_reduce_copy_meta); } } // namespace From 57fa99c5c387458f871e61a357c36c87bf4478ab Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 31 Mar 2025 15:43:24 +0000 Subject: [PATCH 006/332] Revert "enable out variant of 2-shot reduction (#150153)" This reverts commit cdeb32d2d1c31b60c65133e83510977c5c180005. Reverted https://github.com/pytorch/pytorch/pull/150153 on behalf of https://github.com/clee2000 due to failing internal builds D72083877 ([comment](https://github.com/pytorch/pytorch/pull/150153#issuecomment-2766633712)) --- test/distributed/test_symmetric_memory.py | 52 +++--- .../c10d/CUDASymmetricMemoryOps.cu | 170 +++--------------- .../csrc/distributed/c10d/SymmetricMemory.cpp | 4 - 3 files changed, 45 insertions(+), 181 deletions(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index a5b77410b190..34b8ed5a7b10 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -881,38 +881,34 @@ def test_one_shot_all_reduce( @skipIfRocm @skip_if_lt_x_gpu(4) - def test_two_shot_all_reduce(self) -> None: + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_two_shot_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: self._init_process() group_name = dist.group.WORLD.group_name - for dtype, size_bytes, align_bytes, inplace in itertools.product( - [torch.float, torch.bfloat16], - [4, 8192, 8196], - [4, 8, 16], - [True, False], - ): - t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) - symm_mem.rendezvous(t, group=group_name) - - self.assertTrue(t.data_ptr() % 16 == 0) - self.assertTrue(align_bytes % t.element_size() == 0) - self.assertTrue(size_bytes % t.element_size() == 0) - - shift = align_bytes // t.element_size() - numel = size_bytes // t.element_size() - res = t[shift : shift + numel] - res.normal_().fill_(1) - inp = res.clone() - if not inplace: - out = torch.empty_like(inp) - torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out) - else: - torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name) + t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) - # Head and tail should not be written - self.assertTrue(t[:shift].eq(0).all().item()) - self.assertTrue(t[shift + numel :].eq(0).all().item()) - self._verify_all_reduce_result(inp, res if inplace else out) + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + res = t[shift : shift + numel] + res.normal_() + inp = res.clone() + + torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name) + + # Head and tail should not be written + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_all_reduce_result(inp, res) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 
e416e07aea20..438624f4bc07 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -39,16 +39,6 @@ } \ } -#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...) \ - switch (world_size) { \ - INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__); \ - INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__); \ - INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__); \ - default: { \ - TORCH_CHECK(false, "Not implemented for world_size=", world_size); \ - } \ - } - #define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...) \ switch (alignment) { \ INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__); \ @@ -503,70 +493,6 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 512; template static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ void two_shot_all_reduce_kernel( - T** input_ptrs, - T* output_ptr, - size_t input_offset, - size_t numel, - uint32_t** signal_pads, - size_t rank, - size_t world_size) { - static_assert(alignment % sizeof(T) == 0); - constexpr size_t numel_per_thread = alignment / sizeof(T); - - sync_remote_blocks(signal_pads, rank, world_size); - __syncthreads(); - - const size_t numel_per_rank = - at::round_up(numel, alignment * world_size) / world_size; - const size_t start = numel_per_rank * rank; - - auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; - auto stride = blockDim.x * gridDim.x * numel_per_thread; - for (size_t i = offset; i < numel_per_rank; i += stride) { - if (start + i >= numel) { - continue; - } - auto vec = load_and_reduce( - input_ptrs, rank, world_size, input_offset + start + i); - // store to local buffer - st_vec(input_ptrs[rank] + input_offset + start + i, vec); - } - - __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); - __syncthreads(); - for (size_t i = offset; i < numel_per_rank; i += stride) { - Vec tmp[k_world_size]; -#pragma unroll k_world_size - for (size_t step = 0; step < k_world_size; ++step) { - size_t remote_rank = (rank + step) % k_world_size; - size_t remote_start = numel_per_rank * remote_rank; - if (remote_start + i >= numel) { - continue; - } - tmp[step] = ld_vec( - input_ptrs[remote_rank] + input_offset + remote_start + i); - } -#pragma unroll k_world_size - for (size_t step = 0; step < k_world_size; ++step) { - size_t remote_rank = (rank + step) % k_world_size; - size_t remote_start = numel_per_rank * remote_rank; - if (remote_start + i >= numel) { - continue; - } - st_vec( - output_ptr + remote_start + i, tmp[step]); - } - } - // need to make sure all blocks exit simultaneously so that the data - // is not corrupted by the subsequent kernels - __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); -} - -template -static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ - void two_shot_all_reduce_kernel_inplace( T** input_ptrs, size_t input_offset, size_t numel, @@ -602,9 +528,8 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ sync_remote_blocks(signal_pads, rank, world_size); } -at::Tensor two_shot_all_reduce_impl( +at::Tensor two_shot_all_reduce_( at::Tensor input, - c10::optional output, std::string reduce_op, std::string group_name) { TORCH_CHECK( @@ -621,14 +546,6 @@ at::Tensor two_shot_all_reduce_impl( const size_t alignment = get_and_verify_alignment(input, "two_shot_all_reduce"); - if (output.has_value()) { - const size_t output_alignment = - get_and_verify_alignment(*output, "two_shot_all_reduce"); - TORCH_CHECK( - alignment <= output_alignment, - 
"two_shot_all_reduce: output alignment must be equal to or larger than input."); - } - int num_blocks = 0, num_threads = 0; init_elementwise_launch_config( input.numel(), @@ -640,73 +557,30 @@ at::Tensor two_shot_all_reduce_impl( num_blocks, num_threads); - if (!output.has_value()) { - AT_DISPATCH_FLOAT_AND_BFLOAT16( - input.scalar_type(), "two_shot_all_reduce", [&]() { - DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { - DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { - two_shot_all_reduce_kernel_inplace< - scalar_t, - k_alignment, - k_world_size> - <<>>( - reinterpret_cast( - symm_mem->get_buffer_ptrs_dev()), - input.storage_offset(), - input.numel(), - reinterpret_cast( - symm_mem->get_signal_pad_ptrs_dev()), - symm_mem->get_rank(), - symm_mem->get_world_size()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - return input; - } else { - AT_DISPATCH_FLOAT_AND_BFLOAT16( - input.scalar_type(), "two_shot_all_reduce", [&]() { - DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { - DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { - two_shot_all_reduce_kernel - <<>>( - reinterpret_cast( - symm_mem->get_buffer_ptrs_dev()), - output->data_ptr(), - input.storage_offset(), - input.numel(), - reinterpret_cast( - symm_mem->get_signal_pad_ptrs_dev()), - symm_mem->get_rank(), - symm_mem->get_world_size()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); }); }); - return *output; - } -} - -at::Tensor two_shot_all_reduce_( - at::Tensor input, - std::string reduce_op, - std::string group_name) { - return two_shot_all_reduce_impl(input, c10::nullopt, reduce_op, group_name); + }); + return input; } -at::Tensor two_shot_all_reduce_out( - at::Tensor input, - std::string reduce_op, - std::string group_name, - at::Tensor output) { - return two_shot_all_reduce_impl(input, output, reduce_op, group_name); -} } // namespace #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 @@ -839,8 +713,6 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("one_shot_all_reduce", ::one_shot_all_reduce); m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); - m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); - m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm); #endif m.impl("stream_write_value32_", ::stream_write_value32_); diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 9d400395e073..0308f2f5c4b2 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -233,10 +233,6 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)"); - // note this implementation also modified the input tensor - m.def( - "two_shot_all_reduce_out(Tensor(a!) input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)"); - // An mm that supports consuming asynchronous input. 
It guarantees the // following rasterization order, and that the corresponding signal arrives // before an input chunk is consumed. From 7c858066aed352c9e1089db40619e92bc10fddb1 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 31 Mar 2025 15:58:34 +0000 Subject: [PATCH 007/332] Revert "Enable TMA persistent GEMM Template by default (#149427)" This reverts commit b8ef642f04874e13a9f2771902ddb7514f294015. Reverted https://github.com/pytorch/pytorch/pull/149427 on behalf of https://github.com/clee2000 due to failing tests internally D72116141 ([comment](https://github.com/pytorch/pytorch/pull/149427#issuecomment-2766672200)) --- torch/_inductor/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 10904cd53991..c210af25c16d 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1165,7 +1165,7 @@ class triton: # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 # with a verison of triton new enough to support TMA enable_persistent_tma_matmul = ( - os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "1") == "1" + os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" ) # Skip L1 cache for buffers that are used only once. Disabled by default skip_l1_cache = os.environ.get("TORCHINDUCTOR_SKIP_L1", "0") == "1" From 47cdad299521bf1acc9890a0330bba3ce640e325 Mon Sep 17 00:00:00 2001 From: Prachi Gupta Date: Mon, 31 Mar 2025 16:15:57 +0000 Subject: [PATCH 008/332] [ROCm] Enable several fsdp related UTs (#149369) Enabling 26 UTs for ROCm in the following files: - distributed._shard.sharded_optim.test_sharded_optim - 2 UTs - distributed._shard.sharded_tensor.ops.test_binary_cmp - 4 UTs - distributed._shard.sharded_tensor.ops.test_init - 3 UTs - distributed._shard.sharded_tensor.ops.test_embedding - 2 UTs - distributed._shard.sharded_tensor.ops.test_embedding_bag - 2 UTs - distributed._composable.test_replicate_with_compiler - 4 UTs - distributed._composable.fsdp.test_fully_shard_grad_scaler - 1 UTs - distributed.tensor.test_attention - 4 UTs - distributed.tensor.test_matrix_ops - 1 UTs - distributed.tensor.test_tensor_ops - 1 UTs - distributed.fsdp.test_fsdp_grad_acc - 2 UTs Pull Request resolved: https://github.com/pytorch/pytorch/pull/149369 Approved by: https://github.com/jeffdaily --- .../fsdp/test_fully_shard_compile.py | 2 -- .../fsdp/test_fully_shard_grad_scaler.py | 3 +-- .../_composable/test_replicate_with_compiler.py | 14 +++++--------- test/distributed/fsdp/test_fsdp_grad_acc.py | 2 -- test/distributed/tensor/test_matrix_ops.py | 16 ++++++++++------ test/run_test.py | 9 --------- test/test_matmul_cuda.py | 17 +++++------------ torch/testing/_internal/common_device_type.py | 10 ++++++++++ 8 files changed, 31 insertions(+), 42 deletions(-) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 6351a74459bd..db460818dad5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -131,7 +131,6 @@ def skipTestForOldSm(self): if not sm_is_or_higher_than(device, 8, 0): self.skipTest("bf16 requires sm >= 8.0") - @skipIfRocm def test_dynamo_trace_use_training_state(self): torch._dynamo.reset() # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager. 
@@ -169,7 +168,6 @@ def f(x): self.assertEqual(cnt.op_count, 1) self.assertEqual(len(cnt.graphs), 1) - @skipIfRocm def test_trace_fsdp_copy_(self): @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"}) def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None: diff --git a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py index 7b7beb30af9d..bb4f28f43a41 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_grad_scaler.py @@ -13,12 +13,11 @@ ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest, MLP -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_utils import run_tests class TestFullyShardGradientScaler(FSDPTest): @skip_if_lt_x_gpu(4) - @skipIfRocm def test_gradient_scaler(self): self.run_subtests( {"has_inf": [True, False], "test_2d": [True, False]}, diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 839bbcd6920d..3b92dfcb0a9f 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -28,11 +28,10 @@ from torch.testing._internal.common_distributed import ( DistributedTestBase, skip_if_lt_x_gpu, - skip_if_rocm_multiprocess, sm_is_or_higher_than, ) from torch.testing._internal.common_fsdp import get_devtype -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed.fake_pg import FakeStore from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils.checkpoint import checkpoint @@ -194,7 +193,6 @@ def test_compile_cpu_no_sync(self): self._test_compile(no_sync=True, device="cpu") @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch( reorder_for_locality=False, reorder_for_peak_memory=False @@ -203,7 +201,6 @@ def test_compile_gpu(self): self._test_compile(no_sync=False, checkpoint=False, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) @torch._inductor.config.patch( reorder_for_locality=False, reorder_for_peak_memory=False @@ -212,11 +209,13 @@ def test_compile_gpu_ac(self): self._test_compile(no_sync=False, checkpoint=True, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_bf16(self): # Check device capability wrt bf16 - if not sm_is_or_higher_than(torch.device(device_type), 8, 0): + if ( + not sm_is_or_higher_than(torch.device(device_type), 8, 0) + and torch.version.hip is None + ): self.skipTest("bf16 requires sm >= 8.0") def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: @@ -230,7 +229,6 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: self._test_compile(no_sync=False, setup_func=setup, device=device_type) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_fp16(self): def setup(model, compiled_replicate_model, 
compiled_ddp_model) -> None: @@ -247,7 +245,6 @@ def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_backward_only(self): self._test_compile(no_sync=False, no_compile_forward=True, device=device_type) @@ -387,7 +384,6 @@ def tearDown(self): "Temporarily disabled due to SymInt error: `unhashable type: non-nested SymInt`" ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") - @skipIfRocm def test_ddp_tp(self): ref_model = Net() compiled_replicate_model = deepcopy(ref_model) diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index fc371979ca3c..1e51938a033f 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -24,7 +24,6 @@ instantiate_parametrized_tests, parametrize, run_tests, - skipIfRocm, TEST_WITH_DEV_DBG_ASAN, ) @@ -275,7 +274,6 @@ def test_grad_acc( ) @skip_if_lt_x_gpu(2) - @skipIfRocm @parametrize("use_orig_params", [False, True]) def test_grad_acc_cpu_offload( self, diff --git a/test/distributed/tensor/test_matrix_ops.py b/test/distributed/tensor/test_matrix_ops.py index 5c7d7fd43ae2..cd26a31abf7f 100644 --- a/test/distributed/tensor/test_matrix_ops.py +++ b/test/distributed/tensor/test_matrix_ops.py @@ -18,7 +18,8 @@ ) from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 -from torch.testing._internal.common_utils import run_tests, skipIfRocm +from torch.testing._internal.common_device_type import E4M3_MAX_POS, e4m3_type +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, skip_unless_torch_gpu, @@ -33,8 +34,10 @@ def scale_for_fp8( t = t.unsqueeze(0).unsqueeze(-2) else: t = t.unflatten(0, (scale_shape[0], -1)).unflatten(-1, (scale_shape[1], -1)) - scale = t.abs().amax(dim=[1, -1]).float() / torch.finfo(torch.float8_e4m3fn).max - t_fp8 = (t / scale[:, None, :, None]).to(torch.float8_e4m3fn) + + scale = t.abs().amax(dim=[1, -1]).float() / E4M3_MAX_POS + t_fp8 = (t / scale[:, None, :, None]).to(e4m3_type) + return t_fp8.flatten(end_dim=1).flatten(start_dim=-2), scale.view(scale_shape) @@ -205,7 +208,7 @@ def test_scaled_mm(self): full_dist_res = dist_res.full_tensor() # Fp8 matmuls are quite inaccurate, we need high tolerances - self.assertEqual(full_dist_res, full_ref_res, atol=1, rtol=7e-2) + self.assertEqual(full_dist_res, full_ref_res, atol=1.5, rtol=7e-2) self.assertEqual(comm_mode.get_total_counts(), 0) @@ -448,7 +451,6 @@ def test_scaled_dot_product_attention(self): self.assertTrue(dist_value.grad.placements[0].is_shard(dim=1)) self.assertEqual(dist_value.grad.full_tensor(), value.grad) - @skipIfRocm @skip_unless_torch_gpu @with_comms() def test_dtensor_mm(self): @@ -472,7 +474,9 @@ def test_dtensor_mm(self): lhs_dtensor = distribute_tensor(lhs, mesh, [Shard(dim=0), Replicate()]) rhs_dtensor = distribute_tensor(rhs, mesh, [Replicate(), Shard(dim=1)]) dtensor_result = lhs_dtensor @ rhs_dtensor - self.assertEqual(dtensor_result.full_tensor(), mm_result) + self.assertEqual( + dtensor_result.full_tensor(), mm_result, atol=1.5e-5, rtol=1e-6 + ) @with_comms @skip_unless_torch_gpu diff --git a/test/run_test.py b/test/run_test.py index efa7e46554cb..d341a182e29b 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ 
-171,19 +171,10 @@ def __contains__(self, item): "distributed/rpc/test_tensorpipe_agent", "distributed/rpc/test_share_memory", "distributed/rpc/cuda/test_tensorpipe_agent", - "distributed/_shard/checkpoint/test_checkpoint" - "distributed/_shard/checkpoint/test_file_system_checkpoint" - "distributed/_shard/sharding_spec/test_sharding_spec", - "distributed/_shard/sharded_tensor/ops/test_embedding", - "distributed/_shard/sharded_tensor/ops/test_embedding_bag", - "distributed/_shard/sharded_tensor/ops/test_binary_cmp", - "distributed/_shard/sharded_tensor/ops/test_init", - "distributed/_shard/sharded_optim/test_sharded_optim", "test_determination", "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", - "distributed/tensor/test_attention", ] S390X_BLOCKLIST = [ diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 611a2f943f67..64f9ee7ad2df 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -24,7 +24,7 @@ SM90OrLater, _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, - PLATFORM_SUPPORTS_MX_GEMM + PLATFORM_SUPPORTS_MX_GEMM, ) from torch.testing._internal.common_device_type import ( dtypes, @@ -32,6 +32,10 @@ onlyCUDA, tol as xtol, toleranceOverride, + e4m3_type, + e5m2_type, + E4M3_MAX_POS, + E5M2_MAX_POS, ) from torch.testing._internal.common_utils import ( @@ -258,17 +262,6 @@ def _expand_to_batch(t: torch.Tensor): f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" -if torch.version.hip and 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName: - e4m3_type = torch.float8_e4m3fnuz - e5m2_type = torch.float8_e5m2fnuz - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max -else: - e4m3_type = torch.float8_e4m3fn - e5m2_type = torch.float8_e5m2 - E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max - E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max - # avoid division by zero when calculating scale EPS = 1e-12 diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 834e0ed10071..9cd0661cac15 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1982,3 +1982,13 @@ def get_all_device_types() -> list[str]: and torch.cuda.get_device_capability() >= (8, 0), "Requires CUDA and Triton", ) +if torch.version.hip and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName: + e4m3_type = torch.float8_e4m3fnuz + e5m2_type = torch.float8_e5m2fnuz + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max +else: + e4m3_type = torch.float8_e4m3fn + e5m2_type = torch.float8_e5m2 + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max From 284b7668980f31d6ff788048f13a702936960756 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 31 Mar 2025 17:04:25 +0000 Subject: [PATCH 009/332] [dynamic shapes] C++ bindings for guard_or_false/true (#150148) C++ version. Would like to add it in one place to prove it works, but couldn't find one that doesn't expose a chain of data-dependent changes... 
so just gonna put up the base implementation Pull Request resolved: https://github.com/pytorch/pytorch/pull/150148 Approved by: https://github.com/laithsakka, https://github.com/jingsh --- c10/core/SymBool.cpp | 16 ++++++++++++++ c10/core/SymBool.h | 35 +++++++++++++++++++++++++++++++ c10/core/SymNodeImpl.h | 10 +++++++++ torch/csrc/utils/python_symnode.h | 10 +++++++++ torch/fx/experimental/sym_node.py | 12 +++++++++++ 5 files changed, 83 insertions(+) diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index 1b5269c9da13..63fcf064e01b 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -72,6 +72,22 @@ bool SymBool::guard_size_oblivious(const char* file, int64_t line) const { return a->guard_size_oblivious(file, line); } +bool SymBool::guard_or_false(const char* file, int64_t line) const { + if (auto ma = maybe_as_bool()) { + return *ma; + } + SymNode a = toSymNodeImpl(); + return a->guard_or_false(file, line); +} + +bool SymBool::guard_or_true(const char* file, int64_t line) const { + if (auto ma = maybe_as_bool()) { + return *ma; + } + SymNode a = toSymNodeImpl(); + return a->guard_or_true(file, line); +} + bool SymBool::expect_true(const char* file, int64_t line) const { if (auto ma = maybe_as_bool()) { return *ma; diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index c7b1fe5ff316..875377b2eb37 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -62,6 +62,8 @@ class C10_API SymBool { bool guard_bool(const char* file, int64_t line) const; bool expect_true(const char* file, int64_t line) const; bool guard_size_oblivious(const char* file, int64_t line) const; + bool guard_or_false(const char* file, int64_t line) const; + bool guard_or_true(const char* file, int64_t line) const; bool has_hint() const; @@ -113,7 +115,40 @@ inline bool guard_size_oblivious( return b.guard_size_oblivious(file, line); } +inline bool guard_or_false( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_false( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_false(file, line); +} + +inline bool guard_or_true( + bool b, + const char* file [[maybe_unused]], + int64_t line [[maybe_unused]]) { + return b; +} + +inline bool guard_or_true( + const c10::SymBool& b, + const char* file, + int64_t line) { + return b.guard_or_true(file, line); +} + #define TORCH_GUARD_SIZE_OBLIVIOUS(cond) \ c10::guard_size_oblivious((cond), __FILE__, __LINE__) +#define TORCH_GUARD_OR_FALSE(cond) \ + c10::guard_or_false((cond), __FILE__, __LINE__) + +#define TORCH_GUARD_OR_TRUE(cond) c10::guard_or_true((cond), __FILE__, __LINE__) + } // namespace c10 diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 36652e1800ac..6589a1e0b780 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -186,6 +186,16 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { // with a better implementation! return guard_bool(file, line); } + virtual bool guard_or_false(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! + return guard_bool(file, line); + } + virtual bool guard_or_true(const char* file, int64_t line) { + // No improvement for unbacked SymBools by default, replace this + // with a better implementation! 
+ return guard_bool(file, line); + } virtual bool expect_true(const char* file, int64_t line) { // No improvement for unbacked SymBools by default, replace this // with a better implementation! diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index 43ef85ad8fce..9c73f9ca2b9e 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -135,6 +135,16 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return getPyObj().attr("guard_size_oblivious")(file, line).cast(); } + bool guard_or_false(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_false")(file, line).cast(); + } + + bool guard_or_true(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_or_true")(file, line).cast(); + } + int64_t int_() override { py::gil_scoped_acquire acquire; return getPyObj().attr("int_")().cast(); diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index fa4443b1b5d5..66f80e8dbc49 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -592,6 +592,18 @@ def guard_size_oblivious(self, file, line): log.warning("Failed to convert to bool: %s", r) raise + def guard_or_false(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + assert self.is_bool() + return guard_or_false(SymBool(self)) + + def guard_or_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + assert self.is_bool() + return guard_or_true(SymBool(self)) + def bool_(self): return self.guard_bool("", 0) From 5e34758cef85b497a6ba313cc09c4cac283583d5 Mon Sep 17 00:00:00 2001 From: angelayi Date: Fri, 28 Mar 2025 11:31:16 -0700 Subject: [PATCH 010/332] [invoke_subgraph] Support unbacked (#149298) Differential Revision: [D71420641](https://our.internmc.facebook.com/intern/diff/D71420641) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149298 Approved by: https://github.com/zou3519 --- test/higher_order_ops/test_invoke_subgraph.py | 22 ++++++++++++++++++- torch/_dynamo/variables/higher_order_ops.py | 1 + torch/_higher_order_ops/invoke_subgraph.py | 13 +++++++++++ torch/_higher_order_ops/utils.py | 5 +++++ torch/_subclasses/fake_tensor.py | 2 ++ 5 files changed, 42 insertions(+), 1 deletion(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index da071c4d20ed..287957d9f7a6 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -853,6 +853,27 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_unbacked(self): + @mark_compile_region + def gn(x, y): + b = x.item() + torch._check_is_size(b) + torch._check(b < y.shape[0]) + return y[:b].clone() + + def fn(x, y): + return gn(x, y) + + x = torch.tensor(4) + y = torch.randn(8) + ref = fn(x, y) + torch._dynamo.config.capture_scalar_outputs = True + opt_fn = torch.compile( + fn, backend="eager", fullgraph=True + ) # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default + res = opt_fn(x, y) + self.assertEqual(ref, res) + def test_bwd_partitioning(self): @mark_compile_region def gn(x, y): @@ -981,7 +1002,6 @@ def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"): """, ) - @unittest.expectedFailure def test_unbacked(self): @mark_compile_region def gn(x, y): diff --git 
a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 6e971301687e..8eeaacccb38c 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -3129,6 +3129,7 @@ def install_subgraph_in_output_graph( # inputs have already been seen before. If yes, the subgraph is already # installed in the output graph and we can just access the subgraph # using the saved attr name. + fake_inputs = [ node.meta["example_value"] for node in body_gmod.graph.nodes diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 4ac832cc6221..16819b44c6f6 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -31,6 +31,7 @@ track_tensor_tree, ) from torch.fx.graph_module import GraphModule +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts invoke_subgraph_counter = 0 @@ -335,6 +336,18 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands): if graph is None: graph = reenter_make_fx(subgraph)(*operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(operands) + insert_deferred_runtime_asserts( + graph, + fake_mode.shape_env, + "invoke_subgraph_proxy_torch_dispatch_mode", + export=True, + ) + graph.recompile() + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") proxy_mode.tracer.root.register_module(qualname, graph) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 4fb0ca60098b..ca9884687f3c 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -15,6 +15,7 @@ disable_proxy_modes_tracing, make_fx, ) +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from torch.multiprocessing.reductions import StorageWeakRef @@ -293,6 +294,10 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): pre_dispatch=pre_dispatch, _error_on_data_dependent_ops=False, )(*inputs) + if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: + insert_deferred_runtime_asserts( + gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True + ) return gm diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index b1c52f7e1bdf..000949475bc4 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1523,6 +1523,8 @@ def _validate_cache_key( and not op.__name__.startswith("i") ): continue + if op in (torch._check, torch._check_is_size): + continue try: self._validate_cache_key(op, [], {}) except _BypassDispatchCache as e: From ab342d3793472c65aaa0b007ca13a98fc9206dc5 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov Date: Mon, 31 Mar 2025 18:10:02 +0000 Subject: [PATCH 011/332] Make PyTorch buildable by CMake-4.x on s390x (#150294) This is a continuation of https://github.com/pytorch/pytorch/pull/150203 that fixes nightly build on s390x. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150294 Approved by: https://github.com/malfet --- cmake/Dependencies.cmake | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index bd8f7792214e..13b4671e0bfb 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -816,9 +816,18 @@ if(NOT TARGET fp16 AND NOT USE_SYSTEM_FP16) set(FP16_BUILD_TESTS OFF CACHE BOOL "") set(FP16_BUILD_BENCHMARKS OFF CACHE BOOL "") - add_subdirectory( - "${FP16_SOURCE_DIR}" - "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "FP16 is only cmake-2.8 compatible") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + add_subdirectory( + "${FP16_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + unset(CMAKE_POLICY_VERSION_MINIMUM) + else() + add_subdirectory( + "${FP16_SOURCE_DIR}" + "${CONFU_DEPENDENCIES_BINARY_DIR}/FP16") + endif() elseif(NOT TARGET fp16 AND USE_SYSTEM_FP16) add_library(fp16 STATIC "/usr/include/fp16.h") set_target_properties(fp16 PROPERTIES LINKER_LANGUAGE C) From 80b7f6b70426ae329b1c99a7efb863835d1de0cb Mon Sep 17 00:00:00 2001 From: Matthew Haddock Date: Mon, 31 Mar 2025 18:24:12 +0000 Subject: [PATCH 012/332] Adjust TestInductorOpInfo to depend on backend, not device (#146911) As is the case with many inductor tests, this test adapts test criteria based on device type, where it should be adjusting for the backend registered for that device. In this particular case, using the upstream triton CPU backend would lead to failures, as reference_in_float would be true as this is required for the C++/OpenMP backend which does not have float16 support. However most triton backends do, and as such should be tested in float16. Similarly a triton backend with a device not described as a GPU would get skipped from testing entirely. A more generic solution would be ideal, but this would require a lot of work across many tests. 
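In short, the test now chooses its kwargs based on the registered backend rather than the device type. A condensed, self-contained sketch of that selection logic (the full version lives in the diff below; `pick_kwargs` is an illustrative helper, not part of the test suite):

```python
from torch._inductor import config as inductor_config
from torch.utils._triton import has_triton


def pick_kwargs(requires_grad: bool) -> dict:
    # Illustrative helper: base kwargs shared by all backends.
    kwargs = {"check_lowp": False}
    if has_triton():
        # Any Triton backend: float16 is supported, so don't upcast the eager
        # reference to float32, and keep gradient checking enabled.
        kwargs.update(
            copy_to_gpu=False,
            reference_in_float=False,
            check_gradient=requires_grad,
        )
    elif inductor_config.cpu_backend == "cpp":
        # C++/OpenMP backend: no float16 support, so the float32 reference stays
        # and gradient checks are skipped for now.
        kwargs.update(check_gradient=False)
    return kwargs


print(pick_kwargs(requires_grad=True))
```
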
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/146911 Approved by: https://github.com/masnesral --- test/inductor/test_torchinductor_opinfo.py | 75 +++++++++++----------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index b84640235739..765927047700 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -45,6 +45,7 @@ GPU_TYPE, HAS_CPU, HAS_CUDA, + has_triton, HAS_XPU, maybe_skip_size_asserts, ) @@ -1103,48 +1104,50 @@ def _get_tolerances(dtype): # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True, file=f) # print(f"RUNNING OP {op_name} on {device_type} with {dtype}", flush=True) rtol, atol = _get_tolerances(dtype) - if device_type == GPU_TYPE: - # opinfo test case have already place the input on the correct device - # so we don't need do additional copy by setting copy_to_gpu=False - - no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) - for context_fn, kwarg_overrides in get_contexts(has_rng_op): - with context_fn(): - adjusted_kwargs = { - "check_lowp": False, - "nopython": no_python, - "copy_to_gpu": False, - "reference_in_float": False, - "check_gradient": requires_grad, - "check_has_compiled": no_python, - "output_process_fn_grad": sample_input.output_process_fn_grad, - "atol": atol, - "rtol": rtol, - } - adjusted_kwargs.update(overridden_kwargs) - adjusted_kwargs.update(kwarg_overrides) + no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) + for context_fn, kwarg_overrides in get_contexts(has_rng_op): + with context_fn(): + # Base kwargs + adjusted_kwargs = { + "check_lowp": False, + "nopython": no_python, + "check_has_compiled": no_python, + "atol": atol, + "rtol": rtol, + } + + # Backend-specific adjustments + # Triton + if has_triton(): + adjusted_kwargs.update( + { + "copy_to_gpu": False, + "reference_in_float": False, + "check_gradient": requires_grad, + "output_process_fn_grad": sample_input.output_process_fn_grad, + } + ) + # C++ CPU backend + elif torch._inductor.config.cpu_backend == "cpp": + adjusted_kwargs.update( + { + "check_gradient": False, # Skip checking gradient on CPU for now + } + ) + + # Update with overridden kwargs and context-specific overrides + adjusted_kwargs.update(overridden_kwargs) + adjusted_kwargs.update(kwarg_overrides) + + # Call the appropriate check method based on device type + if device_type == GPU_TYPE: self.check_model_gpu( fn, args, kwargs, **adjusted_kwargs, ) - elif device_type == "cpu": - no_python, has_rng_op = do_nopython_and_has_rng(fn, args, kwargs) - for context_fn, kwarg_overrides in get_contexts(has_rng_op): - with context_fn(): - adjusted_kwargs = { - "check_lowp": False, - "nopython": no_python, - "check_has_compiled": no_python, - # skip checking gradient on CPU for now - "check_gradient": False, - "atol": atol, - "rtol": rtol, - } - adjusted_kwargs.update(overridden_kwargs) - adjusted_kwargs.update(kwarg_overrides) - + else: self.check_model( fn, args, From dfcd98e684123b0cb0a143d8718b0672c58ec268 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Mon, 31 Mar 2025 09:03:30 -0700 Subject: [PATCH 013/332] cd: Fix naming for windows arm64 libtorch builds (#150310) Apparently the magical incantation to name these correctly lies in the build_variant variable otherwise it silently does nothing. 
Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/150310 Approved by: https://github.com/atalman --- .github/scripts/generate_ci_workflows.py | 1 + ...generated-windows-arm64-binary-libtorch-debug-nightly.yml} | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) rename .github/workflows/{generated-windows-arm64-binary-libtorch-nightly.yml => generated-windows-arm64-binary-libtorch-debug-nightly.yml} (98%) diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 306061787d58..4f29628373e4 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -294,6 +294,7 @@ class OperatingSystem: BinaryBuildWorkflow( os=OperatingSystem.WINDOWS_ARM64, package_type="libtorch", + build_variant=generate_binary_build_matrix.DEBUG, build_configs=generate_binary_build_matrix.generate_libtorch_matrix( OperatingSystem.WINDOWS_ARM64, generate_binary_build_matrix.DEBUG, diff --git a/.github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml similarity index 98% rename from .github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml rename to .github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml index a70e26c114cc..42e1e18d5dc7 100644 --- a/.github/workflows/generated-windows-arm64-binary-libtorch-nightly.yml +++ b/.github/workflows/generated-windows-arm64-binary-libtorch-debug-nightly.yml @@ -2,7 +2,7 @@ # Template is at: .github/templates/windows_arm64_binary_build_workflow.yml.j2 # Generation script: .github/scripts/generate_ci_workflows.py -name: windows-arm64-binary-libtorch +name: windows-arm64-binary-libtorch-debug on: push: @@ -17,7 +17,7 @@ on: workflow_dispatch: env: - BUILD_ENVIRONMENT: windows-arm64-binary-libtorch + BUILD_ENVIRONMENT: windows-arm64-binary-libtorch-debug GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} From 925fd4aa2e4bf702789b1371531ed4204b47cb1c Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 31 Mar 2025 21:32:20 +0000 Subject: [PATCH 014/332] [export] min/max ranges for dim hints (#149590) Differential Revision: D71522032 Adds min/max ranges to Dim.AUTO/DYNAMIC/STATIC, so users can do `Dim.AUTO(min=2, max=2048)`. 
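For example (mirroring the test added below), a ranged hint bounds the dimension at runtime:

```python
import torch
from torch.export import Dim, export


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


inputs = (torch.randn(6, 4), torch.randn(6, 4))
dynamic_shapes = {
    "x": (Dim.AUTO(min=4), Dim.AUTO),            # dim 0 may vary but must stay >= 4
    "y": (Dim.DYNAMIC(max=16), Dim.AUTO(max=32)),
}
ep = export(Add(), inputs, dynamic_shapes=dynamic_shapes)
ep.module()(torch.randn(8, 5), torch.randn(8, 5))    # ok
# ep.module()(torch.randn(3, 5), torch.randn(3, 5))  # raises: expected >= 4
```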
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149590 Approved by: https://github.com/tugsbayasgalan --- test/export/test_export.py | 91 +++++++++++++++++++++++++++++++ torch/_export/non_strict_utils.py | 61 +++++++++++++++++---- torch/export/dynamic_shapes.py | 11 ++++ 3 files changed, 153 insertions(+), 10 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 104133d379b3..5e7d9a436e3d 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2377,6 +2377,97 @@ def forward(self, x, y, z): ): export(Foo(), inputs, dynamic_shapes=shapes) + def test_dim_hint_ranges(self): + class Foo(torch.nn.Module): + def forward(self, x, y): + return x + y + + inputs = ( + torch.randn(6, 4), + torch.randn(6, 4), + ) + shapes = { + "x": (Dim.AUTO(min=4), Dim.AUTO), + "y": (Dim.DYNAMIC(max=16), Dim.AUTO(max=32)), + } + ep = export(Foo(), inputs, dynamic_shapes=shapes) + ep.module()(torch.randn(8, 5), torch.randn(8, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be >= 4, but got 3" + ): + ep.module()(torch.randn(3, 5), torch.randn(3, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 16, but got 17" + ): + ep.module()(torch.randn(17, 5), torch.randn(17, 5)) + with self.assertRaisesRegex( + RuntimeError, "Expected input at .* to be <= 32, but got 33" + ): + ep.module()(torch.randn(9, 33), torch.randn(9, 33)) + + def test_dim_hint_range_violations(self): + class Foo(torch.nn.Module): + def forward(self, xs): + x, y = xs["data"][0] + assert y.shape[0] <= 32 + return x[6:], y + 2 + + x, y = torch.randn(8), torch.randn(8) + + # conflict with lower bound + shapes = torch.export.ShapesCollection() + shapes[x] = [Dim.DYNAMIC(max=5)] + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[None, 5\], conflicting with the inferred .*" + r"\[6, int_oo\],.* for inputs\['xs'\]\['data'\]\[0\]\[0\]\.shape\[0\]", + ): + export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes) + + # conflict with upper bound + shapes = torch.export.ShapesCollection() + shapes[y] = [Dim.AUTO(min=48, max=62)] + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[48, 62\], conflicting with the inferred .*" + r"\[2, 32\],.* for inputs\['xs'\]\['data'\]\[0\]\[1\]\.shape\[0\]", + ): + export(Foo(), ({"data": [[x, y]]},), dynamic_shapes=shapes) + + class Bar(torch.nn.Module): + def forward(self, x): + return x + 2 + + # conflict with static range + shapes = {"x": [Dim.STATIC(min=6, max=8)]} + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[6, 8\], conflicting with the inferred .*" + r"\[4, 4\],.* for inputs\['x'\].shape\[0\]", + ): + export(Bar(), (torch.randn(4),), dynamic_shapes=shapes) + + # multiple conflicts + class Moo(torch.nn.Module): + def forward(self, x, y): + assert x.shape[0] <= 32 + assert y.shape[0] >= 128 + return x + 2, y + 2 + + inps = (torch.randn(16), torch.randn(256)) + shapes = { + "x": (Dim.DYNAMIC(min=33),), + "y": (Dim.DYNAMIC(max=127),), + } + with self.assertRaisesRegex( + ValueError, + r"Received user-specified .* \[33, None\], conflicting with the inferred .*" + r"\[2, 32\],.* for inputs\['x'\].shape\[0\](.*\n)*.*" + r"Received user-specified .* \[None, 127\], conflicting with the inferred .*" + r"\[128, int_oo\],.* for inputs\['y'\].shape\[0\]", + ): + export(Moo(), inps, dynamic_shapes=shapes) + def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self) -> None: diff --git 
a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 445465bffb64..63c0cf5d30ec 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -52,6 +52,7 @@ SequenceKey, tree_map_with_path, ) +from torch.utils._sympy.numbers import int_oo if TYPE_CHECKING: @@ -362,17 +363,21 @@ def make_constraints( (used only to enumerate the user-input nodes) """ + def is_int(x: object) -> bool: + return isinstance(x, int) or ( + isinstance(x, torch.SymInt) and x.node.expr.is_number + ) + shape_env = fake_mode.shape_env assert shape_env is not None inline_constraints = gm.meta.get("inline_constraints", []) - range_constraints = { - symbol: inline_constraints[symbol] for symbol in inline_constraints - } + range_constraints = defaultdict(lambda: ValueRanges(0, int_oo)) | inline_constraints if not dynamic_shapes: - return range_constraints + return dict(range_constraints) # clean up dynamic markers from tensors - for arg in pytree.tree_flatten(combined_args)[0]: + flat_paths, flat_args = zip(*pytree.tree_flatten_with_path(combined_args)[0]) + for arg in flat_args: if isinstance(arg, torch.Tensor): _clean_dynamic_markers(arg) @@ -388,6 +393,7 @@ def make_constraints( input_dims = defaultdict(list) free_symbols = set() + range_violations = [] for input_index, node in enumerate(gm.graph.nodes): if input_index < num_lifted_inputs or node.op != "placeholder": continue @@ -397,19 +403,54 @@ def make_constraints( continue shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] for i, d in enumerate(node.meta["val"].shape): - if isinstance(d, torch.SymInt) and not d.node.expr.is_number: + dim = None + if isinstance(shape_spec, (list, tuple)): + dim = shape_spec[i] + elif isinstance(shape_spec, dict): + dim = shape_spec.get(i) + if not is_int(d): # Compute the range constraint for the symbolic expression corresponding # to this shape dimension and store it. - dim = shape_spec[i] if shape_spec else None if dim is None or isinstance(dim, _DimHint): - range_constraints[d.node.expr] = shape_env.bound_sympy(d.node.expr) + range_constraints[d.node.expr] &= shape_env.bound_sympy(d.node.expr) else: - range_constraints[d.node.expr] = ValueRanges( + range_constraints[d.node.expr] &= ValueRanges( lower=dim.min, upper=dim.max ) + input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) free_symbols.update(d.node.expr.free_symbols) + # check user-specified min/max range for DimHints; + # we might want to do this even if model tracing inferred a static dimension. + if isinstance(dim, _DimHint): + trace_vr = ( + range_constraints[d.node.expr] + if not is_int(d) + else ValueRanges(int(d), int(d)) + ) + try: + user_vr = ValueRanges( + lower=0 if dim.min is None else dim.min, + upper=int_oo if dim.max is None else dim.max, + ) + if is_int(d): + trace_vr & user_vr + else: + range_constraints[d.node.expr] &= user_vr + shape_env.var_to_range[d.node._expr] &= user_vr + except torch.utils._sympy.value_ranges.ValueRangeError: + msg = ( + f"- Received user-specified min/max range of [{dim.min}, {dim.max}], " + f"conflicting with the inferred min/max range of [{trace_vr.lower}, {trace_vr.upper}], " + f"for inputs{pytree.keystr(flat_paths[input_index])}.shape[{i}]." 
+ ) + range_violations.append(msg) + + if range_violations: + prefix = "Found the following conflicts between user-specified ranges and inferred ranges from model tracing:\n" + raise ValueError(prefix + "\n".join(range_violations)) + for symbol in free_symbols: if symbol not in range_constraints: # Placeholders can have symbolic shapes that are derived expressions. @@ -418,7 +459,7 @@ def make_constraints( # we want to record range constraints for their root symbols. range_constraints[symbol] = shape_env.var_to_range[symbol] - return range_constraints + return dict(range_constraints) def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap: diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 285b0555034b..3b0ce63d134c 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -56,6 +56,9 @@ class _DimHintType(Enum): @dataclasses.dataclass class _DimHint: type: _DimHintType + min: Optional[int] = None + max: Optional[int] = None + _factory: Optional[bool] = True @staticmethod def AUTO(): @@ -69,6 +72,14 @@ def DYNAMIC(): def STATIC(): return _DimHint(_DimHintType.STATIC) + def __call__(self, min=None, max=None) -> "_DimHint": + if not self._factory: + raise TypeError(f"'{type(self)}' object is not callable") + assert min is None or min >= 0, "min must be non-negative" + assert max is None or max >= 0, "max must be non-negative" + assert min is None or max is None or min <= max, "min must be <= max" + return _DimHint(self.type, min=min, max=max, _factory=False) + class Dim: """ From 4e2997db73fd80530505d687260fbcb1bedae369 Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Mon, 31 Mar 2025 21:46:09 +0000 Subject: [PATCH 015/332] [ROCm][CI] Increase wheel build timeout from 210 to 240 (#150221) Fixes #150046. Increasing the timeout from 210 to 240. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150221 Approved by: https://github.com/jeffdaily --- .github/workflows/_binary-build-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index 57a66798468f..507d5419a042 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -23,7 +23,7 @@ on: description: Hardware to run this "build" job on, linux.12xlarge or linux.arm64.2xlarge. 
timeout-minutes: required: false - default: 210 + default: 240 type: number description: timeout for the job use_split_build: From 423e4a4568958845da52808e50d1cdd2ba7fa48d Mon Sep 17 00:00:00 2001 From: Faa Diallo Date: Mon, 31 Mar 2025 21:55:53 +0000 Subject: [PATCH 016/332] [ROCm] cmake 4 workaround for hiprtc (#150324) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150324 Approved by: https://github.com/jeffdaily, https://github.com/atalman, https://github.com/malfet --- cmake/public/LoadHIP.cmake | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 5741cf7d0952..28d15a5ea1b7 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -154,7 +154,15 @@ if(HIP_FOUND) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(hipsolver REQUIRED) - find_package_and_print_version(hiprtc REQUIRED) + # workaround cmake 4 build issue + if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") + message(WARNING "Work around hiprtc cmake failure for cmake >= 4") + set(CMAKE_POLICY_VERSION_MINIMUM 3.5) + find_package_and_print_version(hiprtc REQUIRED) + unset(CMAKE_POLICY_VERSION_MINIMUM) + else() + find_package_and_print_version(hiprtc REQUIRED) + endif() find_package_and_print_version(hipblaslt REQUIRED) if(UNIX) From 1526ff955e61f431a055286cd3f7c854109d659b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 31 Mar 2025 22:19:08 +0000 Subject: [PATCH 017/332] Revert "Add a warning when a tensor with requires_grad=True is converted to a scalar (#143261)" This reverts commit 515b45e5693dbf9dd58d8472806cbe5f49e43074. Reverted https://github.com/pytorch/pytorch/pull/143261 on behalf of https://github.com/clee2000 due to failing internal tests D72135661 ([comment](https://github.com/pytorch/pytorch/pull/143261#issuecomment-2767531682)) --- aten/src/ATen/native/Scalar.cpp | 6 ------ test/test_torch.py | 17 ----------------- 2 files changed, 23 deletions(-) diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index d790a79de83e..0053b86c3373 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -11,17 +11,11 @@ #include #endif -#include - namespace at::native { Scalar item(const Tensor& self) { auto numel = self.sym_numel(); TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar"); - if (torch::autograd::GradMode::is_enabled() && self.requires_grad()) { - TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n" - "Consider using tensor.detach() first."); - } if (self.is_sparse()) { if (self._nnz() == 0) return Scalar(0); if (self.is_coalesced()) return at::_local_scalar_dense(self._values()); diff --git a/test/test_torch.py b/test/test_torch.py index 32838e0ddf33..fa96b40774d0 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10829,23 +10829,6 @@ class MyTwoTensor4(TwoTensor): def test_bf16_supported_on_cpu(self): self.assertFalse(torch.cuda.is_bf16_supported()) - def test_tensor_with_grad_to_scalar_warning(self) -> None: - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - x = torch.tensor(2.0, requires_grad=True) - math.pow(x, 3) # calling this results in a warning - - self.assertEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, UserWarning)) - 
self.assertIn( - "Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.", - str(w[0].message) - ) - - _ = math.pow(x, 3) # calling it again does not result in a second warning - self.assertEqual(len(w), 1) # The following block extends TestTorch with negative dim wrapping tests # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests From 91666eef6043d35c0bd09c6e9a0a05a9f909be43 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 31 Mar 2025 22:40:23 +0000 Subject: [PATCH 018/332] Update gloo submodule (#150320) That updates its CMake minimum version(via https://github.com/facebookincubator/gloo/pull/424 ) and removes cmake-4.0.0 workarounds for gloo Pull Request resolved: https://github.com/pytorch/pytorch/pull/150320 Approved by: https://github.com/atalman --- cmake/Dependencies.cmake | 10 +--------- third_party/gloo | 2 +- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 13b4671e0bfb..1df6b350b9b1 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1215,15 +1215,7 @@ if(USE_GLOO) set(NCCL_EXTERNAL ON) endif() set(GLOO_USE_CUDA_TOOLKIT ON CACHE BOOL "" FORCE) - if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") - # Remove me when https://github.com/facebookincubator/gloo/pull/424 is landed - message(WARNING "Downgrading cmake-policy-version for gloo build") - set(CMAKE_POLICY_VERSION_MINIMUM 3.5) - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) - unset(CMAKE_POLICY_VERSION_MINIMUM) - else() - add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) - endif() + add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/gloo) # Here is a little bit hacky. We have to put PROJECT_BINARY_DIR in front # of PROJECT_SOURCE_DIR with/without conda system. The reason is that # gloo generates a new config.h in the binary diretory. diff --git a/third_party/gloo b/third_party/gloo index 95ca2af4e4c7..e348db90d867 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit 95ca2af4e4c76433fac8911525d8a0142b7a5289 +Subproject commit e348db90d8677277e926c14c94ee2acfa77173d4 From 981048854da154eae8ff0bd439e72e1256ae00da Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Mon, 31 Mar 2025 06:44:37 -0700 Subject: [PATCH 019/332] Merge Triton ScaledMM as epilogue to MM template (#150045) Previously, scaled_mm's (FP8 matmul) Triton lowering for inductor was in a separate template. This PR consolidates that lowering into the mm template, with an added epilogue to deal with multiplying the scales. This paves the way for future scaled variants of BMM, Grouped GEMM in inductor. Currently, there is still a separate template for TMA+persistent version of scaled_mm. The current mm lowering has a separate template for TMA + Persistent version. Will hopefully consolidate the extra scaled_mm TMA+persistent template when the consolidation for the mm template is done. 
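For intuition, the fused epilogue produces the same result as this eager-mode reference (a sketch assuming per-tensor scales; the generated Triton kernel instead keeps the dot product in FP8 and multiplies the fp32 accumulator by scale_a * scale_b, then adds the bias):

```python
import torch


def scaled_mm_reference(a_fp8, b_fp8, scale_a, scale_b, bias=None,
                        out_dtype=torch.bfloat16):
    # Reference semantics only, not the generated Triton kernel.
    acc = a_fp8.to(torch.float32) @ b_fp8.to(torch.float32)
    acc = acc * (scale_a.to(torch.float32) * scale_b.to(torch.float32))  # epilogue
    if bias is not None:
        acc = acc + bias
    return acc.to(out_dtype)
```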
TODO: Consolidate TMA+Persistent logic into 1 template and remove separate scaled_mm TMA template Pull Request resolved: https://github.com/pytorch/pytorch/pull/150045 Approved by: https://github.com/drisspg --- torch/_inductor/kernel/mm.py | 381 ++++++++++++++++- torch/_inductor/kernel/mm_common.py | 70 ++++ torch/_inductor/kernel/mm_scaled.py | 608 ---------------------------- torch/_inductor/utils.py | 8 +- 4 files changed, 454 insertions(+), 613 deletions(-) delete mode 100644 torch/_inductor/kernel/mm_scaled.py diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index ffa1531efd42..2a52d7fc0135 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Optional +from typing import Any, Optional import torch from torch._dynamo.utils import counters @@ -21,10 +21,16 @@ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.wrapper import PythonWrapperCodegen from ..ir import FlexibleLayout, is_triton -from ..lowering import register_lowering +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, + realize_inputs, TritonTemplate, ) from ..utils import ( @@ -46,6 +52,8 @@ mm_options, persistent_mm_grid, persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, should_fallback_to_aten, ) @@ -119,7 +127,11 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + + if USE_FAST_ACCUM: + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + else: + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -188,7 +200,10 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -305,6 +320,179 @@ """, ) +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + 
return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + # prevent duplication registration of extern functions @functools.lru_cache(None) @@ -326,6 +514,10 @@ def lazy_register_extern_choice(fn): has_out_variant=False, ) +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", 
op_overload=aten._scaled_mm.out +) + def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) @@ -336,6 +528,16 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**13 +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt @@ -347,6 +549,32 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + aten_bias_addmm = ExternKernelChoice(bias_addmm, None) @@ -746,6 +974,151 @@ def tuned_sparse_semi_structured_mm( ) +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] + + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] 
+ if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if k == 16 and config.kwargs["BLOCK_M"] >= 64: + continue # Triton crashes in this case + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and k < 32: + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + if should_fallback_to_aten(choices): + return aten_choice.output_node() + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + @functools.lru_cache(None) def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index d990536c4362..663e78dc199c 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -76,6 +76,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): GROUP_M=8, EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, + USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=acc_type(layout.dtype), num_stages=config.num_stages, num_warps=config.num_warps, @@ -92,6 +93,47 @@ def persistent_mm_options(mat1, mat2): ) +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 
1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + return mm_template_options + + def mm_args( mat1, mat2, @@ -154,6 +196,34 @@ def epilogue(acc, bias): return epilogue +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + def _is_static_problem(layout: Layout) -> tuple[bool, bool]: """ Check if input tensors and output layout have static shapes and non-zero sizes. 
diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py deleted file mode 100644 index aa917e120168..000000000000 --- a/torch/_inductor/kernel/mm_scaled.py +++ /dev/null @@ -1,608 +0,0 @@ -import functools -import logging -from collections.abc import Sequence -from typing import Any, Optional - -import sympy - -import torch -from torch._dynamo.utils import counters -from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate -from torch.utils._triton import has_triton_tma_device - -from ..config import triton as triton_config -from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox -from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering -from ..select_algorithm import ( - autotune_select_algorithm, - ExternKernelChoice, - realize_inputs, - TritonTemplate, -) -from ..utils import ( - get_num_sms, - get_tma_workspace_arg, - TMA_DESCRIPTOR_SIZE, - use_aten_gemm_kernels, - use_ck_gemm_template, - use_triton_template, -) -from ..virtualized import V -from .mm_common import ( - _is_static_problem, - mm_args, - mm_grid, - persistent_mm_grid, - should_fallback_to_aten, -) - - -log = logging.getLogger(__name__) -aten = torch.ops.aten - -load_scales = r""" -@triton.jit -def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): - if SCALING_ROWWISE: - # For row-wise scaling, we'll return the pointers - return a_scale_ptr, b_scale_ptr - else: - # For per-tensor scaling, we'll load the scalar values - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr) - return a_scale, b_scale -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALING_ROWWISE: - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -device_tma = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALING_ROWWISE: - stride_a_scale_m = 1 - stride_b_scale_n = 1 - else: - stride_a_scale_m = 0 - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - element_ty=B.dtype.element_ty, - ) - - 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - - -scaled_mm_device_tma_template = TritonTemplate( - name="scaled_mm_device_tma", - grid=persistent_mm_grid, - source=device_tma + load_scales + apply_scaling, -) - - -scaled_mm_template = TritonTemplate( - name="scaled_mm", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) 
- if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -# Inductor does not allow optional tensor input arguments currently (pass None as an -# input node to template choices), but since for _scaled_mm there is only one such arg -# (bias), work around by having a second template when bias is provided. -scaled_mm_bias_template = TritonTemplate( - name="scaled_mm_bias", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) 
- if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # bias - bias = tl.load(bias_ptr + rn, mask=rn < N) - acc += bias - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -aten__fp8_mm = ExternKernelChoice( - torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out -) - - -def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool: - # Same sized scales are compatable - if len(size_a) == len(size_b): - return True - - # Both need to be scalars or len(1) tensors - if len(size_a) <= 1 and len(size_b) <= 1: - return True - - return False - - -def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: - def is_row_major(stride: Sequence[_IntLike]) -> bool: - return stride[1] == 1 - - def is_col_major(stride: Sequence[_IntLike]) -> bool: - return stride[0] == 1 - - def has_zero_dim(size: Sequence[_IntLike]) -> bool: - return bool(size[0] == 0 or size[1] == 0) - - # Check mat_a (self) stride requirements - torch._check( - is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), - lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", - ) - - # Check mat_b stride requirements - torch._check( - is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), - lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", - ) - - -def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." 
- ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - TMA_SIZE=TMA_DESCRIPTOR_SIZE, - NUM_SMS=get_num_sms(), - **config.kwargs, - ) - - -def scaled_mm_options( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - **config.kwargs, - ) - - -add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) - - -def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool: - available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul - # _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes - # When K is 16, BLOCK_K = 16 and is not valid - min_k = k >= 32 - return available and min_k and not has_bias - - -@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] -def tuned_scaled_mm( - mat_a: TensorBox, - mat_b: TensorBox, - scale_a: TensorBox, - scale_b: TensorBox, - bias: Optional[TensorBox] = None, - scale_result: Optional[TensorBox] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - layout: Optional[Layout] = None, -) -> TensorBox: - m, n, k, layout, mat_a, mat_b = mm_args( - mat_a, mat_b, layout=layout, out_dtype=out_dtype - ) - - # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 - log.info( - "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", - m, - n, - k, - mat_a.get_dtype(), - mat_b.get_dtype(), - layout, - ) - - device_type = get_device_type(mat_a) - - check_supported_striding(mat_a, mat_b) - - scale_a, scale_b = realize_inputs(scale_a, scale_b) - - input_nodes: tuple[Any, ...] 
- # workaround for Inductor not supporting optional tensor input arguments - if bias is None: - input_nodes = (mat_a, mat_b, scale_a, scale_b) - triton_template = scaled_mm_template - else: - bias = realize_inputs(bias) - input_nodes = (mat_a, mat_b, scale_a, scale_b, bias) - triton_template = scaled_mm_bias_template - - aten_choice = aten__fp8_mm.bind( - input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum - ) - - choices: list[ChoiceCaller] = [] - if use_aten_gemm_kernels(): - choices.append(aten_choice) - - _, is_nonzero = _is_static_problem(layout) - - scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) - scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( - device_type - ) - - if is_nonzero and use_triton_template(layout, enable_float8=True): - if use_persistent_tma(k, bias is not None): - for config in scaled_persistent_mm_configs(m, n, k): - kwargs = scaled_mm_options_device_tma( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - input_nodes = (mat_a, mat_b, scale_a, scale_b) - scaled_mm_device_tma_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - workspace_arg=get_tma_workspace_arg( - num_tma_descriptors=2, - device=mat_a.get_device(), - ), - **kwargs, - ) - else: - for config in scaled_mm_configs(m, n, k): - if k == 16 and config.kwargs["BLOCK_M"] >= 64: - continue # Triton crashes in this case - - # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid - # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape - if using_b200() and k < 32: - continue - - kwargs = scaled_mm_options( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - # possibly appends a TritonTemplateCaller to choices - triton_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - ) - - if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) - - if should_fallback_to_aten(choices): - return aten_choice.output_node() - - return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) - - -@functools.lru_cache -def using_b200() -> bool: - """Returns true if the device is a NVIDIA B200, otherwise returns false.""" - if not torch.cuda.is_available(): - return False - # compute capability 10.0 or 10.0a is NVIDIA B200 - device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) - return device_properties.major == 10 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e93ed88bcbda..bca3f024d134 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1377,7 +1377,7 @@ def _is_tma_compatible(x: IRNode) -> bool: return False dtype = x.get_dtype() - if dtype not in (torch.float16, torch.bfloat16): + if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): return False layout = x.get_layout() @@ -1388,6 +1388,12 @@ def _is_tma_compatible(x: IRNode) -> bool: inner_dim = layout.size[1] if transposed: inner_dim = layout.size[0] + + if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + inner_dim, 32 + ): + return False + inner_bytes = inner_dim * dtype.itemsize return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) From 982a7f7db0a1068ed80384319a1515f512bc27f8 Mon Sep 17 00:00:00 2001 From: Shiyan Deng Date: Mon, 31 Mar 2025 23:23:03 +0000 Subject: [PATCH 020/332] [cachinghostallocator] remove 
the check on cudaHostRegister path (#150070) Summary: In the cudaHostAlloc path, the flag we used is `cudaHostAllocDefault` [0] which don't really have this strict enforcement (devicePtr retrieved from ` cudaHostGetDevicePointer(()` point to the same addr as the hostPtr) according to the guide [1]. This diff removes the check so that the host register path works for ROCm. [0]https://github.com/pytorch/pytorch/blob/6aca002d82e5131cbf48496a04e7b0213ace1c03/aten/src/ATen/cuda/CachingHostAllocator.cpp#L97 [1] https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1gb65da58f444e7230d3322b6126bb4902 Test Plan: test_pinned_memory_with_cudaregister tests Differential Revision: D71932562 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150070 Approved by: https://github.com/jeffdaily --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 8a039ea3bff9..8e084aec2a0c 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -185,21 +185,6 @@ struct CUDACachingHostAllocatorImpl } } - void registerPages(const void* ptr, size_t size) { - AT_CUDA_CHECK( - cudaHostRegister((void*)ptr, (size_t)size, cudaHostRegisterDefault)); - - // If host and device pointer don't match, give a warning and exit - void* devptr = nullptr; - AT_CUDA_CHECK(cudaHostGetDevicePointer(&devptr, (void*)ptr, 0)); - TORCH_CHECK( - (void*)devptr == (void*)ptr, - "Host and device pointer dont match with cudaHostRegister. " - "Please dont use this feature by setting " - "PYTORCH_CUDA_ALLOC_CONF=use_cuda_host_register:False (default)", - ""); - } - void allocWithCudaHostRegister(void** ptr, size_t roundSize) { // Here we do regular allocation, pre-fault/map the pages, and then do // cudaHostRegister with GPU mapping flags to lock the pages, so we @@ -249,7 +234,8 @@ struct CUDACachingHostAllocatorImpl } // Register the mapped pages using cudaHostRegister - registerPages(*ptr, roundSize); + AT_CUDA_CHECK( + cudaHostRegister(*ptr, roundSize, cudaHostRegisterDefault)); } }; From a2070e2fd5ce7f7bf83852bb4c26fdfb605558d9 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Mon, 31 Mar 2025 19:25:39 +0000 Subject: [PATCH 021/332] [AOTInductor] Free tensors in test (#150274) Summary: This PR frees tensor that were new-ed within the test itself to prevent memory leak. Test Plan: Fixing tests itself. 
Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/150274 Approved by: https://github.com/chenyang78 --- test/cpp/aoti_inference/test.cpp | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 8a9c36db683c..9861fd6bdead 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -230,6 +230,16 @@ void test_aoti_constants_update( actual_output_tensors = runner->run(input_tensors); ASSERT_FALSE( torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : missing_map) { + delete pair.second; + } + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } void test_aoti_extract_constants_map(const std::string& device) { @@ -395,6 +405,13 @@ void test_aoti_double_buffering( runner->swap_constant_buffer(); actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } #if defined(USE_CUDA) || defined(USE_ROCM) @@ -435,6 +452,10 @@ void test_aoti_double_buffering_with_tensor_constants() { runner->swap_constant_buffer(); actual_output_tensors = runner->run(input_tensors); ASSERT_TRUE(torch::allclose(ref_output_tensors[0], actual_output_tensors[0])); + + for (auto& pair : real_map) { + delete pair.second; + } } void test_aoti_free_buffer(bool use_runtime_constant_folding) { @@ -584,6 +605,13 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } ASSERT_EQ(initMemory + DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1); ASSERT_EQ(2 * FOLDEDDATASIZE, active1 - active2); + + for (auto& pair : rand_map) { + delete pair.second; + } + for (auto& pair : real_map) { + delete pair.second; + } } class ThreadPool { From b48505a8a10f0cc1a827a34ce2e8452d846db9c1 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Mon, 31 Mar 2025 23:30:19 +0000 Subject: [PATCH 022/332] [MPS] Add support for hermite_polynomial_h. 
(#150279) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150279 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Co-authored-by: Aaron Gokaslan --- .../native/mps/kernels/BinaryKernel.metal | 10 +++++ .../native/mps/operations/BinaryKernel.mm | 7 ++++ aten/src/ATen/native/native_functions.yaml | 2 +- c10/metal/special_math.h | 37 +++++++++++++++++++ test/test_mps.py | 2 +- 5 files changed, 56 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 12b78f32e96e..eb2a038b16be 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -75,6 +75,13 @@ struct chebyshev_polynomial_w_functor { } }; +struct hermite_polynomial_h_functor { + template + inline T operator()(const T a, const T b) { + return static_cast(c10::metal::hermite_polynomial_h_forward(a, b)); + } +}; + struct nextafter_functor { #if __METAL_VERSION__ < 310 template @@ -164,6 +171,8 @@ REGISTER_BINARY_OP(chebyshev_polynomial_v, float, float); REGISTER_BINARY_OP(chebyshev_polynomial_v, half, half); REGISTER_BINARY_OP(chebyshev_polynomial_w, float, float); REGISTER_BINARY_OP(chebyshev_polynomial_w, half, half); +REGISTER_BINARY_OP(hermite_polynomial_h, float, float); +REGISTER_BINARY_OP(hermite_polynomial_h, half, half); #if __METAL_VERSION__ >= 310 REGISTER_BINARY_OP(copysign, bfloat, bfloat); @@ -176,6 +185,7 @@ REGISTER_BINARY_OP(chebyshev_polynomial_t, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_u, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat); REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat); +REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat); #endif // Complex binary functions diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 92f3aa011ae0..35a3ec81ca07 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -110,6 +110,12 @@ static void chebyshev_polynomial_w_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "chebyshev_polynomial_w"); } +static void hermite_polynomial_h_mps_kernel(TensorIteratorBase& iter) { + TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), + "hermite_polynomial_h_mps not implemented for non-floating types"); + lib.exec_binary_kernel(iter, "hermite_polynomial_h"); +} + static void polar_mps_kernel(TensorIterator& iter) { lib.exec_binary_kernel(iter, "polar"); } @@ -128,6 +134,7 @@ static void complex_mps_kernel(TensorIterator& iter) { REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_mps_kernel) +REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_mps_kernel) REGISTER_DISPATCH(polar_stub, &polar_mps_kernel); REGISTER_DISPATCH(complex_stub, &complex_mps_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 1336122e2fd6..e3a1cd175c86 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15262,7 +15262,7 @@ - func: special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) 
out) -> Tensor(a!) device_check: NoCheck dispatch: - CPU, CUDA: special_hermite_polynomial_h_out + CPU, CUDA, MPS: special_hermite_polynomial_h_out python_module: special structured_inherits: TensorIteratorBase structured: True diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 26e1da619e35..1b60563b205d 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -1716,5 +1716,42 @@ float chebyshev_polynomial_w_forward(T x, int64_t n) { return r; } // chebyshev_polynomial_w_forward(T x, int64_t n) +template +// TODO: Add 512 if/when double will be supported in Metal +inline constexpr int getHermitianLimit() { + return 128; +} + +template +inline float hermite_polynomial_h_forward(T x, int64_t n) { + if (n < 0) { + return 0.0; + } + + if (n == 0) { + return 1.0; + } + + if (n == 1) { + return x + x; + } + + if (n > getHermitianLimit()) { + return NAN; + } + + float p = 1.0; + float q = x + x; + float r = 0.0; + + for (int64_t k = 2; k < n + n; k += 2) { + r = (x + x) * q - k * p; + p = q; + q = r; + } + + return r; +} // hermite_polynomial_h_forward(T x, int64_t n) + } // namespace metal } // namespace c10 diff --git a/test/test_mps.py b/test/test_mps.py index f5c9befc492a..c10381331d2e 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -657,7 +657,6 @@ def mps_ops_modifier(ops): 'sparse.mmreduce': None, 'special.airy_ai': None, 'special.erfcx': None, - 'special.hermite_polynomial_h': None, 'special.hermite_polynomial_he': None, 'special.laguerre_polynomial_l': None, 'special.log_ndtr': None, @@ -714,6 +713,7 @@ def mps_ops_modifier(ops): 'special.zeta': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], 'special.chebyshev_polynomial_t': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], 'special.chebyshev_polynomial_u': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], + 'special.hermite_polynomial_h': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], # entr does not support boolean types 'special.entr': [torch.bool], From c75dac5f5c34f2f7d7eefefc1ce57de73a01f499 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 31 Mar 2025 23:58:37 +0000 Subject: [PATCH 023/332] Fix typo (#150363) Fixes https://github.com/pytorch/pytorch/issues/150339 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150363 Approved by: https://github.com/atalman, https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ce521d594fa8..7aa659565ae7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -4750,7 +4750,7 @@ c10::DeviceIndex ProcessGroupNCCL::guessDeviceId() const { devIdx, " as device used by this process is currently unknown. ", "This can potentially cause a hang if this rank to GPU mapping is incorrect. 
", - "You can pecify device_id in init_process_group() to force use of a particular device."); + "You can specify device_id in init_process_group() to force use of a particular device."); return static_cast(devIdx); } From 49b7d0d84d42b5ab204214fc15067fba48a3ef93 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Tue, 1 Apr 2025 00:30:32 +0000 Subject: [PATCH 024/332] [ROCm] Enable more inductor UTs (#149513) Primarily enable inductor fp8 tests, also enable other inductor tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/149513 Approved by: https://github.com/jeffdaily --- test/dynamo/test_ctx_manager.py | 3 +- test/inductor/test_aot_inductor_package.py | 8 +--- test/inductor/test_cuda_repro.py | 10 +++-- test/inductor/test_fp8.py | 20 +++++----- test/inductor/test_max_autotune.py | 4 +- test/inductor/test_torchinductor.py | 10 ++--- ...st_torchinductor_codegen_dynamic_shapes.py | 10 ++--- test/inductor/test_triton_kernels.py | 37 +++++++------------ 8 files changed, 41 insertions(+), 61 deletions(-) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 44edc5305e14..4e4af9341e75 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -16,7 +16,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - TEST_WITH_ROCM, ) @@ -659,7 +658,7 @@ def fn(a_float32, b_float32): self.assertTrue(same(ref, res)) @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION or TEST_WITH_ROCM, + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Can't run fused SDPA on this platform", ) def test_autocast_sdpa(self): diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 2d9d7cdb1b80..09398c2c59d1 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -19,12 +19,7 @@ from torch._inductor.test_case import TestCase from torch._inductor.utils import fresh_inductor_cache from torch.export import Dim -from torch.testing._internal.common_utils import ( - IS_FBCODE, - skipIfRocm, - skipIfXpu, - TEST_CUDA, -) +from torch.testing._internal.common_utils import IS_FBCODE, skipIfXpu, TEST_CUDA from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -183,7 +178,6 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) @unittest.skipIf(IS_FBCODE, "cmake won't work in fbcode") - @skipIfRocm # build system may be different @skipIfXpu # build system may be different def test_compile_after_package(self): if not self.package_cpp_only: diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index e46dfff708ab..2f28257731af 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -37,12 +37,17 @@ DeterministicGuard, freeze_rng_state, IS_FBCODE, - skipIfRocm, TEST_WITH_ASAN, + TEST_WITH_ROCM, xfailIfPy312Plus, ) +if TEST_WITH_ROCM: + config.force_layout_optimization = 1 + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + + DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -187,7 +192,6 @@ def f(q, k, v, mask): self.assertEqual(out, f(*inputs)) - @skipIfRocm def test_input_channels_last(self): m = torch.nn.Sequential( torch.nn.Conv2d(3, 3, 1, 1), @@ -1403,7 +1407,6 @@ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3): fn(*args) torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address - @skipIfRocm def 
test_non_commutative_scan_op(self): from torch._higher_order_ops.associative_scan import associative_scan @@ -1450,7 +1453,6 @@ def outer_reduce(x): self.assertEqual(outer_reduce(a), out) self.assertTrue("for roffset" not in code) - @skipIfRocm def test_scaled_dot_product_efficient_attention_backward(self): from torch import nn, Tensor diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 64086e5071c6..e208565081a1 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -12,7 +12,6 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - TEST_WITH_ROCM, ) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.utils._triton import has_triton_tma_device @@ -118,7 +117,6 @@ def _fix_fp8_dtype_for_rocm( @instantiate_parametrized_tests class TestFP8Types(TestCase): @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): """ @@ -129,6 +127,7 @@ def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): We should not pick a XBLOCK larger than xnumel """ + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") def f(x): return x.to(dtype=float8_dtype) @@ -139,7 +138,6 @@ def f(x): torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("dtype", (torch.float16, torch.bfloat16)) def test_eager_fallback(self, dtype: torch.dtype): weight_shape = (32, 16) @@ -247,7 +245,6 @@ def fp8_saturated(x, dtype): torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) @@ -303,7 +300,6 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) - @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @@ -413,7 +409,6 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): @instantiate_parametrized_tests class TestFP8Lowering(TestCase): - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.bfloat16, torch.float32)) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) @@ -435,6 +430,7 @@ def test_tensorwise_scaling( device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) shape = [int(dim) for dim in shape.split(",")] M, K, N = shape # Matmul Y = X [M, K] x W [N, K] @@ -491,7 +487,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # setting a small absolute tolerance in these tests torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) 
@parametrize("has_bias", (False, True)) @@ -506,6 +501,7 @@ def test_rowwise_scaling( dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) shape = [int(dim) for dim in shape.split(",")] M, K, N = shape # Matmul Y = X [M, K] x W [N, K] @@ -557,7 +553,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @@ -573,6 +568,7 @@ def test_tensorwise_scaling_acceptable_input_dims( use_fast_accum = True device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -615,7 +611,6 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @@ -630,6 +625,7 @@ def test_rowwise_scaling_acceptable_input_dims( use_fast_accum = True device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -674,13 +670,14 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_unacceptable_input_dims(self): # for compiled ops, type checking is in torch/_meta_registrations.py dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) + M, K, N = 64, 15, 2048 # K needs to be a multiple of 16 x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) @@ -714,12 +711,13 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): in str(cm.exception) ) - @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_unacceptable_scale_dims_rowwise_scaling(self): dtype: torch.dtype = torch.bfloat16 device = "cuda" dtype_float8 = torch.float8_e4m3fn + dtype_float8 = _fix_fp8_dtype_for_rocm(dtype_float8, device) + M, K, N = 233, 32, 128 x = torch.randn(M, K, dtype=dtype, device=device) w = torch.randn(N, K, dtype=dtype, device=device) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index a62711196c88..499dcbf4ae47 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -45,7 +45,7 @@ from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu +from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu from 
torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU @@ -672,7 +672,7 @@ def fn(x, number): torch._export.aot_compile(fn, args=inputs) @config.patch(autotune_local_cache=False, autotune_remote_cache=False) - @skipIfRocm + @runOnRocmArch(MI300_ARCH) def test_precompilations(self): def fn(a, b, c): a = (a @ b) @ c diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index acb7fc2e12ec..693292057c96 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -142,6 +142,10 @@ HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines +if TEST_WITH_ROCM: + torch._inductor.config.force_layout_optimization = 1 + os.environ["PYTORCH_MIOPEN_SUGGEST_NHWC"] = "1" + aten = torch.ops.aten requires_multigpu = functools.partial( @@ -2061,7 +2065,6 @@ def fn(a): self.common(fn, (inp.view(10, -1),), rtol=1e-4, atol=1e-5, check_lowp=False) @skipCUDAIf(not SM80OrLater, "Requires sm80") - @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") @skip_if_gpu_halide # accuracy issue def test_split_cumsum_low_prec(self): if is_cpp_backend(self.device): @@ -2133,7 +2136,6 @@ def fn(a): self.common(fn, (inp,), atol=1e-5, rtol=1e-4, check_lowp=False) @skipCUDAIf(not SM80OrLater, "Requires sm80") - @skipCUDAIf(TEST_WITH_ROCM, "Computation not done in float on ROCm") @skip_if_gpu_halide # accuracy issue def test_split_cumprod_low_prec(self): if is_cpp_backend(self.device): @@ -2172,7 +2174,6 @@ def fn(a, b): self.common(fn, (a, b), atol=1e-5, rtol=1e-5, check_lowp=False) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -2232,7 +2233,6 @@ def fn(a, b, dim): r"triton_.*\.run\(arg[01]_1, arg[12]_1, buf1," ).check_not("run(").run(code[0]) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -2260,7 +2260,6 @@ def argmax_combine(a, b): actual = associative_scan(argmax_combine, (a, idx), 0) self.assertEqual(expect, actual) - @skipCUDAIf(TEST_WITH_ROCM, "associative_scan is not supported on ROCm") @skip_if_halide # scan ops # TODO: support lifted symints when dynamic @torch._dynamo.config.patch( @@ -10446,7 +10445,6 @@ def forward(self, arg0_1, arg1_1): eager_out = eager_mod(*eager_args) self.assertEqual(inductor_out, eager_out) - @skipIfRocm def test_require_stride_expanded(self): def forward(arg6, arg7, arg16): convolution = torch.ops.aten.convolution( diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index ee8e22193f41..c090b7b7846f 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -261,9 +261,6 @@ def run(*ex, **kwargs): ), "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_custom_op_3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), - "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( - ("cuda", "xpu") if IS_LINUX else ("cpu", "cuda", "xpu") - ), "test_cat_uint8_dynamic_shapes": TestFailure( ("cpu",) ), # cat on uint8 input is using aten fallback on cpu @@ -383,11 +380,12 @@ def run(*ex, **kwargs): **dynamic_shapes_test_failures, } -if TEST_WITH_ROCM: +if not TEST_WITH_ROCM: test_failures.update( { - 
"test_split_cumsum_low_prec_dynamic_shapes": TestFailure(("cpu", "cuda")), - "test_split_cumprod_low_prec_dynamic_shapes": TestFailure(("cpu", "cuda")), + "test_custom_op_fixed_layout_sequential_dynamic_shapes": TestFailure( + ("cuda", "xpu") if IS_LINUX else ("cpu", "cuda", "xpu") + ), } ) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 3f495042a392..39b220290d03 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -25,13 +25,7 @@ from torch._library import capture_triton from torch.testing import FileCheck from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import ( - parametrize, - skipIfRocm, - skipIfWindows, - skipIfXpu, - TEST_WITH_ROCM, -) +from torch.testing._internal.common_utils import parametrize, skipIfWindows, skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA, HAS_GPU, HAS_XPU from torch.testing._internal.logging_utils import log_settings, logs_to_string @@ -44,23 +38,22 @@ import triton from triton import language as tl - if not TEST_WITH_ROCM: - if HAS_CUDA: - try: - from triton.language.extra.libdevice import ( # @manual - fast_dividef, - fast_dividef as my_fast_dividef, - ) - except ImportError: - from triton.language.extra.cuda.libdevice import ( # @manual - fast_dividef, - fast_dividef as my_fast_dividef, - ) - elif HAS_XPU: - from triton.language.extra.intel.libdevice import ( # @manual + if HAS_CUDA: + try: + from triton.language.extra.libdevice import ( # @manual fast_dividef, fast_dividef as my_fast_dividef, ) + except ImportError: + from triton.language.extra.cuda.libdevice import ( # @manual + fast_dividef, + fast_dividef as my_fast_dividef, + ) + elif HAS_XPU: + from triton.language.extra.intel.libdevice import ( # @manual + fast_dividef, + fast_dividef as my_fast_dividef, + ) def _triton_get_ast_equal_to_str(params): try: @@ -1341,7 +1334,6 @@ def f(x, y): self.assertEqual(compiled_out, eager_out) @requires_gpu - @skipIfRocm def test_triton_kernel_with_imported_symbol(self): @triton.jit def add_kernel_with_imported_symbol( @@ -1373,7 +1365,6 @@ def f(x): self.assertEqual(compiled_out, eager_out) @requires_gpu - @skipIfRocm def test_triton_kernel_with_imported_symbol_with_custom_name(self): @triton.jit def add_kernel_with_imported_symbol( From 4ce0b959ff42c38a51aba9cfe1a4ec16cfbd61d2 Mon Sep 17 00:00:00 2001 From: Joshua Hamilton Date: Tue, 1 Apr 2025 00:42:46 +0000 Subject: [PATCH 025/332] Add a warning when a tensor with requires_grad=True is converted to a scalar (#143261) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #143071 Operations performed on tensors with `requires_grad=True` such as ```python import torch x = torch.tensor(2.0, requires_grad=True) y = x ** 3 ``` and ```python x = torch.tensor(2.0, requires_grad=True) y = torch.pow(x,3) ``` are valid operations. While an operation using `numpy` like ```python import numpy as np x = torch.tensor(2.0, requires_grad=True) y = np.pow(x,3) # > RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead. ``` leads to an error. However, an operation that uses `math` like ```python import math x = torch.tensor(2.0, requires_grad=True) y = math.pow(x,3) ``` does not cause an error, and `y` is no longer a tensor with a gradient! 
This represents a [footgun](https://en.wiktionary.org/wiki/footgun#Noun) for some users, like myself when training small, custom, non-neural network models. To prevent future undesired behavior, I added a warning when converting tensors with `requires_grad=True` to scalars. Now, when using `math.pow` on a `tensor`, we get a single warning with: ```python x = torch.tensor(2.0, requires_grad=True) y = math.pow(x,3) # > UserWarning: Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior. # Consider using tensor.detach() first. ``` Please let me know if you have any questions 👍 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143261 Approved by: https://github.com/malfet Co-authored-by: albanD Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/native/Scalar.cpp | 6 ++++++ test/test_torch.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp index 0053b86c3373..de56c906d004 100644 --- a/aten/src/ATen/native/Scalar.cpp +++ b/aten/src/ATen/native/Scalar.cpp @@ -11,11 +11,17 @@ #include #endif +#include + namespace at::native { Scalar item(const Tensor& self) { auto numel = self.sym_numel(); TORCH_CHECK(numel == 1, "a Tensor with ", numel, " elements cannot be converted to Scalar"); + if (at::GradMode::is_enabled() && self.requires_grad()) { + TORCH_WARN_ONCE("Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.\n" + "Consider using tensor.detach() first."); + } if (self.is_sparse()) { if (self._nnz() == 0) return Scalar(0); if (self.is_coalesced()) return at::_local_scalar_dense(self._values()); diff --git a/test/test_torch.py b/test/test_torch.py index fa96b40774d0..32838e0ddf33 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10829,6 +10829,23 @@ class MyTwoTensor4(TwoTensor): def test_bf16_supported_on_cpu(self): self.assertFalse(torch.cuda.is_bf16_supported()) + def test_tensor_with_grad_to_scalar_warning(self) -> None: + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + x = torch.tensor(2.0, requires_grad=True) + math.pow(x, 3) # calling this results in a warning + + self.assertEqual(len(w), 1) + self.assertTrue(issubclass(w[0].category, UserWarning)) + self.assertIn( + "Converting a tensor with requires_grad=True to a scalar may lead to unexpected behavior.", + str(w[0].message) + ) + + _ = math.pow(x, 3) # calling it again does not result in a second warning + self.assertEqual(len(w), 1) # The following block extends TestTorch with negative dim wrapping tests # FIXME: replace these with OpInfo sample inputs or systemic OpInfo tests From 7ab8532cf1ac31ce47a6813d8508c6c4db031441 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 31 Mar 2025 15:50:31 -0700 Subject: [PATCH 026/332] [BE] Get rid of cross-compile and x86 build options for Mac (#150362) As both cross-compilation and x86 builds has been removed a while back Remove stale TODO about building with OpenMP support Pull Request resolved: https://github.com/pytorch/pytorch/pull/150362 Approved by: https://github.com/atalman, https://github.com/clee2000 --- .ci/pytorch/macos-build.sh | 27 +++------------------------ 1 file changed, 3 insertions(+), 24 deletions(-) diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index 4a2f63a2ed10..4e1c68be9282 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh 
@@ -33,23 +33,6 @@ if which sccache > /dev/null; then export PATH="${tmp_dir}:$PATH" fi -cross_compile_arm64() { - # Cross compilation for arm64 - # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests - # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 - USE_DISTRIBUTED=0 CMAKE_OSX_ARCHITECTURES=arm64 MACOSX_DEPLOYMENT_TARGET=11.0 USE_MKLDNN=OFF USE_QNNPACK=OFF WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel -} - -compile_arm64() { - # Compilation for arm64 - # TODO: Compile with OpenMP support (but this causes CI regressions as cross-compilation were done with OpenMP disabled) - USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel -} - -compile_x86_64() { - USE_DISTRIBUTED=0 WERROR=1 python setup.py bdist_wheel --plat-name=macosx_10_9_x86_64 -} - build_lite_interpreter() { echo "Testing libtorch (lite interpreter)." @@ -71,16 +54,12 @@ build_lite_interpreter() { print_cmake_info if [[ ${BUILD_ENVIRONMENT} = *arm64* ]]; then - if [[ $(uname -m) == "arm64" ]]; then - compile_arm64 - else - cross_compile_arm64 - fi + # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests + # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 + USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel elif [[ ${BUILD_ENVIRONMENT} = *lite-interpreter* ]]; then export BUILD_LITE_INTERPRETER=1 build_lite_interpreter -else - compile_x86_64 fi if which sccache > /dev/null; then From 0f12951fc2005cd5b3ee13a877567215eb5f4425 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Tue, 1 Apr 2025 01:00:11 +0000 Subject: [PATCH 027/332] [Intel gpu] always set deterministic for xpu accuracy test (#149028) On Intel Max 1550, models like Super_SloMo can actually pass accuracy test after set deterministic, because we do not use atomic in upsampling bilinear backward in some cases when running on XPU. Furthermore, I guess the only reason not to set deterministic on these models is just avoiding errors. We should use warn_only = True. 
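For reference, a minimal sketch of what `warn_only` changes (this is the standard behavior of `torch.use_deterministic_algorithms`, shown here only for context):

```python
import torch

# Strict mode: an op with no deterministic implementation raises a RuntimeError.
torch.use_deterministic_algorithms(True)

# warn_only mode: the same op runs non-deterministically and only emits a warning,
# so the accuracy run is not aborted.
torch.use_deterministic_algorithms(True, warn_only=True)
```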
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149028 Approved by: https://github.com/guangyey, https://github.com/desertfire Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- benchmarks/dynamo/common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index a354501a75ae..c7785940ddc7 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3542,6 +3542,8 @@ def run(runner, args, original_dir=None): }: # some of the models do not support use_deterministic_algorithms torch.use_deterministic_algorithms(True) + if args.devices == ["xpu"]: + torch.use_deterministic_algorithms(True, warn_only=True) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # TODO(eqy): revisit when cuBLASLt workspace size is bumped # if args.only is not None and args.only in { From 5cb5675f1390474781c0b9cfdeb7bdcc45f89c8e Mon Sep 17 00:00:00 2001 From: "Sun, Jiayi" Date: Sun, 23 Mar 2025 20:13:15 -0700 Subject: [PATCH 028/332] [Inductor] optimize the heuristics of parallel reduction (#149614) Fix https://github.com/pytorch/pytorch/issues/148639. Summary: Optimize the heuristics of parallel reduction: When the number of steps of the first inner loop beyond the maximum parallel depth is much larger than the number of steps of all outer loops within the maximum parallel depth, change the starting depth of parallelism to the first inner loop and recalculate the maximum parallel depth. I ran the Inductor benchmark with this PR on CPU. A timm model poolformer_m36 BF16 has about 25% performance improvement, and no performance regression is seen. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149614 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel --- torch/_inductor/codegen/cpp.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 60d151d59b10..23134b7916a7 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -5457,20 +5457,21 @@ def max_parallel_depth(self): start_depth = 0 max_depth = 0 is_reduction = self.loops[0].is_reduction - loop_sizes = sympy.Integer(1) + num_steps = sympy.Integer(1) for loop in self.loops: if loop.is_reduction != is_reduction: break - loop_sizes = loop_sizes * loop.size + num_steps = num_steps * FloorDiv(loop.size, loop.steps) max_depth += 1 - # When the range of the first inner loop is much larger than the range of all outer loops, - # change `start_depth` to the first inner loop and recalculate `max_depth`. + # When the number of steps of the first inner loop is much larger than the number of steps of + # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`. 
if ( max_depth < len(self.loops) - and isinstance(loop_sizes, sympy.Integer) + and isinstance(num_steps, sympy.Integer) and isinstance(self.loops[max_depth].size, sympy.Integer) - and loop_sizes * 300 < self.loops[max_depth].size + and num_steps * 300 + < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps) ): start_depth = max_depth max_depth = 0 From 6470b373c16017f5cb8f1aa4060bb60632b18160 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 1 Apr 2025 01:33:37 +0000 Subject: [PATCH 029/332] `torch.backends.mkldnn.flags()` CM should not warn (#150358) By returning `None` rather than `False` from `THPModule_allowTF32OneDNN` when USE_XPU is not defined Added regression test Fixes https://github.com/pytorch/pytorch/issues/149829 Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/150358 Approved by: https://github.com/atalman --- test/test_mkldnn.py | 11 +++++++++++ torch/csrc/Module.cpp | 4 ++++ 2 files changed, 15 insertions(+) diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 93858f10b5c9..19772c4adaa5 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -4,6 +4,7 @@ import itertools import functools import unittest +import warnings from contextlib import nullcontext try: @@ -1612,6 +1613,16 @@ def common(self, shape1, shape2, op, dtype): ]: common(self, shape1, shape2, op, dtype) + def test_mkldnn_setflags_nowarn(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/149829 + with warnings.catch_warnings(record=True) as w: + rc = torch.backends.mkldnn.set_flags() + # torch.backends.mkldnn. returns previously set flags + # That one should be able to set back without cauinsg a warning + torch.backends.mkldnn.set_flags(*rc) + # Above should trigger no warnings regardless of configuration + self.assertEqual(len(w), 0) + instantiate_device_type_tests(TestMkldnn, globals(), only_for=('cpu',)) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index c67953fc45e2..3612e94a0d02 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -965,10 +965,14 @@ static PyObject* THPModule_setAllowTF32OneDNN( static PyObject* THPModule_allowTF32OneDNN( PyObject* _unused, PyObject* noargs) { +#ifdef USE_XPU if (at::globalContext().allowTF32OneDNN()) Py_RETURN_TRUE; else Py_RETURN_FALSE; +#else + Py_RETURN_NONE; +#endif } static PyObject* THPModule_deterministicAlgorithms( From 827b730f4e1cf172c7ba228a2efd268149163d52 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 1 Apr 2025 02:33:43 +0000 Subject: [PATCH 030/332] [CI] Skip test_copy_large_tensor on M2-15 runners (#150377) They have more than 12Gb memory, but may be running this test causes OOM in CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/150377 Approved by: https://github.com/atalman --- test/test_mps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_mps.py b/test/test_mps.py index c10381331d2e..e27a78785be7 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7476,6 +7476,7 @@ def compare_mm(m, n, k, dtype=torch.float): @unittest.skipIf(total_memory < 12_000_000_000, "Needs at least 12Gb RAM to run the test") @unittest.skipIf(MACOS_VERSION < 14.0, "Can't allocate 4Gb tensor on MacOS 13") + @unittest.skipIf(IS_CI, "May be fixes https://github.com/pytorch/pytorch/issues/149999") def test_copy_large(self): """ Test that copy of 4Gb+ tensors works """ x = torch.ones((2**30 + 11,), 
dtype=torch.float32) From 31634b8c6ac5ae25e792c0c98cb35209d14367eb Mon Sep 17 00:00:00 2001 From: Phillip Liu Date: Tue, 1 Apr 2025 03:07:55 +0000 Subject: [PATCH 031/332] [fr] Added protection against missing stack frames in fr cont. (#150133) Summary: Previously we had D70358287, which didn't fully resolved the issue. Test Plan: # FR `buck2 run @//mode/opt //caffe2/fb/flight_recorder:fr_trace -- --mast_job_id f710320638-TrainingApplication --mast_job_version 0 --mast_job_attempt 0 --bucket tlcm_log_blob --world_size 128 --dump_file_name_offset 0 --allow-incomplete-ranks` Confirm no error # FR analyzer `buck2 run @//mode/opt //investigations/dr_patternson/analyzers/ai_observability:ai_observability-all-analyzers-cli -- flight_recorder_analyzer --mast_job_name f710320638-TrainingApplication --mast_job_version 0 --mast_job_attempt 0` Confirm no error Differential Revision: D71998980 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150133 Approved by: https://github.com/fduwjj --- tools/flight_recorder/components/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index d396551f7cdf..dbad6a93790c 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -224,7 +224,7 @@ def __init__(self, entry: dict[str, Any], expected_ranks: set[int]) -> None: self.input_sizes = entry["input_sizes"] self.output_sizes = entry["output_sizes"] self.collective_state = entry["state"] - self.collective_frames = entry["frames"] + self.collective_frames = entry.get("frames", []) self.expected_ranks = expected_ranks self.missing_ranks: set[int] self.input_numel: int @@ -316,7 +316,7 @@ def to_collective( output_sizes=entry["output_sizes"], expected_ranks=self.expected_ranks, collective_state=entry["state"], - collective_frames=entry["frames"], + collective_frames=entry.get("frames", []), type_of_mismatch=error, ) return Collective( From ce52674b7651921630019de62323ee0bfd69516d Mon Sep 17 00:00:00 2001 From: Stonepia Date: Tue, 1 Apr 2025 04:43:07 +0000 Subject: [PATCH 032/332] [Doc] Update CMAKE_PREFIX_PATH for XPU windows README (#148863) We found that the `pip install cmake` and `conda install cmake` has different behavior. The reason is that the pip installed one doesn't find the corresponding libs under conda env. So we need to set the `CMAKE_PREFIX_PATH` for alignment. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148863 Approved by: https://github.com/CuiYifeng, https://github.com/malfet Co-authored-by: Cui, Yifeng --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index 299cca8e34fd..fcdf761295ae 100644 --- a/README.md +++ b/README.md @@ -355,6 +355,16 @@ Please make sure [the common prerequisites](#prerequisites) as well as [the prer Then PyTorch can be built with the command: ```cmd +:: CMD Commands: +:: Set the CMAKE_PREFIX_PATH to help find corresponding packages +:: %CONDA_PREFIX% only works after `conda activate custom_env` + +if defined CMAKE_PREFIX_PATH ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library;%CMAKE_PREFIX_PATH%" +) else ( + set "CMAKE_PREFIX_PATH=%CONDA_PREFIX%\Library" +) + python setup.py develop ``` From 790d459f85a312781a6ff54d159cb670f611869b Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 31 Mar 2025 15:46:04 -0700 Subject: [PATCH 033/332] [dynamo] add error message for unsupported LOAD_BUILD_CLASS (#150323) Improved error message for https://github.com/pytorch/pytorch/issues/128942 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150323 Approved by: https://github.com/jansel, https://github.com/zou3519 --- test/dynamo/test_error_messages.py | 39 ++++++++++++++++++++++++------ torch/_dynamo/symbolic_convert.py | 11 +++++++++ 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 8db90f836310..255d62a5c4ff 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -11,6 +11,7 @@ import torch._dynamo.test_case import torch.utils._pytree as python_pytree from torch._dynamo.exc import Unsupported +from torch._dynamo.testing import skipIfNotPy312 from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( IS_FBCODE, @@ -646,18 +647,42 @@ def fn(): """, ) - def test_unsupported_bytecode(self): + def test_load_build_class(self): def fn(): class Foo: pass return Foo + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile(fn, backend="eager", fullgraph=True)(), + """\ +LOAD_BUILD_CLASS bytecode not supported + Explanation: Dynamo does not support tracing classes that are defined in the compiled region. + Hint: Move the class definition out of the compiled region. + Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. + + Developer debug context: + + +from user code: + File "test_error_messages.py", line N, in fn + class Foo:""", + ) + + @skipIfNotPy312 + def test_unsupported_bytecode(self): + async def fn(): + async for i in range(3): + print(i) + return 1 + def post_munge(s): s = re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s) s = re.sub( - r"Instruction\(.*opname='LOAD_BUILD_CLASS'.*\)\n", - "Instruction(LOAD_BUILD_CLASS)", + r"Instruction\(.*opname='GET_AITER'.*\)\n", + "Instruction(GET_AITER)", s, ) return s @@ -667,15 +692,15 @@ def post_munge(s): lambda: torch.compile(fn, backend="eager", fullgraph=True)(), """\ Missing bytecode handler - Explanation: Dynamo does not know how to handle the bytecode instruction `LOAD_BUILD_CLASS`. - Hint: Do not trace code that produces the `LOAD_BUILD_CLASS` bytecode instruction (see https:/docs.python.org/3/library/dis.html for bytecode semantics). 
+ Explanation: Dynamo does not know how to handle the bytecode instruction `GET_AITER`. + Hint: Do not trace code that produces the `GET_AITER` bytecode instruction (see https:/docs.python.org/3/library/dis.html for bytecode semantics). Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: LOAD_BUILD_CLASS with args (, Instruction(LOAD_BUILD_CLASS) + Developer debug context: GET_AITER with args (, Instruction(GET_AITER) from user code: File "test_error_messages.py", line N, in fn - class Foo:""", + async for i in range(3):""", post_munge=post_munge, ) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 2ceb1368f7a7..0d8b37d7ce6c 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -2824,6 +2824,17 @@ def MATCH_KEYS(self, inst): def LOAD_ASSERTION_ERROR(self, inst): self.load_builtin_from_argval("AssertionError") + def LOAD_BUILD_CLASS(self, inst): + unimplemented_v2( + gb_type="LOAD_BUILD_CLASS bytecode not supported", + context="", + explanation="Dynamo does not support tracing classes that are defined in the compiled region.", + hints=[ + "Move the class definition out of the compiled region.", + *graph_break_hints.SUPPORTABLE, + ], + ) + UNARY_POSITIVE = stack_op(operator.pos) UNARY_NEGATIVE = stack_op(operator.neg) UNARY_NOT = stack_op(operator.not_) From 7e7e5698cc885583433360e8d64b5f497a9608e2 Mon Sep 17 00:00:00 2001 From: Tugsbayasgalan Manlaibaatar Date: Sun, 30 Mar 2025 19:27:11 -0700 Subject: [PATCH 034/332] Suppress more warnings (#149833) Differential Revision: [D71702307](https://our.internmc.facebook.com/intern/diff/D71702307) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149833 Approved by: https://github.com/malfet, https://github.com/Skylion007 --- torch/_export/passes/lift_constants_pass.py | 13 +++++++++---- torch/export/_unlift.py | 7 ++++++- torch/fx/experimental/proxy_tensor.py | 8 ++++---- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 77255c8d07d7..734f8cd33786 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import collections -import warnings +import logging from typing import Any, Union import torch @@ -19,6 +19,9 @@ from torch.fx.graph_module import _get_attr +log = logging.getLogger(__name__) + + class ConstantAttrMap(collections.abc.MutableMapping): """A mapping class that understands how to use module constants (tensors, ScriptObjects, FakeScriptObjects) as keys. We store tensors and FakeScriptObjects normally, @@ -213,9 +216,11 @@ def lift_constants_pass( elif isinstance(constant_val, torch.Tensor): # Remove the parameterness of constant_val if isinstance(constant_val, torch.nn.Parameter): - warnings.warn( - f"{node.target} created when tracing {node.meta.get('stack_trace', '')} is a parameter. But" - f"it's not registered with register_parameter(). export will treat it as a constant tensor" + log.debug( + "%s created when tracing %s is a parameter. But " + "it's not registered with register_parameter(). 
export will treat it as a constant tensor", + str(node.target), + str(node.meta.get("stack_trace", "")), ) # We get the real data out of the parameter by disabling the surrounding fake mode. with unset_fake_temporarily(): diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 0caf82160054..e51c12800ad9 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -105,7 +105,12 @@ def _unlift_inputs_as_getattr( else: with gm.graph.inserting_after(input_node): - getattr_node = gm.graph.get_attr(lifted_node) + # It is fine to ignore this warning because + # it is guaranteed that we will populate this + # attr later. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + getattr_node = gm.graph.get_attr(lifted_node) input_node.replace_all_uses_with(getattr_node) metadata = input_node.meta gm.graph.erase_node(input_node) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 45e3309208e9..eb2cb0a81f7a 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -14,7 +14,6 @@ import traceback import typing import typing_extensions -import warnings import weakref from collections import defaultdict, OrderedDict from collections.abc import Generator, Mapping, Sequence @@ -1820,11 +1819,12 @@ def call_module( try: return Tracer.call_module(self, m, forward, args, kwargs) except _ModuleNotInstalledAsSubmoduleError: - warnings.warn( - f"Unable to find the path of the module {m}. " + log.debug( + "Unable to find the path of the module %s. " "This might be because the module was not properly registered " "as a submodule, which is not good practice. We will trace " - "through the module without recording stack information." + "through the module without recording stack information.", + str(m), ) return forward(*args, **kwargs) From 414b9ae016f828c314faeafc4c86a111db414afa Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 1 Apr 2025 05:36:04 +0000 Subject: [PATCH 035/332] enable out variant of 2-shot reduction (#150153) Per title, this version uses symm mem input both as input source and as a work buffer, so input is modified after the end (similar to what fbgemm car reduction does). It is intended to be wrapped in an op that would first copy the real inputs to symm mem buffers that wouldn't be exposed. 
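For illustration, a rough sketch of such a wrapper (no wrapper is added by this PR; the helper names and module path follow the test file, and error/alignment handling is omitted):

```python
import torch
import torch.distributed._symmetric_memory as symm_mem

def two_shot_all_reduce(inp: torch.Tensor, group_name: str) -> torch.Tensor:
    # Hidden scratch buffer: the kernel uses it as both source and work space,
    # so the caller's `inp` is never mutated.
    scratch = symm_mem.empty(inp.numel(), dtype=inp.dtype, device=inp.device)
    symm_mem.rendezvous(scratch, group=group_name)
    scratch.copy_(inp.reshape(-1))
    out = torch.empty(inp.numel(), dtype=inp.dtype, device=inp.device)
    torch.ops.symm_mem.two_shot_all_reduce_out(scratch, "sum", group_name, out)
    return out.reshape(inp.shape)
```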
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150153 Approved by: https://github.com/xw285cornell --- test/distributed/test_symmetric_memory.py | 53 +++--- .../c10d/CUDASymmetricMemoryOps.cu | 170 +++++++++++++++--- .../csrc/distributed/c10d/SymmetricMemory.cpp | 4 + 3 files changed, 182 insertions(+), 45 deletions(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index 34b8ed5a7b10..8dce15728056 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -1,5 +1,6 @@ # Owner(s): ["module: c10d"] +import itertools import os from unittest import skipIf @@ -881,34 +882,38 @@ def test_one_shot_all_reduce( @skipIfRocm @skip_if_lt_x_gpu(4) - @parametrize("dtype", [torch.float, torch.bfloat16]) - @parametrize("align_bytes", [4, 8, 16]) - @parametrize("size_bytes", [4, 8192, 8196]) - def test_two_shot_all_reduce( - self, dtype: torch.dtype, size_bytes: int, align_bytes: int - ) -> None: + def test_two_shot_all_reduce(self) -> None: self._init_process() group_name = dist.group.WORLD.group_name - t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) - symm_mem.rendezvous(t, group=group_name) - - self.assertTrue(t.data_ptr() % 16 == 0) - self.assertTrue(align_bytes % t.element_size() == 0) - self.assertTrue(size_bytes % t.element_size() == 0) - - shift = align_bytes // t.element_size() - numel = size_bytes // t.element_size() - res = t[shift : shift + numel] - res.normal_() - inp = res.clone() - - torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name) + for dtype, size_bytes, align_bytes, inplace in itertools.product( + [torch.float, torch.bfloat16], + [4, 8192, 8196], + [4, 8, 16], + [True, False], + ): + t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) + + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + res = t[shift : shift + numel] + res.normal_().fill_(1) + inp = res.clone() + if not inplace: + out = torch.empty_like(inp) + torch.ops.symm_mem.two_shot_all_reduce_out(res, "sum", group_name, out) + else: + torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name) - # Head and tail should not be written - self.assertTrue(t[:shift].eq(0).all().item()) - self.assertTrue(t[shift + numel :].eq(0).all().item()) - self._verify_all_reduce_result(inp, res) + # Head and tail should not be written + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_all_reduce_result(inp, res if inplace else out) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 438624f4bc07..566cfbe5817f 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -39,6 +39,16 @@ } \ } +#define DISPATCH_WORLD_SIZES_NO_DEFAULT(world_size, ...) \ + switch (world_size) { \ + INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__); \ + INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__); \ + INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__); \ + default: { \ + TORCH_CHECK(false, "Not implemented for world_size=", world_size); \ + } \ + } + #define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...) 
\ switch (alignment) { \ INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__); \ @@ -493,6 +503,70 @@ constexpr size_t two_shot_all_reduce_max_num_threads = 512; template static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ void two_shot_all_reduce_kernel( + T** input_ptrs, + T* output_ptr, + size_t input_offset, + size_t numel, + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + static_assert(alignment % sizeof(T) == 0); + constexpr size_t numel_per_thread = alignment / sizeof(T); + + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); + + const size_t numel_per_rank = + at::round_up(numel, alignment * world_size) / world_size; + const size_t start = numel_per_rank * rank; + + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; + auto stride = blockDim.x * gridDim.x * numel_per_thread; + for (size_t i = offset; i < numel_per_rank; i += stride) { + if (start + i >= numel) { + continue; + } + auto vec = load_and_reduce( + input_ptrs, rank, world_size, input_offset + start + i); + // store to local buffer + st_vec(input_ptrs[rank] + input_offset + start + i, vec); + } + + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); + for (size_t i = offset; i < numel_per_rank; i += stride) { + Vec tmp[k_world_size]; +#pragma unroll k_world_size + for (size_t step = 0; step < k_world_size; ++step) { + size_t remote_rank = (rank + step) % k_world_size; + size_t remote_start = numel_per_rank * remote_rank; + if (remote_start + i >= numel) { + continue; + } + tmp[step] = ld_vec( + input_ptrs[remote_rank] + input_offset + remote_start + i); + } +#pragma unroll k_world_size + for (size_t step = 0; step < k_world_size; ++step) { + size_t remote_rank = (rank + step) % k_world_size; + size_t remote_start = numel_per_rank * remote_rank; + if (remote_start + i >= numel) { + continue; + } + st_vec( + output_ptr + remote_start + i, tmp[step]); + } + } + // need to make sure all blocks exit simultaneously so that the data + // is not corrupted by the subsequent kernels + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); +} + +template +static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ + void two_shot_all_reduce_kernel_inplace( T** input_ptrs, size_t input_offset, size_t numel, @@ -528,8 +602,9 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ sync_remote_blocks(signal_pads, rank, world_size); } -at::Tensor two_shot_all_reduce_( +at::Tensor two_shot_all_reduce_impl( at::Tensor input, + std::optional output, std::string reduce_op, std::string group_name) { TORCH_CHECK( @@ -546,6 +621,14 @@ at::Tensor two_shot_all_reduce_( const size_t alignment = get_and_verify_alignment(input, "two_shot_all_reduce"); + if (output.has_value()) { + const size_t output_alignment = + get_and_verify_alignment(*output, "two_shot_all_reduce"); + TORCH_CHECK( + alignment <= output_alignment, + "two_shot_all_reduce: output alignment must be equal to or larger than input."); + } + int num_blocks = 0, num_threads = 0; init_elementwise_launch_config( input.numel(), @@ -557,30 +640,73 @@ at::Tensor two_shot_all_reduce_( num_blocks, num_threads); - AT_DISPATCH_FLOAT_AND_BFLOAT16( - input.scalar_type(), "two_shot_all_reduce", [&]() { - DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { - DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { - two_shot_all_reduce_kernel - <<>>( - reinterpret_cast( - symm_mem->get_buffer_ptrs_dev()), - input.storage_offset(), - input.numel(), - 
reinterpret_cast( - symm_mem->get_signal_pad_ptrs_dev()), - symm_mem->get_rank(), - symm_mem->get_world_size()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + if (!output.has_value()) { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel_inplace< + scalar_t, + k_alignment, + k_world_size> + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); - }); - return input; + return input; + } else { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + output->data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + return *output; + } +} + +at::Tensor two_shot_all_reduce_( + at::Tensor input, + std::string reduce_op, + std::string group_name) { + return two_shot_all_reduce_impl(input, std::nullopt, reduce_op, group_name); } +at::Tensor two_shot_all_reduce_out( + at::Tensor input, + std::string reduce_op, + std::string group_name, + at::Tensor output) { + return two_shot_all_reduce_impl(input, output, reduce_op, group_name); +} } // namespace #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 @@ -713,6 +839,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("one_shot_all_reduce", ::one_shot_all_reduce); m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); + m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); + m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm); #endif m.impl("stream_write_value32_", ::stream_write_value32_); diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 0308f2f5c4b2..9d400395e073 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -233,6 +233,10 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)"); + // note this implementation also modified the input tensor + m.def( + "two_shot_all_reduce_out(Tensor(a!) input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)"); + // An mm that supports consuming asynchronous input. It guarantees the // following rasterization order, and that the corresponding signal arrives // before an input chunk is consumed. From 17005992668f3f6e25761930e6514de435922b13 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Tue, 1 Apr 2025 05:36:41 +0000 Subject: [PATCH 036/332] Add one_shot_all_reduce_copy to allow non-symm-mem allocated tensors to be reduced (#150129) Per title, we want to be able to use it even if inputs are not registered. Separate copy would add latency, and one-shot is all about the lowest possible latency. 
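The intended call pattern, mirroring the updated test (a sketch; assumes the process group has already been initialized the same way the test does it):

```python
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

group_name = dist.group.WORLD.group_name  # process group already initialized

# Rendezvoused scratch buffer; for the copy variant its contents do not matter.
buf = symm_mem.empty(4096, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(buf, group=group_name)

local_inp = torch.randn(4096, dtype=torch.bfloat16, device="cuda")  # never registered
res = torch.ops.symm_mem.one_shot_all_reduce_copy(buf, local_inp, "sum", group_name)

# The existing op, for comparison, reduces whatever is already in the symm-mem buffer:
# res = torch.ops.symm_mem.one_shot_all_reduce(buf, "sum", group_name)
```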
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150129 Approved by: https://github.com/xw285cornell --- test/distributed/test_symmetric_memory.py | 36 ++++++--- .../c10d/CUDASymmetricMemoryOps.cu | 79 ++++++++++++++++--- .../csrc/distributed/c10d/SymmetricMemory.cpp | 14 ++++ 3 files changed, 106 insertions(+), 23 deletions(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index 8dce15728056..b5e961276f87 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -861,22 +861,32 @@ def test_multimem_one_shot_all_reduce( @skipIfRocm @skip_if_lt_x_gpu(4) - @parametrize("dtype", [torch.float, torch.bfloat16]) - @parametrize("align_bytes", [4, 8, 16]) - @parametrize("size_bytes", [4, 8192, 8196]) - def test_one_shot_all_reduce( - self, dtype: torch.dtype, size_bytes: int, align_bytes: int - ) -> None: + def test_one_shot_all_reduce(self) -> None: self._init_process() group_name = dist.group.WORLD.group_name - inp = symm_mem.empty( - size_bytes // dtype.itemsize, dtype=dtype, device=self.device - ).normal_() - symm_mem.rendezvous(inp, group=group_name) - - res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name) - self._verify_all_reduce_result(inp, res) + for dtype, size_bytes, align_bytes, copy, offset in itertools.product( + [torch.float, torch.bfloat16], + [4, 8192, 8196], + [4, 8, 16], + [True, False], + [0, 16], + ): + inp = symm_mem.empty( + size_bytes // dtype.itemsize + offset, dtype=dtype, device=self.device + ) + symm_mem.rendezvous(inp, group=group_name) + if not copy: + inp.normal_() + res = torch.ops.symm_mem.one_shot_all_reduce( + inp[offset:], "sum", group_name + ) + if copy: + local_inp = torch.randn_like(inp[offset:]) + res = torch.ops.symm_mem.one_shot_all_reduce_copy( + inp[offset:], local_inp, "sum", group_name + ) + self._verify_all_reduce_result(local_inp if copy else inp[offset:], res) dist.destroy_process_group() diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 566cfbe5817f..02baeb51e51c 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -397,7 +397,7 @@ at::Tensor multimem_all_gather_out( // One-shot all-reduce is register-intensive because it stages values loaded // from peers in registers before performing reduction. Setting the thread // count to 512 to prevent/alleviate register spill. 
-constexpr size_t one_shot_all_reduce_max_num_blocks = 8; +constexpr size_t one_shot_all_reduce_max_num_blocks = 24; constexpr size_t one_shot_all_reduce_max_num_threads = 512; template @@ -405,6 +405,7 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ void one_shot_all_reduce_kernel( T** input_ptrs, T* output_ptr, + T* input_ptr, size_t input_offset, size_t numel, uint32_t** signal_pads, @@ -412,12 +413,18 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ size_t world_size) { static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - - sync_remote_blocks(signal_pads, rank, world_size); - __syncthreads(); - + // copy input to shared ptr auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; auto stride = blockDim.x * gridDim.x * numel_per_thread; + if (input_ptr) { + for (size_t i = offset; i < numel; i += stride) { + Vec vec_st = ld_vec(input_ptr + i); + st_vec(input_ptrs[rank] + input_offset + i, vec_st); + } + } + // TODO make it sync with one block for no-copy case + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); for (size_t i = offset; i < numel; i += stride) { auto vec = load_and_reduce( @@ -426,11 +433,12 @@ static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ } __syncthreads(); - sync_remote_blocks(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); } -at::Tensor one_shot_all_reduce_out( +at::Tensor one_shot_all_reduce_out_impl( const at::Tensor& input, + const std::optional& local_input, std::string reduce_op, std::string group_name, at::Tensor out) { @@ -440,11 +448,21 @@ at::Tensor one_shot_all_reduce_out( out.is_contiguous(), "one_shot_all_reduce: output must be contiguous."); TORCH_CHECK( out.sizes() == input.sizes(), - "one_shot_all_reduce: input/output size mismatch."); + "one_shot_all_reduce: input/output size mismatch, input.sizes(): ", + input.sizes(), + ", output.sizes(): ", + out.sizes()); TORCH_CHECK( reduce_op == "sum", "one_shot_all_reduce: only sum is supported for now."); - + if (local_input.has_value()) { + TORCH_CHECK( + local_input->is_contiguous(), + "one_shot_all_reduce: local input must be contiguous."); + TORCH_CHECK( + local_input->numel() <= input.numel(), + "one_shot_all_reduce: local input size must be smaller than symm buffer size."); + } auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); TORCH_CHECK( symm_mem != nullptr, @@ -452,6 +470,13 @@ at::Tensor one_shot_all_reduce_out( const size_t alignment = get_and_verify_alignment(input, "one_shot_all_reduce"); + if (local_input.has_value()) { + const size_t local_alignment = + get_and_verify_alignment(*local_input, "one_shot_all_reduce"); + TORCH_CHECK( + alignment == local_alignment, + "one_shot_all_reduce: local input and symm buffer must have the same alignment."); + } int num_blocks = 0, num_threads = 0; init_elementwise_launch_config( @@ -476,6 +501,8 @@ at::Tensor one_shot_all_reduce_out( reinterpret_cast( symm_mem->get_buffer_ptrs_dev()), out.data_ptr(), + local_input.has_value() ? 
local_input->data_ptr() + : nullptr, input.storage_offset(), input.numel(), reinterpret_cast( @@ -489,12 +516,42 @@ at::Tensor one_shot_all_reduce_out( return out; } +at::Tensor one_shot_all_reduce_out( + const at::Tensor& input, + std::string reduce_op, + std::string group_name, + at::Tensor out) { + return one_shot_all_reduce_out_impl( + input, std::nullopt, reduce_op, group_name, out); +} + +at::Tensor one_shot_all_reduce_copy_out( + const at::Tensor& input, + const at::Tensor& local_input, + std::string reduce_op, + std::string group_name, + at::Tensor out) { + return one_shot_all_reduce_out_impl( + input, local_input, reduce_op, group_name, out); +} + at::Tensor one_shot_all_reduce( const at::Tensor& input, std::string reduce_op, std::string group_name) { auto out = at::empty_like(input); - return one_shot_all_reduce_out(input, reduce_op, group_name, out); + return one_shot_all_reduce_out_impl( + input, std::nullopt, reduce_op, group_name, out); +} + +at::Tensor one_shot_all_reduce_copy( + const at::Tensor& input, + const at::Tensor& local_input, + std::string reduce_op, + std::string group_name) { + auto out = at::empty_like(local_input); + return one_shot_all_reduce_out_impl( + input, local_input, reduce_op, group_name, out); } constexpr size_t two_shot_all_reduce_max_num_blocks = 24; @@ -838,6 +895,8 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("multimem_all_gather_out", ::multimem_all_gather_out); m.impl("one_shot_all_reduce", ::one_shot_all_reduce); m.impl("one_shot_all_reduce_out", ::one_shot_all_reduce_out); + m.impl("one_shot_all_reduce_copy", ::one_shot_all_reduce_copy); + m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 9d400395e073..76eb7205a398 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -217,6 +217,14 @@ at::Tensor one_shot_all_reduce_meta( return at::empty_like(input); } +at::Tensor one_shot_all_reduce_copy_meta( + const at::Tensor& symm_buffer, + const at::Tensor& local_input, + std::string reduce_op, + std::string group_name) { + return at::empty_like(local_input); +} + TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)"); @@ -230,6 +238,11 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor"); m.def( "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "one_shot_all_reduce_copy(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name) -> Tensor"); + m.def( + "one_shot_all_reduce_copy_out(Tensor symm_buffer, Tensor local_input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)"); + m.def( "two_shot_all_reduce_(Tensor(a!) 
input, str reduce_op, str group_name) -> Tensor(a!)"); @@ -256,6 +269,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { TORCH_LIBRARY_IMPL(symm_mem, Meta, m) { m.impl("one_shot_all_reduce", one_shot_all_reduce_meta); + m.impl("one_shot_all_reduce_copy", one_shot_all_reduce_copy_meta); } } // namespace From 36f2d0aabacf7953eca2f3b1f6801ba9af663892 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Fri, 21 Mar 2025 10:44:01 +0800 Subject: [PATCH 037/332] Add "xpu" to __all__ for torch/version.py (#149695) As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149695 Approved by: https://github.com/desertfire, https://github.com/guangyey --- tools/generate_torch_version.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/generate_torch_version.py b/tools/generate_torch_version.py index a33ea171edbb..a10d87faf938 100644 --- a/tools/generate_torch_version.py +++ b/tools/generate_torch_version.py @@ -97,7 +97,9 @@ def get_torch_version(sha: str | None = None) -> str: with open(version_path, "w") as f: f.write("from typing import Optional\n\n") - f.write("__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip']\n") + f.write( + "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n" + ) f.write(f"__version__ = '{version}'\n") # NB: This is not 100% accurate, because you could have built the # library code with DEBUG, but csrc without DEBUG (in which case From 48e9ffc873429c650a996f841a22d692e8ac2956 Mon Sep 17 00:00:00 2001 From: Prajesh Praveen Anchalia Date: Tue, 1 Apr 2025 08:55:51 +0000 Subject: [PATCH 038/332] Unify on dynamo_compile as the overall wait counter (#150293) Summary: dynamo_compile for the most part has been accounting for compile time except autotuning. all_compilation_types had earlier been injected on fx_codegen_and_compile, which was incorrect. Add autotuining to dynamo and deprcate all_compilation_types counter. 
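For readers unfamiliar with the wait-counter mechanism referenced here, a minimal sketch of the guard pattern this change standardizes on (the counter name is the real one used in the diff; the wrapped function is a hypothetical placeholder, not PyTorch API):

```python
from torch.monitor import _WaitCounter


def run_compile_or_autotune_work():
    # Hypothetical placeholder for a compile / autotune region.
    pass


# Time spent inside the guard is attributed to this wait counter; after this
# change the same counter wraps dynamo compilation, lazy backward compilation
# and Triton autotuning, giving one overall "dynamo_compile" accounting.
with _WaitCounter("pytorch.wait_counter.dynamo_compile").guard():
    run_compile_or_autotune_work()
```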
Differential Revision: D72145447 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150293 Approved by: https://github.com/masnesral, https://github.com/jamesjwu --- torch/_dynamo/convert_frame.py | 5 ++--- .../_aot_autograd/runtime_wrappers.py | 5 ++++- torch/_inductor/compile_fx.py | 10 ---------- torch/_inductor/runtime/triton_heuristics.py | 20 ++++++++++++------- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index a31b1f7e59c3..44d19986707d 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -774,9 +774,6 @@ def compile_inner( dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ) ) - stack.enter_context( - _WaitCounter("pytorch.wait_counter.dynamo_compile").guard() - ) stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) stack.enter_context(CompileTimeInstructionCounter.record()) return _compile_inner(code, one_graph, hooks, transform) @@ -957,7 +954,9 @@ def count_args(code: CodeType) -> int: chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), + _WaitCounter("pytorch.wait_counter.entire_forward_compile").guard(), metrics_context, + _WaitCounter("pytorch.wait_counter.dynamo_compile").guard(), ): restart_reasons: set[str] = set() # This is shared across restarts diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index cc7be374e35a..6e1d3a714c12 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -31,6 +31,7 @@ from torch._prims_common import CUDARngStateHelper from torch._subclasses import FakeTensor from torch.fx.experimental._backward_state import BackwardState +from torch.monitor import _WaitCounter from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -2225,7 +2226,9 @@ def _backward_impl(ctx, all_args): dynamo_compile_column_us="backward_cumulative_compile_time_us", log_waitcounter=True, waitcounter_name_override="entire_backward_compile", - ): + ), _WaitCounter( + "pytorch.wait_counter.dynamo_compile" + ).guard(): CompileEventLogger.compilation_metric(is_forward=False) # See Note: [Backward graph lazy lowering] CompiledFunction.compiled_bw = aot_config.bw_compiler( diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index a94d224a2d8a..0e86da9e94d8 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -620,15 +620,6 @@ def compile_fx_inner( dynamo_compile_column_us="inductor_cumulative_compile_time_us", ) ) - # NB: Why is this the dynamo_compile counter? The rule here is that - # if it gets an entry in the dynamo_compile table, we also want to - # tick up the wait counter. We have to displeasingly manually trigger - # the counter here because we may dropped into compile_fx directly - # from lazy backwards compilation. 
- stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) - stack.enter_context( - _WaitCounter("pytorch.wait_counter.all_compilation_types").guard() - ) if torch._dynamo.callback_handler.prevent_duplicate_callbacks: stack.enter_context(torch._dynamo.callback_handler.install_callbacks()) @@ -691,7 +682,6 @@ def _compile_fx_inner( with ( _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _, - _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(), ): use_cache = ( not config.force_disable_caches diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 4e1a67139f38..be02d43c28f8 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -31,6 +31,7 @@ import torch from torch._prims_common import compute_required_storage_length +from torch.monitor import _WaitCounter from torch.utils._ordered_set import OrderedSet from ..triton_bundler import TritonBundler @@ -815,13 +816,18 @@ def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: return self.maybe_clone_args(OrderedSet(), *args, **kwargs) def benchmark_all_configs(self, *args, **kwargs): - with dynamo_timed( - "CachingAutotuner.benchmark_all_configs", - log_pt2_compile_event=True, - metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, - dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us", - compile_id=self.compile_id, - is_backward=self.is_backward, + with ( + dynamo_timed( + "CachingAutotuner.benchmark_all_configs", + log_pt2_compile_event=True, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ), + _WaitCounter("pytorch.wait_counter.dynamo_compile").guard(), ): timings = { launcher: self.bench(launcher, *args, **kwargs) From a10b765bf159a86fb2a0ad693c6b72e0c691e60b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 1 Apr 2025 02:18:46 +0800 Subject: [PATCH 039/332] [pytree] add APIs to determine a class is a namedtuple or PyStructSequence (#113257) Changes in this PR: 1. Add `is_structseq` and `is_structseq_class` functions to determine a object or a class is PyStructSequence. 2. Add a generic class `structseq` which can be used as the registration key for PyStructSequence types like `namedtuple` for Named Tuple types. 3. Change `is_namedtuple` to accept subclasses of namedtuple to be namedtuple. Before this PR, only namedtuple class directly created by `collections.namedtuple` or `typing.NamedTuple` were namedtuple classes while their subclasses were not. This PR makes `is_namedtuple` return true for subclasses of namedtuple class. Resolves #75982. New tests are included in this PR. 
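A short usage sketch of the new checks (mirrors the tests added in this PR; not part of the patch):

```python
import time
from collections import namedtuple

import torch.utils._pytree as pytree

Point = namedtuple("Point", ["x", "y"])


class Point3(Point):  # subclass of a namedtuple class
    pass


# Subclasses of namedtuple classes now count as namedtuples.
assert pytree.is_namedtuple(Point3(1, 2))
assert pytree.is_namedtuple_class(Point3)

# PyStructSequence types (e.g. time.struct_time, torch.return_types.*) are
# recognized by the separate structseq checks, not the namedtuple ones.
assert pytree.is_structseq(time.gmtime())
assert pytree.is_structseq_class(time.struct_time)
assert not pytree.is_namedtuple(time.gmtime())
```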
- #75982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113257 Approved by: https://github.com/zou3519 --- benchmarks/dynamo/common.py | 2 +- test/test_pytree.py | 169 +++++++++++++++-- torch/_dynamo/polyfills/pytree.py | 3 +- torch/_export/serde/serialize.py | 2 +- torch/autograd/forward_ad.py | 12 +- .../testing/_internal/composite_compliance.py | 24 ++- torch/utils/_cxx_pytree.py | 17 +- torch/utils/_pytree.py | 173 +++++++++++++++--- 8 files changed, 345 insertions(+), 57 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index c7785940ddc7..7905a12b1d10 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1397,7 +1397,7 @@ def load(cls, model, example_inputs): # see https://github.com/pytorch/pytorch/issues/113029 example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs) - if pytree._is_namedtuple_instance(example_outputs): + if pytree.is_namedtuple_instance(example_outputs): typ = type(example_outputs) pytree._register_namedtuple( typ, diff --git a/test/test_pytree.py b/test/test_pytree.py index 4560ac6e69ed..99dfba3969ea 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -6,6 +6,7 @@ import re import subprocess import sys +import time import unittest from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict from dataclasses import dataclass @@ -731,6 +732,133 @@ def test_pytree_serialize_bad_input(self, pytree_impl): with self.assertRaises(TypeError): pytree_impl.treespec_dumps("random_blurb") + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_is_namedtuple(self, pytree): + DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) + + class DirectNamedTuple2(NamedTuple): + x: int + y: int + + class IndirectNamedTuple1(DirectNamedTuple1): + pass + + class IndirectNamedTuple2(DirectNamedTuple2): + pass + + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1(0, 1))) + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2(0, 1))) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1(0, 1))) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2(0, 1))) + self.assertFalse(pytree.is_namedtuple(time.gmtime())) + self.assertFalse(pytree.is_namedtuple((0, 1))) + self.assertFalse(pytree.is_namedtuple([0, 1])) + self.assertFalse(pytree.is_namedtuple({0: 1, 1: 2})) + self.assertFalse(pytree.is_namedtuple({0, 1})) + self.assertFalse(pytree.is_namedtuple(1)) + + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple(DirectNamedTuple2)) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple(IndirectNamedTuple2)) + self.assertFalse(pytree.is_namedtuple(time.struct_time)) + self.assertFalse(pytree.is_namedtuple(tuple)) + self.assertFalse(pytree.is_namedtuple(list)) + + self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple_class(DirectNamedTuple2)) + self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple1)) + self.assertTrue(pytree.is_namedtuple_class(IndirectNamedTuple2)) + self.assertFalse(pytree.is_namedtuple_class(time.struct_time)) + self.assertFalse(pytree.is_namedtuple_class(tuple)) + self.assertFalse(pytree.is_namedtuple_class(list)) + + @parametrize( + "pytree", + [ + subtest(py_pytree, name="py"), + subtest(cxx_pytree, name="cxx"), + ], + ) + def test_is_structseq(self, pytree): + class FakeStructSeq(tuple): + n_fields = 
2 + n_sequence_fields = 2 + n_unnamed_fields = 0 + + __slots__ = () + __match_args__ = ("x", "y") + + def __new__(cls, sequence): + return super().__new__(cls, sequence) + + @property + def x(self): + return self[0] + + @property + def y(self): + return self[1] + + DirectNamedTuple1 = namedtuple("DirectNamedTuple1", ["x", "y"]) + + class DirectNamedTuple2(NamedTuple): + x: int + y: int + + self.assertFalse(pytree.is_structseq(FakeStructSeq((0, 1)))) + self.assertTrue(pytree.is_structseq(time.gmtime())) + self.assertFalse(pytree.is_structseq(DirectNamedTuple1(0, 1))) + self.assertFalse(pytree.is_structseq(DirectNamedTuple2(0, 1))) + self.assertFalse(pytree.is_structseq((0, 1))) + self.assertFalse(pytree.is_structseq([0, 1])) + self.assertFalse(pytree.is_structseq({0: 1, 1: 2})) + self.assertFalse(pytree.is_structseq({0, 1})) + self.assertFalse(pytree.is_structseq(1)) + + self.assertFalse(pytree.is_structseq(FakeStructSeq)) + self.assertTrue(pytree.is_structseq(time.struct_time)) + self.assertFalse(pytree.is_structseq(DirectNamedTuple1)) + self.assertFalse(pytree.is_structseq(DirectNamedTuple2)) + self.assertFalse(pytree.is_structseq(tuple)) + self.assertFalse(pytree.is_structseq(list)) + + self.assertFalse(pytree.is_structseq_class(FakeStructSeq)) + self.assertTrue( + pytree.is_structseq_class(time.struct_time), + ) + self.assertFalse(pytree.is_structseq_class(DirectNamedTuple1)) + self.assertFalse(pytree.is_structseq_class(DirectNamedTuple2)) + self.assertFalse(pytree.is_structseq_class(tuple)) + self.assertFalse(pytree.is_structseq_class(list)) + + # torch.return_types.* are all PyStructSequence types + for cls in vars(torch.return_types).values(): + if isinstance(cls, type) and issubclass(cls, tuple): + self.assertTrue(pytree.is_structseq(cls)) + self.assertTrue(pytree.is_structseq_class(cls)) + self.assertFalse(pytree.is_namedtuple(cls)) + self.assertFalse(pytree.is_namedtuple_class(cls)) + + inst = cls(range(cls.n_sequence_fields)) + self.assertTrue(pytree.is_structseq(inst)) + self.assertTrue(pytree.is_structseq(type(inst))) + self.assertFalse(pytree.is_structseq_class(inst)) + self.assertTrue(pytree.is_structseq_class(type(inst))) + self.assertFalse(pytree.is_namedtuple(inst)) + self.assertFalse(pytree.is_namedtuple_class(inst)) + else: + self.assertFalse(pytree.is_structseq(cls)) + self.assertFalse(pytree.is_structseq_class(cls)) + self.assertFalse(pytree.is_namedtuple(cls)) + self.assertFalse(pytree.is_namedtuple_class(cls)) + class TestPythonPytree(TestCase): def test_deprecated_register_pytree_node(self): @@ -975,9 +1103,8 @@ def test_pytree_serialize_namedtuple(self): serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", ) - spec = py_pytree.TreeSpec( - namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(Point1(1, 2)) + self.assertIs(spec.type, namedtuple) roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) self.assertEqual(spec, roundtrip_spec) @@ -990,18 +1117,28 @@ class Point2(NamedTuple): serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", ) - spec = py_pytree.TreeSpec( - namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] + spec = py_pytree.tree_structure(Point2(1, 2)) + self.assertIs(spec.type, namedtuple) + roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) + self.assertEqual(spec, roundtrip_spec) + + class Point3(Point2): + pass + + py_pytree._register_namedtuple( + Point3, + 
serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point3", ) + + spec = py_pytree.tree_structure(Point3(1, 2)) + self.assertIs(spec.type, namedtuple) roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) self.assertEqual(spec, roundtrip_spec) def test_pytree_serialize_namedtuple_bad(self): DummyType = namedtuple("DummyType", ["x", "y"]) - spec = py_pytree.TreeSpec( - namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "Please register using `_register_namedtuple`" @@ -1020,9 +1157,7 @@ def __init__(self, x, y): lambda xs, _: DummyType(*xs), ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( NotImplementedError, "No registered serialization name" ): @@ -1042,9 +1177,7 @@ def __init__(self, x, y): to_dumpable_context=lambda context: "moo", from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) serialized_spec = py_pytree.treespec_dumps(spec, 1) self.assertIn("moo", serialized_spec) roundtrip_spec = py_pytree.treespec_loads(serialized_spec) @@ -1082,9 +1215,7 @@ def __init__(self, x, y): from_dumpable_context=lambda dumpable_context: None, ) - spec = py_pytree.TreeSpec( - DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(DummyType(1, 2)) with self.assertRaisesRegex( TypeError, "Object of type type is not JSON serializable" @@ -1095,9 +1226,7 @@ def test_pytree_serialize_bad_protocol(self): import json Point = namedtuple("Point", ["x", "y"]) - spec = py_pytree.TreeSpec( - namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] - ) + spec = py_pytree.tree_structure(Point(1, 2)) py_pytree._register_namedtuple( Point, serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index c62f19e34406..f007b46800b2 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -56,9 +56,10 @@ def _(*args: Any, **kwargs: Any) -> bool: "structseq_fields", ): __func = getattr(optree, __name) - substitute_in_graph(__func, can_constant_fold_through=True)( + globals()[__name] = substitute_in_graph(__func, can_constant_fold_through=True)( __func.__python_implementation__ ) + __all__ += [__name] # noqa: PLE0604 del __func del __name diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 14cc7d2731bb..26ae80af1c6a 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1243,7 +1243,7 @@ def serialize_treespec(self, treespec): def store_namedtuple_fields(ts): if ts.type is None: return - if ts.type == namedtuple: + if ts.type is namedtuple or pytree.is_namedtuple_class(ts.type): serialized_type_name = pytree.SUPPORTED_SERIALIZED_TYPES[ts.context].serialized_type_name if serialized_type_name in self.treespec_namedtuple_fields: field_names = self.treespec_namedtuple_fields[serialized_type_name].field_names diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index 426523865296..8fcb64beba3b 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import os -from 
collections import namedtuple -from typing import Any +from typing import Any, NamedTuple, Optional import torch @@ -129,16 +128,15 @@ def make_dual(tensor, tangent, *, level=None): return torch._VF._make_dual(tensor, tangent, level=level) -_UnpackedDualTensor = namedtuple("_UnpackedDualTensor", ["primal", "tangent"]) - - -class UnpackedDualTensor(_UnpackedDualTensor): +class UnpackedDualTensor(NamedTuple): r"""Namedtuple returned by :func:`unpack_dual` containing the primal and tangent components of the dual tensor. See :func:`unpack_dual` for more details. - """ + primal: torch.Tensor + tangent: Optional[torch.Tensor] + def unpack_dual(tensor, *, level=None): r"""Unpack a "dual tensor" to get both its Tensor value and its forward AD gradient. diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index c0ce944c641d..cbdb601af614 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -552,8 +552,16 @@ def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs): expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs) expected = tree_map(fwAD.unpack_dual, expected) - expected_primals = tree_map(lambda x: x.primal, expected) - expected_tangents = tree_map(lambda x: x.tangent, expected) + expected_primals = tree_map( + lambda x: x.primal, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + expected_tangents = tree_map( + lambda x: x.tangent, + expected, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) # Permutations of arg and kwargs in CCT. for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode): @@ -586,7 +594,15 @@ def unwrap(e): return e.elem if isinstance(e, CCT) else e actual = tree_map(fwAD.unpack_dual, actual) - actual_primals = tree_map(lambda x: unwrap(x.primal), actual) - actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual) + actual_primals = tree_map( + lambda x: unwrap(x.primal), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) + actual_tangents = tree_map( + lambda x: unwrap(x.tangent), + actual, + is_leaf=lambda x: type(x) is fwAD.UnpackedDualTensor, + ) assert_equal_fn(actual_primals, expected_primals, equal_nan=True) assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index b8d869c1c802..028c21a84bc4 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -23,7 +23,15 @@ from optree import PyTreeSpec as TreeSpec # direct import for type annotations import torch.utils._pytree as python_pytree -from torch.utils._pytree import KeyEntry as KeyEntry +from torch.utils._pytree import ( + is_namedtuple as is_namedtuple, + is_namedtuple_class as is_namedtuple_class, + is_namedtuple_instance as is_namedtuple_instance, + is_structseq as is_structseq, + is_structseq_class as is_structseq_class, + is_structseq_instance as is_structseq_instance, + KeyEntry as KeyEntry, +) __all__ = [ @@ -39,6 +47,7 @@ "keystr", "key_get", "register_pytree_node", + "tree_is_leaf", "tree_flatten", "tree_flatten_with_path", "tree_unflatten", @@ -58,6 +67,12 @@ "treespec_dumps", "treespec_loads", "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", ] diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 857ea1aab080..27941c68066b 100644 --- 
a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -31,14 +31,17 @@ Any, Callable, cast, + ClassVar, + Final, Generic, + NoReturn, Optional, overload, Protocol, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple +from typing_extensions import deprecated, NamedTuple, Self __all__ = [ @@ -54,6 +57,7 @@ "keystr", "key_get", "register_pytree_node", + "tree_is_leaf", "tree_flatten", "tree_flatten_with_path", "tree_unflatten", @@ -73,6 +77,12 @@ "treespec_dumps", "treespec_loads", "treespec_pprint", + "is_namedtuple", + "is_namedtuple_class", + "is_namedtuple_instance", + "is_structseq", + "is_structseq_class", + "is_structseq_instance", ] @@ -573,6 +583,90 @@ def get(self, obj: Any) -> Any: return getattr(obj, self.name) +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of namedtuple or a subclass of namedtuple.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_namedtuple_class(cls) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_class(cls: type) -> bool: + """Return whether the class is a subclass of namedtuple.""" + return ( + isinstance(cls, type) + and issubclass(cls, tuple) + and isinstance(getattr(cls, "_fields", None), tuple) + and all(type(field) is str for field in cls._fields) # type: ignore[attr-defined] + and callable(getattr(cls, "_make", None)) + and callable(getattr(cls, "_asdict", None)) + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_namedtuple_instance(obj: object) -> bool: + """Return whether the object is an instance of namedtuple.""" + return is_namedtuple_class(type(obj)) + + +_T_co = TypeVar("_T_co", covariant=True) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +class structseq(tuple[_T_co, ...]): + """A generic type stub for CPython's ``PyStructSequence`` type.""" + + __slots__: ClassVar[tuple[()]] = () + + n_fields: Final[int] # type: ignore[misc] + n_sequence_fields: Final[int] # type: ignore[misc] + n_unnamed_fields: Final[int] # type: ignore[misc] + + def __init_subclass__(cls) -> NoReturn: + """Prohibit subclassing.""" + raise TypeError("type 'structseq' is not an acceptable base type") + + def __new__( + cls: type[Self], + sequence: Iterable[_T_co], + dict: dict[str, Any] = ..., + ) -> Self: + raise NotImplementedError + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq(obj: Union[object, type]) -> bool: + """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.""" + cls = obj if isinstance(obj, type) else type(obj) + return is_structseq_class(cls) + + +# Set if the type allows subclassing (see CPython's Include/object.h) +Py_TPFLAGS_BASETYPE: int = 1 << 10 + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_class(cls: type) -> bool: + """Return whether the class is a class of PyStructSequence.""" + return ( + isinstance(cls, type) + # Check direct inheritance from `tuple` rather than `issubclass(cls, tuple)` + and cls.__bases__ == (tuple,) + # Check PyStructSequence members + and isinstance(getattr(cls, "n_fields", None), int) + and isinstance(getattr(cls, "n_sequence_fields", None), int) + and isinstance(getattr(cls, "n_unnamed_fields", None), int) + # Check the type does not allow subclassing 
+ and not bool(cls.__flags__ & Py_TPFLAGS_BASETYPE) # only works for CPython + ) + + +# Reference: https://github.com/metaopt/optree/blob/main/optree/typing.py +def is_structseq_instance(obj: object) -> bool: + """Return whether the object is an instance of PyStructSequence.""" + return is_structseq_class(type(obj)) + + def _tuple_flatten(d: tuple[T, ...]) -> tuple[list[T], Context]: return list(d), None @@ -807,37 +901,72 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: ) -STANDARD_DICT_TYPES: frozenset[type] = frozenset( - {dict, OrderedDict, defaultdict}, -) +STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) BUILTIN_TYPES: frozenset[type] = frozenset( - {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] + { + tuple, + list, + dict, + namedtuple, # type: ignore[arg-type] + OrderedDict, + defaultdict, + deque, + }, ) -# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +@deprecated( + "torch.utils._pytree._is_namedtuple_instance is private and will be removed in a future release. " + "Please use torch.utils._pytree.is_namedtuple_instance instead.", + category=FutureWarning, +) def _is_namedtuple_instance(tree: Any) -> bool: - typ = type(tree) - bases = typ.__bases__ - if len(bases) != 1 or bases[0] != tuple: - return False - fields = getattr(typ, "_fields", None) - if not isinstance(fields, tuple): - return False - return all(type(entry) == str for entry in fields) + return is_namedtuple_instance(tree) def _get_node_type(tree: Any) -> Any: - if _is_namedtuple_instance(tree): + node_type = type(tree) + # All namedtuple types are implicitly registered as pytree nodes. + # XXX: Other parts of the codebase expect namedtuple types always return + # `namedtuple` instead of the actual namedtuple type. Even if the type + # is explicitly registered. + if is_namedtuple_class(node_type): return namedtuple - return type(tree) + return node_type # A leaf is defined as anything that is not a Node. +def tree_is_leaf( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + """Check if a pytree is a leaf. + + >>> tree_is_leaf(1) + True + >>> tree_is_leaf(None) + True + >>> tree_is_leaf([1, 2, 3]) + False + >>> tree_is_leaf((1, 2, 3), is_leaf=lambda x: isinstance(x, tuple)) + True + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': 3}) + False + >>> tree_is_leaf({'a': 1, 'b': 2, 'c': None}) + False + """ + if is_leaf is not None and is_leaf(tree): + return True + return _get_node_type(tree) not in SUPPORTED_NODES + + +@deprecated( + "torch.utils._pytree._is_leaf is private and will be removed in a future release. " + "Please use torch.utils._pytree.tree_is_leaf instead.", + category=FutureWarning, +) def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: - return (is_leaf is not None and is_leaf(tree)) or _get_node_type( - tree - ) not in SUPPORTED_NODES + return tree_is_leaf(tree, is_leaf=is_leaf) # A TreeSpec represents the structure of a pytree. 
It holds: @@ -1040,7 +1169,7 @@ def tree_flatten( """ def helper(node: PyTree, leaves: list[Any]) -> TreeSpec: - if _is_leaf(node, is_leaf=is_leaf): + if tree_is_leaf(node, is_leaf=is_leaf): leaves.append(node) return _LEAF_SPEC @@ -1074,7 +1203,7 @@ def tree_iter( is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> Iterable[Any]: """Get an iterator over the leaves of a pytree.""" - if _is_leaf(tree, is_leaf=is_leaf): + if tree_is_leaf(tree, is_leaf=is_leaf): yield tree else: node_type = _get_node_type(tree) @@ -1520,7 +1649,7 @@ def _broadcast_to_and_flatten( ) -> Optional[list[Any]]: assert isinstance(treespec, TreeSpec) - if _is_leaf(tree, is_leaf=is_leaf): + if tree_is_leaf(tree, is_leaf=is_leaf): return [tree] * treespec.num_leaves if treespec.is_leaf(): return None From bf4814eb6a2363e38bc8445dbab084aa2074b7d7 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Tue, 1 Apr 2025 11:27:23 +0000 Subject: [PATCH 040/332] [Intel GPU] Allow XPU backend in Quantize operators (#150288) This modification is to support torch.quantize_per_channel() on XPU, otherwise it will cause a segmentation fault. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150288 Approved by: https://github.com/jerryzh168, https://github.com/guangyey --- aten/src/ATen/native/quantized/AffineQuantizer.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aten/src/ATen/native/quantized/AffineQuantizer.cpp b/aten/src/ATen/native/quantized/AffineQuantizer.cpp index 6bd9bfd687aa..dab9e1cf7fc9 100644 --- a/aten/src/ATen/native/quantized/AffineQuantizer.cpp +++ b/aten/src/ATen/native/quantized/AffineQuantizer.cpp @@ -151,6 +151,7 @@ Tensor& quantize_tensor_per_channel_affine( AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); if (qtensor.device().type() != c10::DeviceType::CUDA && + qtensor.device().type() != c10::DeviceType::XPU && qtensor.device().type() != c10::DeviceType::PrivateUse1) { checkZeroPoints(fn_name, zero_points); } // for cuda and privateuse1, this check will occur in the actual device function @@ -242,6 +243,7 @@ Tensor& dequantize_tensor_per_channel_affine( AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() { checkQuantizedTensor(fn_name, qtensor); if(qtensor.device().type() != c10::DeviceType::CUDA && + qtensor.device().type() != c10::DeviceType::XPU && qtensor.device().type() != c10::DeviceType::PrivateUse1){ checkZeroPoints(fn_name, zero_points); } // for cuda and privateuse1, this check will occur in the actual device function From 84c21d2147cdd095915bdc30457f7c1d0b1ed379 Mon Sep 17 00:00:00 2001 From: maajidkhann Date: Tue, 1 Apr 2025 11:54:55 +0000 Subject: [PATCH 041/332] Enable SVE ACLE implementation for tanH Aten op for FP32 dType. (#143741) In deep learning models, the tanh (hyperbolic tangent) function is a widely used activation function, primarily in feedforward networks, recurrent neural networks (RNNs), and various other architectures. Also, the tanh (hyperbolic tangent) function is commonly used in **Physics-Informed Neural Networks (PINNs).** PINNs are a class of machine learning models designed to solve partial differential equations (PDEs) by incorporating the governing physics directly into the loss function, along with data-driven terms. 
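For reference, the SVE ACLE path added below does not call Sleef; it evaluates tanh through the exponential identity, with the input clamped so exp(2x) stays finite in float32. A quick, illustrative check of that identity (not part of the patch):

```python
import torch

x = torch.linspace(-12.0, 12.0, steps=101)       # float32 by default
x_clamped = x.clamp(-10.0, 10.0)                 # same clamp range as the kernel
exp2x = torch.exp(2.0 * x_clamped)
approx = (exp2x - 1.0) / (exp2x + 1.0)           # tanh(x) = (e^{2x} - 1) / (e^{2x} + 1)
print(torch.allclose(approx, torch.tanh(x), atol=1e-6))  # expected: True
```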
In PINNs, activation functions like tanh are used in the neural network architecture to enable the model to learn complex mappings between inputs (such as spatial and temporal coordinates) and outputs (such as field variables). **Operator: tanh()** **Current Implementation in OSS in ATen Backend:** **SVE Flow:** Uses SVE sleef when available else std implementation. **With this PR :** **SVE Flow:** Uses SVE ACLE implementation. (Faster Implementation) **Here are the performance improvements.** **Single core perf numbers:** ![image](https://github.com/user-attachments/assets/c2f4bcb6-11bc-4af1-b5eb-278a4cc4a69d) **Metric:** CPU time avg time per iteration (In ms) As you can see with both gcc and clang compilers, we see a significant performance gain with SVE ACLE implementation over current OSS Implementation (Sleef) and also Neon. **Hardware:** m7g.8xlarge (Graviton 3 Instance) **Script used in benchmarking:** ```python import os #os.environ["ATEN_CPU_CAPABILITY"] = "default" os.environ["ATEN_CPU_CAPABILITY"] = "sve256" import torch import torch.nn as nn #Set the random seed for reproducibility torch.manual_seed(1) #Create a tensor of shape (8521, 50) x = torch.randn(8521, 50) for i in range(10): output = x.tanh() #Perform the tanh operation 1000 times and profile the performance print("### CPU tanh") with torch.autograd.profiler.profile(record_shapes=True) as prof: for i in range(1000): output = x.tanh() #Print the profiling results sorted by self CPU time print(prof.key_averages().table(sort_by="self_cpu_time_total")) #Optionally print the final output (if needed, uncomment the following line) print(output) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/143741 Approved by: https://github.com/malfet --- aten/src/ATen/cpu/vec/sve/vec_float.h | 80 ++++++++++++++++++++++- aten/src/ATen/test/vec_test_all_types.cpp | 11 ++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h index 6a3dc2bc1c10..dd35787dfb5b 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_float.h +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -85,6 +85,58 @@ template <> class Vectorized { } return b; } + //Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L105 + inline svfloat32_t svexp_f32_z(svbool_t pg, svfloat32_t x) const { + const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); // x^1: 0x1.ffffecp-1f + const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); // x^2: 0x1.fffdb6p-2f + const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); // x^3: 0x1.555e66p-3f + const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); // x^4: 0x1.573e2ep-5f + const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); // x^5: 0x1.0e4020p-7f + const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + svreinterpret_f32_u32(svdup_n_u32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); 
// Approximately ln(2^-125) + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = svmla_f32_z(pg, shift, x, inv_ln2); + const auto n = svsub_f32_z(pg, z, shift); + const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = svmla_f32_z(pg, x, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = svmul_f32_z(pg, r, r); + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + auto poly = svmla_f32_z(pg, scale, p12345, scale); + // Handle underflow and overflow. + poly = svsel_f32(svcmplt_f32(pg, x, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, x, max_input), inf, poly); + return poly; + } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) return svld1_f32(ptrue, reinterpret_cast(ptr)); @@ -333,8 +385,34 @@ template <> class Vectorized { Vectorized tan() const { return USE_SLEEF(Vectorized(Sleef_tanfx_u10sve(values)),map(std::tan)); } + //Implementation is picked from https://github.com/ARM-software/ComputeLibrary/blob/v25.01/src/core/NEON/SVEMath.inl#L179 Vectorized tanh() const { - return USE_SLEEF(Vectorized(Sleef_tanhfx_u10sve(values)),map(std::tanh)); + // Constants used for the tanh calculation. + const svfloat32_t CONST_1 = svdup_n_f32(1.f); // Constant 1.0f for the tanh formula. + const svfloat32_t CONST_2 = svdup_n_f32(2.f); // Constant 2.0f for the tanh formula (used in exp(2x)). + const svfloat32_t CONST_MIN_TANH = svdup_n_f32(-10.f); // Minimum threshold for input values to prevent overflow. + const svfloat32_t CONST_MAX_TANH = svdup_n_f32(10.f); // Maximum threshold for input values to prevent overflow. + + // Step 1: Clamp the values within the range [-10, 10] to prevent overflow during exponentiation. + // The tanh function approaches ±1 rapidly as the input grows large, so we limit the input range to avoid numerical instability. + // svmax_f32_z ensures values are greater than -10, and svmin_f32_z ensures they are less than 10. + svfloat32_t x = svmin_f32_z(ptrue, svmax_f32_z(ptrue, values, CONST_MIN_TANH), CONST_MAX_TANH); + + // Step 2: Calculate exp(2 * x), where x is the clamped value. + // svmul_f32_z computes 2 * x, and svexp_f32_z computes the exponential of the result. 
+ svfloat32_t exp2x = svexp_f32_z(ptrue, svmul_f32_z(ptrue, CONST_2, x)); + + // Step 3: Calculate the numerator of the tanh function, which is exp(2x) - 1. + svfloat32_t num = svsub_f32_z(ptrue, exp2x, CONST_1); + + // Step 4: Calculate the denominator of the tanh function, which is exp(2x) + 1. + svfloat32_t den = svadd_f32_z(ptrue, exp2x, CONST_1); + + // Step 5: Calculate the tanh function as the ratio of the numerator and denominator: num / den. + svfloat32_t tanh = svdiv_f32_z(ptrue, num, den); + + // Return the calculated tanh values. + return tanh; } Vectorized trunc() const { return svrintz_f32_x(ptrue, values); diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 4e0780800906..beca3043ce71 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -371,11 +371,22 @@ namespace { } TYPED_TEST(Hyperbolic, Tanh) { using vec = TypeParam; +// NOTE: Because SVE uses ACL logic, the precision changes, hence the adjusted tolerance. +#if defined(CPU_CAPABILITY_SVE) + using UVT = UvalueType; + UVT tolerance = getDefaultTolerance(); + test_unary( + NAME_INFO(tanH), + RESOLVE_OVERLOAD(std::tanh), + [](vec v) { return v.tanh(); }, + createDefaultUnaryTestCase(TestSeed(), tolerance)); +#else test_unary( NAME_INFO(tanH), RESOLVE_OVERLOAD(std::tanh), [](vec v) { return v.tanh(); }, createDefaultUnaryTestCase(TestSeed())); +#endif } TYPED_TEST(Hyperbolic, Sinh) { using vec = TypeParam; From 0d96c38b76b484ad6b66ddca02270018858a4e4e Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Tue, 1 Apr 2025 13:24:21 +0000 Subject: [PATCH 042/332] [AOTI] Skip test_buffer_mutation_and_force_mmap_weights for fbcode (#150340) Summary: Skip due to an older ideep version Differential Revision: D72190746 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150340 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index dd82b5a26f29..ce653436a860 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1806,6 +1806,7 @@ def forward(self, x): @skipCUDAIf(True, "Test for x86 backend") @skipIfXpu + @unittest.skipIf(IS_FBCODE, "Need newer ideep") def test_buffer_mutation_and_force_mmap_weights(self): class Model(nn.Module): def __init__(self): From 1c6e88eb0330a6b3fecf6033b0ffc55a11f181be Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 31 Mar 2025 20:18:17 -0700 Subject: [PATCH 043/332] [MPS] Test bf16 perf of few unary and binary ops (#150382) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150382 Approved by: https://github.com/Skylion007 --- test/bench_mps_ops.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index 7dac60ff93cd..009de265bf38 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -72,6 +72,8 @@ def bench_binary( def main() -> None: dtypes = [torch.float16, torch.float32] + if torch.backends.mps.is_macos_or_newer(14, 0): + dtypes.append(torch.bfloat16) # Profile unary ops rc = [] for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes): @@ -83,8 +85,8 @@ def main() -> None: ops = [torch.fmax, torch.add] for op, dtype in itertools.product(ops, dtypes): rc.extend(bench_binary(op, dt_a=dtype)) - for op in ops: - rc.extend(bench_binary(op, dt_b=torch.float16)) + if dtype 
== torch.float32: + rc.extend(bench_binary(op, dt_b=torch.float16)) Compare(rc).print() From 428234bc285412f5d379831462b982599d1d2f6f Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 1 Apr 2025 07:17:18 -0700 Subject: [PATCH 044/332] [MPSInductor] torch.complex128 is unsupported on MPS (#150386) Same as torch.float64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150386 Approved by: https://github.com/dcci ghstack dependencies: #150382 --- test/inductor/test_mps_basic.py | 1 + test/inductor/test_torchinductor.py | 2 ++ torch/_dynamo/device_interface.py | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 021ab0440492..8376100b91c4 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -162,6 +162,7 @@ def fn(a): # Copy tests for test_name in [ "test_min_max_reduction", + "test_add_complex4", "test_add_const_int", "test_add_inplace_permuted", "test_addmm", diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 693292057c96..b7e7d2eb2c0b 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1368,6 +1368,8 @@ def fn(a, b): return c + d for dtype in [torch.complex32, torch.complex64, torch.complex128]: + if not self.is_dtype_supported(dtype): + continue x = torch.tensor( [1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1], dtype=dtype, diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index d8610915ec3a..b24a94ea7cd5 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -376,7 +376,7 @@ def is_bf16_supported(including_emulation: bool = False) -> bool: def is_dtype_supported( cls, dtype: torch.dtype, including_emulation: bool = False ) -> bool: - if dtype == torch.float64: + if dtype in [torch.float64, torch.complex128]: return False return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation) From 7382654ebcadf5a0abd95a381556b7cb447a951b Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 1 Apr 2025 16:30:09 +0000 Subject: [PATCH 045/332] Update ExecuTorch pin to latest viable/strict 3/28/2025 (#150308) From latest viable/strict: https://hud.pytorch.org/hud/pytorch/executorch/viable%2Fstrict/1?per_page=50 Fixes https://github.com/pytorch/pytorch/issues/144480 This commit has important CI stability fixes, such as https://github.com/pytorch/executorch/pull/9561 and https://github.com/pytorch/executorch/pull/9634 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150308 Approved by: https://github.com/jathu, https://github.com/malfet --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- .ci/docker/common/install_executorch.sh | 3 +-- .ci/pytorch/test.sh | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index dc4f8b30fe87..6e9cfe33fe63 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -cedf52aa8e4df879886270a5920da6fe84cbaa67 +ebe8522378c3f9944aaaef44868f5ececdd845fc diff --git a/.ci/docker/common/install_executorch.sh b/.ci/docker/common/install_executorch.sh index a9a558b86f99..e30e0a787bbe 100755 --- a/.ci/docker/common/install_executorch.sh +++ b/.ci/docker/common/install_executorch.sh @@ -50,8 +50,7 @@ setup_executorch() { pushd executorch export 
PYTHON_EXECUTABLE=python - export EXECUTORCH_BUILD_PYBIND=ON - export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" + export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" as_jenkins .ci/scripts/setup-linux.sh --build-tool cmake || true popd diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 96a160cf618d..69566a244c9b 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -1475,8 +1475,7 @@ test_executorch() { pushd /executorch export PYTHON_EXECUTABLE=python - export EXECUTORCH_BUILD_PYBIND=ON - export CMAKE_ARGS="-DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" + export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON" # For llama3 bash examples/models/llama3_2_vision/install_requirements.sh From 35c45a4a315425a0940dcc040d8262b6331804dc Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 31 Mar 2025 23:58:44 -0700 Subject: [PATCH 046/332] [Reland] Launch kernel on current stream & remove `record_stream` entirely (#150398) Relanding #148590 due to merge conflict. This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related): 1. When async_op=False, we directly launch the collective on "current" stream, instead of a trampoline stream and join back. - Resolves #147729 - Resolves #146881 - Also saves two event syncs (which have overhead in case of HIP) and one pybind when we call `work.wait()` in distributed_c10d.py on behalf of user. 2. Entirely remove `record_stream` and use CPU-side stashing for managing tensor lifetime against recycling. - Resolves #147168 3. Remove tensor life management when async_op=False; only use it when async_op=True. 4. To guard against user not calling `work.wait()`, we ask watchdog to unstash tensors after detecting completion of collectives, to prevent us from holding reference to tensors forever. This is a safety net, rather than a service guarantee, see discussion [here](https://github.com/pytorch/pytorch/issues/147168#issuecomment-2660142460). 5. Profile in async_op=False mode would look different -- collective kernels would show up in the same line and compute kernels. Joint work with @cenzhaometa who wants to remove the event sync overhead. Squashed contents: * [ptd][nccl] use current-stream as nccl-stream under async=False mode (#147820) PTD current workflow: - PTD creates its own dedicated `ncclStream` for comm operation - it will first add a dependency on current-stream (typically the compute stream) to ensure tensors are ready before invoking collective such stream synchronization become expensive in Inference world (cpu overhead: 70us vs GPU kernel time: 160us). 
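As a concrete picture of the two modes this reland distinguishes, a minimal sketch (not part of the patch; assumes an initialized NCCL process group and a CUDA device for this rank):

```python
import torch
import torch.distributed as dist

t = torch.ones(1024, device="cuda")

# async_op=False (default): with this change the collective is issued on the
# current (compute) stream, so no extra stream dependency / event sync is paid.
dist.all_reduce(t)

# async_op=True: keeps a dedicated NCCL stream; the returned Work must be
# waited on before the result is safe to consume.
work = dist.all_reduce(t, async_op=True)
work.wait()
```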
This diff: - async=False [default], will use current-stream as nccl-stream and avoid the stream-sync overhead - async=True, will retain existing logic: create new nccl-stream, let it wait on current-stream to ensure tensors are ready - pass down async from c10d down to NCCL-PG this helps shave off 50% CPU overhead **(70us -> 35us)**, which reduce total CPU/GPU from **230us to 195us by 15%** * [PGNCCL] Make avoid-record-stream default * [c10d] Add asyncOp argument to Ops * Change python side wait * Pass asyncOp at ProcessGroup level * Watchdog unstashing tensors as a safety net * Stash tensors for reduce_scatter_v and all_gather_v Pull Request approved: https://github.com/pytorch/pytorch/pull/149753 * [c10d] Move unstashing from watchdog to main thread Pull Request approved: https://github.com/pytorch/pytorch/pull/150079 * [PGNCCL][BE] Merge mutex into TensorShelf for encapsulation Pull Request approved: https://github.com/pytorch/pytorch/pull/150130 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150398 Approved by: https://github.com/atalman --- test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp | 3 + test/distributed/test_c10d_ops_nccl.py | 26 ++ .../check_forward_backward_compatibility.py | 3 + torch/_C/_distributed_c10d.pyi | 8 +- torch/csrc/distributed/c10d/Ops.cpp | 130 +++--- torch/csrc/distributed/c10d/ProcessGroup.hpp | 43 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 393 +++++++++--------- .../distributed/c10d/ProcessGroupNCCL.hpp | 59 ++- torch/csrc/distributed/c10d/Types.hpp | 5 + torch/csrc/distributed/c10d/init.cpp | 18 +- torch/distributed/distributed_c10d.py | 104 ++++- .../_internal/distributed/distributed_test.py | 92 +--- 12 files changed, 521 insertions(+), 363 deletions(-) diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index a2fa2b467c52..533c50a43fe8 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -363,6 +363,9 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter { }; TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { + // Note (kwen2501) 03/07/2025 + // TODO: re-enable + GTEST_SKIP() << "Skipping test as the trace write seems unstable."; int heartBeatIntervalInSec = 2; std::string timeInterval = std::to_string(heartBeatIntervalInSec); ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index 73bad39956c6..4b8aac29e503 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -733,6 +733,32 @@ def reduce_scatter_base(output_t, input_t): # fails the check because the dtype is different reduce_scatter_base(output_t, tensor) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_reduce_scatter_v(self): + device = torch.device("cuda", self.rank_to_GPU[self.rank][0]) + # A list of tensors with different sizes + input_list = [torch.ones(i, device=device) for i in range(self.world_size)] + # The i-th output should have size i + output = torch.zeros(self.rank, device=device) + work = c10d.reduce_scatter(output, input_list, group=self.pg, async_op=True) + expected = torch.ones(self.rank, device=device) * self.world_size + work.wait() + self.assertEqual(expected, output) + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ 
GPUs") + def test_all_gather_v(self): + device = torch.device("cuda", self.rank_to_GPU[self.rank][0]) + # A list of tensors with different sizes + output_list = [torch.zeros(i, device=device) for i in range(self.world_size)] + # The i-th input has size i, filled with value i + input = torch.ones(self.rank, device=device) * self.rank + work = c10d.all_gather(output_list, input, group=self.pg, async_op=True) + expected = [torch.ones(i, device=device) * i for i in range(self.world_size)] + work.wait() + self.assertEqual(expected, output_list) + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_reduce_scatter_ops(self): diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 03b065a3691a..bfd255c50111 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -126,6 +126,9 @@ ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)), ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)), ("aten::all_reduce", datetime.date(9999, 1, 30)), + # These ops are defined in torch/csrc/distributed/c10d/Ops.cpp + # TODO: add back restriction when c10d ops can be exported + ("c10d::.*", datetime.date(9999, 1, 1)), ] ALLOW_LIST_COMPILED = [ diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 77a8f9c33e04..6aaaf4b9c5f1 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -2,7 +2,7 @@ # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum -from typing import Any, overload +from typing import Any, Optional, overload import torch from torch import Tensor @@ -139,6 +139,8 @@ class BroadcastOptions: class AllreduceOptions: reduceOp: ReduceOp timeout: timedelta + asyncOp: bool + sparseIndices: Optional[Tensor] class AllreduceCoalescedOptions(AllreduceOptions): ... @@ -147,6 +149,7 @@ class ReduceOptions: rootRank: int rootTensor: int timeout: timedelta + asyncOp: bool class AllgatherOptions: timeout: timedelta @@ -155,6 +158,7 @@ class AllgatherOptions: class GatherOptions: rootRank: int timeout: timedelta + asyncOp: bool class ScatterOptions: rootRank: int @@ -170,9 +174,11 @@ class BarrierOptions: device_ids: list[int] device: torch.device timeout: timedelta + asyncOp: bool class AllToAllOptions: timeout: timedelta + asyncOp: bool class Store: def set(self, key: str, value: str): ... diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index 6251bfa1817d..0480f1b9191d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) { .def("wait", [](const c10::intrusive_ptr& self) { self->wait(); }); m.class_("ReduceOp").def(torch::init<>()); m.def( - "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? 
sparse_indices, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> __torch__.torch.classes.c10d.Work"); + "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); + "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)"); m.def( - "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); + "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work"); m.def( - "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work"); + "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)"); + "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)"); m.def( - "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> 
__torch__.torch.classes.c10d.Work"); + "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work"); + "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int timeout) -> __torch__.torch.classes.c10d.Work"); + "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); + "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)"); m.def( - "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, int timeout) -> __torch__.torch.classes.c10d.Work"); + "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( - "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout) -> __torch__.torch.classes.c10d.Work"); + "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work"); m.def( "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()"); m.def( @@ -118,6 +118,7 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) const c10::intrusive_ptr& reduce_op, \ int64_t root_rank, \ int64_t root_tensor, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ @@ -127,7 +128,8 @@ IMPL_RECV_ANY_SOURCE(PrivateUse1) *reduce_op.get(), \ root_rank, \ root_tensor, \ - std::chrono::milliseconds(timeout)}); \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ } IMPL_REDUCE(CPU) @@ -169,12 +171,13 @@ IMPL_BROADCAST(PrivateUse1) const c10::intrusive_ptr& process_group, \ const 
c10::intrusive_ptr& reduce_op, \ const std::optional& sparse_indices, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allreduce( \ tensor_vec, \ AllreduceOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), std::chrono::milliseconds(timeout), asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(tensor_vec), work); \ } @@ -188,11 +191,13 @@ IMPL_ALLREDUCE(PrivateUse1) at::TensorList tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto tensor_vec = tensors.vec(); \ AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; \ opts.reduceOp = *reduce_op.get(); \ opts.timeout = std::chrono::milliseconds(timeout); \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allreduce_coalesced(tensor_vec, opts); \ } @@ -209,12 +214,13 @@ IMPL_ALLREDUCE_COALESCED(PrivateUse1) const std::vector>& output_tensors, \ at::TensorList input_tensors, \ const c10::intrusive_ptr& process_group, \ + bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> allgather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - AllgatherOptions{std::chrono::milliseconds(timeout)}); \ + AllgatherOptions{std::chrono::milliseconds(timeout), asyncOp}); \ return std:: \ tuple>, c10::intrusive_ptr>( \ output_tensors, work); \ @@ -249,12 +255,16 @@ IMPL__ALLGATHER_BASE(PrivateUse1) c10::intrusive_ptr allgather_coalesced_##DEV( \ const std::vector>& output_lists, \ const at::TensorList& input_list, \ - const c10::intrusive_ptr& process_group) { \ + const c10::intrusive_ptr& process_group, \ + bool asyncOp) { \ auto input_list_vec = input_list.vec(); \ + auto opts = AllgatherOptions{}; \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ ->allgather_coalesced( \ const_cast>&>(output_lists), \ - input_list_vec); \ + input_list_vec, \ + opts); \ } IMPL_ALLGATHER_COALESCED(CPU) @@ -265,11 +275,14 @@ IMPL_ALLGATHER_COALESCED(PrivateUse1) c10::intrusive_ptr allgather_into_tensor_coalesced_##DEV( \ at::TensorList outputs, \ at::TensorList inputs, \ - const c10::intrusive_ptr& process_group) { \ + const c10::intrusive_ptr& process_group, \ + bool asyncOp) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ + auto opts = AllgatherOptions{}; \ + opts.asyncOp = asyncOp; \ return process_group->getBackend(c10::DeviceType::DEV) \ - ->allgather_into_tensor_coalesced(output_vec, input_vec); \ + ->allgather_into_tensor_coalesced(output_vec, input_vec, opts); \ } IMPL_ALLGATHER_INTO_TENSOR_COALESCED(CPU) @@ -283,6 +296,7 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) const std::vector>& input_tensors, \ const c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto work = \ @@ -290,7 +304,9 @@ IMPL_ALLGATHER_INTO_TENSOR_COALESCED(PrivateUse1) output_tensors_vec, \ const_cast>&>(input_tensors), \ ReduceScatterOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ output_tensors_vec, work); \ } @@ -329,6 +345,7 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) at::TensorList inputs, \ const 
c10::intrusive_ptr& process_group, \ const c10::intrusive_ptr& reduce_op, \ + bool asyncOp, \ int64_t timeout) { \ auto output_vec = outputs.vec(); \ auto input_vec = inputs.vec(); \ @@ -337,7 +354,9 @@ IMPL__REDUCE_SCATTER_BASE(PrivateUse1) output_vec, \ input_vec, \ ReduceScatterOptions{ \ - *reduce_op.get(), std::chrono::milliseconds(timeout)}); \ + *reduce_op.get(), \ + std::chrono::milliseconds(timeout), \ + asyncOp}); \ } IMPL_REDUCE_SCATTER_TENSOR_COALESCED(CPU) @@ -350,13 +369,15 @@ IMPL_REDUCE_SCATTER_TENSOR_COALESCED(PrivateUse1) const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ int64_t root_rank, \ + bool asyncOp, \ int64_t timeout) { \ auto input_tensors_vec = input_tensors.vec(); \ return process_group->getBackend(c10::DeviceType::DEV) \ ->gather( \ const_cast>&>(output_tensors), \ input_tensors_vec, \ - GatherOptions{root_rank, std::chrono::milliseconds(timeout)}); \ + GatherOptions{ \ + root_rank, std::chrono::milliseconds(timeout), asyncOp}); \ } IMPL_GATHER(CPU) @@ -391,13 +412,14 @@ IMPL_SCATTER(PrivateUse1) const at::TensorList& output_tensors, \ const at::TensorList& input_tensors, \ const c10::intrusive_ptr& process_group, \ + bool asyncOp, \ int64_t timeout) { \ auto output_tensors_vec = output_tensors.vec(); \ auto input_tensors_vec = input_tensors.vec(); \ auto work = process_group->getBackend(c10::DeviceType::DEV) -> alltoall( \ output_tensors_vec, \ input_tensors_vec, \ - AllToAllOptions{std::chrono::milliseconds(timeout)}); \ + AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ return std::tuple, c10::intrusive_ptr>( \ std::move(output_tensors_vec), work); \ } @@ -406,21 +428,22 @@ IMPL_ALLTOALL(CPU) IMPL_ALLTOALL(CUDA) IMPL_ALLTOALL(PrivateUse1) -#define IMPL_ALLTOALL_BASE(DEV) \ - c10::intrusive_ptr alltoall_base_##DEV( \ - at::Tensor& output, \ - at::Tensor& input, \ - const c10::intrusive_ptr& process_group, \ - std::vector output_split_sizes, \ - std::vector input_split_sizes, \ - int64_t timeout) { \ - return process_group->getBackend(c10::DeviceType::DEV) \ - ->alltoall_base( \ - output, \ - input, \ - output_split_sizes, \ - input_split_sizes, \ - AllToAllOptions{std::chrono::milliseconds(timeout)}); \ +#define IMPL_ALLTOALL_BASE(DEV) \ + c10::intrusive_ptr alltoall_base_##DEV( \ + at::Tensor& output, \ + at::Tensor& input, \ + const c10::intrusive_ptr& process_group, \ + std::vector output_split_sizes, \ + std::vector input_split_sizes, \ + bool asyncOp, \ + int64_t timeout) { \ + return process_group->getBackend(c10::DeviceType::DEV) \ + ->alltoall_base( \ + output, \ + input, \ + output_split_sizes, \ + input_split_sizes, \ + AllToAllOptions{std::chrono::milliseconds(timeout), asyncOp}); \ } IMPL_ALLTOALL_BASE(CPU) @@ -428,15 +451,18 @@ IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) // NOLINTBEGIN(performance-unnecessary-value-param) -#define IMPL_BARRIER(DEV) \ - c10::intrusive_ptr barrier##DEV( \ - at::Tensor /* unused */, \ - const c10::intrusive_ptr& process_group, \ - const std::vector& device_ids, \ - int64_t timeout) { \ - return process_group->getBackend(c10::DeviceType::DEV) \ - ->barrier( \ - BarrierOptions{device_ids, std::chrono::milliseconds(timeout)}); \ +#define IMPL_BARRIER(DEV) \ + c10::intrusive_ptr barrier##DEV( \ + at::Tensor /* unused */, \ + const c10::intrusive_ptr& process_group, \ + const std::vector& device_ids, \ + bool asyncOp, \ + int64_t timeout) { \ + auto opts = BarrierOptions{}; \ + opts.device_ids = device_ids; \ + opts.timeout = 
std::chrono::milliseconds(timeout); \ + opts.asyncOp = asyncOp; \ + return process_group->getBackend(c10::DeviceType::DEV)->barrier(opts); \ } IMPL_BARRIER(CPU) @@ -464,6 +490,7 @@ allreduce_sparse_cuda_( const c10::intrusive_ptr& process_group, const c10::intrusive_ptr& reduce_op, const std::optional& sparse_indices, + bool asyncOp, int64_t timeout) { auto tensor_vec = tensors.vec(); auto work = process_group->getBackend(c10::DeviceType::CUDA) @@ -472,6 +499,7 @@ allreduce_sparse_cuda_( AllreduceOptions{ *reduce_op, std::chrono::milliseconds(timeout), + asyncOp, sparse_indices}); return std::tuple, c10::intrusive_ptr>( diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index b3f3d9bdd72d..4ce67c9f5798 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -224,6 +224,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, const std::optional& sparse_indices, + bool, int64_t)>(); auto work = std::get<1>(op.call( @@ -231,6 +232,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), opts.sparseIndices, + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = op.call( tensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive(opts.reduceOp), + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -277,6 +281,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ReduceOp>&, int64_t, int64_t, + bool, int64_t)>(); auto work = op.call( tensors, @@ -284,6 +289,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::make_intrusive(opts.reduceOp), opts.rootRank, opts.rootTensor, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) { - static auto op = - c10::Dispatcher::singleton() - .findSchemaOrThrow("c10d::allgather_coalesced_", "") - .typed( - const std::vector>&, - const at::TensorList&, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allgather_coalesced_", "") + .typed( + const std::vector>&, + const at::TensorList&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool)>(); auto work = op.call( outputTensorLists, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + 
opts.asyncOp); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor_list : outputTensorLists) { @@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { .typed( const at::TensorList, const at::TensorList, - const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool)>(); auto work = op.call( outputTensors, inputTensors, - c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this)); + c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp); if (c10d::allow_inflight_collective_as_graph_input()) { for (const auto& tensor : outputTensors) { @@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, int64_t, + bool, int64_t)>(); auto work = op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.rootRank, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const std::vector>&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -546,6 +561,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const c10::intrusive_ptr<::c10d::ReduceOp>&, + bool, int64_t)>(); auto work = op.call( @@ -553,6 +569,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp), + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -577,6 +594,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const c10::intrusive_ptr<::c10d::ProcessGroup>&, std::vector, std::vector, + bool, int64_t)>(); auto work = op.call( outputBuffer, @@ -584,6 +602,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), outputSplitSizes, inputSplitSizes, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { const at::TensorList&, const at::TensorList&, const c10::intrusive_ptr<::c10d::ProcessGroup>&, + bool, int64_t)>(); auto work = std::get<1>(op.call( outputTensors, inputTensors, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), + opts.asyncOp, opts.timeout.count())); if (c10d::allow_inflight_collective_as_graph_input()) { @@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { at::Tensor, const c10::intrusive_ptr<::c10d::ProcessGroup>&, const std::vector&, + bool, int64_t)>(); auto work = op.call( tensor, c10::intrusive_ptr::unsafe_reclaim_from_nonowning(this), opts.device_ids, + opts.asyncOp, opts.timeout.count()); if (c10d::allow_inflight_collective_as_graph_input()) { c10d::register_work(tensor, work); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp 
index 7aa659565ae7..e473912ea62a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -440,6 +440,36 @@ std::ostream& operator<<( return output << workInfo; } +/* Implementation of TensorShelf class */ + +void TensorShelf::stash(std::vector& tensors) { + std::lock_guard lock(mutex_); + tVector_.insert(tVector_.end(), tensors.begin(), tensors.end()); +} + +void TensorShelf::stash(TensorShelf& other) { + std::vector& otherVec = other.get(); + this->stash(otherVec); +} + +void TensorShelf::unstash() { + this->clear(); +} + +bool TensorShelf::empty() { + std::lock_guard lock(mutex_); + return tVector_.empty(); +} + +void TensorShelf::clear() { + std::lock_guard lock(mutex_); + tVector_.clear(); +} + +std::vector& TensorShelf::get() { + return tVector_; +} + ProcessGroupNCCL::WorkNCCL::WorkNCCL( std::string pgUID, std::string pgDesc, @@ -482,6 +512,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( } futureWorkResult_ = c10::make_intrusive(c10::AnyEnumType::get()); + // other functions expect an initialized ptr + stashed_for_allocator_safety_ = std::make_shared(); } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) @@ -503,6 +535,11 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) numelIn_(w.numelIn_), numelOut_(w.numelOut_), store_(w.store_), + // Note: the `work` returned to user and the `work` enqueued to watchdog + // share the pointer to the tensor stash. At least one of them should + // clean the tensor stash, the earlier the better, i.e. user calling + // `work.wait` than watchdog detecting work completion. + stashed_for_allocator_safety_(w.stashed_for_allocator_safety_), futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), @@ -700,10 +737,9 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); // Block the current stream on the NCCL stream ncclEndEvent_->block(currentStream); - - if (avoidRecordStreams_) { - stashed_for_allocator_safety_->clear(); - } + // Unstage the stashed tensors so that CachingAllocator can recycle them + // THIS MUST HAPPEN AFTER THE BLOCKING CALL ABOVE + stashed_for_allocator_safety_->unstash(); } // Same as calling synchronize() when blockingWait_ is false @@ -919,7 +955,10 @@ ProcessGroupNCCL::ProcessGroupNCCL( enableTiming_.store( getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_); #endif // ENABLE_NCCL_ERROR_CHECKING - avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false); + if (getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false)) { + TORCH_WARN_ONCE( + "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); + } #ifdef NCCL_HAS_COMM_REGISTER useTensorRegisterAllocatorHook_ = getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); @@ -2309,6 +2348,23 @@ void ProcessGroupNCCL::watchdogHandler() { // Clean up completed work if (work.isCompleted()) { + // In case user didn't call `work.wait()` with async collectives, + // watchdog would unstage the stashed tensors when detecting completion + // of the collective, to prevent ProcessGroupNCCL from holding reference + // to those tensors forever. + // work.stashed_for_allocator_safety_->unstash(); + // Update: it seems directly unstashing from watchdog thread would cause + // some rare problems. We thus move the unstashing to main thread, + // triggered by a next user call, see `workEnqueue`. 
But `work` is going + // to be destructed, so we transfer the work's shelf to a shelves + // structure owned by the PG. + if (!work.stashed_for_allocator_safety_->empty()) { + std::lock_guard lock(shelvesMutex_); + // We are just pushing back a shared_ptr here, so the cost should be + // minimal + shelvesToUnstash_.push_back(work.stashed_for_allocator_safety_); + } + // Work status logging for desync debug desyncDebugger_.logWorkEnd(work); @@ -3043,6 +3099,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( enableTiming_.load(), cudaEventCacheEnabled_.load(), dist_debug_level_); + if (record) { bool isP2P = isP2POp(opType); // Ideally record every work that we enqueue, rather than every work we @@ -3122,6 +3179,17 @@ void ProcessGroupNCCL::assignTimeoutToWork( void ProcessGroupNCCL::workEnqueue( const c10::intrusive_ptr& work) { + // We clean up the TensorShelf's in case user hasn't called `work.wait()`. + // This has nothing to do with new work enqueue. We are just using a place + // that would be triggered by a next user call. + { + std::lock_guard lock(shelvesMutex_); + for (auto& shelf : shelvesToUnstash_) { + shelf->unstash(); + } + shelvesToUnstash_.clear(); + } + // in blockingWait_ mode, we don't need watchdog thread, so no need to enqueue // the work if (!terminateProcessGroup_.load() && !blockingWait_) { @@ -3158,6 +3226,7 @@ void ProcessGroupNCCL::startCoalescing() { coalescedDevice_.set_index(-1); coalescedComm_ = nullptr; + coalescedTensors_.clear(); coalescing_state_ |= CoalActive; groupStart(); } @@ -3200,10 +3269,12 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { enqueue); work->ncclComm_ = comm; work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams_; work->store_ = store_; assignTimeoutToWork(work, options_); + // Hand over references to tensors during coalescing to work's stash + work->stashed_for_allocator_safety_->stash(coalescedTensors_); + // Record start before ncclGroupEnd if (work->timingEnabled_) { work->ncclStartEvent_->record(ncclStream); @@ -3219,19 +3290,17 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { // TODO(eqy): is this still necessary if avoidRecordStreams_ is set? work->ncclEndEvent_->record(ncclStream); - if (avoidRecordStreams_) { - // other functions expect an initialized ptr if avoidRecordStreams_ is set - work->stashed_for_allocator_safety_ = - std::make_shared>(); - } - if (enqueue) { workEnqueue(work); } + // Reset coalescing state coalescing_state_ = 0; coalescedComm_ = nullptr; - return work; + coalescedTensors_.clear(); + // If in async mode, return work; otherwise, kernel is enqueued on current + // stream, no need to return work + return coalescedAsync_ ? 
work : nullptr; } c10::intrusive_ptr ProcessGroupNCCL::endCoalescing() { @@ -3264,11 +3333,10 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; nanCheck &= enableNanCheck_; auto device = getDevice(inputs[0]); @@ -3309,13 +3377,17 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + coalescedAsync_ = asyncOp; } - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); - - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // in asyncOp=false [default] mode, we use currentStream as ncclStream + // otherwise, we use separate ncclStream and let it sync on currentStream + auto ncclStream = asyncOp ? ncclStreams_.at(key) + : at::cuda::getCurrentCUDAStream(device.index()); + if (asyncOp) { + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + } bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; @@ -3325,9 +3397,19 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); + // If we are performing sync operations, i.e. equeuing kernel onto "current" + // stream, we don't need to do anything for tensor lifetime management. + // Otherwise, we need to stage the tensors will `work.wait()`. + if (asyncOp) { + // First select which shelf to stash onto: to `work` if single collective; + // to an inflight shelf if coalescing. + if (coalescing_state_) { + coalescedTensors_.stash(inputs); + coalescedTensors_.stash(outputs); + } else { + work->stashed_for_allocator_safety_->stash(inputs); + work->stashed_for_allocator_safety_->stash(outputs); + } } if (nanCheck) { @@ -3353,21 +3435,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // operations where `inputs' and `outputs' are not the same. // // See [Sync Streams]. - if (!avoidRecordStreams) { - for (const auto& input : inputs) { - if (!input.is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - input.values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - input.indices().storage().data_ptr(), ncclStream); - } - } - } // Not all collectives have the same signature, e.g, all-reduce take in a Tensor // as the input and output while all-to-all take in a vector of Tensors as input @@ -3419,7 +3486,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. We only record the size on the first device as @@ -3437,7 +3503,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( workEnqueue(work); } - return work; + return asyncOp ? 
work : nullptr; } template @@ -3446,11 +3512,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::vector& outputs, Fn fn, OpType opType, - const char* profilingTitle, - bool avoidRecordStreams) { - // Environment setting by the user may add onto collective call's option - avoidRecordStreams |= avoidRecordStreams_; - + bool asyncOp, + const char* profilingTitle) { // Currently, the API permits one scenario where inputs.size() and // outputs.size() are > 0. // 1. If the call was a _coalesced call, all inputs must be on the same @@ -3496,13 +3559,17 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + coalescedAsync_ = asyncOp; } - // Used many times below, so we stash the unordered_map lookup - auto ncclStream = ncclStreams_.at(key); - - // First let NCCL streams wait for input tensors allocation streams - syncStream(device, ncclEvents_[key], ncclStream); + // in asyncOp=false [default] mode, we use currentStream as ncclStream + // otherwise, we use separate ncclStream and let it sync on currentStream + auto ncclStream = asyncOp ? ncclStreams_.at(key) + : at::cuda::getCurrentCUDAStream(device.index()); + if (asyncOp) { + // First let NCCL streams wait for input tensors allocation streams + syncStream(device, ncclEvents_[key], ncclStream); + } auto work = initWork( device, @@ -3517,9 +3584,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); - if (avoidRecordStreams) { - work->stashed_for_allocator_safety_ = - std::make_shared>(inputs); + // If we are performing sync operations, i.e. equeuing kernel onto "current" + // stream, we don't need to do anything for tensor lifetime management. + // Otherwise, we need to stage the tensors will `work.wait()`. + if (asyncOp) { + work->stashed_for_allocator_safety_->stash(inputs); + work->stashed_for_allocator_safety_->stash(outputs); } // Start event should only be recorded before the ncclGroupStart() (which @@ -3545,27 +3615,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( { torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking()); for (const auto i : c10::irange(inputs.size())) { - // Both `inputs' and `outputs' are created on a worker stream and used in - // different ncclStreams. Hence, both must record the ncclStream to - // prevent being freed before the collective finishes. - // - // We only record `inputs' here, and leave recording `outputs' to `fn' for - // operations where `inputs' and `outputs' are not the same. - // - // See [Sync Streams]. - if (!avoidRecordStreams) { - if (!inputs[i].is_sparse()) { - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].storage().data_ptr(), ncclStream); - } else { - // for sparse input case record streams on both index and value - // tensors - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].values().storage().data_ptr(), ncclStream); - c10::cuda::CUDACachingAllocator::recordStream( - inputs[i].indices().storage().data_ptr(), ncclStream); - } - } #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( fn(inputs[i], outputs[i], comm, ncclStream), @@ -3606,7 +3655,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // Set appropriate work parameters. work->blockingWait_ = blockingWait_; - work->avoidRecordStreams_ = avoidRecordStreams; work->store_ = store_; assignTimeoutToWork(work, options_); // Record size info for debug. 
We only record the size on the first device as @@ -3637,7 +3685,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( // it, since interactions with it by usercode won't behave normally - they // won't observe work completion, for instance. Will this lead to silent // problems during capture? - return work; + return asyncOp ? work : nullptr; } template @@ -3655,13 +3703,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // to wait() on the returned handle, so ProcessGroupNCCL can't know // when it's safe to release the input back to the allocator, // and the present call has no way to know it's not an isend. - // Therefore, we warn and fall back to the typical recordStream logic: - if (avoidRecordStreams_) { - TORCH_WARN_ONCE( - "TORCH_NCCL_AVOID_RECORD_STREAMS=1 has no effect for point-to-point " - "collectives."); - } - + // Therefore, we warn and fall back to the typical recordStream logic. + // TODO( kwen2501 ): revisit this when we have a better solution. auto device = getDevice(tensor); at::cuda::OptionalCUDAGuard gpuGuard(device); @@ -3716,6 +3759,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } else { TORCH_CHECK(coalescedComm_ == ncclComm, MULTI_DEVICE_ERROR_MSG); } + // For now, P2P ops are always put on internal stream + coalescedAsync_ = true; } // Used many times below, so we stash the unordered_map lookup @@ -3887,8 +3932,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3899,8 +3944,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( pre, post, opType, + asyncOp, profilingTitle, - avoidRecordStreams, nanCheck); } @@ -3910,8 +3955,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( at::Tensor& output, Fn fn, OpType opType, + bool asyncOp, const char* profilingTitle, - bool avoidRecordStreams, bool nanCheck) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; @@ -3924,8 +3969,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, opType, + asyncOp, profilingTitle, - avoidRecordStreams, nanCheck); } @@ -3977,6 +4022,8 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( auto recvIndices = indices[0] * colSize; // prevent output and recvIndices from being freed + // TODO: not changing the lifetime management of outputs this time, + // revisit later c10::cuda::CUDACachingAllocator::recordStream( output.storage().data_ptr(), stream); c10::cuda::CUDACachingAllocator::recordStream( @@ -4008,6 +4055,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_sparse( } }, OpType::_ALLREDUCE_SPARSE, + opts.asyncOp, "nccl:all_reduce_sparse"); return work; #else @@ -4042,6 +4090,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_impl( stream.stream()); }, OpType::ALLREDUCE, + opts.asyncOp, profilingTitle); } @@ -4142,6 +4191,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:allreduce_coalesced"); } @@ -4173,12 +4223,10 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( globalRankStride_, // globalRankStride_ this->getSize()); // worldSize - // avoidRecordStreams_ note: collective() will stash tensors. 
- bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank + opts.rootTensor; bool nanCheck = (root == rank_); + // avoidRecordStreams_ note: collective() will stash tensors. return collective( tensor, tensor, @@ -4195,8 +4243,8 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( stream.stream()); }, OpType::BROADCAST, + opts.asyncOp, "nccl:broadcast", - avoidRecordStreams, nanCheck); } @@ -4235,8 +4283,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_broadcast_oop( stream.stream()); }, OpType::BROADCAST, + opts.asyncOp, "nccl:_broadcast_oop", - /*avoidRecordStreams=*/false, nanCheck); } @@ -4295,6 +4343,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( stream.stream()); }, OpType::REDUCE, + opts.asyncOp, "nccl:reduce"); } @@ -4336,6 +4385,7 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_oop( stream.stream()); }, OpType::REDUCE, + opts.asyncOp, "nccl:_reduce_oop"); } @@ -4379,10 +4429,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } + // See [We actually don't need to stash anything here]. return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -4398,27 +4445,27 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( // - inputTensors is stashed onto work->stashed_for_allocator_safety_ // in collective(). // - outputFlattened is stashed onto work->outputs_ in collective(). - // - User-facing outputTensors should be held by the user until after - // waiting on work_, or the call makes no sense. - // So all participating tensors are accounted for, and won't be - // released back to their allocation streams until after work_ is - // waited on. }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { + // User-facing outputTensors should be held by the user until after + // waiting on work_, or the call makes no sense. We do a stashing here + // in case user doesn't hold the outputTensors in downstream code, + // which can cause an early recyle by the CachingAllocator, which can + // lead to segfault or data corruption. + if (opts.asyncOp) { + work->stashed_for_allocator_safety_->stash(outputTensors_); + } // Copy the flattened output tensors to the outputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(outputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - outputTensors_[j].storage().data_ptr(), ncclStream); - } + // See [We actually don't need to stash anything here]. outputTensors_[j].copy_( outputFlattened[static_cast(j)], true); } }, OpType::ALLGATHER, + opts.asyncOp, "nccl:all_gather"); } else { const auto num_reduces = outputTensors_.size(); @@ -4426,7 +4473,8 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( for (const int64_t i : c10::irange(static_cast(num_reduces))) { auto& output = outputTensors_[i]; auto& input = (i == rank_) ? 
inputTensor : output; - auto broadcastOpts = BroadcastOptions{i, int64_t(0), opts.timeout}; + auto broadcastOpts = + BroadcastOptions{i, int64_t(0), opts.timeout, opts.asyncOp}; _broadcast_oop(output, input, broadcastOpts); } auto work = endCoalescing(OpType::ALLGATHER); @@ -4482,6 +4530,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:all_gather_into_tensor_coalesced"); } @@ -4527,10 +4576,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } const auto ncclDataType = getNcclDataType(input.scalar_type()); const auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4545,27 +4590,18 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( }, [&](at::cuda::CUDAStream& ncclStream, c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // We only need to stash inputTensors. - // - inputFlattened is stashed onto - // work->stashed_for_allocator_safety_ - // in collective(). - // - User-facing outputTensors is stashed onto work->outputs_ in - // collective(), - // and should also be held by the user until after waiting on - // work_. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors_.begin(), inputTensors_.end()); + // We only need to stash inputTensors. + // - inputFlattened is stashed onto + // work->stashed_for_allocator_safety_ in collective(). + // - User-facing outputTensors is stashed onto work->outputs_ in + // collective(), and should also be held by the user until after + // waiting on work_. + if (opts.asyncOp) { + work->stashed_for_allocator_safety_->stash(inputTensors_); } - // Copy the input tensors to the flattened inputs. at::cuda::CUDAStreamGuard guard(ncclStream); for (const auto j : c10::irange(inputTensors_.size())) { - // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - inputTensors_[j].storage().data_ptr(), ncclStream); - } inputFlattened[static_cast(j)].copy_( inputTensors_[j], true); } @@ -4573,6 +4609,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( [&](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::REDUCE_SCATTER, + opts.asyncOp, "nccl:reduce_scatter"); } else { const auto num_reduces = inputTensors_.size(); @@ -4584,7 +4621,8 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( opts.reduceOp, static_cast(i), static_cast(0), - opts.timeout}; + opts.timeout, + opts.asyncOp}; _reduce_oop(output, input, reduceOpts); } auto work = endCoalescing(OpType::REDUCE_SCATTER); @@ -4638,7 +4676,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `reduce_scatter_tensor`. 
- bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( inputTensor, @@ -4647,10 +4684,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4664,8 +4697,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( stream.stream()); }, OpType::_REDUCE_SCATTER_BASE, - "nccl:_reduce_scatter_base", - avoidRecordStreams); + opts.asyncOp, + "nccl:_reduce_scatter_base"); } c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( @@ -4702,10 +4735,6 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } auto ncclDataType = getNcclDataType(input.scalar_type()); auto ncclReduceOp = getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm); @@ -4719,6 +4748,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( stream.stream()); }, OpType::COALESCED, + opts.asyncOp, "nccl:reduce_scatter_tensor_coalesced"); } @@ -4797,13 +4827,28 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); // All reduce to achieve the barrier - auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier"); + AllreduceOptions arOpts = AllreduceOptions(); + arOpts.asyncOp = opts.asyncOp; + auto work = allreduce_impl(barrierTensor, "nccl:all_reduce_barrier", arOpts); + + if (opts.asyncOp) { + // Work will take over barrierTensors + auto ncclWork = dynamic_cast(work.get()); + // If user specified async, the work should not be nullptr + TORCH_CHECK(ncclWork); + // Put a marker here so that `work.wait()` issue by users does + // barrier-specific thing: CPU sync + ncclWork->isBarrierOp_ = true; + return work; + } - // Work will take over barrierTensors - auto ncclWork = dynamic_cast(work.get()); - TORCH_CHECK(ncclWork); - ncclWork->isBarrierOp_ = true; - return work; + // Otherwise, we are in sync mode, we directly wait here. + // (It is a CPU wait for barrier) + auto currentStream = at::cuda::getCurrentCUDAStream(barDevIdx); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); + // No work to return + return nullptr; } c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( @@ -4811,7 +4856,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, - const AllToAllOptions& /* unused */) { + const AllToAllOptions& opts) { check_gpu_single_tensor(outputTensor); check_gpu_single_tensor(inputTensor); if (outputSplitSizes.empty() && inputSplitSizes.empty()) { @@ -4842,16 +4887,12 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - // See [Sync Streams]. 
- if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } torch::cuda::nccl::all2all_single_equal_split( input, output, this->getSize(), comm, stream); return ncclSuccess; }, OpType::ALLTOALL_BASE, + opts.asyncOp, "nccl:all_to_all"); } else { c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_); @@ -4893,10 +4934,6 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10d::computeLengthsAndOffsets( outputSplitSizes, output, &recv_lengths, &recv_offsets); // See [Sync Streams]. - if (!avoidRecordStreams_) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } torch::cuda::nccl::all2all_single_unequal_split( input.data_ptr(), send_lengths.data(), @@ -4911,6 +4948,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( return ncclSuccess; }, OpType::ALLTOALL_BASE, + opts.asyncOp, "nccl:all_to_all"); } } @@ -4918,7 +4956,7 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10::intrusive_ptr ProcessGroupNCCL::alltoall( std::vector& outputTensors, std::vector& inputTensors, - const AllToAllOptions& /* unused */) { + const AllToAllOptions& opts) { int64_t input_total_numel = 0; int64_t output_total_numel = 0; @@ -4963,18 +5001,11 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( return ncclSuccess; }, [&](at::cuda::CUDAStream&, - c10::intrusive_ptr& work) { - if (avoidRecordStreams_) { - // inputTensor0 and outputTensor0 are stashed redundantly by - // collective(), but that's ok. - auto& v = work->stashed_for_allocator_safety_; - v->insert(v->end(), inputTensors.begin(), inputTensors.end()); - v->insert(v->end(), outputTensors.begin(), outputTensors.end()); - } - }, + c10::intrusive_ptr& work) {}, [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::ALLTOALL, + opts.asyncOp, "nccl:all_to_all"); } @@ -5172,14 +5203,6 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( ncclComm_t comm, at::cuda::CUDAStream& stream) { const auto root = opts.rootRank; - if (getRank() == root) { - if (!avoidRecordStreams_) { - for (auto const& output : outputs) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } - } - } torch::cuda::nccl::gather( inputTensor, outputs, comm, stream, static_cast(root)); return ncclSuccess; @@ -5189,6 +5212,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::GATHER, + opts.asyncOp, "nccl:gather"); } @@ -5257,8 +5281,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( // avoidRecordStreams_ note: collective() will stash outputTensors and // inputs, which == inputTensors[0] on the root rank where it matters. 
- bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); - const auto root = opts.rootRank; bool nanCheck = (rank_ == root); @@ -5270,14 +5292,6 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::Tensor& /* unused */, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (getRank() == root) { - if (!avoidRecordStreams) { - for (auto const& input : inputs) { - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), stream); - } - } - } torch::cuda::nccl::scatter( inputs, outputTensor, comm, stream, static_cast(root)); return ncclSuccess; @@ -5287,8 +5301,8 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( [](at::cuda::CUDAStream&, c10::intrusive_ptr& work) {}, OpType::SCATTER, + opts.asyncOp, "nccl:scatter", - avoidRecordStreams, nanCheck); } @@ -5344,7 +5358,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( // stream so that the caching allocator can reuse memory pool for this stream // in a clever way. This setting is added for libraries like FSDP which uses // `all_gather_into_tensor`. - bool avoidRecordStreams = avoidRecordStreams_ || (!opts.asyncOp); return collective( input_tensor, @@ -5353,10 +5366,6 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( at::Tensor& output, ncclComm_t comm, at::cuda::CUDAStream& stream) { - if (!avoidRecordStreams) { - c10::cuda::CUDACachingAllocator::recordStream( - output.storage().data_ptr(), stream); - } return ncclAllGather( input.data_ptr(), output.data_ptr(), @@ -5366,8 +5375,8 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( stream.stream()); }, OpType::_ALLGATHER_BASE, - "nccl:_all_gather_base", - avoidRecordStreams); + opts.asyncOp, + "nccl:_all_gather_base"); } // Create a memory allocator for NCCL. This allocator is used to allocate memory diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index ca870f702013..9f3ad484e55d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -235,6 +235,34 @@ struct DumpPipe { }; #endif +// A shelf for stashing tensors between op call and `work.wait()`. +// Used in case of async ops. +class TensorShelf { + public: + // Stash tensors so that CachingAllocator cannot recycle them prematurely. + void stash(std::vector& tensors); + // Stash tensors from another shelf. + void stash(TensorShelf& other); + // Unstage the stashed tensors so that CachingAllocator can recycle them. + // Same as `clear()`. + void unstash(); + // Whether shelf is empty. + bool empty(); + // Clear the shelf. + void clear(); + + protected: + // Get the inner tensor vector. Use with caution as it is not protected by + // mutex. + std::vector& get(); + + private: + std::vector tVector_; + // Need a mutex to protect `tVector_` because it can be potentially accessed + // from both main thread and watchdog thread. + std::mutex mutex_; +}; + // ProcessGroupNCCL implements NCCL bindings for c10d. // // All functions of the class are expected to be called in the same order @@ -382,9 +410,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Clone of blockingWait_ from ProcessGroupNCCL. bool blockingWait_{false}; - // Clone of avoidRecordStreams_ from ProcessGroupNCCL. - bool avoidRecordStreams_{false}; - // Clone of opTimeout_ from ProcessGroupNCCL. std::chrono::milliseconds opTimeout_{}; @@ -448,7 +473,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // caching allocator safety without any recordStream calls. 
// For in-place collectives, some refs stashed here may alias outputs_, // but that doesn't do any harm. - std::shared_ptr> stashed_for_allocator_safety_; + std::shared_ptr stashed_for_allocator_safety_; // The future returned by getFuture. c10::intrusive_ptr future_; @@ -889,8 +914,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { at::Tensor& output, Fn fn, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -901,8 +926,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -913,8 +938,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { PreProcess pre, PostProcess post, OpType opType, + bool asyncOp, const char* profilingTitle = nullptr, - bool avoidRecordStreams = false, bool nanCheck = true); template @@ -923,8 +948,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::vector& output, Fn fn, OpType opType, - const char* profilingTitle = nullptr, - bool avoidRecordStreams = false); + bool asyncOp, + const char* profilingTitle = nullptr); // Helper that encapsulates work shared across point-to-point communication // primitives. It is the same structure as the helper used for collective @@ -1233,6 +1258,22 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Stores communicators for all collectives run inside a coalescing block std::shared_ptr coalescedComm_ = nullptr; + // Whether the coalesced calls are sync or async. + bool coalescedAsync_; + + // keeps track of input and output tensors when coalescing is in flight. Will + // hand over these tensors to WorkNCCL's stash when coalescing is ended. + TensorShelf coalescedTensors_; + + // Some ops may have completed, but user still hasn't called `work.wait()`. + // When watchdog detects this, it transfers the TensorShelf from `work` to + // this `shelves` structure. Next time we execute ProcessGroupNCCL's methods + // on main thread, we clear the `shelves` in one shot. This is mainly because + // watchdog (a side thread) unstashing the shelf directly seems to cause some + // problem. + std::vector> shelvesToUnstash_; + std::mutex shelvesMutex_; + // Whether or not wait() and synchronize() are blocking operations that wait // for the operation to complete. 
bool blockingWait_ = false; diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 5d15708c953e..8fec5dd0e9e2 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -122,6 +122,7 @@ struct BroadcastOptions { struct AllreduceOptions { ReduceOp reduceOp = ReduceOp::SUM; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; std::optional sparseIndices = std::nullopt; }; @@ -132,6 +133,7 @@ struct ReduceOptions { int64_t rootRank = 0; int64_t rootTensor = 0; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct AllgatherOptions { @@ -142,6 +144,7 @@ struct AllgatherOptions { struct GatherOptions { int64_t rootRank = 0; std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct ScatterOptions { @@ -158,12 +161,14 @@ struct ReduceScatterOptions { struct AllToAllOptions { std::chrono::milliseconds timeout = kUnsetTimeout; + bool asyncOp = true; }; struct BarrierOptions { std::vector device_ids; std::chrono::milliseconds timeout = kUnsetTimeout; std::optional device; + bool asyncOp = true; }; struct DistributedBackendOptions { diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index ddd75d234449..0217d2471dc8 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -999,20 +999,23 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::AllreduceOptions>(module, "AllreduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceOptions::reduceOp) - .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout); + .def_readwrite("timeout", &::c10d::AllreduceOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllreduceOptions::asyncOp); py::class_<::c10d::AllreduceCoalescedOptions>( module, "AllreduceCoalescedOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::AllreduceCoalescedOptions::reduceOp) - .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout); + .def_readwrite("timeout", &::c10d::AllreduceCoalescedOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllreduceCoalescedOptions::asyncOp); py::class_<::c10d::ReduceOptions>(module, "ReduceOptions") .def(py::init<>()) .def_readwrite("reduceOp", &::c10d::ReduceOptions::reduceOp) .def_readwrite("rootRank", &::c10d::ReduceOptions::rootRank) .def_readwrite("rootTensor", &::c10d::ReduceOptions::rootTensor) - .def_readwrite("timeout", &::c10d::ReduceOptions::timeout); + .def_readwrite("timeout", &::c10d::ReduceOptions::timeout) + .def_readwrite("asyncOp", &::c10d::ReduceOptions::asyncOp); py::class_<::c10d::AllgatherOptions>(module, "AllgatherOptions") .def(py::init<>()) @@ -1022,7 +1025,8 @@ This class does not support ``__members__`` property.)"); py::class_<::c10d::GatherOptions>(module, "GatherOptions") .def(py::init<>()) .def_readwrite("rootRank", &::c10d::GatherOptions::rootRank) - .def_readwrite("timeout", &::c10d::GatherOptions::timeout); + .def_readwrite("timeout", &::c10d::GatherOptions::timeout) + .def_readwrite("asyncOp", &::c10d::GatherOptions::asyncOp); py::class_<::c10d::ScatterOptions>(module, "ScatterOptions") .def(py::init<>()) @@ -1040,11 +1044,13 @@ This class does not support ``__members__`` property.)"); .def(py::init<>()) .def_readwrite("device_ids", &::c10d::BarrierOptions::device_ids) .def_readwrite("timeout", &::c10d::BarrierOptions::timeout) - .def_readwrite("device", &::c10d::BarrierOptions::device); + 
.def_readwrite("device", &::c10d::BarrierOptions::device) + .def_readwrite("asyncOp", &::c10d::BarrierOptions::asyncOp); py::class_<::c10d::AllToAllOptions>(module, "AllToAllOptions") .def(py::init<>()) - .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout); + .def_readwrite("timeout", &::c10d::AllToAllOptions::timeout) + .def_readwrite("asyncOp", &::c10d::AllToAllOptions::asyncOp); py::class_<::c10d::DistributedBackendOptions>( module, "_DistributedBackendOptions") diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 339afeffdc7f..668dbf49a0d0 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -2501,7 +2501,7 @@ class _CoalescingManager: def __init__(self) -> None: self.works: list[Work] = [] - def append(self, work: Work): + def append(self, work: Optional[Work] = None): if work: self.works.append(work) @@ -2514,7 +2514,7 @@ def wait(self): def _coalescing_manager( group: Optional[ProcessGroup] = None, device: Optional[torch.device] = None, - async_ops: Optional[bool] = False, + async_ops: bool = False, ): """ Context manager used to coalesce collectives or P2P operations when possible. @@ -2553,6 +2553,7 @@ def _coalescing_manager( group._start_coalescing(device) cm = _CoalescingManager() yield cm + work = None op_list = _world.pg_coalesce_state.pop(group) if op_list: # Collectives supporting "Fast Path" coalescing are captured. @@ -2566,6 +2567,7 @@ def _coalescing_manager( tensors = [op.tensor for op in op_list] all_reduce_opts = AllreduceCoalescedOptions() all_reduce_opts.reduceOp = not_none(op_list[0].redop) + all_reduce_opts.asyncOp = async_ops work = group.allreduce_coalesced(tensors, all_reduce_opts) elif op0 == all_gather_into_tensor: inputs = [] @@ -2573,6 +2575,8 @@ def _coalescing_manager( for op in op_list: inputs.append(op.tensor) outputs.append(not_none(op.dst_tensor)) + all_gather_opts = AllgatherOptions() + all_gather_opts.asyncOp = async_ops work = group.allgather_into_tensor_coalesced(outputs, inputs) elif op0 == reduce_scatter_tensor: inputs = [] @@ -2582,6 +2586,7 @@ def _coalescing_manager( outputs.append(not_none(op.dst_tensor)) reduce_opts = ReduceScatterOptions() reduce_opts.reduceOp = not_none(op_list[0].redop) + reduce_opts.asyncOp = async_ops work = group.reduce_scatter_tensor_coalesced(outputs, inputs, reduce_opts) else: raise AssertionError( @@ -2594,9 +2599,12 @@ def _coalescing_manager( work = group._end_coalescing(device) if async_ops: - cm.append(work) # type: ignore[possibly-undefined] - else: - work.wait() # type: ignore[possibly-undefined] + cm.append(work) + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level + work.wait() + # Otherwise, the backend has sync'ed at CPP level class _TimeEstimator: @@ -2772,8 +2780,11 @@ def broadcast( work = group.broadcast([tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2853,6 +2864,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): opts = AllreduceOptions() opts.reduceOp = op + opts.asyncOp = async_op if group is None: group = _get_default_group() @@ -2869,8 +2881,11 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False): if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level 
work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2929,13 +2944,17 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False): opts = AllreduceCoalescedOptions() opts.reduceOp = op + opts.asyncOp = async_op group = group or _get_default_group() work = group.allreduce_coalesced(tensors, opts) if async_op: return work.get_future() - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -2980,11 +2999,15 @@ def reduce( opts = ReduceOptions() opts.reduceOp = op opts.rootRank = group_dst + opts.asyncOp = async_op work = group.reduce([tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def _object_to_tensor(obj, device, group): @@ -3783,12 +3806,17 @@ def all_gather(tensor_list, tensor, group=None, async_op=False): tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) group = group or _get_default_group() - work = group.allgather([tensor_list], [tensor]) + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather([tensor_list], [tensor], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -3891,8 +3919,11 @@ def all_gather_into_tensor(output_tensor, input_tensor, group=None, async_op=Fal if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4002,12 +4033,17 @@ def all_gather_coalesced( ] group = group or _get_default_group() - work = group.allgather_coalesced(output_tensor_lists, input_tensor_list) + opts = AllgatherOptions() + opts.asyncOp = async_op + work = group.allgather_coalesced(output_tensor_lists, input_tensor_list, opts) if async_op: return work.get_future() - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def _validate_output_list_for_rank(my_rank, dst, gather_list): @@ -4093,12 +4129,16 @@ def gather( opts = GatherOptions() opts.rootRank = group_dst + opts.asyncOp = async_op work = group.gather(output_tensors, input_tensors, opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4199,8 +4239,11 @@ def scatter( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4232,14 +4275,18 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal opts = ReduceScatterOptions() opts.reduceOp = op + opts.asyncOp = async_op group = group or _get_default_group() work = group.reduce_scatter([output], [input_list], opts) if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level 
@_exception_logger @@ -4336,8 +4383,11 @@ def reduce_scatter_tensor(output, input, op=ReduceOp.SUM, group=None, async_op=F if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @deprecated( @@ -4490,6 +4540,7 @@ def all_to_all_single( return opts = AllToAllOptions() + opts.asyncOp = async_op _check_single_tensor(output, "output") _check_single_tensor(input, "input") _ensure_all_tensors_same_dtype(output, input) @@ -4509,8 +4560,11 @@ def all_to_all_single( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4611,6 +4665,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False return opts = AllToAllOptions() + opts.asyncOp = async_op _check_tensor_list(output_tensor_list, "output_tensor_list") _check_tensor_list(input_tensor_list, "input_tensor_list") _ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list) @@ -4627,8 +4682,11 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level @_exception_logger @@ -4659,6 +4717,7 @@ def barrier( opts = BarrierOptions() opts.device = torch.device(_get_object_coll_device(group)) + opts.asyncOp = async_op if device_ids is not None: if isinstance(device_ids, list): opts.device_ids = device_ids @@ -4672,8 +4731,11 @@ def barrier( if async_op: return work - else: + elif ( + work is not None + ): # Backward compatible with backends that don't sync at CPP level work.wait() + # Otherwise, the backend has sync'ed at CPP level def monitored_barrier( diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 3f4a24a1ffb1..db9f9e70dee1 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -96,7 +96,7 @@ import torchvision HAS_TORCHVISION = True -except ImportError: +except Exception: # Covering both ImportError and RuntimeError HAS_TORCHVISION = False if sys.platform == "win32": @@ -8310,50 +8310,14 @@ def test_compute_bucket_assignment_by_size_sparse_error_without_logger(self): def test_compute_bucket_assignment_by_size_sparse_error_with_logger(self): self._test_compute_bucket_assignment_by_size(use_logger=True) - def _determine_expected_error_verify_model_across_rank( - self, group_to_use, diff_num_params=False - ): - # When running with NCCL backend, we don't expect an error on rank 0, - # rather, it will be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING. When - # running with Gloo or with debug mode wrapper, we expect the error - # to be caught inline. - # All ranks report same error when there is a # of parameter - # mismatch since we use allgather in the impl. 
- if diff_num_params: - expected_err = "DDP expects same model across all ranks" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - return ctx, expected_err - - is_detail_dbg_mode = dist.get_debug_level() == dist.DebugLevel.DETAIL - if self.rank == 0: - if ( - dist.get_backend(group_to_use) == dist.Backend.NCCL - and not is_detail_dbg_mode - ): - expected_err = "caught collective operation timeout" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - else: - expected_err = None - ctx = self.assertRaises(RuntimeError) - else: - expected_err = "appears not to match" - ctx = self.assertRaisesRegex(RuntimeError, expected_err) - return ctx, expected_err - def _test_verify_model_across_rank(self, use_logger): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=5) ) torch.cuda.set_device(self.rank) - ctx, expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use - ) # Create a valid model. The constructor initializes the logger that we use later. net = EmbeddingNetDifferentParams(0) @@ -8371,7 +8335,8 @@ def _test_verify_model_across_rank(self, use_logger): net.module.lin = nn.Linear(100 if self.rank == 0 else 10, 1) # if we pass a logger we can verify that it was logged - with ctx: + caught = 0 + try: if use_logger: _verify_param_shape_across_processes( net.process_group, list(net.parameters()), net.logger @@ -8380,18 +8345,13 @@ def _test_verify_model_across_rank(self, use_logger): _verify_param_shape_across_processes( net.process_group, list(net.parameters()) ) - # Should only be run by rank 0, and blocking_wait catches and - # reports exception. - dist.barrier(group_to_use) + except Exception: + caught = 1 - # We don't check when self.rank != 0 because the logger doesn't log - # the error "Caught collective operation" as that is not thrown in the reducer. - if use_logger and self.rank != 0: - verify_ddp_error_logged(net, expected_err) - - # Perform gloo-based barrier to ensure one rank doesn't exit test - # early which causes failure with Barrier.sync. - dist.barrier(group_gloo) + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_but_pass_in_sandcastle_if( @@ -8409,20 +8369,19 @@ def test_verify_model_across_rank_with_logger(self): def test_verify_model_across_rank_without_logger(self): self._test_verify_model_across_rank(use_logger=False) - def _run_test_ddp_model_with_diff_params(self, ctx, net, ddp_group, group_gloo): - with ctx: + def _run_test_ddp_model_with_diff_params(self, net, ddp_group, group_gloo): + caught = 0 + try: net = torch.nn.parallel.DistributedDataParallel( net.to(self.rank), device_ids=[self.rank], process_group=ddp_group ) - # Should only be run by rank 0, and blocking_wait catches and - # reports exception. - dist.barrier(ddp_group) - - # can't use verify_ddp_error_logged here because net was never properly constructed + except Exception: + caught = 1 - # Perform gloo-based barrier to ensure one rank doesn't exit test - # early which causes failure with Barrier.sync. 
- dist.barrier(group_gloo) + # As long as there is one rank catching the exception + t = torch.Tensor([caught]) + dist.all_reduce(t, group=group_gloo) + self.assertGreater(t, 0) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_but_pass_in_sandcastle_if( @@ -8433,21 +8392,15 @@ def test_ddp_model_diff_shape_across_ranks(self): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use - ) # Creates network with different sized embedding table on different # ranks. This should throw an error during DDP init. net = EmbeddingNetDifferentParams(self.rank) self._run_test_ddp_model_with_diff_params( - ctx, net, group_to_use, group_gloo + net, group_to_use, group_gloo ) @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @@ -8459,16 +8412,10 @@ def test_ddp_model_diff_num_params_across_ranks(self): group_gloo = dist.new_group( timeout=timedelta(seconds=60), backend=dist.Backend.GLOO ) - # Set TORCH_NCCL_BLOCKING_WAIT and use a new NCCL group to improve test - # determinism. - os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" group_to_use = dist.new_group( backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( - group_to_use, diff_num_params=True - ) # Creates network with diff # of param across ranks, reducer should # recognize this and throw appropriate error. @@ -8477,7 +8424,6 @@ def test_ddp_model_diff_num_params_across_ranks(self): ) self._run_test_ddp_model_with_diff_params( - ctx, net, group_to_use, group_gloo, From a19b667bca844f46f1dbfd444407e93407ff1d04 Mon Sep 17 00:00:00 2001 From: Sriram Kumar Date: Tue, 1 Apr 2025 16:49:03 +0000 Subject: [PATCH 047/332] [ROCm] Update CUDAPluggableAllocator.h (#1984) (#150010) Altering the flag to use the correct streamType in CUDAPluggableAllocator class for ROCm gpu. The flag TORCH_HIP_VERSION does not work for ROCm as intended. This flag is replaced with USE_ROCM. This is impacting Distributed Fused Adam in Rocm/APEX when using nccl_ub feature. This has been tested with rocm/apex. 
See PR https://github.com/ROCm/apex/pull/184 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150010 Approved by: https://github.com/jeffdaily --- torch/csrc/cuda/CUDAPluggableAllocator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 140ac95a071a..ade983e708c1 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -37,7 +37,7 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocatorDeleterContext { cudaStream_t stream_{}; }; -#if defined(TORCH_HIP_VERSION) +#if defined(USE_ROCM) using streamType = c10::hip::HIPStream; #else using streamType = c10::cuda::CUDAStream; From ae74ef9d53498840874df0e6c5d448f6b46c967b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 2 Apr 2025 00:17:51 +0800 Subject: [PATCH 048/332] Set proper `LD_LIBRARY_PATH` on Linux in nightly venv in nightly pull tool (#143262) Before this change: ```console $ make setup-env-cuda PYTHON="${HOMEBREW_PREFIX}/bin/python3.12" $ source venv/bin/activate $ python3 -c 'import torch' Traceback (most recent call last): File "", line 1, in File "/home/PanXuehai/Projects/pytorch/torch/__init__.py", line 379, in from torch._C import * # noqa: F403 ^^^^^^^^^^^^^^^^^^^^^^ ImportError: libcudnn.so.9: cannot open shared object file: No such file or directory ``` This PR adds `site-packages/nvidia/**/lib` to `LD_LIBRARY_PATH` in `venv/bin/activate` script to let NVIDIA PyPI packages can be loaded correctly. See also: - #141837 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143262 Approved by: https://github.com/malfet --- tools/nightly.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tools/nightly.py b/tools/nightly.py index 45ca897cbe55..9fa1dcba9f51 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -50,6 +50,7 @@ import subprocess import sys import tempfile +import textwrap import time import uuid from ast import literal_eval @@ -340,6 +341,44 @@ def create(self, *, remove_if_exists: bool = False) -> Path: self.base_python("-m", "venv", str(self.prefix)) assert self.is_venv(), "Failed to create virtual environment." 
(self.prefix / ".gitignore").write_text("*\n", encoding="utf-8") + + if LINUX: + activate_script = self.activate_script + st_mode = activate_script.stat().st_mode + # The activate script may be read-only and we need to add write permissions + activate_script.chmod(st_mode | 0o200) + with activate_script.open(mode="a", encoding="utf-8") as f: + f.write( + "\n" + + textwrap.dedent( + f""" + # Add NVIDIA PyPI packages to LD_LIBRARY_PATH + export LD_LIBRARY_PATH="$( + {self.executable.name} - < Path: From f94ac263afb849f40728481cd8eff07a9dd63b0d Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 1 Apr 2025 07:17:22 -0700 Subject: [PATCH 049/332] [MPSInductor] Fix neg for unsigned types (#150412) By more-or-less copy-n-pasting the fix from https://github.com/pytorch/pytorch/pull/94035 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150412 Approved by: https://github.com/jansel, https://github.com/dcci ghstack dependencies: #150382, #150386 --- test/inductor/test_mps_basic.py | 1 + torch/_inductor/codegen/mps.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 8376100b91c4..4d16f4301de8 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -208,6 +208,7 @@ def fn(a): "test_multilayer_prime_size", "test_min_max_reduction_nan", "test_nan_to_num", + "test_neg_max_uint8", "test_pow2", "test_prod", "test_randint_int64_mod", diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index a5ea219eb037..ac2218e3e0f5 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -298,6 +298,12 @@ def atan2(x: CSEVariable, y: CSEVariable) -> str: def sqrt(x: CSEVariable) -> str: return f"metal::sqrt({x})" + @staticmethod + def neg(x: CSEVariable) -> str: + # TODO: Does it rely on undefined behavior? 
+ # If so, add special logic for unsigned types + return f"static_cast(-{x})" + @staticmethod def rsqrt(x: CSEVariable) -> str: return f"metal::rsqrt({x})" From a17ee8181a19f2bf8c89a265e6c4d6d95886bfb3 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 1 Apr 2025 17:13:58 +0000 Subject: [PATCH 050/332] [CI] Fix log artifact not containing test logs attempt 2 (#150234) Fixes #ISSUE_NUMBER Take two of https://github.com/pytorch/pytorch/pull/149577 since it didn't work Pull Request resolved: https://github.com/pytorch/pytorch/pull/150234 Approved by: https://github.com/malfet, https://github.com/seemethere --- .github/actions/upload-test-artifacts/action.yml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 5effc5f3689a..fe949516402d 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -48,14 +48,8 @@ runs: run: | # Remove any previous usage logs if they exist rm -f logs-*.zip - # this workflow is also run in bazel build test, but we dont generate usage reports for it - # so check to see if the file exists first - if [ -f 'usage_log.txt' ]; then - zip "logs-${FILE_SUFFIX}.zip" 'usage_log.txt' - fi - if find "test/test-reports" -name "*.log" 2>/dev/null | stdbuf -o0 grep -q .; then - zip -r "logs-${FILE_SUFFIX}.zip" test/test-reports -i '*.log' - fi + zip "logs-${FILE_SUFFIX}.zip" 'usage_log.txt' || true + zip -r "logs-${FILE_SUFFIX}.zip" test/test-reports -i '*.log' || true - name: Zip debugging artifacts for upload if: runner.os != 'Windows' && !inputs.use-gha From 783f045c4f26cb9d11789d80f68a86854dfad9f9 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 1 Apr 2025 17:14:29 +0000 Subject: [PATCH 051/332] [ez] Remove dead lite interpreter CI code (#150424) There are no lite-interpreter build environments in CI I assume every mac build is arm64 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150424 Approved by: https://github.com/seemethere, https://github.com/malfet --- .ci/pytorch/macos-build.sh | 29 +++-------------------------- 1 file changed, 3 insertions(+), 26 deletions(-) diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index 4e1c68be9282..d538581c09a6 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -33,34 +33,11 @@ if which sccache > /dev/null; then export PATH="${tmp_dir}:$PATH" fi -build_lite_interpreter() { - echo "Testing libtorch (lite interpreter)." - - CPP_BUILD="$(pwd)/../cpp_build" - # Ensure the removal of the tmp directory - trap 'rm -rfv ${CPP_BUILD}' EXIT - rm -rf "${CPP_BUILD}" - mkdir -p "${CPP_BUILD}/caffe2" - - # It looks libtorch need to be built in "${CPP_BUILD}/caffe2 folder. - BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py - pushd "${CPP_BUILD}/caffe2" || exit - VERBOSE=1 DEBUG=1 python "${BUILD_LIBTORCH_PY}" - popd || exit - - "${CPP_BUILD}/caffe2/build/bin/test_lite_interpreter_runtime" -} - print_cmake_info -if [[ ${BUILD_ENVIRONMENT} = *arm64* ]]; then - # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests - # that building with USE_DISTRIBUTED=0 works at all. 
See https://github.com/pytorch/pytorch/issues/86448 - USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel -elif [[ ${BUILD_ENVIRONMENT} = *lite-interpreter* ]]; then - export BUILD_LITE_INTERPRETER=1 - build_lite_interpreter -fi +# Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests +# that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 +USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python setup.py bdist_wheel if which sccache > /dev/null; then print_sccache_stats From 3b0cd9b542bc86b6e9f28a61960173371d3522ad Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 1 Apr 2025 03:24:00 -0700 Subject: [PATCH 052/332] [Quant][PT2E] add a lowering pass for x86 backend (#149708) **Summary** This PR adds a lowering pass for x86 backend - Patterns of `dequantize -> conv/linear (-> quantize)` are fused to corresponding quantized onednn ops. - Weights are prepacked ahead of time. - Post ops of conv/linear are fused if supported. - The pass returns a `GraphModule` with the modifications mentioned above. **Test plan** ``` pytest test/quantization/pt2e/test_x86inductor_quantizer.py -k test_lowering_to_x86 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149708 Approved by: https://github.com/jerryzh168, https://github.com/leslie-fang-intel --- docs/source/quantization-support.rst | 14 ++ docs/source/quantization.rst | 1 + .../pt2e/test_x86inductor_quantizer.py | 129 ++++++++++++++++++ torch/ao/quantization/pt2e/lowering.py | 60 ++++++++ 4 files changed, 204 insertions(+) create mode 100644 torch/ao/quantization/pt2e/lowering.py diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index 4e4ce90c6055..83ad054514ef 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -146,6 +146,20 @@ torch.ao.quantization.pt2e.export_utils .. currentmodule:: torch.ao.quantization +torch.ao.quantization.pt2e.lowering +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. currentmodule:: torch.ao.quantization.pt2e.lowering + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + lower_pt2e_quantized_to_x86 + +.. currentmodule:: torch.ao.quantization + PT2 Export (pt2e) Numeric Debugger ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autosummary:: diff --git a/docs/source/quantization.rst b/docs/source/quantization.rst index 1b808136ef11..226bb143d322 100644 --- a/docs/source/quantization.rst +++ b/docs/source/quantization.rst @@ -1341,6 +1341,7 @@ Please take a look at `Limitations of Symbolic Tracing torch.fx.GraphModule: + """Lower a PT2E-qantized model to x86 backend. + + Args: + * `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow. + * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model. + + Return: + A GraphModule lowered to x86 backend. + """ + + def _post_autograd_decomp_table(): # type: ignore[no-untyped-def] + decomp_table = torch.export.default_decompositions() + + # if we are post-autograd, we shouldn't + # decomp prim ops. 
+ for k in list(decomp_table.keys()): + if not torch._export.utils._is_cia_op(k): + del decomp_table[k] + + return decomp_table + + def _node_replace(m): # type: ignore[no-untyped-def] + # Replace aten.t(x) with aten.permute(x, [1, 0]) + aten = torch.ops.aten + g = m.graph + for node in g.nodes: + if node.target == aten.t.default: + with g.inserting_before(node): + x = node.args[0] + dims = [1, 0] + perm_node = g.call_function(aten.permute.default, args=(x, dims)) + node.replace_all_uses_with(perm_node) + g.erase_node(node) + + g.lint() + m.recompile() + + lowered_model = ( + torch.export.export_for_training(model, example_inputs) + .run_decompositions(_post_autograd_decomp_table()) + .module() + ) + _node_replace(lowered_model) + freezing_passes(lowered_model, example_inputs) + constant_fold(lowered_model) + return lowered_model From 48af2cdd270c275acccc4a94b04e4ccdb64d557a Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 1 Apr 2025 17:33:12 +0000 Subject: [PATCH 053/332] [BE] Move all lint runner to 24.04 (#150427) As Ubuntu-20 reached EOL on Apr 1st, see https://github.com/actions/runner-images/issues/11101 This forces older python version to be 3.8 Delete all linux-20.04 runners from the lintrunner.yml Pull Request resolved: https://github.com/pytorch/pytorch/pull/150427 Approved by: https://github.com/seemethere --- .github/actionlint.yaml | 3 --- .github/workflows/lint.yml | 6 ++---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index c33e09d37efc..1c44ba1f888a 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -3,9 +3,6 @@ self-hosted-runner: # GitHub hosted runner that actionlint doesn't recognize because actionlint version (1.6.21) is too old - ubuntu-24.04 # GitHub hosted x86 Linux runners - # TODO: Cleanup mentions of linux.20_04 when upgrade to linux.24_04 is complete - - linux.20_04.4x - - linux.20_04.16x - linux.24_04.4x - linux.24_04.16x # Organization-wide AWS Linux Runners diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 9d72b0c5bbb6..db00515a79b6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -233,10 +233,8 @@ jobs: runner: linux.24_04.4x - test_type: without_torch runner: linux.24_04.4x - # NOTE: The oldest supported version of python for 24.04 is 3.8 - # so this cannot be updated if we want to keep this test at 3.6 - test_type: older_python_version - runner: linux.20_04.4x + runner: linux.24_04.4x steps: # [see note: pytorch repo ref] # deep clone (fetch-depth 0) required, to allow us to use git log @@ -256,7 +254,7 @@ jobs: if: matrix.test_type == 'older_python_version' uses: actions/setup-python@v5 with: - python-version: 3.6 + python-version: 3.8 architecture: x64 check-latest: false cache: pip From b0c560ef2ab05e6158b73bb0ee864a074bb41076 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 31 Mar 2025 14:45:21 -0700 Subject: [PATCH 054/332] [dynamo][hooks] use wrap_top_frame config for functions (#150209) When torch.compile is applied to a module via `mod.compile(...)`, it's equivalent to `torch.compile(mod._call_impl)` which takes a different path than `OptimizedModule`. This PR ensures that the `wrap_top_frame` config can also take effect for the `torch.compile(mod._call_impl)` use case. 
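
As a minimal sketch of the intended behavior (assuming the flag is exposed as `torch._dynamo.config.wrap_top_frame`, mirroring the `config.wrap_top_frame` check added to `eval_frame.py`, and using the hypothetical `ToyModel` below), compiling via `mod.compile()` should now capture the pre-hook and forward in a single frame, just like the `OptimizedModule` path:

```python
import torch
from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import CompileCounter


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x.relu()


mod = ToyModel()
# Forward pre-hook that shifts the input before forward runs
mod.register_forward_pre_hook(lambda m, args: args[0] + 1)

with dynamo_config.patch(wrap_top_frame=True):
    cnts = CompileCounter()
    # Equivalent to torch.compile(mod._call_impl, backend=cnts)
    mod.compile(backend=cnts)
    mod(torch.randn(4))
    # Hook + forward are expected to be captured as one frame
    assert cnts.frame_count == 1
```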
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150209 Approved by: https://github.com/anijain2305 --- test/dynamo/test_hooks.py | 8 ++++++++ torch/_dynamo/eval_frame.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index a75bc7ac1af7..3793db65d73f 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -872,6 +872,7 @@ def forward(self, x): mod = ToyModel() mod.register_forward_pre_hook(lambda mod, input: input[0] + 1) + # Case 1: torch.compile(mod) cnts = torch._dynamo.testing.CompileCounter() compiled_mod = torch.compile(mod, backend=cnts) @@ -881,6 +882,13 @@ def forward(self, x): self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) + # Case 2: mod.compile() + cnts = torch._dynamo.testing.CompileCounter() + mod.compile(backend=cnts) + res = mod(x) + self.assertEqual(ref, res) + self.assertEqual(cnts.frame_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 18450464197b..8527daa7a796 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -597,7 +597,7 @@ def get_compiler_config(): filename = inspect.getsourcefile(fn) except TypeError: filename = None - if ( + if config.wrap_top_frame or ( (filename is None or trace_rules.check(fn)) and ( getattr(fn, "__name__", "") From f04cf13bddd700055069689216fcdc80a80d60cc Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 1 Apr 2025 17:54:28 +0000 Subject: [PATCH 055/332] Revert "Merge Triton ScaledMM as epilogue to MM template (#150045)" This reverts commit 981048854da154eae8ff0bd439e72e1256ae00da. Reverted https://github.com/pytorch/pytorch/pull/150045 on behalf of https://github.com/PaulZhang12 due to Need to add PR 150415 fixes for internal merge ([comment](https://github.com/pytorch/pytorch/pull/150045#issuecomment-2770252452)) --- torch/_inductor/kernel/mm.py | 381 +---------------- torch/_inductor/kernel/mm_common.py | 70 ---- torch/_inductor/kernel/mm_scaled.py | 608 ++++++++++++++++++++++++++++ torch/_inductor/utils.py | 8 +- 4 files changed, 613 insertions(+), 454 deletions(-) create mode 100644 torch/_inductor/kernel/mm_scaled.py diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 2a52d7fc0135..ffa1531efd42 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Any, Optional +from typing import Optional import torch from torch._dynamo.utils import counters @@ -21,16 +21,10 @@ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.wrapper import PythonWrapperCodegen from ..ir import FlexibleLayout, is_triton -from ..lowering import ( - add_layout_constraint, - constrain_to_fx_strides, - lowerings as L, - register_lowering, -) +from ..lowering import register_lowering from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, - realize_inputs, TritonTemplate, ) from ..utils import ( @@ -52,8 +46,6 @@ mm_options, persistent_mm_grid, persistent_mm_options, - scale_mm_epilogue, - scaled_mm_options, should_fallback_to_aten, ) @@ -127,11 +119,7 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - - if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, 
allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -200,10 +188,7 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - {% if USE_FAST_ACCUM %} - acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) - {% else %} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -320,179 +305,6 @@ """, ) -load_scales = r""" -@triton.jit -def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): - if SCALING_ROWWISE: - # For row-wise scaling, we'll return the pointers - return a_scale_ptr, b_scale_ptr - else: - # For per-tensor scaling, we'll load the scalar values - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr) - return a_scale, b_scale -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALING_ROWWISE: - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -device_tma = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALING_ROWWISE: - stride_a_scale_m = 1 - stride_b_scale_n = 1 - else: - stride_a_scale_m = 0 - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - element_ty=B.dtype.element_ty, - ) - - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) - - 
for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - - -scaled_mm_device_tma_template = TritonTemplate( - name="scaled_mm_device_tma", - grid=persistent_mm_grid, - source=device_tma + load_scales + apply_scaling, -) - # prevent duplication registration of extern functions @functools.lru_cache(None) @@ -514,10 +326,6 @@ def lazy_register_extern_choice(fn): has_out_variant=False, ) -aten__fp8_mm = ExternKernelChoice( - torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out -) - def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) @@ -528,16 +336,6 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**13 -@functools.lru_cache -def using_b200() -> bool: - """Returns true if the device is a NVIDIA B200, otherwise returns false.""" - if not torch.cuda.is_available(): - return False - # compute capability 10.0 or 10.0a is NVIDIA B200 - device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) - return device_properties.major == 10 - - def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt @@ -549,32 +347,6 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) -def check_supported_striding(mat_a, mat_b) -> None: - def is_row_major(stride) -> bool: - return V.graph.sizevars.statically_known_equals(stride[1], 1) - - def is_col_major(stride) -> bool: - return V.graph.sizevars.statically_known_equals(stride[0], 1) - - def has_zero_dim(size) -> bool: - return bool( - V.graph.sizevars.statically_known_equals(size[0], 0) - or V.graph.sizevars.statically_known_equals(size[1], 0) - ) - - # Check mat_a (self) stride requirements - torch._check( - is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), - lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", - ) - - # Check mat_b stride requirements - torch._check( - is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), - lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", - ) - - aten_bias_addmm = ExternKernelChoice(bias_addmm, None) @@ -974,151 +746,6 @@ def 
tuned_sparse_semi_structured_mm( ) -add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) - - -@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] -def tuned_scaled_mm( - mat_a, - mat_b, - scale_a, - scale_b, - bias=None, - scale_result=None, - out_dtype=None, - use_fast_accum=False, - layout=None, -): - m, n, k, layout, mat_a, mat_b = mm_args( - mat_a, mat_b, layout=layout, out_dtype=out_dtype - ) - # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 - log.info( - "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", - m, - n, - k, - mat_a.get_dtype(), - mat_b.get_dtype(), - layout, - ) - - device_type = ir.get_device_type(mat_a) - check_supported_striding(mat_a, mat_b) - - scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) - - input_nodes: tuple[Any, ...] - - if not bias: - input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) - else: - bias_real = realize_inputs(bias) - input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) - - aten_choice = aten__fp8_mm.bind( - input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum - ) - - choices = [] - if use_aten_gemm_kernels(): - choices.append(aten_choice) - - _, is_nonzero = _is_static_problem(layout) - - scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) - scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( - device_type - ) - - if is_nonzero and use_triton_template(layout, enable_float8=True): - triton_input_nodes: tuple[Any, ...] - if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: - # Need to unsqueeze bias from [N] -> [1, N] - triton_bias = L[aten.unsqueeze](bias, 0) - else: - triton_bias = bias - - if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: - assert len(scale_a.get_size()) == len(scale_b.get_size()) - # Need to unsqueeze scale from [] -> [1, 1] - triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) - triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) - else: - triton_scale_a = scale_a - triton_scale_b = scale_b - - if bias: - triton_input_nodes = ( - mat_a, - mat_b, - triton_scale_a, - triton_scale_b, - triton_bias, - ) - suffix_args = 3 - else: - triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) - suffix_args = 2 - - # TODO (paulzhan): There is no template that exists for bias and TMA - # Don't run tma template currently if bias exists - if use_triton_tma_template(mat_a, mat_b) and not bias: - for config in scaled_persistent_mm_configs(m, n, k): - kwargs = scaled_mm_options( - config, - m, - n, - k, - layout, - scale_a, - scale_b, - use_fast_accum, - device_tma=True, - ) - scaled_mm_device_tma_template.maybe_append_choice( - choices, - input_nodes=triton_input_nodes, - layout=layout, - workspace_arg=get_tma_workspace_arg( - num_tma_descriptors=2, - device=mat_a.get_device(), - ), - **kwargs, - ) - - for config in scaled_mm_configs(m, n, k): - if k == 16 and config.kwargs["BLOCK_M"] >= 64: - continue # Triton crashes in this case - - # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid - # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape - if using_b200() and k < 32: - continue - - kwargs = scaled_mm_options( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - # possibly appends a 
TritonTemplateCaller to choices - mm_template.maybe_append_choice( - choices, - input_nodes=triton_input_nodes, - layout=layout, - **kwargs, - suffix_args=suffix_args, - epilogue_fn=scale_mm_epilogue(), - ) - - if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) - - if should_fallback_to_aten(choices): - return aten_choice.output_node() - - return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) - - @functools.lru_cache(None) def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 663e78dc199c..d990536c4362 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -76,7 +76,6 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): GROUP_M=8, EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, - USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=acc_type(layout.dtype), num_stages=config.num_stages, num_warps=config.num_warps, @@ -93,47 +92,6 @@ def persistent_mm_options(mat1, mat2): ) -def scaled_mm_options( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a, - scale_b, - use_fast_accum: bool, - device_tma: bool = False, -) -> dict[str, Any]: - def are_compatible_scales(size_a, size_b) -> bool: - # Same sized scales are compatable - if len(size_a) == len(size_b): - return True - - # Both need to be scalars or len(1) tensors - if len(size_a) <= 1 and len(size_b) <= 1: - return True - - return False - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - - mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) - - mm_template_options["ACC_TYPE"] = "tl.float32" - mm_template_options["USE_FAST_ACCUM"] = use_fast_accum - mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 - - if device_tma: - mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE - mm_template_options["NUM_SMS"] = get_num_sms() - - return mm_template_options - - def mm_args( mat1, mat2, @@ -196,34 +154,6 @@ def epilogue(acc, bias): return epilogue -def scale_mm_epilogue(): - """ - Create an epilogue function that applies scaling to matrix multiplication result - using the given scale factors. - - Args: - dtype: The data type of the output - scale_a: Scale factor for matrix A - scale_b: Scale factor for matrix B - - Returns: - Epilogue function that takes the accumulator and applies scaling - """ - - def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): - # The epilogue function receives the accumulator (result of mat1 @ mat2) - # and applies the scaling factors - # In the original scaled_mm, we use inverse scales, so we multiply by them - mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) - mul_acc = V.ops.mul(acc, mul_scales) - if bias is not None: - return V.ops.add(mul_acc, bias) - else: - return mul_acc - - return epilogue - - def _is_static_problem(layout: Layout) -> tuple[bool, bool]: """ Check if input tensors and output layout have static shapes and non-zero sizes. 
diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py new file mode 100644 index 000000000000..aa917e120168 --- /dev/null +++ b/torch/_inductor/kernel/mm_scaled.py @@ -0,0 +1,608 @@ +import functools +import logging +from collections.abc import Sequence +from typing import Any, Optional + +import sympy + +import torch +from torch._dynamo.utils import counters +from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate +from torch.utils._triton import has_triton_tma_device + +from ..config import triton as triton_config +from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox +from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering +from ..select_algorithm import ( + autotune_select_algorithm, + ExternKernelChoice, + realize_inputs, + TritonTemplate, +) +from ..utils import ( + get_num_sms, + get_tma_workspace_arg, + TMA_DESCRIPTOR_SIZE, + use_aten_gemm_kernels, + use_ck_gemm_template, + use_triton_template, +) +from ..virtualized import V +from .mm_common import ( + _is_static_problem, + mm_args, + mm_grid, + persistent_mm_grid, + should_fallback_to_aten, +) + + +log = logging.getLogger(__name__) +aten = torch.ops.aten + +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + + +scaled_mm_template = TritonTemplate( + name="scaled_mm", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ if USE_FAST_ACCUM: + acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) + else: + acc += tl.dot(a, b, out_dtype=ACC_TYPE) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + if SCALING_ROWWISE: + inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) + inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) + inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] + acc *= inv_scale_row + else: + # for tensor-wise scaling, the scales are scalars + inv_a_scale = tl.load(A_inverse_scale) + inv_b_scale = tl.load(B_inverse_scale) + inv_scale = inv_a_scale * inv_b_scale + acc *= inv_scale + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +# Inductor does not allow optional tensor input arguments currently (pass None as an +# input node to template choices), but since for _scaled_mm there is only one such arg +# (bias), work around by having a second template when bias is provided. +scaled_mm_bias_template = TritonTemplate( + name="scaled_mm_bias", + grid=mm_grid, + source=r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # based on triton.ops.matmul + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ if USE_FAST_ACCUM: + acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) + else: + acc += tl.dot(a, b, out_dtype=ACC_TYPE) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + if SCALING_ROWWISE: + inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) + inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) + inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] + acc *= inv_scale_row + else: + # for tensor-wise scaling, the scales are scalars + inv_a_scale = tl.load(A_inverse_scale) + inv_b_scale = tl.load(B_inverse_scale) + inv_scale = inv_a_scale * inv_b_scale + acc *= inv_scale + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # bias + bias = tl.load(bias_ptr + rn, mask=rn < N) + acc += bias + + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask")}} +""", +) + + +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + + +def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + +def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: + def is_row_major(stride: Sequence[_IntLike]) -> bool: + return stride[1] == 1 + + def is_col_major(stride: Sequence[_IntLike]) -> bool: + return stride[0] == 1 + + def has_zero_dim(size: Sequence[_IntLike]) -> bool: + return bool(size[0] == 0 or size[1] == 0) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + +def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a: StorageBox, + scale_b: StorageBox, + use_fast_accum: bool, +) -> dict[str, Any]: + even_k_symbolic = ( + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." 
+ ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ACC_TYPE="tl.float32", + USE_FAST_ACCUM=use_fast_accum, + num_stages=config.num_stages, + num_warps=config.num_warps, + # tensor-wise scaling if scalar scales + SCALING_ROWWISE=len(scale_a.get_size()) == 2, + TMA_SIZE=TMA_DESCRIPTOR_SIZE, + NUM_SMS=get_num_sms(), + **config.kwargs, + ) + + +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a: StorageBox, + scale_b: StorageBox, + use_fast_accum: bool, +) -> dict[str, Any]: + even_k_symbolic = ( + sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] + ) + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + return dict( + GROUP_M=8, + EVEN_K=even_k_symbolic, + ACC_TYPE="tl.float32", + USE_FAST_ACCUM=use_fast_accum, + num_stages=config.num_stages, + num_warps=config.num_warps, + # tensor-wise scaling if scalar scales + SCALING_ROWWISE=len(scale_a.get_size()) == 2, + **config.kwargs, + ) + + +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool: + available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul + # _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes + # When K is 16, BLOCK_K = 16 and is not valid + min_k = k >= 32 + return available and min_k and not has_bias + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a: TensorBox, + mat_b: TensorBox, + scale_a: TensorBox, + scale_b: TensorBox, + bias: Optional[TensorBox] = None, + scale_result: Optional[TensorBox] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, + layout: Optional[Layout] = None, +) -> TensorBox: + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = get_device_type(mat_a) + + check_supported_striding(mat_a, mat_b) + + scale_a, scale_b = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] 
+ # workaround for Inductor not supporting optional tensor input arguments + if bias is None: + input_nodes = (mat_a, mat_b, scale_a, scale_b) + triton_template = scaled_mm_template + else: + bias = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a, scale_b, bias) + triton_template = scaled_mm_bias_template + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices: list[ChoiceCaller] = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + if use_persistent_tma(k, bias is not None): + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options_device_tma( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + input_nodes = (mat_a, mat_b, scale_a, scale_b) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + else: + for config in scaled_mm_configs(m, n, k): + if k == 16 and config.kwargs["BLOCK_M"] >= 64: + continue # Triton crashes in this case + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and k < 32: + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + triton_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + if should_fallback_to_aten(choices): + return aten_choice.output_node() + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index bca3f024d134..e93ed88bcbda 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1377,7 +1377,7 @@ def _is_tma_compatible(x: IRNode) -> bool: return False dtype = x.get_dtype() - if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): + if dtype not in (torch.float16, torch.bfloat16): return False layout = x.get_layout() @@ -1388,12 +1388,6 @@ def _is_tma_compatible(x: IRNode) -> bool: inner_dim = layout.size[1] if transposed: inner_dim = layout.size[0] - - if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( - inner_dim, 32 - ): - return False - inner_bytes = inner_dim * dtype.itemsize return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) From 15dbad2115bb21ccd5e9bd3dcfcbbb2aad763e17 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 28 Mar 2025 08:44:12 -0700 Subject: [PATCH 056/332] Update torch.compile issue template 
(#150192) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150192 Approved by: https://github.com/malfet ghstack dependencies: #149947 --- .github/ISSUE_TEMPLATE/pt2-bug-report.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/pt2-bug-report.yml b/.github/ISSUE_TEMPLATE/pt2-bug-report.yml index be22b1446b4e..2f8ab54a2337 100644 --- a/.github/ISSUE_TEMPLATE/pt2-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/pt2-bug-report.yml @@ -20,7 +20,7 @@ body: - Don't compare indices of max/min etc, because that avoids the above requirement - - If comparing eager and torch.compile at fp16/bf16, you should use fp32 as baseline + - When comparing eager and torch.compile, use a higher precision result as a baseline. `torch._dynamo.utils.same` with fp64_ref will handle this comparison. - Ensure rng state used to compare results is equivalent. Use `torch._inductor.config.fallback_random=True` and reset the torch rng seed between comparisons From 37ebb0b56a3af1a5e8083337b4d670fc70fe23a3 Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Tue, 1 Apr 2025 03:24:34 +0000 Subject: [PATCH 057/332] [inductor] Fix inductor windows linker error (#150256) Fixes #149889 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150256 Approved by: https://github.com/anijain2305, https://github.com/eellison --- torch/_inductor/cpp_builder.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index a8f25056dd52..aeef51ae6cc0 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -838,8 +838,13 @@ def _get_python_related_args() -> tuple[list[str], list[str]]: python_include_dirs.append(python_include_path) if _IS_WINDOWS: - python_path = os.path.dirname(sys.executable) - python_lib_path = [os.path.join(python_path, "libs")] + python_lib_path = [ + str( + ( + Path(sysconfig.get_path("include", scheme="nt")).parent / "libs" + ).absolute() + ) + ] else: python_lib_path = [sysconfig.get_config_var("LIBDIR")] From 78300c82054e7c54dd67aa780fa45b594785a19e Mon Sep 17 00:00:00 2001 From: Ethan Wee Date: Tue, 1 Apr 2025 18:31:21 +0000 Subject: [PATCH 058/332] [ROCm] update test buffer fudge factor for hipblaslt (#150348) The default workspace for hipblaslt is larger than for cublas/cublaslt which requires a slight increase to the buffer needed. Forward-fix for #150227 that broke ROCm distributed tests but wasn't part of initial CI signal. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150348 Approved by: https://github.com/jeffdaily --- test/distributed/_composable/fsdp/test_fully_shard_memory.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py index 340fe913c1eb..de6df77479c9 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py @@ -117,6 +117,9 @@ def _test_fully_shard_training_memory( # number is kept much smaller than the actual memory usage, which is on # the order of 100-200+ MB) buffer_mb = 16 + # The default workspace for hipblaslt is larger than for cublas/cublaslt + # which requires a slight increase to this buffer value. 
+ buffer_mb = 16 if torch.version.cuda else 18 if reshard_after_forward: # 3x max unsharded block parameters (current all-gather + copy-out # and next all-gather), non-block parameters, and other From a37afd23facbbabf06ee01884b165bfdf0af1db4 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 1 Apr 2025 04:28:05 -0700 Subject: [PATCH 059/332] [custom_ops][perf] Move expensive pytree traversals of tensors to C++ (#148555) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (benchmark for 1 call) Before: ``` └─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py DO_BENCH mutate: 77.72445678710938 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json DO_BENCH no_mutate: 64.61143493652344 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json DO_BENCH direct_mutate: 11.682510375976562 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json DO_BENCH direct_no_mutate: 18.596649169921875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json ``` After: ``` └─ $ python ~/task_custom_ops_perf/test_custom_ops_perf_repro.py DO_BENCH mutate: 47.6837158203125 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/mutate.json DO_BENCH no_mutate: 31.709671020507812 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/no_mutate.json DO_BENCH direct_mutate: 10.967254638671875 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_mutate.json DO_BENCH direct_no_mutate: 10.728836059570312 us PROFILE:/home/ivankobzarev/task_custom_ops_perf/direct_no_mutate.json ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148555 Approved by: https://github.com/zou3519 --- test/test_custom_ops.py | 121 +++++++++++++++++++++++++++++++ torch/_C/__init__.pyi.in | 2 + torch/_library/autograd.py | 4 +- torch/_library/custom_ops.py | 5 +- torch/_library/utils.py | 25 +++++++ torch/csrc/autograd/init.cpp | 135 +++++++++++++++++++++++++++++++++++ 6 files changed, 287 insertions(+), 5 deletions(-) diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index c92edc279f55..a2691d5e1cbf 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3910,6 +3910,127 @@ def fvmap2(info, in_dims, x, y): self.assertTrue(called) self.assertEqual(result, x + y) + @skipIfTorchDynamo("Skip due to sys.refcount") + def test_any_requires_grad(self): + test_fn = torch._C._any_requires_grad + # Regression test on not leaking kwargs + t = torch.randn(2, 2) + t_refcount = sys.getrefcount(t) + test_fn(t, a=t) + self.assertEqual(sys.getrefcount(t), t_refcount) + + self.assertTrue( + test_fn( + torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True) + ) + ) + self.assertFalse(test_fn(torch.ones(1), torch.zeros(1))) + self.assertTrue( + test_fn( + [torch.zeros(1, requires_grad=True), torch.ones(1, requires_grad=True)] + ) + ) + # _C_any_requires_grad supports only List[Tensor] in args, not List[List[Tensor]] + self.assertFalse(test_fn([[torch.zeros(1, requires_grad=True)]], torch.ones(1))) + self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)])) + self.assertTrue(test_fn(torch.zeros(1), a=torch.ones(1, requires_grad=True))) + self.assertFalse(test_fn(torch.zeros(1), a=torch.ones(1))) + self.assertTrue( + test_fn([torch.zeros(1, requires_grad=True), torch.ones(1)], torch.zeros(1)) + ) + self.assertFalse(test_fn([torch.zeros(1), torch.ones(1)], torch.zeros(1))) + + @skipIfTorchDynamo("Skip due to sys.refcount") + def test_any_output_is_alias_to_input_or_output(self): + test_fn = 
torch._C._any_output_is_alias_to_input_or_output + # Regression test on not leaking kwargs + t = torch.randn(2, 2) + t_refcount = sys.getrefcount(t) + test_fn((t,), {"a": t}, ()) + assert sys.getrefcount(t) == t_refcount + + x = torch.randn(2, 2) + y = torch.randn(2, 2) + self.assertTrue( + test_fn( + (x,), + {}, + (x.t(),), + ) + ) + self.assertFalse(test_fn((x,), None, (2 * x,))) + self.assertTrue( + test_fn( + (), + {"a": x.view(-1)}, + (x,), + ) + ) + self.assertTrue( + test_fn( + (), + {"a": x.view(-1)}, + (x.t(),), + ) + ) + self.assertTrue(test_fn((y,), {}, (y[1:],))) + self.assertFalse( + test_fn( + (x,), + {"a": x}, + (), + ) + ) + self.assertFalse( + test_fn( + (torch.tensor([]),), + {}, + (torch.tensor([]),), + ) + ) + self.assertTrue( + test_fn( + ([x], x + 1), + {}, + (x.t(),), + ) + ) + self.assertTrue( + test_fn( + ([x], x + 1), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([x], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([x, 1], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([[x]], x), + {}, + ([x.t()], x + 1), + ) + ) + self.assertTrue( + test_fn( + ([[1, x], 2], 3), + {}, + ([x.t()], x + 1), + ) + ) + class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 09744f2b043d..893df5db74f8 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1358,6 +1358,8 @@ def _set_grad_enabled(enabled: _bool) -> None: ... def is_grad_enabled() -> _bool: ... def _set_fwd_grad_enabled(enabled: _bool) -> None: ... def _is_fwd_grad_enabled() -> _bool: ... +def _any_requires_grad(*args, **kwargs) -> _bool: ... +def _any_output_is_alias_to_input_or_output(*args, **kwargs) -> _bool: ... def is_inference_mode_enabled() -> _bool: ... @overload def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ... diff --git a/torch/_library/autograd.py b/torch/_library/autograd.py index 5c8c713b6e42..3f3e9295549b 100644 --- a/torch/_library/autograd.py +++ b/torch/_library/autograd.py @@ -105,9 +105,7 @@ def backward(ctx, *grads): # The dispatcher passes any keyword-only-args as kwargs and the # rest of the args (even if specified as kwargs) as args. 
def autograd_impl(keyset, *args, **keyword_only_args): - if _C.is_grad_enabled() and _pytree.tree_any_only( - Tensor, lambda x: x.requires_grad, args, not_list_of_tensor - ): + if _C.is_grad_enabled() and _C._any_requires_grad(*args): result = Generated.apply(*args, Metadata(keyset, keyword_only_args)) # type: ignore[attr-defined] else: result = forward_no_grad(*args, Metadata(keyset, keyword_only_args)) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 66aeccc58a0c..544bbbf61582 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -347,9 +347,10 @@ def get_module(): fn = self._backend_fns[device_type] return inspect.getmodule(fn) - utils.check_aliasing_constraint( + utils._c_check_aliasing_constraint( self._name, - utils.iter_tensors(args, kwargs), + args, + kwargs, result, get_module, ) diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 8348883cee30..908280ecf292 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -373,6 +373,31 @@ def check_aliasing_constraint(name, prev, result, get_module=lambda: "???"): storages.add(key) +def _c_check_aliasing_constraint(name, args, kwargs, result, get_module=lambda: "???"): + """ + custom operators' outputs must not have any aliases + This version uses C++ implementation for perf. + Only List container is supported. + Tensors in Lists with not only Tensors are checked. + """ + tuple_result = result + if not isinstance(result, tuple): + tuple_result = (result,) + if _C._any_output_is_alias_to_input_or_output(args, kwargs, tuple_result): + raise RuntimeError( + f"{name} (with implementation in {get_module()}): " + f"The output of this custom operator (1) must not " + f"also be an input to this custom operator and " + f"(2) may not alias any inputs to this custom operator " + f"or other returns. " + f"The most common way to trigger this error is if " + f"we have y = custom_op(x) and y and x are the same Tensor. " + f"Please instead return a clone of the offending output " + f"tensor(s) (e.g. return x.clone()) or refactor the custom " + f"operator to not return y." + ) + + class MutationChecker: """ Check if an operator mutated its arguments. diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 6eb3cdcdbdfc..b376c295b77a 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -955,6 +955,133 @@ static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { END_HANDLE_TH_ERRORS } +template +static bool visit( + PyObject* o, + const std::function& visit_tensor) { + if (THPVariable_Check(o)) { + auto t = THPVariable_Unpack(o); + if (visit_tensor(t)) { + return true; + } + } else if (PyList_Check(o)) { + // Check that this List is TensorList + if constexpr (skip_tensors_in_non_tensorlist) { + for (const auto i : c10::irange(PyList_GET_SIZE(o))) { + if (!THPVariable_Check(PyList_GET_ITEM(o, i))) { + return false; + } + } + } + for (const auto i : c10::irange(PyList_GET_SIZE(o))) { + if (visit( + PyList_GET_ITEM(o, i), visit_tensor)) { + return true; + }; + } + } + return false; +} + +// Visiting of tensors in args and kwargs, +// only List container is visited. +// skip_tensors_in_non_tensorlist will skip any List with non-Tensor. +// Lambda returning true means short circuit, traversal stops after that. 
+template +static void visit_tensors( + PyObject* args, + PyObject* kwargs, + const std::function& visit_tensor) { + if (args && PyTuple_Check(args)) { + for (const auto i : c10::irange(PyTuple_GET_SIZE(args))) { + if (visit( + PyTuple_GET_ITEM(args, i), visit_tensor)) { + return; + } + } + } + if (kwargs && PyDict_Check(kwargs)) { + auto vals = THPObjectPtr{PyDict_Values(kwargs)}; + for (const auto i : c10::irange(PyList_Size(vals))) { + if (visit( + PyList_GetItem(vals, i), visit_tensor)) { + return; + } + } + } +} + +// Returns true if any of the args, kwargs tensor leaves have requires_grad. +// Only List[Tensor] container in args is supported. +static PyObject* any_requires_grad( + PyObject* _unused, + PyObject* args, + PyObject* kwargs) { + HANDLE_TH_ERRORS + bool has_requires_grad = false; + visit_tensors(args, kwargs, [&has_requires_grad](at::Tensor& t) { + if (t.requires_grad()) { + has_requires_grad = true; + return true; + } + return false; + }); + if (has_requires_grad) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS +} + +// Checks aliasing constraint for custom ops: +// Returns true if any of outputs is alias to any of inputs or another output +// Args: +// args[0] - inputs args +// args[1] - inputs kwargs +// args[2] - outputs +// Only List container is supported. +// Tensors in Lists that has not only Tensor are checked. +static PyObject* any_output_is_alias_to_input_or_output( + PyObject* _unused, + PyObject* args) { + HANDLE_TH_ERRORS + PyObject* inps = PyTuple_GET_ITEM(args, 0); + PyObject* inps_kwargs = PyTuple_GET_ITEM(args, 1); + PyObject* outs = PyTuple_GET_ITEM(args, 2); + std::unordered_set s; + visit_tensors(inps, inps_kwargs, [&s](at::Tensor& t) { + if (!t.storage()) { + return false; + } + auto* cp = t.storage().data_ptr().get_context(); + if (cp) { + s.insert(cp); + } + return false; + }); + bool ret = false; + visit_tensors(outs, nullptr, [&s, &ret](at::Tensor& t) { + if (!t.storage()) { + return false; + } + auto* cp = t.storage().data_ptr().get_context(); + if (!cp) { + return false; + } + if (s.find(cp) != s.end()) { + ret = true; + return true; + } + s.insert(cp); + return false; + }); + if (ret) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; + END_HANDLE_TH_ERRORS +} + static PyObject* set_multithreading_enabled( PyObject* self, PyObject* args, @@ -1326,6 +1453,14 @@ static PyMethodDef methods[] = { nullptr}, {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr}, + {"_any_requires_grad", + castPyCFunctionWithKeywords(any_requires_grad), + METH_VARARGS | METH_KEYWORDS, + nullptr}, + {"_any_output_is_alias_to_input_or_output", + any_output_is_alias_to_input_or_output, + METH_VARARGS, + nullptr}, {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr}, {"is_inference_mode_enabled", is_inference_mode_enabled, From 5d6ac2dcedef1249385172841f22e29a9598285c Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Tue, 1 Apr 2025 19:15:25 +0000 Subject: [PATCH 060/332] [dtensor] add op support for select_backward and slice_backward (#150357) Inheriting and rebasing @awgu 's PR https://github.com/pytorch/pytorch/pull/149071 - fixed an issue for `select_backward` and an issue for `slice_backward` - removed `_experimental_ops.py` as it becomes empty Pull Request resolved: https://github.com/pytorch/pytorch/pull/150357 Approved by: https://github.com/awgu, https://github.com/XilunWu --- test/distributed/tensor/test_tensor_ops.py | 4 +- 
torch/distributed/tensor/_ops/__init__.py | 1 - .../tensor/_ops/_experimental_ops.py | 27 -------- torch/distributed/tensor/_ops/_tensor_ops.py | 62 ++++++++++++++++++- torch/distributed/tensor/_sharding_prop.py | 2 + 5 files changed, 65 insertions(+), 31 deletions(-) delete mode 100644 torch/distributed/tensor/_ops/_experimental_ops.py diff --git a/test/distributed/tensor/test_tensor_ops.py b/test/distributed/tensor/test_tensor_ops.py index 6d970c379065..ddaee7ab2405 100644 --- a/test/distributed/tensor/test_tensor_ops.py +++ b/test/distributed/tensor/test_tensor_ops.py @@ -649,8 +649,8 @@ def test_slice(self): global_out.backward(gradient=torch.ones_like(global_out)) with comm_mode: - sharded_out_grad = torch.distributed._tensor.ones( - sharded_out.shape, device_mesh=mesh, placements=[Shard(1)] + sharded_out_grad = torch.distributed.tensor.ones( + sharded_out.shape, device_mesh=mesh, placements=shard_spec ) sharded_out.backward(gradient=sharded_out_grad) diff --git a/torch/distributed/tensor/_ops/__init__.py b/torch/distributed/tensor/_ops/__init__.py index dec4665b1c8b..7cfaa668a183 100644 --- a/torch/distributed/tensor/_ops/__init__.py +++ b/torch/distributed/tensor/_ops/__init__.py @@ -1,7 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates from ._conv_ops import * # noqa: F403 from ._embedding_ops import * # noqa: F403 -from ._experimental_ops import * # noqa: F403 from ._math_ops import * # noqa: F403 from ._matrix_ops import * # noqa: F403 from ._pointwise_ops import * # noqa: F403 diff --git a/torch/distributed/tensor/_ops/_experimental_ops.py b/torch/distributed/tensor/_ops/_experimental_ops.py deleted file mode 100644 index 59e907dc5ba1..000000000000 --- a/torch/distributed/tensor/_ops/_experimental_ops.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. 
and affiliates -# implement matrix related ops for distributed tensor - -import torch -from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementStrategy, - StrategyType, -) -from torch.distributed.tensor._ops.utils import register_op_strategy -from torch.distributed.tensor.placement_types import Replicate - - -aten = torch.ops.aten - - -@register_op_strategy(aten.slice_backward.default) -def slice_backward_rules(op_schema: OpSchema) -> StrategyType: - """ - slice_backward is a new_zeros + slice_scatter, we only allow replication - on the input/output for now since new_zeros would produce replication - """ - mesh = op_schema.get_mesh_from_args(validate=False) - replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) - return OpStrategy([PlacementStrategy(replicate_spec)]) diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index d100aaea4ad7..9b73f36d855f 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -20,6 +20,7 @@ from torch.distributed.tensor._ops._embedding_ops import _MaskPartial from torch.distributed.tensor._ops.utils import ( expand_to_full_mesh_op_strategy, + generate_redistribute_costs, is_tensor_dim_sharded, is_tensor_evenly_shardable, is_tensor_partial, @@ -237,7 +238,7 @@ def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: @register_op_strategy(aten.select.int, schema_info=RuntimeSchemaInfo(1)) -def gen_select_strategy(op_schema: OpSchema) -> StrategyType: +def select_int_strategy(op_schema: OpSchema) -> StrategyType: """ In this select op, first determine the input specs, then determine the output specs. 
- Input specs: @@ -299,6 +300,38 @@ def gen_select_strategy(op_schema: OpSchema) -> StrategyType: return select_strategy +@register_op_strategy( + aten.select_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def select_backward_strategy(op_schema: OpSchema) -> OpStrategy: + # func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + assert isinstance(dim, int) + output_strategies: list[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + input_spec = placement_strategy.output_spec + output_spec_placements: list[Placement] = [] + for placement in input_spec.placements: + if isinstance(placement, Shard): + shard_dim = placement.dim + if shard_dim >= dim: + # NOTE: shard_dim is guaranteed to exist because + # grad_input has one more dim than grad_output + output_spec_placements.append(Shard(shard_dim + 1)) + else: + output_spec_placements.append(Shard(shard_dim)) + else: + output_spec_placements.append(placement) + output_specs = DTensorSpec(input_spec.mesh, tuple(output_spec_placements)) + output_strategies.append( + PlacementStrategy(output_specs=output_specs, input_specs=(input_spec,)) + ) + return OpStrategy(output_strategies) + + @register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: """Forward all shardings except the slice dimension.""" @@ -349,6 +382,33 @@ def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: return slice_strategy +@register_op_strategy( + aten.slice_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: + # func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + output_strategies: list[PlacementStrategy] = [] + for placement_strategy in input_strategy.strategies: + output_spec = placement_strategy.output_spec + new_placements: list[Placement] = [] + for placement in output_spec.placements: + # Redistribute to replicate only if the dim is sharded and matches the slice dim + if isinstance(placement, Shard) and placement.dim == dim: + new_placements.append(Replicate()) + else: + new_placements.append(placement) + new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements)) + redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)] + placement_strategy.redistribute_cost = redistribute_cost + new_strategy = PlacementStrategy(output_specs=new_spec) + output_strategies.append(new_strategy) + return OpStrategy(output_strategies) + + def unshard_tensor_dim( placements: Sequence[Placement], dim: int ) -> tuple[Placement, ...]: diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index c5bb22a92b7d..0e186da56152 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -77,6 +77,8 @@ def __init__(self) -> None: aten.reshape.default: 1, aten.view.default: 1, aten._unsafe_view.default: 1, + aten.select_backward.default: 1, + aten.slice_backward.default: 1, } def register_sharding_prop_rule( From d2ad9aa2f296c83f4c531d486126c6cf8c49720f Mon Sep 17 
00:00:00 2001 From: Tianyu Liu Date: Tue, 1 Apr 2025 19:15:40 +0000 Subject: [PATCH 061/332] [dtensor][tp] add a ParallelStyle PrepareModuleInputOutput (#150372) Needed this class for because `parallelize_module` takes a dict, which doesn't allow `PrepareModuleInput` and `PrepareModuleOutput` to be applied at the same time. The `PrepareModuleInputOutput` in this PR initializes two variables `prepare_module_input` and `prepare_module_output` and uses them to process module / inputs / outputs. I had another implementation which put all code in `PrepareModuleInputOutput` and let `PrepareModuleInput` and `PrepareModuleOutput` inherit the monolithic `PrepareModuleInputOutput`. But it is 1. less cleaner 2. conceptually abusing inheritance because `PrepareModuleInput` shouldn't be able to access class methods of `PrepareModuleOutput` and vice versa Pull Request resolved: https://github.com/pytorch/pytorch/pull/150372 Approved by: https://github.com/wanchaol --- docs/source/distributed.tensor.parallel.rst | 4 + .../tensor/parallel/test_parallelize_api.py | 24 ++++ torch/distributed/tensor/parallel/__init__.py | 2 + torch/distributed/tensor/parallel/style.py | 112 ++++++++++++++++++ 4 files changed, 142 insertions(+) diff --git a/docs/source/distributed.tensor.parallel.rst b/docs/source/distributed.tensor.parallel.rst index 694212296e35..75cedd809fdc 100644 --- a/docs/source/distributed.tensor.parallel.rst +++ b/docs/source/distributed.tensor.parallel.rst @@ -46,6 +46,10 @@ the ``parallelize_plan`` when calling ``parallelize_module``: :members: :undoc-members: +.. autoclass:: torch.distributed.tensor.parallel.PrepareModuleInputOutput + :members: + :undoc-members: + .. note:: when using the ``Shard(dim)`` as the input/output layouts for the above ``ParallelStyle`` s, we assume the input/output activation tensors are evenly sharded on the tensor dimension ``dim`` on the ``DeviceMesh`` that TP operates on. 
For instance, diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 18128366c8db..ae94d8c3ec68 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -9,6 +9,7 @@ from torch.distributed.tensor.parallel.style import ( ColwiseParallel, PrepareModuleInput, + PrepareModuleInputOutput, PrepareModuleOutput, RowwiseParallel, ) @@ -201,6 +202,29 @@ def test_prepare_module_output(self): inp = dtensor.redistribute(device_mesh, [Shard(0)]).to_local() self.assertEqual(inp, output) + @with_comms + def test_prepare_module_input_output(self): + module = DummyModule() + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + parallelize_module( + module, + device_mesh, + PrepareModuleInputOutput( + input_layouts=Shard(0), + desired_input_layouts=Replicate(), + output_layouts=Replicate(), + desired_output_layouts=Shard(1), + ), + ) + inp = torch.rand(5, 7, device=self.device_type) + output = module(inp) + inp = ( + DTensor.from_local(inp, device_mesh, [Shard(0)], run_check=False) + .redistribute(device_mesh, [Shard(1)]) + .to_local() + ) + self.assertEqual(inp, output) + @with_comms def test_parallelize_module_with_star(self): inp_size = [12, 10] diff --git a/torch/distributed/tensor/parallel/__init__.py b/torch/distributed/tensor/parallel/__init__.py index 9fe378c51b0d..5e4881de4387 100644 --- a/torch/distributed/tensor/parallel/__init__.py +++ b/torch/distributed/tensor/parallel/__init__.py @@ -5,6 +5,7 @@ ColwiseParallel, ParallelStyle, PrepareModuleInput, + PrepareModuleInputOutput, PrepareModuleOutput, RowwiseParallel, SequenceParallel, @@ -15,6 +16,7 @@ "ColwiseParallel", "ParallelStyle", "PrepareModuleInput", + "PrepareModuleInputOutput", "PrepareModuleOutput", "RowwiseParallel", "SequenceParallel", diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py index e5ce3371ff96..3580a924d183 100644 --- a/torch/distributed/tensor/parallel/style.py +++ b/torch/distributed/tensor/parallel/style.py @@ -23,6 +23,7 @@ "SequenceParallel", "ColwiseParallel", "PrepareModuleInput", + "PrepareModuleInputOutput", "PrepareModuleOutput", ] @@ -698,3 +699,114 @@ def __repr__(self) -> str: tmpstr += f"use_local_output={self.use_local_output}" tmpstr += ")" return tmpstr + + +class PrepareModuleInputOutput(ParallelStyle): + """ + Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module + to DTensors at runtime according to ``input_layouts`` (and output_layouts, respectively), and perform layout redistribution + according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of + :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. 
+ input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_input (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor + >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated + >>> # as Replicated DTensor and then redistributed to Sharded DTensor. 
+ >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInputOutput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...), + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0), + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, tuple[Optional[Placement]]] + ] = None, + input_kwarg_layouts: Optional[dict[str, Placement]] = None, + desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + use_local_input: bool = False, + output_layouts: Union[Placement, tuple[Placement]], + desired_output_layouts: Union[Placement, tuple[Placement]], + use_local_output: bool = True, + ): + self.prepare_module_input = PrepareModuleInput( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_input, + ) + self.prepare_module_output = PrepareModuleOutput( + output_layouts=output_layouts, + desired_output_layouts=desired_output_layouts, + use_local_output=use_local_output, + ) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + self.prepare_module_input._apply(module, device_mesh) + self.prepare_module_output._apply(module, device_mesh) + + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, " + tmpstr += ( + f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, " + ) + tmpstr += ( + f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, " + ) + tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, " + tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, " + tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, " + tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}" + tmpstr += ")" + return tmpstr From 295162ec3abb2a58ab6a54af15196ea6fadf4852 Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 1 Apr 2025 19:18:44 +0000 Subject: [PATCH 062/332] Smoke Test - disable pypi package validation for binaries that package cuda libs (#150194) Smoke Test - disable pypi package validation for binaries that package cuda libs. These binaries do not install packages via pypi. 
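For context, the check being gated compares the cuDNN/NCCL versions torch itself reports against the `nvidia-*` packages pip knows about; since the full wheels bundle those libraries instead of installing them from PyPI, the lookup finds nothing and the comparison raises. A rough illustrative sketch of that failure mode (names follow `smoke_test.py`; this is not the actual implementation):

```python
import torch

# Full-CUDA wheels carry the libs inside the package, so pip has no nvidia-cudnn-* record:
# find_pypi_package_version("nvidia-cudnn") returns None while torch still reports a version.
torch_cudnn_version = torch.backends.cudnn.version()  # e.g. 90501, formatted by smoke_test as "9.5.1"
pypi_cudnn_version = None                             # nothing installed from PyPI to compare against

if pypi_cudnn_version is None:
    # compare_pypi_to_torch_versions() raises exactly the error shown in the traceback below;
    # --pypi-pkg-check disabled skips this branch.
    print(f"Can't find cudnn in PyPI for Torch: {torch_cudnn_version}")
```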
Should Resolve this from `linux-binary-manywheel / manywheel-py3_11-cuda12_6-full-test / test`: ``` Traceback (most recent call last): File "/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 468, in main() File "/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 462, in main smoke_test_cuda( File "/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 274, in smoke_test_cuda compare_pypi_to_torch_versions( File "/pytorch/.ci/pytorch/smoke_test/smoke_test.py", line 220, in compare_pypi_to_torch_versions raise RuntimeError(f"Can't find {package} in PyPI for Torch: {torch_version}") RuntimeError: Can't find cudnn in PyPI for Torch: 9.5.1 ``` Link: https://github.com/pytorch/pytorch/actions/runs/14101221665/job/39505479587#step:15:982 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150194 Approved by: https://github.com/ZainRizvi --- .ci/pytorch/smoke_test/smoke_test.py | 22 ++++++++++++++++++---- .circleci/scripts/binary_linux_test.sh | 13 +++++++++++-- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index c4f41a874774..acc69e36a5a5 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -227,7 +227,10 @@ def compare_pypi_to_torch_versions( def smoke_test_cuda( - package: str, runtime_error_check: str, torch_compile_check: str + package: str, + runtime_error_check: str, + torch_compile_check: str, + pypi_pkg_check: str, ) -> None: if not torch.cuda.is_available() and is_cuda_system: raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.") @@ -268,13 +271,14 @@ def smoke_test_cuda( print(f"cuDNN enabled? {torch.backends.cudnn.enabled}") torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") + torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) + print(f"Torch nccl; version: {torch_nccl_version}") # Pypi dependencies are installed on linux ony and nccl is availbale only on Linux. 
- if sys.platform in ["linux", "linux2"]: + if pypi_pkg_check == "enabled" and sys.platform in ["linux", "linux2"]: compare_pypi_to_torch_versions( "cudnn", find_pypi_package_version("nvidia-cudnn"), torch_cudnn_version ) - torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) compare_pypi_to_torch_versions( "nccl", find_pypi_package_version("nvidia-nccl"), torch_nccl_version ) @@ -436,6 +440,13 @@ def parse_args(): choices=["enabled", "disabled"], default="enabled", ) + parser.add_argument( + "--pypi-pkg-check", + help="Check pypi package versions cudnn and nccl", + type=str, + choices=["enabled", "disabled"], + default="enabled", + ) return parser.parse_args() @@ -460,7 +471,10 @@ def main() -> None: smoke_test_modules() smoke_test_cuda( - options.package, options.runtime_error_check, options.torch_compile_check + options.package, + options.runtime_error_check, + options.torch_compile_check, + options.pypi_pkg_check, ) diff --git a/.circleci/scripts/binary_linux_test.sh b/.circleci/scripts/binary_linux_test.sh index 3ee84f46d8fa..051b4f16f27a 100755 --- a/.circleci/scripts/binary_linux_test.sh +++ b/.circleci/scripts/binary_linux_test.sh @@ -90,8 +90,17 @@ fi /pytorch/.ci/pytorch/check_binary.sh if [[ "\$GPU_ARCH_TYPE" != *s390x* && "\$GPU_ARCH_TYPE" != *xpu* && "\$GPU_ARCH_TYPE" != *rocm* && "$PACKAGE_TYPE" != libtorch ]]; then - # Exclude s390, xpu, rocm and libtorch builds from smoke testing - python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled + + torch_pkg_size="$(ls -1 /final_pkgs/torch-* | sort |tail -1 |xargs wc -c |cut -d ' ' -f1)" + # todo: implement check for large binaries + # if the package is larger than 1.5GB, we disable the pypi check. + # this package contains all libraries packaged in torch libs folder + # example of such package is https://download.pytorch.org/whl/cu126_full/torch + if [[ "\$torch_pkg_size" -gt 1500000000 ]]; then + python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled --pypi-pkg-check disabled + else + python /pytorch/.ci/pytorch/smoke_test/smoke_test.py --package=torchonly --torch-compile-check disabled $extra_parameters + fi fi # Clean temp files From 99fd96c10b373dca37b7abd22739395f2428d505 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 31 Mar 2025 08:15:18 -0700 Subject: [PATCH 063/332] [Hierarchical Compile] Remove spammy debug log (#150303) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150303 Approved by: https://github.com/williamwen42 --- torch/_dynamo/graph_region_tracker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 1be528a7ed72..824d600c63ba 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -337,7 +337,6 @@ def fully_expand_region_group( debug_log("--------------------") debug_log("considering adding: %s, cur_node: %s", node, current_node) debug_log("previously claimed nodes: %s", node in seen_nodes) - debug_log("%s", seen_nodes) if node: debug_log("is_identical: %s", is_identical_fn(node, current_node)) add_node &= ( From a2300aff94981e055c2a37bda599f375fad95665 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 31 Mar 2025 08:15:22 -0700 Subject: [PATCH 064/332] [Hierarchical Compile] Add cycle detection function for debug (#150304) Remove print Pull Request resolved: https://github.com/pytorch/pytorch/pull/150304 Approved by: 
https://github.com/anijain2305 ghstack dependencies: #150303 --- test/dynamo/test_graph_deduplication.py | 71 +++++++++++++++++++++++++ torch/_dynamo/graph_utils.py | 43 +++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 torch/_dynamo/graph_utils.py diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 805d8f6be2d0..8cbf11e65f30 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -3,6 +3,7 @@ import torch import torch.fx from torch._dynamo.graph_deduplication import _flatten_args_kwargs +from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.test_case import TestCase from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm @@ -585,6 +586,76 @@ def test_flatten_with_slices(self): str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]""" ) + def test_cycle_detection_no_cycle(self): + def fn(x, y): + x0 = x + 1 + y0 = y + 2 + z = x0.sum() + y0.sum() + return z + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + + _, _, fw_graphs = self.run_and_return_graphs(fn, x, y) + mod = fw_graphs[0] + self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""") + + def test_cycle_detection_simple(self): + def fn(x, y): + x0 = x + 1 + y0 = y + 2 + z = x0.sum() + y0.sum() + return z + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + + _, _, fw_graphs = self.run_and_return_graphs(fn, x, y) + mod = fw_graphs[0] + add_node = next(n for n in mod.graph.nodes if n.name == "add") + add_2 = next(n for n in mod.graph.nodes if n.name == "add_2") + args = add_node.args + add_node.args = (args[0], add_2) + self.assertExpectedInline( + _detect_cycles(mod.graph), + """cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""", + ) + + def test_cycle_detection_complex(self): + def inner_fn(x, y): + x0 = x.view(x.size()) + return x0.view(x.size()) + + def inner_fn2(x, y): + x = x * 2 + y = y * 2 + return x.sum() + y.sum() + + def fn(x, y): + o0 = inner_fn(x, y) + o1 = inner_fn(x, y) + o2 = inner_fn2(x, y) + o3 = inner_fn2(x, y) + return o0 + o1 + o2.sum() + o3.sum() + + x = torch.rand(10, 10, requires_grad=False) + y = torch.rand(10, 20, requires_grad=False) + x_clone = x.clone() + y_clone = y.clone() + + _, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone) + mod = fw_graphs[0] + invoke_subgraph_node = next( + n for n in mod.graph.nodes if n.name == "invoke_subgraph" + ) + add_2 = next(n for n in mod.graph.nodes if n.name == "add_2") + args = invoke_subgraph_node.args + invoke_subgraph_node.args = (add_2, args[1]) + self.assertExpectedInline( + _detect_cycles(mod.graph), + """cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""", + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/graph_utils.py b/torch/_dynamo/graph_utils.py new file mode 100644 index 000000000000..6113233df69c --- /dev/null +++ b/torch/_dynamo/graph_utils.py @@ -0,0 +1,43 @@ +from collections import deque + +from torch.fx import Graph, Node + + +def _detect_cycles(graph: Graph) -> str: + current_path: deque[Node] = deque() + current_path_set: set[Node] = set() + pending: deque[tuple[Node, Node]] = deque() + + def add_to_current_path(node: Node) -> None: + current_path.append(node) + current_path_set.add(node) + + def pop_current_path() -> None: 
+ node = current_path.pop() + current_path_set.remove(node) + + def current_path_head() -> Node: + return current_path[-1] + + for origin in graph.find_nodes(op="placeholder"): + current_path.clear() + current_path_set.clear() + add_to_current_path(origin) + for child in origin.users: + pending.append((child, origin)) + + while pending: + cur_node, parent = pending.pop() + + while current_path_head() != parent: + pop_current_path() + + if cur_node in current_path_set: + current_path.append(cur_node) + return f"cycle detected in path: {current_path}" + + add_to_current_path(cur_node) + for child in cur_node.users: + pending.append((child, cur_node)) + + return "no cycle detected" From 8740ffa760a098363863de18e3adff0f0d819fbd Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 31 Mar 2025 08:15:26 -0700 Subject: [PATCH 065/332] [Hierarchical Compile] Add cycle detection to graph region expansion (#150305) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150305 Approved by: https://github.com/anijain2305 ghstack dependencies: #150303, #150304 --- torch/_dynamo/graph_deduplication.py | 26 +------------ torch/_dynamo/graph_region_tracker.py | 54 +++++++++++++++++++++++++-- torch/_dynamo/graph_utils.py | 26 +++++++++++++ 3 files changed, 78 insertions(+), 28 deletions(-) diff --git a/torch/_dynamo/graph_deduplication.py b/torch/_dynamo/graph_deduplication.py index c9ee689e3da5..b1140788cf18 100644 --- a/torch/_dynamo/graph_deduplication.py +++ b/torch/_dynamo/graph_deduplication.py @@ -15,9 +15,9 @@ import torch.fx from torch._dynamo import config from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation -from torch.utils._pytree import tree_flatten from .graph_region_tracker import Node, Region +from .graph_utils import _flatten_args_kwargs log = logging.getLogger(__name__) @@ -87,30 +87,6 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore return output_replacements -# flattens with support for slices -# Note: a better way to do this would -# be register/unregister slices as pytree nodes -# but there is no unregister API in the pytorch -# pytree impl -def _flatten_args_kwargs(args: Any) -> list[Node]: - fully_flattened = [] - - def flatten(args: Any) -> None: - flattened, _ = tree_flatten(args) - for arg in flattened: - if isinstance(arg, slice): - start = arg.start - stop = arg.stop - step = arg.step - flatten((start, stop, step)) - else: - fully_flattened.append(arg) - - flatten(args) - - return fully_flattened - - def _replace_region_with_subgraph( graph: torch.fx.Graph, region: Region, diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 824d600c63ba..272eeff54f44 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -27,6 +27,8 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.utils._pytree import tree_flatten +from .graph_utils import _flatten_args_kwargs + T = TypeVar("T") @@ -253,6 +255,8 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: """ topological_ranking = {node: i for i, node in enumerate(graph.nodes)} region_groups_with_rank = [] + # needed to detect if replacing a region will create cycles + node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph) # Create region groups; a region group is a group # of regions that are all identical. 
In this initial state @@ -281,7 +285,12 @@ def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]: # overlap. seen_nodes: set[Node] = set() for region_group in region_groups: - fully_expand_region_group(region_group, seen_nodes, self._is_identical) + fully_expand_region_group( + region_group, + seen_nodes, + node_to_recursive_ancestors, + self._is_identical, + ) # sort topologically for region in region_group: region.sort(key=lambda n: topological_ranking[n]) @@ -297,6 +306,7 @@ def __str__(self) -> str: def fully_expand_region_group( regions: list[Region], seen_nodes: set[Node], + node_to_recursive_ancestors: dict[Node, set[Node]], is_identical_fn: Callable[[Node, Node], bool], ) -> None: debug_log("--------------------------------------------------") @@ -327,11 +337,14 @@ def fully_expand_region_group( # regions are only expanded if the node to add is valid # for ALL regions while current_node: - add_node = True + add_node = not _will_create_cycle( + current_node, regions[0], node_to_recursive_ancestors + ) nodes_to_add.clear() nodes_to_add.append(current_node) nodes_to_add_set = set(nodes_to_add) - for region_it in region_iters[1:]: + for ind, region_it in enumerate(region_iters[1:]): + ind += 1 # compensate for the 0th region node = region_it.next() debug_log("--------------------") @@ -344,6 +357,9 @@ def fully_expand_region_group( and node not in nodes_to_add_set and node.op != "placeholder" and is_identical_fn(node, current_node) + and not _will_create_cycle( + node, regions[ind], node_to_recursive_ancestors + ) ) nodes_to_add.append(node) nodes_to_add_set.add(node) @@ -368,3 +384,35 @@ def fully_expand_region_group( debug_log("end expand new region group: %s", regions) debug_log("--------------------------------------------------") + + +def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]: + node_to_recursive_ancestors: dict[Node, set[Node]] = {} + for node in graph.nodes: + node_to_recursive_ancestors[node] = set() + for node in graph.nodes: + all_args = _flatten_args_kwargs((node.args, node.kwargs)) + for arg in all_args: + if isinstance(arg, Node): + node_to_recursive_ancestors[node].update( + node_to_recursive_ancestors[arg] + ) + node_to_recursive_ancestors[node].add(node) + return node_to_recursive_ancestors + + +def _will_create_cycle( + node_to_add: Node, + region: Region, + node_to_recursive_ancestors: dict[Node, set[Node]], +) -> bool: + region_set: set[Node] = set(region) + region_ancestors: set[Node] = set( + tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0] + ) + external_users = [user for user in node_to_add.users if user not in region_set] + for user in external_users: + if user in region_ancestors: + return True + + return False diff --git a/torch/_dynamo/graph_utils.py b/torch/_dynamo/graph_utils.py index 6113233df69c..cde627f244e8 100644 --- a/torch/_dynamo/graph_utils.py +++ b/torch/_dynamo/graph_utils.py @@ -1,6 +1,32 @@ from collections import deque +from typing import Any from torch.fx import Graph, Node +from torch.utils._pytree import tree_flatten + + +# flattens with support for slices +# Note: a better way to do this would +# be register/unregister slices as pytree nodes +# but there is no unregister API in the pytorch +# pytree impl +def _flatten_args_kwargs(args: Any) -> list[Node]: + fully_flattened = [] + + def flatten(args: Any) -> None: + flattened, _ = tree_flatten(args) + for arg in flattened: + if isinstance(arg, slice): + start = arg.start + stop = arg.stop + step = 
arg.step + flatten((start, stop, step)) + else: + fully_flattened.append(arg) + + flatten(args) + + return fully_flattened def _detect_cycles(graph: Graph) -> str: From 0d44a8aea1e97e80a446a0a27193bba4546433f8 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Mon, 31 Mar 2025 08:15:30 -0700 Subject: [PATCH 066/332] [Hierarchical Compile] Apply deduplication after output node creation (#150306) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150306 Approved by: https://github.com/anijain2305 ghstack dependencies: #150303, #150304, #150305 --- test/dynamo/test_graph_deduplication.py | 91 ++++++++++++------------- torch/_dynamo/graph_deduplication.py | 60 +++++++++------- torch/_dynamo/output_graph.py | 23 +++---- 3 files changed, 86 insertions(+), 88 deletions(-) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 8cbf11e65f30..99ed2f5a8dd2 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -60,18 +60,15 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): subgraph_0 = self.subgraph_0 l_x_ = L_x_ l_y_ = L_y_ - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None o1: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, o1)); o1 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None @@ -266,31 +263,27 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): y0: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \ -(x0, y0)); invoke_subgraph_3 = None - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, l_y_)) + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)) getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None o1: "f32[]" = torch.sin(getitem); getitem = None - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, y0)) + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0)) getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \ -(x0, y0)); subgraph_1 = x0 = y0 = None - - getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None - mul_2: "f32[]" = o1 * getitem_1; o1 = 
getitem_1 = None + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None + + getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None return (add_13,) @@ -329,27 +322,29 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): ___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph = None - getitem_1: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None + getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None - sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1) + sin_1: "f32[]" = torch.ops.aten.sin.default(getitem) ___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_post_graph_1 = None - getitem_2: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None - ___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph - invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None - getitem_19: "f32[]" = invoke_subgraph_11[3] - getitem_18: "f32[10, 20]" = invoke_subgraph_11[2] - getitem_17: "f32[10, 10]" = invoke_subgraph_11[1] - getitem_3: "f32[10, 10]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None + getitem_1: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None + + mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None + ___forward_subgraph_0_post_graph_2 = self.___forward_subgraph_0_post_graph - invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None - getitem_4: "f32[]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None + invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None + getitem_2: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None + ___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph + invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None + getitem_19: "f32[]" = invoke_subgraph_12[3] + getitem_18: "f32[10, 20]" = invoke_subgraph_12[2] + getitem_17: "f32[10, 10]" = invoke_subgraph_12[1] + getitem_4: "f32[10, 10]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None - mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None - mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_3); mul = None - add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_4); mul_1 = getitem_4 = None - return (add, primals_1, primals_2, getitem_1, getitem_2, getitem_19, getitem_18, 
getitem_17, getitem_3) + mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_4); mul = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_2); mul_1 = getitem_2 = None + return (add, primals_1, primals_2, getitem, getitem_1, getitem_19, getitem_18, getitem_17, getitem_4) class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): @@ -476,12 +471,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None - repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (add_2, add_3)); repeated_subgraph0 = None - getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None - - clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None + clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2) clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3) add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1) @@ -492,9 +482,11 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None + getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None @@ -552,18 +544,19 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None + repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None - repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None - getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None + + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None + getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + sum_2: "f32[]" = 
torch.ops.aten.sum.default(getitem_1); getitem_1 = None add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None return (add_2,) diff --git a/torch/_dynamo/graph_deduplication.py b/torch/_dynamo/graph_deduplication.py index b1140788cf18..3a3f7e65491a 100644 --- a/torch/_dynamo/graph_deduplication.py +++ b/torch/_dynamo/graph_deduplication.py @@ -12,18 +12,19 @@ from collections.abc import Iterable from typing import Any +import torch import torch.fx from torch._dynamo import config from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation from .graph_region_tracker import Node, Region -from .graph_utils import _flatten_args_kwargs +from .graph_utils import _detect_cycles, _flatten_args_kwargs log = logging.getLogger(__name__) -def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore[no-untyped-def] +def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def] """ This is the main entry point for applying the graph deduplication pass. \ Deduplication occurs in two phases: @@ -50,15 +51,14 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore Returns a mapping of nodes to their subgraph output replacement node to remap outputs when they are created in output_graph. """ + from torch._inductor.pattern_matcher import stable_topological_sort + duplicated_region_groups = output_graph.region_tracker.get_identical_regions( output_graph.graph ) - # Used to track which nodes were replaced with subgraph outputs - # today, we have to register the new subgraph submodules before the - # graph outputs have been created, so we pass the replacement mapping - # back to output graph to do the replacements at the site of output creation - output_replacements: dict[Node, Node] = {} + sub_gms: dict[str, torch.fx.GraphModule] = {} + for region_group in duplicated_region_groups: inds_with_external_users = _get_all_output_indices(region_group) region = region_group[0] @@ -66,8 +66,14 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore subgraph, node_ind_arg_inds, ) = _create_subgraph(region, inds_with_external_users) + + # Ignore regions with no args for now, could they possibly be evaluated at compile time? 
+ if not list(node_ind_arg_inds): + continue + sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph) subgraph_name = output_graph.install_subgraph("subgraph", sub_gm) + sub_gms[subgraph_name] = sub_gm with output_graph.graph.inserting_before(): get_subgraph_node = output_graph.graph.create_node( "get_attr", subgraph_name, (), {} @@ -81,10 +87,10 @@ def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore inds_with_external_users, sub_gm, subgraph_name, - output_replacements, ) - return output_replacements + stable_topological_sort(output_graph.graph) + return sub_gms def _replace_region_with_subgraph( @@ -95,7 +101,6 @@ def _replace_region_with_subgraph( inds_with_external_users: list[int], sub_gm: torch.fx.GraphModule, subgraph_name: str, - output_replacements: dict[Node, Node], ) -> None: sub_args = [] for node_ind, arg_ind in node_ind_arg_ind: @@ -113,23 +118,26 @@ def _replace_region_with_subgraph( ) return - latest_region_node = region[-1] - with graph.inserting_after(latest_region_node): - invoke_subgraph_node = graph.create_node( - "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {} + from torch._inductor.pattern_matcher import stable_topological_sort + + invoke_subgraph_node = graph.create_node( + "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {} + ) + for ind, external_user_ind in enumerate(inds_with_external_users): + node = region[external_user_ind] + subgraph_output = graph.create_node( + "call_function", operator.getitem, (invoke_subgraph_node, ind), {} ) - with graph.inserting_after(invoke_subgraph_node): - for ind, external_user_ind in enumerate(inds_with_external_users): - node = region[external_user_ind] - subgraph_output = graph.create_node( - "call_function", operator.getitem, (invoke_subgraph_node, ind), {} - ) - output_replacements[node] = subgraph_output - node.replace_all_uses_with(subgraph_output, propagate_meta=True) - - # Erase in reverse topological order - for node in reversed(region): - graph.erase_node(node) + node.replace_all_uses_with(subgraph_output, propagate_meta=True) + + # Erase in reverse topological order + for node in reversed(region): + graph.erase_node(node) + + if config.graph_deduplication_lint: + _detect_cycles(graph) + stable_topological_sort(graph) + graph.lint() if config.graph_deduplication_lint: graph.lint() diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ba3dea42864d..69d1ac475790 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -240,6 +240,10 @@ def __init__(self, nn_modules: dict[str, torch.nn.Module]): def __repr__(self) -> str: return "FakeRootModule(...)" + def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]): + for k, v in nn_modules.items(): + setattr(self, k, v) + class WrapperBackend: def __init__(self, backend: CompilerFn): @@ -1070,8 +1074,6 @@ def append_prefix_insts(): for value in stack_values: value.realize() - output_replacements = self.dedup_pass() - # Use nn.Module "proxies" in the constructed GraphModule so that # the resulting GM does not hold additional strong references to the original modules. 
# This prevents a strong ref cycle where Dynamo created code holds on to references @@ -1155,9 +1157,7 @@ def append_prefix_insts(): append_prefix_insts() # optimization to generate better code in a common case self.add_output_instructions( - self.compile_and_call_fx_graph( - tx, list(reversed(stack_values)), root, output_replacements - ) + self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] ) # restore all the live local vars @@ -1190,9 +1190,7 @@ def append_prefix_insts(): output = [] if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: output.extend( - self.compile_and_call_fx_graph( - tx, pass2.graph_output_vars(), root, output_replacements - ) + self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) ) if len(pass2.graph_outputs) != 0: @@ -1356,7 +1354,7 @@ def run_compiler_collective(self, tx): tx.speculation_log.clear() raise exc.CompileCollectiveRestartAnalysis - def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): + def compile_and_call_fx_graph(self, tx, rv, root): """ Generate code from self.graph and return the Instruction()s to call that generated code. @@ -1379,9 +1377,8 @@ def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs): (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), {}, ) - - for old_node, new_node in replaced_outputs.items(): - old_node.replace_all_uses_with(new_node) + sub_gms = self.dedup_pass() + root.add_nn_modules(sub_gms) tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node) if not config.do_not_emit_runtime_asserts: @@ -1576,7 +1573,7 @@ def dedup_pass(self): if torch._dynamo.config.use_graph_deduplication: return apply_graph_deduplication(self) else: - return dict() + return {} def install_subgraph(self, name, sub_gm): next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True) From b70d105c7792a0014b69e248f7f3c9e0fdbe13fd Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Tue, 1 Apr 2025 21:13:39 +0000 Subject: [PATCH 067/332] infer dynamic shapes through additional inputs (#150144) Summary: Instead of explicitly specifying dynamic shapes, it is possible to infer them from additional example inputs. Together with the example inputs provided to export, we can basically make any varying dim dynamic and keep any fixed dim static. This should be useful for prod scenarios that have access to tests and/or profiling data, yet are somewhat removed from the model authoring process. However this alone is not satisfactory: the exported program by design has only one graph, representing one path through the model, and we cannot necessarily guarantee that this graph works for the additional example inputs because different guards might have been created if we had exported with them instead (corresponding to different traced paths). However, checking that the additional example inputs satisfy the guards created by the original export should be sufficient for generalization. Now, while we don't preserve all guards in the exported program, we do check a subset of them as part of input matching. So we add a verification step at the end of export when such additional example inputs are provided. This should be enough for now. 
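As a rough usage sketch (the toy module and shapes below are illustrative only; the `AdditionalInputs` calls mirror the test added in this PR):

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y):
            return (x + y)[:3]

    m = M()
    args = (torch.randn(4), torch.randn(4))  # original example inputs for export

    additional_inputs = torch.export.AdditionalInputs()
    # extra representative inputs: dims that differ from the original (4 -> 5)
    # are inferred as dynamic, dims that match stay static
    additional_inputs.add((torch.randn(5), torch.randn(5)))

    # export with the inferred dynamic_shapes; at the end of export the
    # resulting program is also verified against the additional inputs
    ep = torch.export.export(m, args, dynamic_shapes=additional_inputs)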
Test Plan: added test (positive and negative cases) Differential Revision: D72001771 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150144 Approved by: https://github.com/bobrenjc93 --- docs/source/export.rst | 6 +++ test/export/test_export.py | 56 ++++++++++++++++++++++++ torch/export/__init__.py | 3 +- torch/export/_trace.py | 13 +++++- torch/export/dynamic_shapes.py | 79 ++++++++++++++++++++++++++++++++++ 5 files changed, 154 insertions(+), 3 deletions(-) diff --git a/docs/source/export.rst b/docs/source/export.rst index 3f947e54e568..8db7fb118334 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -797,6 +797,12 @@ API Reference .. automethod:: dynamic_shapes +.. autoclass:: torch.export.dynamic_shapes.AdditionalInputs + + .. automethod:: add + .. automethod:: dynamic_shapes + .. automethod:: verify + .. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes .. autoclass:: Constraint .. autoclass:: ExportedProgram diff --git a/test/export/test_export.py b/test/export/test_export.py index 5e7d9a436e3d..4507bf93b9d6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3888,6 +3888,62 @@ def forward(self, inp: Inp1): if node.op == "placeholder": self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") + def test_dynamic_shapes_inferred_basic(self): + class M(torch.nn.Module): + def forward(self, x, y, z): + # x and y[0] must have same dynamic shape (say `dim`) >= 3 + tmp = (x + y[0])[:3] + # z["k"] must have static shape = 3 + return tmp * z["k"] + + m = M() + args = (torch.randn(4), [torch.randn(4)], {"k": torch.randn(3)}) + + additional_inputs = torch.export.AdditionalInputs() + # 4->5, 4->5, 3->3 + good_args = (torch.randn(5), [torch.randn(5)], {"k": torch.randn(3)}) + additional_inputs.add(good_args) + + ep = export(m, args, dynamic_shapes=additional_inputs) + got_shapes = [ + str(tuple(node.meta["val"].shape)) + for node in ep.graph.find_nodes(op="placeholder") + ] + dim = next(iter(ep.range_constraints.keys())) + expected_shapes = [f"({dim},)", f"({dim},)", "(3,)"] + self.assertEqual(got_shapes, expected_shapes) + + def expect_error(bad_args, run_time_msg, compile_time_msg): + with self.assertRaisesRegex(RuntimeError, run_time_msg): + ep.module()(*bad_args) + + additional_inputs = torch.export.AdditionalInputs() + additional_inputs.add(bad_args) + + with self.assertRaisesRegex(RuntimeError, compile_time_msg): + export(m, args, dynamic_shapes=additional_inputs) + + expect_error( + # 4->2, 4->2, 3->3 + bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}), + run_time_msg="Expected input.*to be >= 3, but got 2", + compile_time_msg="Expected input.*to be >= 3, but got 2", + ) + + expect_error( + # 4->6, 4->7, 3->3 + bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}), + run_time_msg="Expected input.*to be equal to 6, but got 7", + compile_time_msg="Expected input.*to be equal to 6, but got 7", + ) + + expect_error( + # 4->5, 4->5, 3->4 + bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}), + run_time_msg="Expected input.*to be equal to 3, but got 4", + compile_time_msg=r"Constraints violated.*\n.*was inferred to be a constant \(3\)", + ) + def test_mismatched_dynamic_shapes(self): AUTO, STATIC = Dim.AUTO, Dim.STATIC diff --git a/torch/export/__init__.py b/torch/export/__init__.py index f3cd894185e6..41b9421be641 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -51,13 +51,14 @@ "unflatten", "FlatArgsAdapter", 
"UnflattenedModule", + "AdditionalInputs", ] # To make sure export specific custom ops are loaded import torch.export.custom_ops from .decomp_utils import CustomDecompTable -from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection +from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection from .exported_program import ( default_decompositions, ExportedProgram, diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 830710f44e7d..63c7472f0a7c 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1155,10 +1155,15 @@ def _process_export_inputs(mod, args, kwargs, dynamic_shapes): kwargs = kwargs if kwargs is not None else {} _, original_in_spec = pytree.tree_flatten((args, kwargs)) - if isinstance(dynamic_shapes, torch.export.ShapesCollection): + if isinstance(dynamic_shapes, torch.export.AdditionalInputs): + verify_additional_inputs = dynamic_shapes.verify dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + else: + verify_additional_inputs = lambda ep: None # noqa: E731 + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) - return args, kwargs, original_in_spec, dynamic_shapes + return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs def _get_module_call_graph( @@ -1971,6 +1976,7 @@ def _export_for_training( kwargs, orig_in_spec, dynamic_shapes, + verify_additional_inputs, ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) original_state_dict = _get_original_state_dict(mod) @@ -2033,6 +2039,7 @@ def _export_for_training( verifiers=[TrainingIRVerifier], ) + verify_additional_inputs(exported_program) return exported_program @@ -2132,6 +2139,7 @@ def _export( kwargs, original_in_spec, dynamic_shapes, + verify_additional_inputs, ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) original_state_dict = _get_original_state_dict(mod) @@ -2205,4 +2213,5 @@ def _export( dtrace_structured("exported_program", payload_fn=lambda: str(exported_program)) + verify_additional_inputs(exported_program) return exported_program diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 3b0ce63d134c..50682a948aaf 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -34,6 +34,7 @@ "Dim", "dims", "refine_dynamic_shapes_from_suggested_fixes", + "AdditionalInputs", ] @@ -713,6 +714,84 @@ def find_shape(path, t): return dynamic_shapes +class AdditionalInputs: + """ + Infers dynamic_shapes based on additional inputs. + + This is useful particularly for deployment engineers who, on the one hand, may + have access to ample testing or profiling data that can provide a fair sense of + representative inputs for a model, but on the other hand, may not know enough + about the model to guess which input shapes should be dynamic. + + Input shapes that are different than the original are considered dynamic; conversely, + those that are the same as the original are considered static. Moreover, we verify + that the additional inputs are valid for the exported program. This guarantees that + tracing with them instead of the original would have generated the same graph. + + Example:: + + args0, kwargs0 = ... # example inputs for export + + # other representative inputs that the exported program will run on + dynamic_shapes = torch.export.AdditionalInputs() + dynamic_shapes.add(args1, kwargs1) + ... 
+ dynamic_shapes.add(argsN, kwargsN) + + torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._examples = [] + + def add(self, args, kwargs=None): + """ + Additional input :func:`args` and :func:`kwargs`. + """ + + assert type(args) is tuple, f"Representative args {args} must be a tuple" + assert ( + kwargs is None or type(kwargs) is dict + ), f"Representative kwargs {kwargs} must be None or a dict" + self._examples.append((args, kwargs)) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Infers a :func:`dynamic_shapes` pytree structure by merging shapes of the + original input :func:`args` and :func:`kwargs` and of each additional input + args and kwargs. + """ + + dynamic_shapes, *other_dynamic_shapes = [ + _tree_map_with_path( + lambda path, t: tuple(t.shape), _combine_args(m, args, kwargs) + ) + for args, kwargs in [(args, kwargs), *self._examples] + ] + + return tree_map_with_path( + lambda path, dim, *other_dims: ( + dim + if all(other_dim == dim for other_dim in other_dims) + else Dim.DYNAMIC + ), + dynamic_shapes, + *other_dynamic_shapes, + is_leaf=lambda i: type(i) is int, + ) + + def verify(self, ep): + """ + Verifies that an exported program is valid for each additional input. + """ + + epm = ep.module() + for args, kwargs in self._examples: + torch.export._unlift._check_input_constraints_pre_hook( + epm, args, kwargs or {} + ) + + def _warn_on_None_dynamic_shape_dimension(): msg = ( "Using None as a dynamic shape dimension is deprecated. " From 629c1bd2dd9448126c45ac0b104a948e53ac01a1 Mon Sep 17 00:00:00 2001 From: henrylhtsang Date: Tue, 1 Apr 2025 01:30:23 +0000 Subject: [PATCH 068/332] [ez][inductor][tests] Skip triton backend only for CPU tests (#150343) Motivation: to unblock https://github.com/pytorch/pytorch/pull/148622 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150343 Approved by: https://github.com/chenyang78 --- test/inductor/test_aot_inductor.py | 15 +++++++++++++++ test/inductor/test_torchinductor.py | 3 +++ 2 files changed, 18 insertions(+) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index ce653436a860..12443a3bee89 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -434,6 +434,9 @@ def forward(self, y): self.check_model(model, example_inputs) def test_linear_dynamic_maxautotune(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -550,6 +553,9 @@ def forward(self, x, y): @skip("Test was marked as expected failure, but does not fail always anymore.") def test_dynamic_smem_above_default_limit(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def forward(self, x, y): return x @ y @@ -870,6 +876,9 @@ def forward(self, x, y): ) def test_addmm_multiple_dynamic(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self, n, k, device): super().__init__() @@ -907,6 +916,9 @@ def forward(self, a): ) def test_bmm_multiple_dynamic(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2954,6 +2966,9 @@ def forward(self, 
x): self.check_model(Model(), inputs) def test_convolution(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b7e7d2eb2c0b..a2357af8ee84 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -3782,6 +3782,9 @@ def forward(self, x): } ) def test_linear_dynamic_maxautotune(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + @torch.compile(dynamic=True) class Model(torch.nn.Module): def __init__(self) -> None: From 76e1b3ba4c8e0b79c093296fb8b420d7b6dcd356 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 1 Apr 2025 22:31:13 +0000 Subject: [PATCH 069/332] Revert "[ROCm] use correct workspace for hipblaslt, silence warning (#150227)" This reverts commit c158eac0de2afe38d68952ca401888ed5777f6b0. Reverted https://github.com/pytorch/pytorch/pull/150227 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/150227#issuecomment-2770827563)) --- aten/src/ATen/cuda/CUDABlas.cpp | 24 +++++------------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index a374ee3c8b7c..d39fe4be31c9 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -222,35 +222,19 @@ static size_t _getWorkspaceSize() { return workspace_size; } -static at::DataPtr _getNewWorkspace() { - return c10::cuda::CUDACachingAllocator::get()->allocate(_getWorkspaceSize()); -} - -// See Note [hipblaslt handles]. -// ROCm's hipblas and hipblaslt do not share handles, unlike with CUDA. -// Using getCurrentCUDABlasLtHandle is on purpose. For CUDA it's the same as -// getCurrentCUDABlasHandle, but for ROCm it's a unique handle. void* _getWorkspaceWithoutHandle() { - cublasLtHandle_t handle = at::cuda::getCurrentCUDABlasLtHandle(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); -#ifdef USE_ROCM - // The first call to _getWorkspaceWithoutHandle could be empty, so allocate and store. - if (workspace_it == at::cuda::cublas_handle_stream_to_workspace().end()) { - workspace_it = at::cuda::cublas_handle_stream_to_workspace().insert(workspace_it, {key, _getNewWorkspace()}); - } -#else TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); -#endif return workspace_it->second.mutable_get(); } void* _getWorkspace(size_t& workspaceSize) { +// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2)) workspaceSize = _getWorkspaceSize(); -#ifndef USE_ROCM - // See Note [hipblaslt handles]. 
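The gist of the gating (a sketch; `run_reinplacement` is a placeholder for the actual `GraphTransformObserver` call shown in the diff below):

    # inside Inductor's post-grad passes
    if not torch._dynamo.config.skip_fsdp_hooks:
        # skip_fsdp_hooks is False only when Dynamo is configured to trace
        # FSDP hooks (the Traceable FSDP2 setup); for other setups, including
        # SimpleFSDP, the reinplacement pass is now skipped entirely
        run_reinplacement(gm)  # placeholder for the real pass application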
auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize(); if (cublasWorkspaceSize < workspaceSize) { TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize, @@ -261,7 +245,9 @@ void* _getWorkspace(size_t& workspaceSize) { " size will be limited to the CUBLAS workspace size."); workspaceSize = cublasWorkspaceSize; } -#endif +// #else +// workspaceSize = at::cuda::getChosenWorkspaceSize(); +// #endif auto workspace_ptr = _getWorkspaceWithoutHandle(); return workspace_ptr; } From 9458460211a93911181b9f28cfb3245d58b0a12b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 1 Apr 2025 22:52:22 +0000 Subject: [PATCH 070/332] Revert "if blaslt fails, fall back to blas (#150147)" This reverts commit 65139eb050817329ac8e541c377b2be3bb5ffe14. Reverted https://github.com/pytorch/pytorch/pull/150147 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/150147#issuecomment-2770847320)) --- aten/src/ATen/cuda/CUDABlas.cpp | 91 +++++++++++------------------- aten/src/ATen/cuda/CUDABlas.h | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 18 +----- 3 files changed, 37 insertions(+), 74 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index d39fe4be31c9..ad92af61ff96 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -366,7 +366,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< template -static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { +static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { cudaDataType_t abcType = CUDA_R_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cudaDataType_t scaleType = CUDA_R_32F; @@ -454,7 +454,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment); #endif - cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -469,10 +468,10 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { &heuristicResult, &returnedResult)); if (returnedResult == 0) { - cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; + TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); } - else { - cublasStatus = cublasLtMatmul( + + cublasStatus_t cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), alpha_ptr, @@ -489,10 +488,9 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { workspace_ptr, workspaceSize, at::cuda::getCurrentCUDAStream()); - } - if (cublasStatus != CUBLAS_STATUS_SUCCESS) { - TORCH_WARN( - "bgemm_internal_cublaslt error: ", + TORCH_CHECK( + cublasStatus == CUBLAS_STATUS_SUCCESS, + "CUDA error: ", at::cuda::blas::_cublasGetErrorEnum(cublasStatus), " when calling cublasLtMatmul with transpose_mat1 ", (opa == CUBLAS_OP_T), @@ -515,11 +513,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { " computeType ", computeType, " scaleType ", - scaleType, - ". 
Will attempt to recover by calling cublas instead."); - return false; - } - return true; + scaleType); } @@ -680,9 +674,7 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)) // hipblaslt does not support double gemm yet bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(double)); #else - if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(double))) { - bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(double)); - } + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(double)); #endif } else { @@ -694,9 +686,7 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float))) { - bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); - } + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float)); } else { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); @@ -711,9 +701,7 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex gemm yet bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); #else - if (!bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex))) { - bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); - } + bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex)); #endif } else { @@ -729,9 +717,7 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex gemm yet bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); #else - if (!bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex))) { - bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); - } + bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex)); #endif } else { @@ -743,9 +729,7 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half))) { - bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); - } + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half)); } else { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); @@ -756,9 +740,7 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16))) { - bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); - } + bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } #if defined(USE_ROCM) && !defined(_MSC_VER) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { @@ -881,11 +863,18 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { } } +template +inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // forward to bgemm implementation but set strides and batches to 0 + bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0); +} + template inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented"); } + template <> void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] @@ -1095,14 +1084,6 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } -template -inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - // forward to bgemm implementation but set strides and batches to 0 - if (!bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0)) { - 
gemm_internal_cublas(CUDABLAS_GEMM_ARGS(Dtype)); - } -} - template <> void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) { @@ -1319,7 +1300,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { template -bool gemm_and_bias( +void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1434,12 +1415,11 @@ bool gemm_and_bias( 1, &heuristicResult, &returnedResult)); - cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; if (returnedResult == 0) { - cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; + TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); } - else { - cublasStatus = cublasLtMatmul( + + cublasStatus_t cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), alpha_ptr, @@ -1456,10 +1436,9 @@ bool gemm_and_bias( workspace_ptr, workspaceSize, stream); - } - if (cublasStatus != CUBLAS_STATUS_SUCCESS) { - TORCH_WARN( - "gemm_and_bias error: ", + TORCH_CHECK( + cublasStatus == CUBLAS_STATUS_SUCCESS, + "CUDA error: ", at::cuda::blas::_cublasGetErrorEnum(cublasStatus), " when calling cublasLtMatmul with transpose_mat1 ", transpose_mat1, @@ -1482,14 +1461,10 @@ bool gemm_and_bias( " computeType ", computeType, " scaleType ", - scaleType, - ". Will attempt to recover by calling unfused cublas path."); - return false; - } - return true; + scaleType); } -template bool gemm_and_bias( +template void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1505,7 +1480,7 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template bool gemm_and_bias( +template void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1521,7 +1496,7 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template bool gemm_and_bias( +template void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1537,7 +1512,7 @@ template bool gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template bool gemm_and_bias( +template void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index b65a7c79ee10..637b48c797fa 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -91,7 +91,7 @@ enum GEMMAndBiasActivationEpilogue { // NOTE: GELU activation is not supported prior to CUDA 11.4 and will // do nothing if passed in that case. template -bool gemm_and_bias( +void gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 50043e3e8534..eaa90de69570 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -326,7 +326,7 @@ static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha } } -Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) { +Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call // expand(). 
@@ -352,8 +352,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma #else static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt(); #endif - // if lt path fails, we recurse back into this function here and force the lt path to off - disable_addmm_cuda_lt |= disable_addmm_cuda_lt_override; at::ScalarType scalar_type = self.scalar_type(); c10::MaybeOwned self_; if (&result != &self) { @@ -448,7 +446,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma if (useLtInterface) { #if defined(USE_ROCM) - bool okay = true; AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -464,7 +461,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_to_gemm_and_blas_arg(activation)); } else { - okay = at::cuda::blas::gemm_and_bias( + at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', args.m, @@ -483,10 +480,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_to_gemm_and_blas_arg(activation) ); }}); - if (!okay) { - // lt path failed; recurse but disable lt path - return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); - } #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -498,7 +491,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None; #endif - bool okay = true; AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -514,7 +506,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue); } else { - okay = at::cuda::blas::gemm_and_bias( + at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', args.m, @@ -531,10 +523,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue ); }}); - if (!okay) { - // lt path failed; recurse but disable lt path - return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); - } #endif } else { From 80ab2337863b3a7fad74838a20aa40872507aaa4 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Tue, 1 Apr 2025 10:55:23 -0700 Subject: [PATCH 071/332] [Inductor] Hide reinplace_fsdp_all_gather pass behind skip_fsdp_hooks config (#150436) The `reinplace_fsdp_all_gather` pass is currently only for Traceable FSDP2 and doesn't work together with SimpleFSDP. We should hide the pass behind `skip_fsdp_hooks` config which makes it only apply to Traceable FSDP2. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150436 Approved by: https://github.com/BoyuanFeng --- torch/_inductor/fx_passes/post_grad.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index f2ab19dd720f..327e15cce92c 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -197,9 +197,10 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass( decompose_auto_functionalized ) - GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( - comms.reinplace_fsdp_all_gather - ) + if not torch._dynamo.config.skip_fsdp_hooks: + GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass( + comms.reinplace_fsdp_all_gather + ) GraphTransformObserver(gm, "lower_scan_to_while_loop").apply_gm_pass( lower_scan_to_while_loop ) From 203a27e0cecce5b9050218c9d6a56c8cd2ebd0a5 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 1 Apr 2025 23:07:28 +0000 Subject: [PATCH 072/332] Revert "[cuBLAS][cuBLASLt] Unify `cuBLASLt` workspaces with `cuBLAS` workspaces (#145130)" This reverts commit 8f7fbe3d7d2cd301df48fcbe8a14f8aa1a9c1e48. Reverted https://github.com/pytorch/pytorch/pull/145130 on behalf of https://github.com/clee2000 due to reverted internally by D72140190 ([comment](https://github.com/pytorch/pytorch/pull/145130#issuecomment-2770874244)) --- aten/src/ATen/cuda/CUDABlas.cpp | 68 ++++++++----------------- aten/src/ATen/cuda/CUDAContextLight.h | 3 -- aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++-- benchmarks/dynamo/common.py | 10 ---- 4 files changed, 26 insertions(+), 65 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index ad92af61ff96..52aee1378c0e 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -3,7 +3,6 @@ */ #include -#include #include #include #include @@ -222,36 +221,6 @@ static size_t _getWorkspaceSize() { return workspace_size; } -void* _getWorkspaceWithoutHandle() { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - auto stream = c10::cuda::getCurrentCUDAStream(); - cudaStream_t _stream = stream; - auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); - auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key); - TORCH_INTERNAL_ASSERT(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end()); - return workspace_it->second.mutable_get(); -} - -void* _getWorkspace(size_t& workspaceSize) { -// #ifdef (defined(USE_ROCM) || defined(FBCODE_CAFFE2)) - workspaceSize = _getWorkspaceSize(); - auto cublasWorkspaceSize = at::cuda::getChosenWorkspaceSize(); - if (cublasWorkspaceSize < workspaceSize) { - TORCH_WARN_ONCE("Requested CUBLASLT workspace size of ", workspaceSize, - " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize, - " bytes. Please increase CUBLAS workspace size", - " via CUBLAS_WORKSPACE_CONFIG or decrease requested" - " CUBLASLT_WORKSPACE_SIZE. 
Otherwise CUBLASLT workspace" - " size will be limited to the CUBLAS workspace size."); - workspaceSize = cublasWorkspaceSize; - } -// #else -// workspaceSize = at::cuda::getChosenWorkspaceSize(); -// #endif - auto workspace_ptr = _getWorkspaceWithoutHandle(); - return workspace_ptr; -} - } // anonymous namespace namespace at::cuda::blas { @@ -441,8 +410,9 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { } CuBlasLtMatmulPreference preference; - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. + size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); #ifndef USE_ROCM @@ -454,6 +424,8 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment); #endif + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -485,7 +457,7 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { c, Cdesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream()); TORCH_CHECK( @@ -1385,8 +1357,9 @@ void gemm_and_bias( CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld); CuBlasLtMatmulPreference preference; - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. 
+ size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); #ifndef USE_ROCM @@ -1400,7 +1373,8 @@ void gemm_and_bias( preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment); #endif - auto stream = c10::cuda::getCurrentCUDAStream(); + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); @@ -1433,9 +1407,9 @@ void gemm_and_bias( result_ptr, Cdesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -1621,9 +1595,9 @@ void scaled_gemm( #endif // if CUDA_VERSION >= 12080 } - auto stream = c10::cuda::getCurrentCUDAStream(); - size_t workspaceSize = 0; - auto workspace_ptr = _getWorkspace(workspaceSize); + size_t workspaceSize = _getWorkspaceSize(); + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + CuBlasLtMatmulPreference preference; preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); cublasLtMatmulHeuristicResult_t heuristicResult = {}; @@ -1706,9 +1680,9 @@ void scaled_gemm( result_ptr, Ddesc.descriptor(), &heuristicResult.algo, - workspace_ptr, + workspace.mutable_data_ptr(), workspaceSize, - stream); + at::cuda::getCurrentCUDAStream()); TORCH_CHECK( cublasStatus == CUBLAS_STATUS_SUCCESS, "CUDA error: ", @@ -1784,8 +1758,8 @@ void int8_gemm( CuBlasLtMatmulPreference preference; size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - auto workspace = allocator.allocate(workspaceSize); + auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -1823,7 +1797,7 @@ void int8_gemm( nullptr, // Heuristics don't seem to work for int8 #endif #ifdef USE_ROCM - workspace.mutable_get(), + workspace.mutable_data_ptr(), #else nullptr, // Non-zero workspace doesn't seem to work. 
#endif diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 65019bb6097c..dc33cb541370 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -2,7 +2,6 @@ // Light-weight version of CUDAContext.h with fewer transitive includes #include -#include #include #include @@ -88,8 +87,6 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); TORCH_CUDA_CPP_API void clearCublasWorkspaces(); -TORCH_CUDA_CPP_API std::map, at::DataPtr>& cublas_handle_stream_to_workspace(); -TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); #if defined(CUDART_VERSION) || defined(USE_ROCM) TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 6f7f0536437c..9b183848503e 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -83,6 +83,11 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v #endif +std::map, at::DataPtr>& cublas_handle_stream_to_workspace() { + static auto& instance = *new std::map, at::DataPtr>; + return instance; +} + void createCublasHandle(cublasHandle_t *handle) { TORCH_CUDABLAS_CHECK(cublasCreate(handle)); } @@ -104,11 +109,6 @@ using CuBlasPoolType = DeviceThreadHandlePool, at::DataPtr>& cublas_handle_stream_to_workspace() { - static auto& instance = *new std::map, at::DataPtr>; - return instance; -} - void clearCublasWorkspaces() { cublas_handle_stream_to_workspace().clear(); } diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 7905a12b1d10..d23c528c9de9 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3545,16 +3545,6 @@ def run(runner, args, original_dir=None): if args.devices == ["xpu"]: torch.use_deterministic_algorithms(True, warn_only=True) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # TODO(eqy): revisit when cuBLASLt workspace size is bumped - # if args.only is not None and args.only in { - # "DebertaForQuestionAnswering", - # "RobertaForQuestionAnswering", - # "nvidia_deeprecommender", - # "volo_d1_224", - # }: - # # These seem unhappy with numerics of larger cuBLASLt workspace - # # sizes following #145130 (due to enabling split-k?) 
- # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False torch.backends.cudnn.deterministic = True torch.backends.cudnn.allow_tf32 = False torch.backends.cudnn.benchmark = False From 60fe0922f61773f31eacc1bfc7b861e1bba3d5c5 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 1 Apr 2025 23:28:20 +0000 Subject: [PATCH 073/332] [pytree] Register normal class to register_dataclass (#147752) Fixes https://github.com/pytorch/pytorch/pull/147532#discussion_r1964365330 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147752 Approved by: https://github.com/zou3519 --- test/export/test_export.py | 3 -- test/test_pytree.py | 59 +++++++++++++++++++---- torch/export/__init__.py | 7 +-- torch/utils/_pytree.py | 96 ++++++++++++++++++++++++++++++++++---- 4 files changed, 138 insertions(+), 27 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 4507bf93b9d6..4fc5515c7665 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3859,7 +3859,6 @@ def forward(self, x, y, z): if node.op == "placeholder": self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") - @testing.expectedFailureRetraceability def test_dynamic_shapes_builder_pytree(self): torch.export.register_dataclass( Inp1, @@ -5097,7 +5096,6 @@ def forward(self, x): ): self.assertTrue("source_fn_stack" in node.meta) - @testing.expectedFailureRetraceability def test_dynamic_shapes_dataclass(self): torch.export.register_dataclass( Inp2, @@ -7144,7 +7142,6 @@ def forward(self): ep = export(m, ()) self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"]) - @testing.expectedFailureRetraceability def test_preserve_shape_dynamism_for_unused_inputs(self): torch.export.register_dataclass( Inp3, diff --git a/test/test_pytree.py b/test/test_pytree.py index 99dfba3969ea..82665854c2b1 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -9,9 +9,9 @@ import time import unittest from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import auto -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional import torch import torch.utils._pytree as py_pytree @@ -1297,16 +1297,55 @@ def test_tree_map_with_path(self): def test_dataclass(self): @dataclass - class Point: - x: torch.Tensor - y: torch.Tensor + class Data: + a: torch.Tensor + b: str = "moo" + c: Optional[str] = None + d: str = field(init=False, default="") + + py_pytree.register_dataclass(Data) + old_data = Data(torch.tensor(3), "b", "c") + old_data.d = "d" + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + self.assertEqual(new_data.a, torch.tensor(3)) + self.assertEqual(new_data.b, "b") + self.assertEqual(new_data.c, "c") + self.assertEqual(new_data.d, "") + py_pytree._deregister_pytree_node(Data) + + with self.assertRaisesRegex(ValueError, "Missing fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "e"]) + + with self.assertRaisesRegex(ValueError, "Unexpected fields"): + py_pytree.register_dataclass(Data, field_names=["a", "b", "c", "d"]) + + py_pytree.register_dataclass( + Data, field_names=["a"], drop_field_names=["b", "c"] + ) + old_data = Data(torch.tensor(3), "b", "c") + new_data = py_pytree.tree_unflatten(*py_pytree.tree_flatten(old_data)) + 
self.assertEqual(new_data.a, torch.tensor(3)) + self.assertEqual(new_data.b, "moo") + self.assertEqual(new_data.c, None) + py_pytree._deregister_pytree_node(Data) + + def test_register_dataclass_class(self): + class CustomClass: + def __init__(self, x, y): + self.x = x + self.y = y - py_pytree.register_dataclass(Point) + with self.assertRaisesRegex(ValueError, "field_names must be specified"): + py_pytree.register_dataclass(CustomClass) - point = Point(torch.tensor(0), torch.tensor(1)) - point = py_pytree.tree_map(lambda x: x + 1, point) - self.assertEqual(point.x, torch.tensor(1)) - self.assertEqual(point.y, torch.tensor(2)) + py_pytree.register_dataclass(CustomClass, field_names=["x", "y"]) + c = CustomClass(torch.tensor(0), torch.tensor(1)) + mapped = py_pytree.tree_map(lambda x: x + 1, c) + self.assertEqual(mapped.x, torch.tensor(1)) + self.assertEqual(mapped.y, torch.tensor(2)) def test_constant(self): # Either use `frozen=True` or `unsafe_hash=True` so we have a diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 41b9421be641..e95ac3f3a1df 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -523,9 +523,4 @@ def forward(self, x: InputDataClass) -> OutputDataClass: print(ep) """ - - from torch._export.utils import register_dataclass_as_pytree_node - - return register_dataclass_as_pytree_node( - cls, serialized_type_name=serialized_type_name - ) + pytree.register_dataclass(cls, serialized_type_name=serialized_type_name) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 27941c68066b..9b5d472321e5 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -205,6 +205,10 @@ def register_pytree_node( ) -> None: """Register a container-like type as pytree node. + Note: + :func:`register_dataclass` is a simpler way of registering a container-like + type as a pytree node. + Args: cls: the type to register flatten_fn: A callable that takes a pytree and returns a flattened @@ -265,14 +269,34 @@ def register_pytree_node( _cxx_pytree_pending_imports.append((args, kwargs)) -def register_dataclass(cls: type[Any]) -> None: - """Registers a ``dataclasses.dataclass`` type as a pytree node. +def register_dataclass( + cls: type[Any], + *, + field_names: Optional[list[str]] = None, + drop_field_names: Optional[list[str]] = None, + serialized_type_name: Optional[str] = None, +) -> None: + """ + Registers a type that has the semantics of a ``dataclasses.dataclass`` type + as a pytree node. This is a simpler API than :func:`register_pytree_node` for registering - a dataclass. + a dataclass or a custom class with the semantics of a dataclass. Args: - cls: the dataclass type to register + cls: The python type to register. The class must have the semantics of a + dataclass; in particular, it must be constructed by passing the fields + in. + field_names (Optional[List[str]]): A list of field names that correspond + to the **non-constant data** in this class. This list must contain + all the fields that are used to initialize the class. This argument + is optional if ``cls`` is a dataclass, in which case the fields will + be taken from ``dataclasses.fields()``. + drop_field_names (Optional[List[str]]): A list of field names that + should not be included in the pytree. + serialized_type_name: A keyword argument used to specify the fully + qualified name used when serializing the tree spec. This is only + needed for serializing the treespec in torch.export. 
Example: @@ -293,11 +317,67 @@ def register_dataclass(cls: type[Any]) -> None: >>> assert torch.allclose(point.y, torch.tensor(2)) """ - import torch.export + drop_field_names = drop_field_names or [] + + if not dataclasses.is_dataclass(cls): + if field_names is None: + raise ValueError( + "field_names must be specified with a list of all fields used to " + f"initialize {cls}, as it is not a dataclass." + ) + elif field_names is None: + field_names = [f.name for f in dataclasses.fields(cls) if f.init] + else: + dataclass_init_fields = {f.name for f in dataclasses.fields(cls) if f.init} + dataclass_init_fields.difference_update(drop_field_names) + + if dataclass_init_fields != set(field_names): + error_msg = "field_names does not include all dataclass fields.\n" + + if missing := dataclass_init_fields - set(field_names): + error_msg += ( + f"Missing fields in `field_names`: {missing}. If you want " + "to include these fields in the pytree, please add them " + "to `field_names`, otherwise please add them to " + "`drop_field_names`.\n" + ) + + if unexpected := set(field_names) - dataclass_init_fields: + error_msg += ( + f"Unexpected fields in `field_names`: {unexpected}. " + "Please remove these fields, or add them to `drop_field_names`.\n" + ) + + raise ValueError(error_msg) + + def _flatten_fn(obj: Any) -> tuple[list[Any], Context]: + flattened = [] + flat_names = [] + none_names = [] + for name in field_names: + val = getattr(obj, name) + if val is not None: + flattened.append(val) + flat_names.append(name) + else: + none_names.append(name) + return flattened, [flat_names, none_names] + + def _unflatten_fn(values: Iterable[Any], context: Context) -> Any: + flat_names, none_names = context + return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names)) - # Eventually we should move the export code here. It is not specific to export, - # aside from the serialization pieces. - torch.export.register_dataclass(cls) + def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]: + flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc] + return [(MappingKey(k), v) for k, v in zip(flat_names, flattened)], flat_names + + _private_register_pytree_node( + cls, + _flatten_fn, + _unflatten_fn, + serialized_type_name=serialized_type_name, + flatten_with_keys_fn=_flatten_fn_with_keys, + ) CONSTANT_NODES: set[type] = set() From 4934a8334726fa804653581e32603df0669e99b6 Mon Sep 17 00:00:00 2001 From: Nick Riasanovsky Date: Tue, 1 Apr 2025 23:29:35 +0000 Subject: [PATCH 074/332] [AMD] [TRITON] [INDUCTOR] Add tl.assume to enable bufferops on AMD (#150373) Summary: Update the GEMM template to include the necessary `tl.assume` annotations to enable bufferops with AMD. Test Plan: Tested manually with a simple matmul run with torch.complie(f, mode="max-autotune") the environment variables TRITON_ALWAYS_COMPILE=1 AMDGCN_ENABLE_DUMP=1 AMDGCN_USE_BUFFER_OPS=1. Inspecting the generated AMDGCN all loads/stores use bufferops. Note: Since inductor is loading constants for many of the shape values assumes are generally not needed for the stride/shape information, but pid calculations are generally a gap in Triton's inference capability. 
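A condensed version of that manual check, in case it is useful as a repro (the matmul, dtype and shapes are illustrative; the environment variables are the ones listed above):

```python
# Run as:
#   TRITON_ALWAYS_COMPILE=1 AMDGCN_ENABLE_DUMP=1 AMDGCN_USE_BUFFER_OPS=1 python repro.py
import torch

def f(a, b):
    return a @ b

compiled_f = torch.compile(f, mode="max-autotune")
a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
out = compiled_f(a, b)
# With the tl.assume(pid_m >= 0) / tl.assume(pid_n >= 0) hints added to the GEMM
# templates, the dumped AMDGCN for the autotuned kernel should rely on buffer ops
# for its loads and stores.
```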
Differential Revision: D71922698 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150373 Approved by: https://github.com/eellison --- torch/_inductor/kernel/bmm.py | 2 ++ torch/_inductor/kernel/mm.py | 4 ++++ torch/_inductor/kernel/mm_plus_mm.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index c3886111cb02..cd074e2c36d4 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -74,6 +74,8 @@ def _is_large_block_for_cpu(m, n, k): group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index ffa1531efd42..e4389ce9e78c 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -90,6 +90,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) @@ -159,6 +161,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) diff --git a/torch/_inductor/kernel/mm_plus_mm.py b/torch/_inductor/kernel/mm_plus_mm.py index ac6bbee6c75a..2e190595c0d1 100644 --- a/torch/_inductor/kernel/mm_plus_mm.py +++ b/torch/_inductor/kernel/mm_plus_mm.py @@ -53,6 +53,8 @@ group_size = min(grid_m - group_id * GROUP_M, GROUP_M) pid_m = group_id * GROUP_M + (pid % group_size) pid_n = (pid % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) From 6aea4d90fb1b147e8e244abdfb93153bc06ff6c7 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 1 Apr 2025 23:37:25 +0000 Subject: [PATCH 075/332] gloo: use shared Stores (#150230) Summary: X-link: https://github.com/facebookincubator/gloo/pull/423 This modifies `connectFullMesh` to take in a shared_ptr instead of a reference. This is an API breaking change but fairly easy to work around. To have backwards compatibility in PyTorch during the commit phase we add a new ifdef `GLOO_SHARED_STORE` which can provide backwards compatibility until we update the pinned Gloo version in pytorch OSS repo. This also adds a new `wait_get` method to `IStore` which will allow us to do a more efficient operation in PyTorch TCPStore. PyTorch's `Store::get` automatically waits so we want to make sure we can avoid waiting twice to reduce network traffic. This change will land simultaneously in PyTorch and Gloo repos. Test Plan: ``` buck2 test //gloo/... 
//caffe2/caffe2/contrib/gloo: ``` Differential Revision: D72084111 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150230 Approved by: https://github.com/fduwjj --- .../distributed/c10d/ProcessGroupGloo.cpp | 19 +++++++++++++++++-- .../distributed/c10d/ProcessGroupGloo.hpp | 4 ++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 345b2741dc97..3c5644eeab68 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -785,10 +785,25 @@ ProcessGroupGloo::ProcessGroupGloo( contexts_.reserve(options_->devices.size()); for (const auto i : c10::irange(options_->devices.size())) { auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_); - auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_); + +#ifdef GLOO_SHARED_STORE + auto underlyingStore = store_; +#else + auto& underlyingStore = *store_; +#endif + + auto store = std::make_shared<::gloo::rendezvous::PrefixStore>( + std::to_string(i), underlyingStore); + +#ifdef GLOO_SHARED_STORE + auto connectStore = store; +#else + auto& connectStore = *store; +#endif + context->setTimeout(options_->timeout); try { - context->connectFullMesh(store, options_->devices[i]); + context->connectFullMesh(connectStore, options_->devices[i]); } catch (const std::runtime_error& e) { auto err = e.what(); // TORCH_CHECK to print the cpp stacktrace. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index b44cba9f35a4..059ba8a4ee3f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -367,7 +367,7 @@ class TORCH_API ProcessGroupGloo : public Backend { void enableCollectivesTiming() override; - const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const { + const std::shared_ptr<::gloo::rendezvous::Store>& _getStore() const { return store_; } @@ -393,7 +393,7 @@ class TORCH_API ProcessGroupGloo : public Backend { } protected: - std::unique_ptr<::gloo::rendezvous::Store> store_; + std::shared_ptr<::gloo::rendezvous::Store> store_; const c10::intrusive_ptr options_; // Every Gloo context represents a set of connections to its peers. From d22e3d5efe42cfccd26a2c48e243a78d690a4f8a Mon Sep 17 00:00:00 2001 From: "Junjie Wang (PyTorch)" Date: Tue, 1 Apr 2025 23:54:07 +0000 Subject: [PATCH 076/332] [fr] Add logger config for flight record in PGNCCL (#150356) Summary: We want to move from a scuba based direct logging to a logger config based logging. Mostly changes are internal but we need to change the exception to exception_msg. Test Plan: Following https://www.internalfb.com/wiki/Server_Logging/Getting_Started_with_Logging/Onboarding_Existing_Scribe-Based_Logging_(Alpha)/ to test it. 
Differential Revision: D72198171 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150356 Approved by: https://github.com/fegin --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index e473912ea62a..734705a93cc9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1320,7 +1320,7 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( e.what()); debugLog.strings["status"] = "EXCEPTION"; - debugLog.strings["exception"] = e.what(); + debugLog.strings["exception_msg"] = e.what(); LOG(ERROR) << errorMsg; } catch (...) { errorMsg = c10::str( @@ -1328,7 +1328,7 @@ bool ProcessGroupNCCL::waitForFutureOrTimeout( "Unknown exception thrown when waiting for future ", futDescription); debugLog.strings["status"] = "EXCEPTION"; - debugLog.strings["exception"] = "Unknown exception"; + debugLog.strings["exception_msg"] = "Unknown exception"; LOG(ERROR) << errorMsg; } } else { From db32093192c9dd9a37f6066ac540228de7ed3855 Mon Sep 17 00:00:00 2001 From: tvukovic-amd <127323445+tvukovic-amd@users.noreply.github.com> Date: Wed, 2 Apr 2025 00:35:43 +0000 Subject: [PATCH 077/332] [ROCm][Windows] Fix torchvision build with ROCm 6.4 on windows (#150180) Since with HIP SDK 6.4 hipcc files and calls and restructured, the case for calling hipcc.exe is added in case of building torchvision with HIP SDK 6.4 on Windows Pull Request resolved: https://github.com/pytorch/pytorch/pull/150180 Approved by: https://github.com/malfet, https://github.com/jeffdaily Co-authored-by: Jeff Daily --- torch/utils/cpp_extension.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 4d4e115f67b2..197eba777930 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -2139,7 +2139,9 @@ def _jit_compile(name, def _get_hipcc_path(): if IS_WINDOWS: - return _join_rocm_home('bin', 'hipcc.bat') + # mypy thinks ROCM_VERSION is None but it will never be None here + hipcc_exe = 'hipcc.exe' if ROCM_VERSION >= (6, 4) else 'hipcc.bat' # type: ignore[operator] + return _join_rocm_home('bin', hipcc_exe) else: return _join_rocm_home('bin', 'hipcc') From ee9729996107cbc7bc94cf705505fb8615eb230d Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 1 Apr 2025 12:33:11 -0700 Subject: [PATCH 078/332] [MPS][Testing] Benchmark reduction ops (#150452) That compares eager vs compile On my M4Pro mini I'm getting the following now ``` [--------------------------------------------------------------------------------------------- --------------------------------------------------------------------------------------------] | eager-512x512 | compile-512x512 | eager-1024x1024 | compile-1024x1024 | eager-2048x2048 | compile-2048x2048 | eager-4096x4096 | compile-4096x4096 1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- sum (torch.float32) | 121.0 | 201.5 | 130.3 | 772.3 | 179.4 | 1470.5 | 476.1 | 2980.0 max (torch.float32) | 154.1 | 165.9 | 198.7 | 211.6 | 344.2 | 386.9 | 1326.6 | 1345.6 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150452 Approved by: https://github.com/dcci, 
https://github.com/manuelcandales --- test/bench_mps_ops.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/bench_mps_ops.py b/test/bench_mps_ops.py index 009de265bf38..319c8eb9ef40 100644 --- a/test/bench_mps_ops.py +++ b/test/bench_mps_ops.py @@ -3,6 +3,7 @@ # Useful as reference tool when migrating ops from MPS to Metal import itertools import timeit +import warnings from typing import Optional import torch @@ -70,16 +71,50 @@ def bench_binary( return rc +def bench_reduction( + reduction_func, device: str = "mps", dtype: torch.dtype = torch.float32 +) -> list[Measurement]: + rc = [] + + # Bench 2D with reduction over dim=0 + def f(t): + return reduction_func(t, dim=0) + + f.__name__ = reduction_func.__name__ + f_c = torch.compile(f, dynamic=False) + + for size in (512, 1024, 2048, 4096): + x = torch.testing.make_tensor(size, size, device=device, dtype=dtype) + rc_c, rc_e = f(x), f_c(x) + rc_c, rc_e = (rc_c[0], rc_e[0]) if isinstance(rc_c, tuple) else (rc_c, rc_e) + if not torch.allclose(rc_c, rc_e): + mdiff = (rc_c - rc_e).abs().max() + warnings.warn( + f"Eager and compile reduction do not match for {reduction_func.__name__} and {dtype} max_diff={mdiff}", + stacklevel=2, + ) + rc.append(bench_unary_op(f, x, f"eager-{size}x{size}")) + rc.append(bench_unary_op(f_c, x, f"compile-{size}x{size}")) + return rc + + def main() -> None: dtypes = [torch.float16, torch.float32] if torch.backends.mps.is_macos_or_newer(14, 0): dtypes.append(torch.bfloat16) + # Profile unary ops rc = [] for op, dtype in itertools.product([torch.sqrt, torch.sin], dtypes): rc.extend(bench_unary(op, dtype=dtype)) Compare(rc).print() + # Profile reduction ops + rc = [] + for op in [torch.sum, torch.max]: + rc.extend(bench_reduction(op)) + Compare(rc).print() + # Profile binary ops rc = [] ops = [torch.fmax, torch.add] From c974b5322a1c5139800a6422e95368a7b89c5388 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 1 Apr 2025 13:19:59 -0700 Subject: [PATCH 079/332] enable torch.compile for torch._scaled_mm nvfp4 recipe (#150462) Summary: Updates the meta registration for `torch._scaled_mm` to work for the nvfp4 recipe. 
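As a small, hedged illustration of the shape bookkeeping this requires (constants come from the meta-registration change below; the concrete sizes mirror the new test and are only an example):

```python
# For float8_e4m3fn scales the meta kernel uses block_size_k = 16 and doubles K, since
# float4_e2m1fn_x2 stores two 4-bit values per element, so the K read off the packed
# operand is half of the logical K. The e8m0 (mx) path keeps block_size_k = 32.
stored_k = 64                                  # K of the packed fp4 operand
logical_k = stored_k * 2                       # unpacked K used for the checks -> 128
block_size_k = 16                              # nvfp4 block size from this patch
scale_cols = (logical_k + block_size_k - 1) // block_size_k  # ceil(128 / 16) == 8
# matches the new test, which builds float8_e4m3fn scales of shape (M, ceil_div(K, 16))
print(logical_k, scale_cols)
```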
Test Plan: ```bash pytest test/test_matmul_cuda.py -s -k test_blockwise_nvfp4 ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/150462 Approved by: https://github.com/eellison --- test/test_matmul_cuda.py | 29 +++++++++++++++++++++++++++++ torch/_meta_registrations.py | 31 +++++++++++++++++++++++-------- torch/fx/graph.py | 1 + 3 files changed, 53 insertions(+), 8 deletions(-) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 64f9ee7ad2df..49da165ca20e 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1397,6 +1397,35 @@ def test_blockwise_mxfp8_compile(self) -> None: ) torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) + def test_blockwise_nvfp4_compile(self) -> None: + + device = "cuda" + M, K, N = 128, 128, 128 + BLOCK_SIZE = 16 + + A_ref = torch.eye(M, device=device, dtype=torch.bfloat16) + B_ref = torch.eye(M, device=device, dtype=torch.bfloat16) + + A = _bfloat16_to_float4_e2m1fn_x2(A_ref) + B = _bfloat16_to_float4_e2m1fn_x2(B_ref) + + A_scale = torch.full((M, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn) + B_scale = torch.full((N, ceil_div(K, BLOCK_SIZE)), 1.0, device=device, dtype=torch.float8_e4m3fn) + C_ref = A_ref @ B_ref.t() + + compiled_scaled_mm = torch.compile(torch._scaled_mm, backend="inductor") + # C = torch._scaled_mm( + C = compiled_scaled_mm( + A, + B.t(), + A_scale, + B_scale, + out_dtype=torch.bfloat16, + use_fast_accum=False, + ) + torch.testing.assert_close(C, C_ref, atol=0, rtol=0) + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 1141484db6aa..dab0e92558fc 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -6182,12 +6182,13 @@ def meta_scaled_mm( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ): - def is_fp8_type(dtype): + def is_fp8_or_fp4_type(dtype): return dtype in ( torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e4m3fnuz, torch.float8_e5m2fnuz, + torch.float4_e2m1fn_x2, ) torch._check( @@ -6195,8 +6196,8 @@ def is_fp8_type(dtype): lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}", ) torch._check( - is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype), - lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", + is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype), + lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) if device_hint(self) == "cuda": @@ -6232,18 +6233,32 @@ def has_zero_dim(tensor_2d): m, _k = self.shape n = mat2.size(1) + is_blockwise_scaling = ( + scale_a.dtype == torch.float8_e8m0fnu + and scale_b.dtype == torch.float8_e8m0fnu + ) or ( + scale_a.dtype == torch.float8_e4m3fn + and scale_b.dtype == torch.float8_e4m3fn + ) + if scale_a.numel() == 1 and scale_b.numel() == 1: # tensorwise scaling torch._check( scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32, lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.", ) - elif ( - scale_a.dtype == torch.float8_e8m0fnu - and scale_b.dtype == torch.float8_e8m0fnu - ): + elif is_blockwise_scaling: # blockwise scaling - block_size_k = 32 + + if scale_a.dtype 
== torch.float8_e4m3fn: + # NVIDIA's nvfp4 recipe: + # * block size is 16 elements packed (32 unpacked) + # * _k needs to be translated to the unpacked version + block_size_k = 16 + _k = _k * 2 + else: + block_size_k = 32 + block_size_mn = 128 def ceil_div(a, b): diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 0e483f19c866..541a76942739 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -222,6 +222,7 @@ def _rename_object(self, obj: Any, name: str): torch.float8_e4m3fnuz: "f8e4m3fnuz", torch.float8_e5m2fnuz: "f8e5m2fnuz", torch.float8_e8m0fnu: "f8e8m0fnu", + torch.float4_e2m1fn_x2: "f4e2m1fnx2", torch.complex32: "c32", torch.complex64: "c64", torch.complex128: "c128", From e872c38eb3d17b4439b00aec3e360fe25846e2db Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 2 Apr 2025 01:33:20 +0000 Subject: [PATCH 080/332] Remove cppcoreguidelines-pro-type-member-init_fix suppression (#148638) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/148638 Approved by: https://github.com/zou3519 --- .../src/ATen/native/cpu/group_norm_kernel.cpp | 6 +- torch/csrc/dynamo/guards.cpp | 58 ++++++------------- torch/csrc/jit/codegen/fuser/arg_spec.h | 1 - torch/csrc/jit/ir/attributes.h | 2 - torch/csrc/jit/ir/ir.h | 5 -- torch/csrc/jit/runtime/graph_executor.h | 1 - torch/csrc/jit/runtime/interpreter.h | 1 - torch/csrc/jit/runtime/operator.h | 2 - torch/csrc/jit/runtime/static/impl.cpp | 3 +- torch/csrc/jit/runtime/static/impl.h | 14 ++--- torch/csrc/jit/runtime/static/ops.cpp | 1 - 11 files changed, 29 insertions(+), 65 deletions(-) diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index 8c1000f8de47..4807a689e8c2 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -570,10 +570,8 @@ ComputeInternalGradients( at::parallel_for(0, N * C, 1, [=](int64_t start, int64_t end) { constexpr int64_t K = Vec::size(); const int64_t inner_size = HxW / K * K; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - std::array ds_arr; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) - std::array db_arr; + std::array ds_arr{}; + std::array db_arr{}; for (const auto i : c10::irange(start, end)) { const T* dY_ptr = dY + i * HxW; const T* X_ptr = X + i * HxW; diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 6795857ed9f5..0b8dec86f98d 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -525,11 +525,11 @@ static PyTypeObject TensorGuardsType = { PyVarObject_HEAD_INIT(nullptr, 0) struct AutocastState { static constexpr auto& DEVICES = at::autocast::_AUTOCAST_SUPPORTED_DEVICES; - std::array enabled; - std::array dtype; + std::array enabled{}; + std::array dtype{}; bool cache_enabled; - AutocastState() : enabled{}, dtype{} { + AutocastState() { for (size_t i = 0; i < DEVICES.size(); i++) { enabled[i] = at::autocast::is_autocast_enabled(DEVICES[i]); dtype[i] = at::autocast::get_autocast_dtype(DEVICES[i]); @@ -1977,8 +1977,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { py::object py_addr_keep_alive, py::object verbose_code_parts) : RelationalGuard(std::move(verbose_code_parts)), - _py_addr_keep_alive(std::move(py_addr_keep_alive)), - _args_seen{0} { + _py_addr_keep_alive(std::move(py_addr_keep_alive)) { _nargs_int = PyLong_AsSize_t(nargs_int.ptr()); _nargs_float = PyLong_AsSize_t(nargs_float.ptr()); _nargs = _nargs_int + _nargs_float; @@ -2072,7 +2071,7 @@ class 
SYMBOLIC_SHAPE_GUARD : public RelationalGuard { private: py::object _py_addr_keep_alive; - size_t _args_seen, _nargs_float, _nargs_int, _nargs; + size_t _args_seen{0}, _nargs_float, _nargs_int, _nargs; std::vector _args_int; std::vector _args_float; std::function _guard_check_fn; @@ -3496,7 +3495,6 @@ class GetAttrGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -3515,7 +3513,7 @@ class GetAttrGuardAccessor : public GuardAccessor { private: // no need of py::object here because the attr_name is already passed on to // the base class as accessor_key which is a py::object. - PyObject* _attr_name; + PyObject* _attr_name{nullptr}; }; /** @@ -3571,7 +3569,6 @@ class GetGenericDictGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetGenericDictGuardAccessor( GuardManager* guard_manager, GetGenericDictGuardAccessor* from) @@ -3639,7 +3636,6 @@ class GetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GetItemGuardAccessor(GuardManager* guard_manager, GetItemGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -3658,7 +3654,7 @@ class GetItemGuardAccessor : public GuardAccessor { private: // no need of py::object here because the attr_name is already passed on to // the base class as accessor_key which is a py::object. - PyObject* _attr_name; + PyObject* _attr_name{nullptr}; }; /** @@ -3757,7 +3753,6 @@ class FrameLocalsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FrameLocalsGuardAccessor( GuardManager* guard_manager, FrameLocalsGuardAccessor* from) @@ -3778,12 +3773,12 @@ class FrameLocalsGuardAccessor : public GuardAccessor { } private: - PyObject* _key; - int _framelocals_idx; + PyObject* _key{nullptr}; + int _framelocals_idx{-1}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. - bool _is_immutable_object; + bool _is_immutable_object{false}; }; /** @@ -3847,7 +3842,6 @@ class DictGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) DictGetItemGuardAccessor( GuardManager* guard_manager, DictGetItemGuardAccessor* from) @@ -3867,11 +3861,11 @@ class DictGetItemGuardAccessor : public GuardAccessor { } private: - PyObject* _key; + PyObject* _key{nullptr}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. 
- bool _is_immutable_object; + bool _is_immutable_object{false}; }; /** @@ -3924,7 +3918,6 @@ class ListGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ListGetItemGuardAccessor( GuardManager* guard_manager, ListGetItemGuardAccessor* from) @@ -3943,7 +3936,7 @@ class ListGetItemGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -3996,7 +3989,6 @@ class TupleGetItemGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TupleGetItemGuardAccessor( GuardManager* guard_manager, TupleGetItemGuardAccessor* from) @@ -4016,7 +4008,7 @@ class TupleGetItemGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; enum class TensorProperty { @@ -4143,7 +4135,6 @@ class TensorPropertyGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorPropertyGuardAccessor( GuardManager* guard_manager, TensorPropertyGuardAccessor<_prop>* from) @@ -4163,7 +4154,7 @@ class TensorPropertyGuardAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -4210,7 +4201,6 @@ class IndexedGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) IndexedGuardAccessor(GuardManager* guard_manager, IndexedGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4227,7 +4217,7 @@ class IndexedGuardAccessor : public GuardAccessor { } private: - py::int_ _index; + py::int_ _index{-1}; }; /** @@ -4287,7 +4277,6 @@ class GradGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GradGuardAccessor(GuardManager* guard_manager, GradGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4361,7 +4350,6 @@ class FuncDefaultsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FuncDefaultsGuardAccessor( GuardManager* guard_manager, FuncDefaultsGuardAccessor* from) @@ -4437,7 +4425,6 @@ class FuncKwDefaultsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) FuncKwDefaultsGuardAccessor( GuardManager* guard_manager, FuncKwDefaultsGuardAccessor* from) @@ -4494,7 +4481,6 @@ class GlobalsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GlobalsGuardAccessor(GuardManager* guard_manager, GlobalsGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4513,7 +4499,7 @@ class GlobalsGuardAccessor : public GuardAccessor { private: // no need of py::object here because the globals_dict is already passed on to // the base class as accessor_key which is a py::object. 
- PyObject* _globals_dict; + PyObject* _globals_dict{nullptr}; }; /** @@ -4554,7 +4540,6 @@ class TypeGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TypeGuardAccessor(GuardManager* guard_manager, TypeGuardAccessor* from) : GuardAccessor(guard_manager, from) { from->clone_visitor(this); @@ -4623,7 +4608,6 @@ class TupleIteratorGetItemAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TupleIteratorGetItemAccessor( GuardManager* guard_manager, TupleIteratorGetItemAccessor* from) @@ -4643,7 +4627,7 @@ class TupleIteratorGetItemAccessor : public GuardAccessor { } private: - Py_ssize_t _index; + Py_ssize_t _index{-1}; }; /** @@ -4739,7 +4723,6 @@ class GlobalWeakRefGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GlobalWeakRefGuardAccessor( GuardManager* guard_manager, GlobalWeakRefGuardAccessor* from) @@ -4758,7 +4741,7 @@ class GlobalWeakRefGuardAccessor : public GuardAccessor { } private: - PyObject* _global_name; + PyObject* _global_name{nullptr}; }; /** @@ -4830,7 +4813,6 @@ class WeakRefCallGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) WeakRefCallGuardAccessor( GuardManager* guard_manager, WeakRefCallGuardAccessor* from) @@ -4910,7 +4892,6 @@ class CallFunctionNoArgsGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) CallFunctionNoArgsGuardAccessor( GuardManager* guard_manager, CallFunctionNoArgsGuardAccessor* from) @@ -4982,7 +4963,6 @@ class PythonLambdaGuardAccessor : public GuardAccessor { } public: // cloning functions - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) PythonLambdaGuardAccessor( GuardManager* guard_manager, PythonLambdaGuardAccessor* from) diff --git a/torch/csrc/jit/codegen/fuser/arg_spec.h b/torch/csrc/jit/codegen/fuser/arg_spec.h index 7239e0391b8f..923aa324aa7a 100644 --- a/torch/csrc/jit/codegen/fuser/arg_spec.h +++ b/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -16,7 +16,6 @@ namespace torch::jit::fuser { // Note: the device to run on is included in the arg spec because kernels // are compiled per-device. 
struct TORCH_API ArgSpec { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ArgSpec(at::TensorList inputs, const int _device) : descs_{c10::fmap(inputs)}, hash_code_{c10::get_hash(_device, inputs.size(), descs_)}, diff --git a/torch/csrc/jit/ir/attributes.h b/torch/csrc/jit/ir/attributes.h index fb2c44350d2d..f6e8f2148078 100644 --- a/torch/csrc/jit/ir/attributes.h +++ b/torch/csrc/jit/ir/attributes.h @@ -86,7 +86,6 @@ template struct VectorAttributeValue : public AttributeValue { using ConstructorType = std::vector; using ValueType = std::vector; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { @@ -144,7 +143,6 @@ struct TORCH_API GraphAttr : public AttributeValue { struct TORCH_API GraphsAttr : public AttributeValue { using ConstructorType = std::vector>; using ValueType = std::vector>; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) GraphsAttr(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {} ValueType& value() { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index 44087074e891..fc780c26c3dd 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -1490,7 +1490,6 @@ struct WithCurrentScope { ScopePtr prev_scope_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) inline Value::Value(Node* node_, size_t offset_) : node_(node_), offset_(offset_), @@ -1651,7 +1650,6 @@ struct TORCH_API OperatorSet { }; template -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct OperatorMap { // Type aliasing using OpMapType = typename std::pair, T>; @@ -1659,12 +1657,10 @@ struct OperatorMap { using MapType = std::unordered_map; OperatorMap() = default; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit OperatorMap( std::initializer_list, T>> init) { insert(init); } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit OperatorMap(std::initializer_list> init) { insert(init); } @@ -1760,7 +1756,6 @@ struct OperatorMap { }; template -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct FunctionSchemaMap { // Type aliasing using FuncSchemaMapType = typename std::pair; diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 8295b9d6c378..d1039216de3e 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -43,7 +43,6 @@ struct ExecutionPlan { // They are only valid only right after you call getDebugState() and should // never be used again once another GraphExecutor function is called. 
-// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct GraphExecutorState { const Graph* graph = nullptr; ExecutionPlan fallback; // XXX: members of this field are optional diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index e6a71dc0a0b9..6ae9f52a0cda 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -111,7 +111,6 @@ struct Suspend : public std::exception { return "Suspend"; } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) explicit Suspend(c10::intrusive_ptr future_) : future(std::move(future_)) {} diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 2e609f18ecc0..bde3825f5ea3 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -60,7 +60,6 @@ const std::array kJitOnlyOperatorTags = { // the concrete operator nature. struct TORCH_API Operator { private: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct C10Operator final { c10::OperatorHandle handle_; Operation op_; @@ -69,7 +68,6 @@ struct TORCH_API Operator { std::string schema_string_; mutable std::optional alias_analysis_; }; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct JitOnlyOperator final { // The only valid transition for schema_ is from right->left, i.e. // when the schema gets parsed. diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index ec736d006be0..0e2a89544b56 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -38,7 +38,6 @@ #include #include #include -#include #ifdef FBCODE_CAFFE2 #include @@ -953,11 +952,11 @@ BlockRunner::BlockRunner( } } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) BlockRunner::BlockRunner(BlockRunner&&) noexcept = default; BlockRunner::~BlockRunner() = default; +// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) void BlockRunner::set_arg(const size_t idx, std::vector&& args) { DCHECK(idx < args.size()); Input(idx + first_input_is_self_) = std::move(args[idx]); diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index 04a0862f9795..e8a3bdbc42ff 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -815,10 +815,8 @@ class TORCH_API BlockRunner { std::vector nodes_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API StaticNodeInfo { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) StaticNodeInfo( Node* n, ProcessedFunction* fn, @@ -873,6 +871,9 @@ class TORCH_API ProcessedNodeMetadata { // if the contained type (BlockRunner) is not copyable ProcessedNodeMetadata(const ProcessedNodeMetadata&) = delete; ProcessedNodeMetadata& operator=(const ProcessedNodeMetadata&) = delete; + ProcessedNodeMetadata(ProcessedNodeMetadata&&) = delete; + ProcessedNodeMetadata&& operator=(ProcessedNodeMetadata&&) = delete; + ~ProcessedNodeMetadata() = default; std::vector& block_runners() { return block_runners_; @@ -895,10 +896,8 @@ class TORCH_API ProcessedNodeMetadata { torch::jit::TaskLauncher* launcher_; }; -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) class TORCH_API ProcessedNode { public: - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) ProcessedNode() = default; ProcessedNode(const StaticNodeInfo& other, IValue* values) @@ -917,6 +916,7 @@ class TORCH_API ProcessedNode { ProcessedNode(const ProcessedNode&) = delete; ProcessedNode& 
operator=(const ProcessedNode& other) = delete; ProcessedNode& operator=(ProcessedNode&&) = default; + ~ProcessedNode() = default; void run(); @@ -1025,10 +1025,10 @@ class TORCH_API ProcessedNode { [[nodiscard]] bool verify_inputs_dont_overlap_outputs(bool force_check) const; - Node* node_; - const ProcessedFunction* fn_; + Node* node_{nullptr}; + const ProcessedFunction* fn_{nullptr}; ProcessedNodeInputs inputs_; - uint16_t outputs_offset_; + uint16_t outputs_offset_{0}; bool overlap_detected_{false}; IValue* values_ = nullptr; // unowned // Metadata for ProcessedNode. diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 60fca2f87066..d5586a5b9cd7 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -1344,7 +1344,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::pow, aten_pow, [](Node* n) -> SROperator { namespace { -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) struct ToArgs { std::optional dtype; c10::Layout layout; From 0ae75ca2de0551406c3e91849ead32eed981b581 Mon Sep 17 00:00:00 2001 From: Rithesh Baradi Date: Wed, 2 Apr 2025 01:54:35 +0000 Subject: [PATCH 081/332] assert on all_reduce_event only if it's not CPU device. (#150316) Summary: For CPU based runs, `all_reduce_event` would be None since this is the result of the `all_reduce_stream.record_event()`, which does not do much other than returning None when device type is CPU. Test Plan: CI Differential Revision: D72176406 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150316 Approved by: https://github.com/kwen2501, https://github.com/weifengpy, https://github.com/mori360 --- .../fsdp/_fully_shard/_fsdp_param_group.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index e149005ffc2c..c9c36654e882 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -95,17 +95,17 @@ def get_all_gather_streams( # See [Note: Overlapping all-gather copy-in and all-gather] class AllGatherState(NamedTuple): all_gather_result: AllGatherResult - event: torch.Event # all-gather copy-out + event: Optional[torch.Event] # all-gather copy-out class ReduceScatterState(NamedTuple): reduce_scatter_input: torch.Tensor - event: torch.Event # reduce-scatter event + event: Optional[torch.Event] # reduce-scatter event class AllReduceState(NamedTuple): all_reduce_input: torch.Tensor - event: torch.Event # all-reduce event + event: Optional[torch.Event] # all-reduce event class FSDPParamGroup: @@ -310,11 +310,11 @@ def wait_for_unshard(self): self._wait_all_gather_streams_on_event(all_gather_copy_out_event) self._all_gather_result = None # free unless saved in `all_gather_state` - def _wait_all_gather_streams_on_event(self, event: torch.Event): + def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): # Calling `unshard` before lazy init means streams are not initialized - if hasattr(self.comm_ctx, "all_gather_copy_in_stream"): + if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None: self.comm_ctx.all_gather_copy_in_stream.wait_event(event) - if hasattr(self.comm_ctx, "all_gather_stream"): + if hasattr(self.comm_ctx, "all_gather_stream") and event is not None: self.comm_ctx.all_gather_stream.wait_event(event) def reshard(self): @@ -414,11 +414,14 @@ def 
post_backward(self, *unused: Any): if len(fsdp_params_with_grad) == 0: return with record_function(self._with_fqn("FSDP::post_backward_reduce")): - if self.comm_ctx.reduce_scatter_state is not None: + if ( + self.comm_ctx.reduce_scatter_state is not None + and self.comm_ctx.reduce_scatter_state.event is not None + ): self.device_handle.current_stream().wait_event( self.comm_ctx.reduce_scatter_state.event ) - self.comm_ctx.reduce_scatter_state = None + self.comm_ctx.reduce_scatter_state = None all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None all_reduce_stream: torch.cuda.Stream if all_reduce_pg is None and self._all_reduce_hook_stream is not None: @@ -458,7 +461,8 @@ def post_backward(self, *unused: Any): reduce_scatter_input, reduce_scatter_event ) if all_reduce_input is not None: - assert all_reduce_event is not None + if self.device.type != "cpu": + assert all_reduce_event is not None self._all_reduce_state = AllReduceState( all_reduce_input, all_reduce_event ) @@ -484,9 +488,12 @@ def _wait_for_post_backward(self): if self._post_reduce_event is not None: self.device_handle.current_stream().wait_event(self._post_reduce_event) self._post_reduce_event = None - if self._all_reduce_state is not None: + if ( + self._all_reduce_state is not None + and self._all_reduce_state.event is not None + ): self.device_handle.current_stream().wait_event(self._all_reduce_state.event) - self._all_reduce_state = None + self._all_reduce_state = None def _backward_prefetch(self) -> None: if self._training_state == TrainingState.PRE_BACKWARD: From b060fedfa8ce8aabc3bab97119489d16dac31348 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 1 Apr 2025 11:42:45 -0700 Subject: [PATCH 082/332] [invoke_subgraph] Support None in the fwd output (#150082) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150082 Approved by: https://github.com/zou3519 --- test/higher_order_ops/test_invoke_subgraph.py | 103 ++++++++++++++++++ torch/_higher_order_ops/base_hop.py | 2 +- torch/_higher_order_ops/invoke_subgraph.py | 53 +++++++-- 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 287957d9f7a6..c508539d708c 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -819,6 +819,109 @@ def run(x, train=True): r1.sum().backward() weight.grad.clone() + def test_return_none_from_fwd(self): + @mark_compile_region + def gn(x): + return x * 2, None, x * 3 + + def fn(x): + ys = gn(x) + return ys[0] + ys[2] + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + x = torch.randn(8, 8, requires_grad=True) + x_clone = x.detach().clone().requires_grad_(True) + + ref = fn(x) + res = opt_fn(x_clone) + + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + + backend = AotEagerAndRecordGraphs() + + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + + x = torch.randn(8, 8, requires_grad=True) + res = opt_fn(x_clone) + res.sum().backward() + + self.assertEqual(len(backend.graphs), 1) + self.assertEqual(len(backend.fw_graphs), 1) + self.assertEqual(len(backend.bw_graphs), 1) + self.count_unique_get_attr_nodes(backend.graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.fw_graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.bw_graphs[0], [], 1) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + 
normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8, 8]"): + l_x_ = L_x_ + + invoke_subgraph_0 = self.invoke_subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None + getitem: "f32[8, 8]" = invoke_subgraph[0] + getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None + + add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None + return (add,) + + class invoke_subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8, 8]"): + child: "f32[8, 8]" = l_x_ * 2 + child_1: "f32[8, 8]" = l_x_ * 3; l_x_ = None + return (child, None, child_1) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[8, 8]"): + ___forward_invoke_subgraph_0_post_graph = self.___forward_invoke_subgraph_0_post_graph + + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_post_graph, '___forward_invoke_subgraph_0_post_graph', (primals_1,)); ___forward_invoke_subgraph_0_post_graph = primals_1 = None + getitem: "f32[8, 8]" = invoke_subgraph_2[0] + getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None + + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(getitem, getitem_2); getitem = getitem_2 = None + return (add,) + + class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module): + def forward(self, primals_0: "f32[8, 8]"): + mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 2) + mul_1: "f32[8, 8]" = torch.ops.aten.mul.Tensor(primals_0, 3); primals_0 = None + return (mul, None, mul_1) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[8, 8]"): + ___backward_invoke_subgraph_0_post_graph = self.___backward_invoke_subgraph_0_post_graph + + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_post_graph, '___backward_invoke_subgraph_0_post_graph', (tangents_1, tangents_1)); ___backward_invoke_subgraph_0_post_graph = tangents_1 = None + getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None + return (getitem_3,) + + class ___backward_invoke_subgraph_0_post_graph(torch.nn.Module): + def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): + mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3) + mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + return (add,) +""", + ) + def test_dynamic(self): @mark_compile_region def gn(x): diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index 02eee4b2c07b..5f634f0c6436 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -151,7 +151,7 @@ def backward(ctx, *grad_outputs): from .utils import _from_fun fw_inputs = pytree.tree_map(_from_fun, operands) - _, joint_graph, _ = create_fw_bw_graph( + _, joint_graph, _, _ = create_fw_bw_graph( subgraph, fw_inputs, grad_outputs ) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 16819b44c6f6..1e2c2ce95a30 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ 
b/torch/_higher_order_ops/invoke_subgraph.py @@ -195,6 +195,19 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): with context: grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + num_fw_outs = len(grad_outputs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + none_indexes_in_fwd_out = set() + + for idx, grad in enumerate(grad_outputs): + if grad is None: + none_indexes_in_fwd_out.add(idx) + + grad_outputs = [grad for grad in grad_outputs if grad is not None] + if any( not isinstance(out, torch.Tensor) for out in grad_outputs @@ -214,7 +227,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): fw_inputs, grad_outputs, ) - return fw_graph, bw_graph, len(grad_outputs) + return fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out class InvokeSubgraphAutogradOp(torch.autograd.Function): @@ -224,11 +237,20 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands): + def forward( + ctx, + fw_graph, + bw_graph, + identifier, + num_fw_outs, + none_indexes_in_fwd_out, + *operands, + ): ctx._fw_graph = fw_graph ctx._bw_graph = bw_graph ctx._identifier = identifier ctx._num_fw_outs = num_fw_outs + ctx._none_indexes_in_fwd_out = none_indexes_in_fwd_out with torch._C._AutoDispatchBelowAutograd(): out = invoke_subgraph( @@ -238,6 +260,12 @@ def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands): ) save_tensors_and_symints_for_backward(ctx, operands) + + # Check that None is at expected indexes. + for idx, o in enumerate(out): + if o is None: + assert idx in none_indexes_in_fwd_out + return out @staticmethod @@ -246,10 +274,19 @@ def backward(ctx, *grad_outs): identifier = ctx._identifier primals = saved_tensors_and_symints(ctx) num_fw_outs = ctx._num_fw_outs + none_indexes_in_fwd_out = ctx._none_indexes_in_fwd_out # While tracing we made the assumption that tangents are contiguous. So, - # force the grad_outs to be contiguous. - contiguous_grad_outs = tuple([o.contiguous() for o in grad_outs]) + # force the grad_outs to be contiguous. Some of the grads can be None, + # because the forward outs could be None. Filter them out. + contiguous_grad_outs = [] + for idx, o in enumerate(grad_outs): + if o is not None: + contiguous_grad_outs.append(o.contiguous()) + else: + # Check that None is at expected indexes. + assert idx in none_indexes_in_fwd_out + contiguous_grad_outs = tuple(contiguous_grad_outs) # bw_graph is a joint graph with signature (*primals_and_tangents) and # returns (*grads_and_fw_outs). 
To get the grads, we use the num_fw_outs @@ -258,7 +295,7 @@ def backward(ctx, *grad_outs): grads = invoke_subgraph( bw_graph, f"___backward_{identifier}", primals_and_tangents )[:-num_fw_outs] - return None, None, None, None, *grads + return None, None, None, None, None, *grads @invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) @@ -294,11 +331,13 @@ def _(subgraph, identifier, operands): ): return saved_autograd_fn(*operands) - fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands) + fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out = create_fw_bw_graph( + subgraph, operands + ) def autograd_fn_callable(*args): return InvokeSubgraphAutogradOp.apply( - fw_graph, bw_graph, identifier, num_fw_outs, *args + fw_graph, bw_graph, identifier, num_fw_outs, none_indexes_in_fwd_out, *args ) # Save the autograd_fn_callable in the dispatch set cache. From 61ebe999ccdc8b591019caaa959357ff5de4a374 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 1 Apr 2025 12:05:31 -0700 Subject: [PATCH 083/332] [invoke_subgraph] Do not cache fake tensors for AOTDispatcher first pass (#150450) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150450 Approved by: https://github.com/zou3519 ghstack dependencies: #150082 --- test/higher_order_ops/test_invoke_subgraph.py | 2 +- torch/_subclasses/fake_tensor.py | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index c508539d708c..aa6ecdd15928 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -956,6 +956,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + @torch._dynamo.config.patch(capture_scalar_outputs=True) def test_unbacked(self): @mark_compile_region def gn(x, y): @@ -970,7 +971,6 @@ def fn(x, y): x = torch.tensor(4) y = torch.randn(8) ref = fn(x, y) - torch._dynamo.config.capture_scalar_outputs = True opt_fn = torch.compile( fn, backend="eager", fullgraph=True ) # Inductor fails with assertion error when lowering aten.sym_constrain_range_for_size.default diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 000949475bc4..1328d5233d36 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1515,16 +1515,15 @@ def _validate_cache_key( for node in subgraph_mod.graph.nodes: if node.op == "call_function": op = node.target - # Dynamo graphs can have operator.add type of operations. For these operations, it is safe to cache. - if ( - callable(op) - and getattr(op, "__module__", None) - in {"_operator", "operator"} - and not op.__name__.startswith("i") - ): - continue - if op in (torch._check, torch._check_is_size): - continue + + # AOTDispatcher first pass does not run make_fx on + # dynamo graphs. As a result, it can have non OpOverload + # ops. 
+ if not isinstance(op, torch._ops.OpOverload): + raise _BypassDispatchCache( + f"{func.name()} hop with a non OpOverload input" + ) + try: self._validate_cache_key(op, [], {}) except _BypassDispatchCache as e: From f09513e515c34e7cd1e5540ebe784ed44f5c30bf Mon Sep 17 00:00:00 2001 From: eqy Date: Wed, 2 Apr 2025 02:41:07 +0000 Subject: [PATCH 084/332] [CUDA]][SymmetricMemory] Interpret empty string as `std::nullopt` in `rendezvous` (#149793) this is a "temporary" fix as current internal API requires strings at some interfaces instead of `std::optional` and empty strings are presumably used in-lieu of `nullopt`. e.g., https://github.com/pytorch/pytorch/blob/9d02b3993f7dae7fa3379d5190ac88291ecd4dce/torch/csrc/distributed/c10d/intra_node_comm.cu#L49 this currently breaks `test_intra_node_comm_all_reduce` Pull Request resolved: https://github.com/pytorch/pytorch/pull/149793 Approved by: https://github.com/kwen2501, https://github.com/cyyever --- torch/csrc/distributed/c10d/CUDASymmetricMemory.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu index 721d2c815875..08f61c80b1bb 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -786,7 +786,7 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::rendezvous( std::string group_name_; // Treat empty string and std::nullopt the same as empty string seems to be // implicitly used that way - if (group_name != "") { + if (group_name.has_value() && group_name != "") { group_name_ = *group_name; } else { if (!block->default_group_name.has_value()) { From 5734909f343ab1de44ed5ab23311d43a9c6afaed Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Wed, 2 Apr 2025 02:44:47 +0000 Subject: [PATCH 085/332] [Profiler] Fix Empty C Call Queue (#150370) Summary: My commandeer of https://github.com/pytorch/pytorch/pull/150102 Based on description of PR it seems that we need to add C calls for each starting python event with a callable such that when the tracing exits we will have a matching enter for any given exit. It adds some unnecessary events at worst but prevents segfaults/failures. My PR just cleans up some refcount impl and logging. Test Plan: Ran resnet test internally. Will check CI and ask reviewers to make sure it resolves their issues. 
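A minimal, hypothetical repro sketch (not taken from the PR) of the situation the backfilled C-call events guard against: the tracer attaches while Python frames are already on the stack, so those frames later exit without a previously recorded enter.

```python
# Hypothetical sketch: profiling starts inside an already-running call, so
# `outer` and `inner` are live frames when the tracer attaches and will exit
# while it is still active. Backfilled enter events keep the replay stack balanced.
import torch

def inner(x):
    with torch.profiler.profile(with_stack=True):
        return x * 2

def outer(x):
    return inner(x)

outer(torch.randn(4))
```
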
Differential Revision: D72207570 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150370 Approved by: https://github.com/aaronenyeshi --- torch/csrc/autograd/profiler_python.cpp | 39 +++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a98d1a8b7934..02ab02856864 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -709,6 +709,8 @@ class PythonTracer final : public python_tracer::PythonTracerBase { const std::vector interpreterThreads() const; + PyObject* get_callable_from_frame(PyFrameObject* frame); + std::atomic active_lock_{false}; bool active_{false}; @@ -787,6 +789,13 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { recordPyCall(thread_local_results_.back(), it->get(), true); + PyFrameObject* frame = it->get(); + PyObject* callable = get_callable_from_frame(frame); + if (callable) { + // Call recordCCall with the callable and the frame + recordCCall(thread_local_results_.back(), it->get(), callable); + } + auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -901,6 +910,26 @@ void PythonTracer::recordCCall( queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); } +PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) { + if (frame == nullptr) { + return nullptr; + } + // Get the code object associated with the frame + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + if (code == nullptr) { + return nullptr; + } + // Get the function name (if needed) + auto name = THPUtils_unpackStringView(code->co_name).data(); + // To get the function object, you will need to look in the globals or the + // frame's f_globals + PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name); + if (func) { + Py_INCREF(func); // Make sure the returned function has a reference + } + return func; // Returns a PyObject* (the function) +} + // ============================================================================ // == Post processing ========================================================= // ============================================================================ @@ -983,9 +1012,13 @@ class PostProcess { using stack_t = std::vector>; const auto initial_size = out.size(); auto pop = [](stack_t& stack, c10::time_t t) { - TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty."); - std::get>(stack.back()->extra_fields_).end_time_ns_ = t; - stack.pop_back(); + if (!stack.empty()) { + std::get>(stack.back()->extra_fields_).end_time_ns_ = t; + stack.pop_back(); + } else { + TORCH_WARN_ONCE( + "Python replay stack is empty during pop operation! May result in incorrect stack tracing."); + } }; ska::flat_hash_map stacks; From 063ea5d66995e2e27b869f9a2af4dae641488479 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Tue, 1 Apr 2025 13:42:21 -0700 Subject: [PATCH 086/332] [AOTInductor] Modify test for Memory tracking for memory-related (#150269) operations Summary: Fix the test for memory tracking. This PR does: (1) Add tracking before and after for all memory-related operations. Make sure the operation do indeed captures memory both in CUDA and torch's CUDACachAllocator Make sure the operation do indeed captures consumed memory both in CUDA and torch's CUDACachAllocator. 
(2) Keep track of memory being reserved by CUDACacheAllocator in torch and it's relationship with global CUDA memory consumption. Test Plan: This PR is adding tests. Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/150269 Approved by: https://github.com/jingsh, https://github.com/chenyang78, https://github.com/desertfire --- test/cpp/aoti_inference/test.cpp | 70 +++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/test/cpp/aoti_inference/test.cpp b/test/cpp/aoti_inference/test.cpp index 9861fd6bdead..1bf6ecc1ecfe 100644 --- a/test/cpp/aoti_inference/test.cpp +++ b/test/cpp/aoti_inference/test.cpp @@ -460,7 +460,6 @@ void test_aoti_double_buffering_with_tensor_constants() { void test_aoti_free_buffer(bool use_runtime_constant_folding) { torch::NoGradGuard no_grad; - size_t allocated, reserved, active; std::string data_path = (std::filesystem::path( @@ -511,7 +510,11 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } c10::cuda::CUDACachingAllocator::DeviceStats stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + size_t initTorchActive = stats.active_bytes[0].current; + size_t initTorchReserved = stats.reserved_bytes[0].current; // This should contain one set of weight (128MB) loaded from .so + size_t torchActive1, torchActive2; + size_t torchReserved1, torchReserved2; size_t initMemory = 0; size_t totalMemory = 0; cudaStatus = cudaMemGetInfo(&initMemory, &totalMemory); @@ -532,18 +535,30 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { // (64MB). if (use_runtime_constant_folding) { runner->run_const_fold(/* use_inactive = */ true); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; size_t constFoldMemory = 0; cudaStatus = cudaMemGetInfo(&constFoldMemory, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - DATASIZE - FOLDEDDATASIZE, constFoldMemory); + ASSERT_EQ( + initMemory - DATASIZE - (torchReserved1 - initTorchReserved), + constFoldMemory); + ASSERT_EQ(torchActive1 - initTorchActive, FOLDEDDATASIZE); } // We swap and free the inactive buffer. (Use #2 and free #1) - // Note that buffer #1 do not include folded-const + // Note that buffer #1 does not include folded-const + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; runner->swap_constant_buffer(); runner->free_inactive_constant_buffer(); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; size_t postFreeMemory = 0; cudaStatus = cudaMemGetInfo(&postFreeMemory, &totalMemory); if (cudaStatus != cudaSuccess) { @@ -551,60 +566,77 @@ void test_aoti_free_buffer(bool use_runtime_constant_folding) { } // We should only have one set of buffer (#2), available memory should equal // initial memory minus the folded constants. - ASSERT_EQ(initMemory - FOLDEDDATASIZE, postFreeMemory); + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), postFreeMemory); + // Buffer #1 does not include folded-consts + ASSERT_EQ(torchActive2 - torchActive1, 0); // We update random weights to buffer #1 and run const fold. 
// We will have 2 full set of data plus 2 set of const-folded data. runner->update_inactive_constant_buffer(rand_map); runner->run_const_fold(/* use_inactive = */ true); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; size_t updateMemory1 = 0; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1); + ASSERT_EQ( + initMemory - DATASIZE - (torchReserved1 - initTorchReserved), + updateMemory1); + ASSERT_EQ(torchActive1 - initTorchActive, 2 * FOLDEDDATASIZE); // We directly free the buffer #1. This would free the DATASIZE weight. // If folded constant exists, it will not directly free the cudaMalloc, but // decrease the active buffer in CachingAllocator instead. - size_t active1, active2; - size_t allocated1, allocated2; stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active1 = stats.active_bytes[0].current; - allocated1 = stats.allocated_bytes[0].current; + torchActive1 = stats.active_bytes[0].current; runner->free_inactive_constant_buffer(); cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active2 = stats.active_bytes[0].current; - allocated2 = stats.allocated_bytes[0].current; - ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(FOLDEDDATASIZE, active1 - active2); + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), updateMemory1); + ASSERT_EQ(FOLDEDDATASIZE, torchActive1 - torchActive2); // Free buffer #1 again, since #1 is freed, nothing should change. + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; runner->free_inactive_constant_buffer(); + stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive2 = stats.active_bytes[0].current; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(FOLDEDDATASIZE, active1 - active2); + ASSERT_EQ(initMemory - (torchReserved2 - initTorchReserved), updateMemory1); + ASSERT_EQ(torchActive1 - torchActive2, 0); // Swap and free #2, no data should exist in memory now. - // However, the folded constants still occupies the CUDA memory in + // However, the folded constants might still occupies the CUDA memory in // CachedAllocator. 
+ stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); + torchActive1 = stats.active_bytes[0].current; + torchReserved1 = stats.reserved_bytes[0].current; runner->swap_constant_buffer(); runner->free_inactive_constant_buffer(); stats = c10::cuda::CUDACachingAllocator::getDeviceStats(device_idx); - active2 = stats.active_bytes[0].current; + torchActive2 = stats.active_bytes[0].current; + torchReserved2 = stats.reserved_bytes[0].current; cudaStatus = cudaMemGetInfo(&updateMemory1, &totalMemory); if (cudaStatus != cudaSuccess) { throw std::runtime_error("cudaMemGetInfo failed!"); } - ASSERT_EQ(initMemory + DATASIZE - 2 * FOLDEDDATASIZE, updateMemory1); - ASSERT_EQ(2 * FOLDEDDATASIZE, active1 - active2); + + ASSERT_EQ( + initMemory + DATASIZE - (torchReserved2 - initTorchReserved), + updateMemory1); + ASSERT_EQ(FOLDEDDATASIZE, torchActive1 - torchActive2); + ASSERT_EQ(0, torchActive2 - initTorchActive); for (auto& pair : rand_map) { delete pair.second; From 25eff6e991e27350c4d7494cea79d58f48c90417 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 1 Apr 2025 14:34:14 -0700 Subject: [PATCH 087/332] [dynamo] add reason field to torch.compiler.disable (#150341) Implements https://github.com/pytorch/pytorch/issues/146445 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150341 Approved by: https://github.com/zou3519, https://github.com/jansel --- test/dynamo/test_error_messages.py | 107 +++++++++++++++++++-------- torch/_dynamo/decorators.py | 9 ++- torch/_dynamo/eval_frame.py | 4 +- torch/_dynamo/symbolic_convert.py | 6 +- torch/_dynamo/variables/builder.py | 8 +- torch/_dynamo/variables/functions.py | 4 +- torch/_dynamo/variables/misc.py | 20 ++++- torch/_dynamo/variables/tensor.py | 3 +- torch/compiler/__init__.py | 5 +- 9 files changed, 123 insertions(+), 43 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 255d62a5c4ff..3793ade26738 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -308,38 +308,6 @@ def post_munge(s): post_munge=post_munge, ) - def test_disable(self): - @torch.compiler.disable - def inner(): - return 1 - - def fn(): - return inner() - - def post_munge(s): - return re.sub( - r"\.inner at 0x[0-9A-Fa-f]+>", - "", - s, - ) - - self.assertExpectedInlineMunged( - Unsupported, - lambda: torch.compile(fn, backend="eager", fullgraph=True)(), - """\ -Skip calling `torch.compiler.disable()`d function - Explanation: Skip calling function `` since it was wrapped with `torch.compiler.disable` - Hint: Remove the `torch.compiler.disable` call - - Developer debug context: - - -from user code: - File "test_error_messages.py", line N, in fn - return inner()""", - post_munge=post_munge, - ) - def test_dynamo_graph_break_fn(self): def fn(): torch._dynamo.graph_break() @@ -1115,6 +1083,81 @@ def f3(x): """, ) + def test_disable_message(self): + @torch.compile(backend="eager", fullgraph=True) + def outer(fn, x): + return fn(x) + + @torch.compiler.disable + def f(x): + return x + 1 + + def post_munge(s): + return re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s) + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(f, torch.randn(3)), + """\ +Skip calling `torch.compiler.disable()`d function + Explanation: Skip calling function `.f at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: None) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .f at 0xmem_addr> + + +from user code: + 
File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + + @torch.compiler.disable(reason="test message") + def g(x): + return x + 2 + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(g, torch.randn(3)), + """\ +Skip calling `torch.compiler.disable()`d function + Explanation: Skip calling function `.g at 0xmem_addr>` since it was wrapped with `torch.compiler.disable` (reason: test message) + Hint: Remove the `torch.compiler.disable` call + + Developer debug context: .g at 0xmem_addr> + + +from user code: + File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + + class Mod(torch.nn.Module): + def forward(self, x): + return x + 3 + + mod = Mod() + mod.compile() + mod = torch.compiler.disable(mod, reason="test message 2") + + self.assertExpectedInlineMunged( + Unsupported, + lambda: outer(mod, torch.randn(3)), + """\ +Unsupported function call (delayed) + Explanation: Dynamo determined that a graph break should occur when calling `L['fn']`. Reason: Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: test message 2) + + + Developer debug context: source: LocalSource(local_name='fn', is_input=True, dynamism=None, is_derefed_cell_contents=False) + + +from user code: + File "test_error_messages.py", line N, in outer + return fn(x)""", + post_munge=post_munge, + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index def6c5fd2919..5d966c5d1f64 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -65,7 +65,7 @@ def run(fn=None): return RunOnlyContext() -def disable(fn=None, recursive=True): +def disable(fn=None, recursive=True, *, reason=None): """ Decorator to disable TorchDynamo @@ -74,13 +74,15 @@ def disable(fn=None, recursive=True): If recursive=False, Dynamo skips frames associated with the function code, but still process recursively invoked frames. + + If reason is provided, it will be printed when Dynamo attempts to trace the disabled function. """ if recursive: if fn is not None: fn = innermost_fn(fn) assert callable(fn) - return DisableContext()(fn) - return DisableContext() + return DisableContext(msg=reason)(fn) + return DisableContext(msg=reason) else: def wrap(fn): @@ -89,6 +91,7 @@ def wrap(fn): nonrecursive_disable_wrapper = get_nonrecursive_disable_wrapper(fn) nonrecursive_disable_wrapper._torchdynamo_disable = True # type: ignore[attr-defined] + nonrecursive_disable_wrapper._torchdynamo_disable_msg = reason # type: ignore[attr-defined] nonrecursive_disable_wrapper._torchdynamo_orig_callable = fn # type: ignore[attr-defined] return nonrecursive_disable_wrapper diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 8527daa7a796..b86870f4ca27 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -805,8 +805,9 @@ def __reduce__(self): class DisableContext(_TorchDynamoContext): - def __init__(self) -> None: + def __init__(self, msg: Optional[str] = None) -> None: super().__init__(callback=None) + self.msg = msg def __call__(self, fn): # Earlier this code was in the base class _TorchDynamoContext. But we @@ -854,6 +855,7 @@ def _fn(*args, **kwargs): _maybe_set_eval_frame(prior) _fn._torchdynamo_disable = True # type: ignore[attr-defined] + _fn._torchdynamo_disable_msg = self.msg # type: ignore[attr-defined] # Save the function pointer to find the original callable while nesting # of decorators. 
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 0d8b37d7ce6c..80960d6eb94a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3784,10 +3784,14 @@ def check_inlineable(func): if isinstance(func, UserFunctionVariable) and inspect.getattr_static( func.get_function(), "_torchdynamo_disable", False ): + msg = inspect.getattr_static( + func.get_function(), "_torchdynamo_disable_msg", None + ) unimplemented_v2( gb_type="Skip inlining `torch.compiler.disable()`d function", context=str(func.get_function()), - explanation=f"Skip inlining function {func.get_function()} since it was wrapped with `torch.compiler.disable`", + explanation=f"Skip inlining function {func.get_function()} since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", hints=[ "Remove the `torch.compiler.disable` call", ], diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d00ea5edc90a..49d0c162d68a 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -1539,7 +1539,13 @@ def wrap_module(self, value: torch.nn.Module): # we graph break here, Dynamo does not know how to create # continuation functions for such bytecodes. So, we delay the # graph break to CALL_FUNCTION. - return DelayGraphBreakVariable(source=self.source) + msg = inspect.getattr_static( + value.forward, "_torchdynamo_disable_msg", None + ) + return DelayGraphBreakVariable( + source=self.source, + msg=f"Optimized `nn.Module` is wrapped with `torch.compiler.disable` (reason: {msg})", + ) self.install_guards(GuardBuilder.TYPE_MATCH) self.source = AttrSource(self.source, "_orig_mod") diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index f30b69e44b6b..fc20350dc943 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1198,10 +1198,12 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": if inspect.getattr_static(self.value, "_torchdynamo_disable", False): + msg = inspect.getattr_static(self.value, "_torchdynamo_disable_msg", None) unimplemented_v2( gb_type="Skip calling `torch.compiler.disable()`d function", context=str(self.value), - explanation=f"Skip calling function `{self.value}` since it was wrapped with `torch.compiler.disable`", + explanation=f"Skip calling function `{self.value}` since it was wrapped " + f"with `torch.compiler.disable` (reason: {msg})", hints=[ "Remove the `torch.compiler.disable` call", ], diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 241bfb2c808b..7eaa01c2a5da 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -35,7 +35,7 @@ from .. import config, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import raise_observed_exception, unimplemented +from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init from ..source import AttrSource, GetItemSource, TypeSource, WeakRefCallSource @@ -396,6 +396,24 @@ class DelayGraphBreakVariable(UnknownVariable): Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION. 
""" + def __init__(self, msg=None, **kwargs): + super().__init__(**kwargs) + self.msg = msg + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + unimplemented_v2( + gb_type="Unsupported function call (delayed)", + context=f"source: {self.source}", + explanation="Dynamo determined that a graph break should occur " + f"when calling `{self.source.name()}`. Reason: {self.msg}", + hints=[], + ) + class ComptimeVariable(VariableTracker): """ diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 3470adfa2c7e..5b10a643ad94 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -453,7 +453,8 @@ def var_getattr(self, tx: "InstructionTranslator", name): ): # Delay the graph break to the actual call of unsqueeze_/resize_/resize_as_ etc. return variables.misc.DelayGraphBreakVariable( - source=AttrSource(self.source, name) + source=AttrSource(self.source, name), + msg="Getting an inplace view on a graph input is not supported", ) # For attributes (not methods) that were not caught in the special handling above, diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 321ededbb24a..aa6a27a3dcc3 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -228,7 +228,7 @@ def assume_constant_result(fn): return torch._dynamo.assume_constant_result(fn) -def disable(fn=None, recursive=True): +def disable(fn=None, recursive=True, *, reason=None): """ This function provides a decorator to disable compilation on a function. It also provides the option of recursively disabling called functions. @@ -236,10 +236,11 @@ def disable(fn=None, recursive=True): Args: fn (optional): The function to disable recursive (optional): A boolean value indicating whether the disabling should be recursive. + reason (optional): A string value indicating the reason for disabling the function. 
""" import torch._dynamo - return torch._dynamo.disable(fn, recursive) + return torch._dynamo.disable(fn, recursive, reason=reason) def set_stance( From 3ac5a499ddac701f607a9f7206f9bec8871e1cbb Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 1 Apr 2025 15:08:04 -0700 Subject: [PATCH 088/332] [dynamo] add dynamo disable reasons to codebase (#150440) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150440 Approved by: https://github.com/jansel, https://github.com/zou3519 ghstack dependencies: #150341 --- torch/_dynamo/backends/common.py | 9 ++++-- torch/_dynamo/eval_frame.py | 32 +++++++++++++++---- torch/_dynamo/output_graph.py | 9 ++++-- torch/_dynamo/utils.py | 12 +++++-- torch/_higher_order_ops/cond.py | 2 +- .../fsdp/_fully_shard/_fsdp_state.py | 6 +++- torch/export/unflatten.py | 5 ++- torch/fx/graph_module.py | 5 ++- 8 files changed, 62 insertions(+), 18 deletions(-) diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index f92d16bf2b30..246596bcbcab 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -69,7 +69,12 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs): def wrap_bw_compiler(bw_compiler_fn): def _wrapped_bw_compiler(*args, **kwargs): # stop TorchDynamo from trying to compile our generated backwards pass - return disable(disable(bw_compiler_fn)(*args, **kwargs)) + return disable( + disable( + bw_compiler_fn, reason="do not trace backward compiler function" + )(*args, **kwargs), + reason="do not trace generated backwards pass", + ) return _wrapped_bw_compiler @@ -100,7 +105,7 @@ def _wrapped_bw_compiler(*args, **kwargs): with enable_aot_logging(), patch_config: cg = aot_module_simplified(gm, example_inputs, **self.kwargs) counters["aot_autograd"]["ok"] += 1 - return disable(cg) + return disable(cg, reason="do not trace AOT-compiled graph") except TensorifyScalarRestartAnalysis: raise except Exception: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index b86870f4ca27..1bcad8ef5b5f 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1901,11 +1901,20 @@ def patch(): # with torch.deploy internally. 
from .decorators import disable - torch.jit.trace = disable(torch.jit.trace) - torch.jit.trace_module = disable(torch.jit.trace_module) - torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph) + torch.jit.trace = disable( + torch.jit.trace, reason="tracing into TorchScript not fully supported" + ) + torch.jit.trace_module = disable( + torch.jit.trace_module, + reason="tracing into TorchScript not fully supported", + ) + torch.jit._get_trace_graph = disable( + torch.jit._get_trace_graph, + reason="tracing into TorchScript not fully supported", + ) torch.fx._symbolic_trace.Tracer.trace = disable( - torch.fx._symbolic_trace.Tracer.trace + torch.fx._symbolic_trace.Tracer.trace, + reason="tracing into FX not fully supported", ) torch.distributions.Distribution.set_default_validate_args(False) @@ -1947,7 +1956,12 @@ def patch(): if hasattr(opt_mod, fused_fn_name): setattr( - opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name)) + opt_mod, + fused_fn_name, + disable( + getattr(opt_mod, fused_fn_name), + reason="don't trace into fused optimizer", + ), ) optimizer_classes = [ @@ -1964,10 +1978,14 @@ def patch(): for opt in optimizer_classes: if opt in excluded_optimizer_classes: - opt.step = disable(opt.step) + opt.step = disable( + opt.step, reason=f"optimizer {opt} step not supported" + ) if hasattr(opt, "_init_group"): - opt._init_group = disable(opt._init_group) + opt._init_group = disable( + opt._init_group, reason=f"optimizer {opt} _init_group not supported" + ) @staticmethod def suppress_torch_distributed_warnings(fn): diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 69d1ac475790..c11e6deccc7d 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -1122,7 +1122,10 @@ def append_prefix_insts(): append_prefix_insts() random_calls_instructions = [] self.random_values_var = self.new_var("random_values") - rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) + rand_fn = disable( + _get_gen_rand_values_fn(self.random_calls), + reason="do not trace into Dynamo rng recovery function", + ) rand_fn_name = self.install_global("__gen_rand_values", rand_fn) codegen = PyCodegen(tx, root, overridden_sources=overridden_sources) random_calls_instructions.extend( @@ -1470,7 +1473,9 @@ def compile_and_call_fx_graph(self, tx, rv, root): # replace compiled_fn with the real forward method compiled_fn = lazy_gm.forward - compiled_fn = disable(compiled_fn) + compiled_fn = disable( + compiled_fn, reason="do not trace Dynamo-compiled graph" + ) counters["stats"]["unique_graphs"] += 1 # This is safe because we pre-process name to be unique diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 89e47e823cdd..8ee9289633b1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4371,7 +4371,9 @@ def does_not_override_dict_iter_methods(user_cls): # compiled bytecode # They will be skipped which is the desired result def call_size(x, i): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ tracing helper function" + ) def fn(x, i): return x.size(i) @@ -4379,7 +4381,9 @@ def fn(x, i): def call_stride(x, i): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ tracing helper function" + ) def fn(x, i): return x.stride(i) @@ -4387,7 +4391,9 @@ def fn(x, i): def call_storage_offset(x): - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable( + recursive=True, reason="__torch_function__ 
tracing helper function" + ) def fn(x): return x.storage_offset() diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 6501ca6ad1ca..31846752e3db 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -204,7 +204,7 @@ def materialize_as_graph( exclude_key_set: torch._C.DispatchKeySet, force_enable_grad=False, ) -> torch.fx.GraphModule: - @torch._dynamo.disable(recursive=True) + @torch._dynamo.disable(recursive=True, reason=None) def _materialize_as_graph_inner(): with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index 5d11f0359f1f..601a77185e40 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -59,7 +59,11 @@ def disable_if_config_true(func): @functools.wraps(func) def fsdp_hook_wrapper(*args, **kwargs): if torch._dynamo.config.skip_fsdp_hooks: - return torch._dynamo.disable(func, recursive=True)(*args, **kwargs) + return torch._dynamo.disable( + func, + recursive=True, + reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set", + )(*args, **kwargs) else: return func(*args, **kwargs) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 833961170001..85a01ea13ee7 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -589,7 +589,10 @@ def process_forward_inputs(self, *args, **kwargs): return flat_args def forward(self, *args, **kwargs): - flat_args = torch._dynamo.disable(self.process_forward_inputs)(*args, **kwargs) + flat_args = torch._dynamo.disable( + self.process_forward_inputs, + reason="do not trace into preprocessing the inputs", + )(*args, **kwargs) signature = self.module_call_graph[0].signature if is_fx_tracing(): diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 3910020cfad9..57f65acef9b2 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -372,7 +372,10 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: all_src_lines = linecache.getlines(frame_summary.filename) # constituent substrings of the error message - tb_repr = torch._dynamo.disable(traceback.format_exc)() + tb_repr = torch._dynamo.disable( + traceback.format_exc, + reason="do not trace into traceback.format_exc when generating error message", + )() custom_msg = ( "Call using an FX-traced Module, " f"line {err_lineno} of the traced Module's " From dee016ceb7451d2b9815ec1735afb08a6c4732e4 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 1 Apr 2025 12:33:15 -0700 Subject: [PATCH 089/332] [MPSInductor] Add `store_reduce` method (#150457) That restrict the store operation to 0th thread, which should be much better, shouldn't it (Though I don't observe it in the benchmark) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150457 Approved by: https://github.com/jansel, https://github.com/dcci ghstack dependencies: #150452 --- torch/_inductor/codegen/mps.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index ac2218e3e0f5..b600721e1a30 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -493,6 +493,16 @@ def store( else: self.stores.writeline(DeferredLine(name, line)) + def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> 
None: + var = self.args.output(name) + index = self.prepare_indexing(index) + dtype_str = self.dtype_to_str(V.graph.get_dtype(name)) + reduction_dim = next(t for t in self.range_trees if t.is_reduction) + # Only one thread in the reduction group needs to store the results + line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});" + line = f"if ({reduction_dim.name} == 0) {line}" + self.stores.writeline(DeferredLine(name, line)) + def _new_accvar( self, dtype: torch.dtype, From c65de03196ae3dbeb67ef38d43c4639b85a60ce4 Mon Sep 17 00:00:00 2001 From: Rebecca Chen Date: Wed, 2 Apr 2025 05:25:03 +0000 Subject: [PATCH 090/332] Add `Any` return annotation to `__getattr__` methods that return a union of types. (#150204) Adds an `Any` return type annotation to `__getattr__` methods in `torch/_ops.py` that return a union of types. Attribute access returning a union of types can cause issues downstream because consumers would need to handle all of the possible types to make the type checker happy. This doesn't seem to matter today for mypy, presumably because `Any` is always inferred when a return type annotation is missing, but it still makes explicit what mypy is already doing implicitly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150204 Approved by: https://github.com/malfet --- torch/_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_ops.py b/torch/_ops.py index c6f5be583e41..0842f57fbff7 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -1086,7 +1086,7 @@ def _schemas(self): for overload_name in self._overload_names } - def __getattr__(self, key): + def __getattr__(self, key) -> Any: # It is not a valid op_name when __file__ is passed in if key == "__file__": return "torch.ops" @@ -1246,7 +1246,7 @@ def __init__(self, name): def __iter__(self): return iter(self._dir) - def __getattr__(self, op_name): + def __getattr__(self, op_name) -> Any: # It is not a valid op_name when __file__ is passed in if op_name == "__file__": return "torch.ops" From 0da8127f77f9bf05ba204ea7659cb15ec85e88a7 Mon Sep 17 00:00:00 2001 From: Sukchul Cho Date: Wed, 2 Apr 2025 06:06:02 +0000 Subject: [PATCH 091/332] Compare device name of profiler dynamically (#150396) Compare self.use_device of torch.autograd.profiler.profiler with _get_privateuse1_backend_name(), since privateuse1 backend can be renamed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150396 Approved by: https://github.com/sraikund16 --- torch/autograd/profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 50745586ca63..0a8b9d1e29d3 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -629,7 +629,7 @@ def _device_memory_usage(mem_record): ) max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: - if self.use_device == "privateuseone": + if self.use_device == _get_privateuse1_backend_name(): privateuse1_time = kineto_event.privateuse1_elapsed_us() if privateuse1_time > 0: fe.append_kernel(fe.name, fe.device_index, privateuse1_time) From 3f54b14c753f61ac52c5e1c0a6cf1567b6eced2d Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Wed, 2 Apr 2025 07:21:46 +0000 Subject: [PATCH 092/332] [CUDAGraph] support meta tensor (#150478) Previously, cudagraph is skipped if the graph contains any meta tensor. However, we should not skip since meta tensor does not have actual computation. 
This PR fixes the issue. ### Example ```python import torch def foobar(x, y): return x * 2, y * 3 foo_c = torch.compile(mode="reduce-overhead")(foobar) t = torch.empty((1, 16, 128, 128), device="meta") y = torch.rand([64], device="cuda") eager_out = foobar(t, y) for _ in range(3): compiled_out = foo_c(t, y) ``` Prior to this PR, above code leads to ``` skipping cudagraphs due to multiple devices: device(type='cuda', index=0), device(type='meta') ``` With this PR, we don't skip. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150478 Approved by: https://github.com/eellison --- test/inductor/test_cudagraph_trees.py | 17 +++++++++++++++++ torch/_inductor/cudagraph_utils.py | 3 +++ 2 files changed, 20 insertions(+) diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 742000347a35..a536aa7ab74f 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -3048,6 +3048,23 @@ def run(shape_x, shape_y): self.assertEqual(self.get_manager().new_graph_id().id, 3) + def test_meta_tensor(self): + def foobar(x, y): + return x * 2, y * 3 + + foo_c = torch.compile(mode="reduce-overhead")(foobar) + t = torch.empty((1, 16, 128, 128), device="meta") + y = torch.rand([64], device="cuda") + + eager_out = foobar(t, y) + + for _ in range(3): + compiled_out = foo_c(t, y) + + compiled_out = foo_c(t, y) + self.assertEqual(eager_out, compiled_out) + self.assertEqual(self.get_manager().new_graph_id().id, 1) + class TestSAC(TestCase): def _make_observer_mode(self): class ObserverMode(TorchDispatchMode): diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 68ea4a010e6e..f6ce7e43ad95 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -167,6 +167,9 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]: def check_multiple_devices_or_any_cpu_nodes( device_node_mapping: dict[torch.device, torch.fx.Node], ) -> Optional[str]: + # meta tensors are supported since there is no compute + device_node_mapping.pop(torch.device("meta"), None) + if torch._inductor.config.graph_partition: # graph partition supports splitting on cpu op. So we can ignore cpu nodes. device_node_mapping.pop(torch.device("cpu"), None) From 75f38dfd4e8daa38a705d677eca5648743fab6bd Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Wed, 2 Apr 2025 03:54:39 +0000 Subject: [PATCH 093/332] cpp_wrapper: precompile a few more commonly used headers, and improve RAIIPyObject interface (#149350) Add includes for torch.device, torch.dtype, torch.layout, and torch.memory_format to the cpp_wrapper common header, so that they get precompiled. Additionally, add move constructors and operator bool to RAIIPyObject. Closes #142005. 
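A rough usage sketch (illustrative only, not part of this patch) of how the updated interface reads at call sites, assuming `RAIIPyObject` as defined in the patched `torch/csrc/inductor/cpp_wrapper/common.h`:

```cpp
#include <Python.h>
#include <stdexcept>
#include <utility>
// Assumes the patched header is available in the include path.
#include <torch/csrc/inductor/cpp_wrapper/common.h>

void demo() {
  RAIIPyObject args(PyTuple_New(2));       // wrapper steals the new reference
  if (!args) {                             // operator bool replaces `args.get() == NULL`
    throw std::runtime_error("PyTuple_New failed");
  }
  RAIIPyObject owner(std::move(args));     // move ctor: no refcount churn, `args` is left empty
  PyTuple_SetItem(owner, 0, PyLong_FromLong(1));  // implicit conversion to PyObject*
}
```
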
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149350 Approved by: https://github.com/desertfire --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 24 +++------------ torch/csrc/inductor/cpp_wrapper/common.h | 35 +++++++++++++++++++--- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 1ea1459659ae..fc30b0f3e437 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -2041,11 +2041,11 @@ def load_custom_op_wrapper(self): lines = """ RAIIPyObject codecache_module(PyImport_ImportModule("torch._inductor.codecache")); -if (codecache_module.get() == NULL) { +if (!codecache_module) { throw std::runtime_error("Failed to load torch._inductor.codecache"); } custom_op_wrapper = PyObject_GetAttrString(codecache_module, "custom_op_wrapper"); -if (custom_op_wrapper.get() == NULL) { +if (!custom_op_wrapper) { throw std::runtime_error("Failed to load torch._inductor.codecache.custom_op_wrapper"); }""" @@ -2070,11 +2070,6 @@ def generate_float_value(self, val): def generate_py_arg(self, py_args_var, idx, raw_arg, arg_type): def generate_py_arg_inner(lines, raw_arg, arg_type): - def add_py_newref(): - if sys.version_info < (3, 10): - # Py_NewRef is only available since Python 3.10 - self.include_extra_header("torch/csrc/utils/pythoncapi_compat.h") - def handle_scalar(scalar): if isinstance(scalar, int): return f"PyLong_FromLongLong({scalar})" @@ -2135,24 +2130,13 @@ def handle_scalar(scalar): # torch/_prims_common/__init__.py return handle_scalar(raw_arg) elif isinstance(raw_arg, torch.device): - # device - self.include_extra_header("torch/csrc/Device.h") device_str, device_index = self.codegen_device(raw_arg).split(", ") return f"THPDevice_New(c10::Device(static_cast({device_str}), {device_index}))" elif isinstance(raw_arg, torch.dtype): - # dtype - add_py_newref() - self.include_extra_header("torch/csrc/DynamicTypes.h") return f"Py_NewRef(torch::getTHPDtype(static_cast({self.codegen_dtype(raw_arg)})))" elif isinstance(raw_arg, torch.layout): - # memory layout - add_py_newref() - self.include_extra_header("torch/csrc/DynamicTypes.h") return f"Py_NewRef(torch::getTHPLayout(static_cast({self.codegen_layout(raw_arg)})))" elif isinstance(raw_arg, torch.memory_format): - # memory_format - add_py_newref() - self.include_extra_header("torch/csrc/utils/tensor_memoryformats.h") return ( "Py_NewRef(torch::utils::getTHPMemoryFormat(static_cast(" f"{self.codegen_memory_format(raw_arg)})))" @@ -2204,7 +2188,7 @@ def generate_fallback_kernel_with_runtime_lookup_jit( lines = textwrap.dedent( f""" RAIIPyObject {py_args_var}(PyTuple_New({num_args + 1})); - if ({py_args_var}.get() == NULL) {{ + if (!{py_args_var}) {{ throw std::runtime_error("PyTuple_New {py_args_var} failed"); }} PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); @@ -2224,7 +2208,7 @@ def generate_fallback_kernel_with_runtime_lookup_jit( f""" // Call the custom op in Python RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); - if (py_{buf_name}.get() == NULL) {{ + if (!py_{buf_name}) {{ if (PyErr_Occurred()) {{ return; }} diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h index 3f77347f5274..2b59855cbc6e 100644 --- a/torch/csrc/inductor/cpp_wrapper/common.h +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -3,16 +3,33 @@ #include #include #include +#include 
#include #define PYBIND11_SIMPLE_GIL_MANAGEMENT #include -namespace py = pybind11; + +// Include some often-used cpp_wrapper headers, for precompiling. +#include +#include +#include +#include +#include + +namespace py = pybind11; // NOLINT(misc-unused-alias-decls) class RAIIPyObject { public: - RAIIPyObject() : obj_(nullptr) {} - RAIIPyObject(PyObject* obj) : obj_(obj) {} + RAIIPyObject() = default; + // steals a reference to a PyObject + RAIIPyObject(PyObject* obj) : obj_{obj} {} + RAIIPyObject(const RAIIPyObject& other) : obj_{other.obj_} { + Py_XINCREF(obj_); + } + RAIIPyObject(RAIIPyObject&& other) noexcept { + // refcount doesn't change, and obj_ is currently nullptr + std::swap(obj_, other.obj_); + } ~RAIIPyObject() { Py_XDECREF(obj_); } @@ -24,6 +41,16 @@ class RAIIPyObject { } return *this; } + RAIIPyObject& operator=(RAIIPyObject&& other) noexcept { + // refcount to the current object decreases, but refcount to other.obj_ is + // the same + Py_XDECREF(obj_); + obj_ = std::exchange(other.obj_, nullptr); + return *this; + } + operator bool() const noexcept { + return obj_; + } operator PyObject*() { return obj_; } @@ -32,7 +59,7 @@ class RAIIPyObject { } private: - PyObject* obj_; + PyObject* obj_{nullptr}; }; #include From 03138733ba7b7afb2821ed4516eea87a6215798f Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sun, 30 Mar 2025 16:32:52 -0700 Subject: [PATCH 094/332] [AOTI] Emit Triton kernels as comment (#150188) Summary: Emit the corresponding Triton kernel code as comment in each call_triton_ wrapper function, for easier debugging. Differential Revision: [D72178907](https://our.internmc.facebook.com/intern/diff/D72178907) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150188 Approved by: https://github.com/yushangdi --- test/inductor/test_aot_inductor.py | 4 ++++ torch/_inductor/codecache.py | 4 ++++ torch/_inductor/codegen/cpp_wrapper_gpu.py | 13 ++++++++++++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 12443a3bee89..ce661145d9c7 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -4060,6 +4060,10 @@ def forward(self, a, b, c): AOTIRunnerUtil.compile, model, example_inputs ) self.assertEqual("aoti_torch_print_tensor_handle" in code, True) + + # check if the triton kernel is printed as comment + self.assertEqual("def triton_" in code, True) + # check the codegen for debug printing around aoti model inputs is expected for kernel_call, count in kernel_calls: FileCheck().check_count( diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index e331badcff34..aaad5a53486a 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1451,6 +1451,10 @@ def compile( extra=cpp_command, specified_dir=specified_output_path, ) + kernel_code = ( + f"// Triton kernels are embedded as comments in {wrapper_path}\n" + + kernel_code + ) _, kernel_path = write( kernel_code, "kernel.cpp", diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 56f0941715a9..77364ae48734 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -54,6 +54,7 @@ class DeferredTritonCallWrapper: wrapper_name: str kernel_name: str + kernel_name_to_body: dict[str, str] arg_types: list[Any] def generate(self, wrapper: CppWrapperGpu): @@ -122,6 +123,11 @@ def generate(self, wrapper: CppWrapperGpu): ) prefix.writeline("){") 
with prefix.indent(): + if V.graph.aot_mode: + # Emit the original Triton kernel for debugging purposes + prefix.writeline("/*") + prefix.splice(self.kernel_name_to_body[self.kernel_name]) + prefix.writeline("*/") self.generate_grid(prefix, inductor_meta, params) self.generate_load_kernel(prefix, kernel_var_name, params) self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) @@ -205,6 +211,7 @@ def __init__(self) -> None: self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() + self._kernel_name_to_body: dict[str, str] = {} self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" @@ -296,6 +303,7 @@ def define_kernel( cpp_definition: Optional[str] = None, ): if gpu: + self._kernel_name_to_body[kernel_name] = kernel_body if config.triton.autotune_at_compile_time: # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.define_kernel( @@ -502,7 +510,10 @@ def generate_kernel_call( wrapper_name = f"call_{kernel_name}" if wrapper_name not in self._triton_call_wrappers: self._triton_call_wrappers[wrapper_name] = DeferredTritonCallWrapper( - wrapper_name, kernel_name, arg_types + wrapper_name, + kernel_name, + self._kernel_name_to_body, + arg_types, ) call_args.append(stream) if V.graph.aot_mode: From c41fbb4f783fd4910d7776a32f608d74dbb0fba9 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 1 Apr 2025 21:07:25 -0700 Subject: [PATCH 095/332] Change arg_kwarg_vals propagation strategy (#148046) Instead of always propagating arg_kwarg_vals in _COPY_META_FIELDS, we special-case the pattern matcher to propagate arg_kwarg_vals when it sees triton_kernel_wrapper_functional. The strategy is: 1) trace out the replacement graph with arg_kwarg_vals (which have accurate eager-mode metadata) 2) trace out the replacement graph with vals (which have the accurate Inductor metadata) 3) Propagate the arg_kwarg_vals from the first graph to the second. 4) Use the second graph as the replacement graph. The strategy is this because we want to extend this to handle auto_functionalized later up in the stack. Test Plan: - existing tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/148046 Approved by: https://github.com/eellison --- torch/_inductor/fx_passes/reinplace.py | 10 +++- torch/_inductor/pattern_matcher.py | 70 ++++++++++++++++++++++++-- torch/fx/proxy.py | 1 - 3 files changed, 76 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index a4d6f482e25d..48541bcc5e34 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -21,7 +21,7 @@ ) from torch._inductor.virtualized import V from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode -from torch.fx.immutable_collections import immutable_dict +from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.fx.passes.reinplace import _is_view_op from torch.utils import _pytree as pytree from torch.utils._ordered_set import OrderedSet @@ -720,6 +720,14 @@ def tensor_with_same_storage_already_reinplaced(arg): kwargs = dict(node.kwargs) kwargs["tensors_to_clone"] = tensors_to_clone node.kwargs = immutable_dict(kwargs) + if "arg_kwarg_vals" in node.meta: + # We changed the kwargs, so we need to update arg_kwarg_vals + # to something sane. 
+ args, kwargs = node.meta["arg_kwarg_vals"] + new_kwargs = {**kwargs} + new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) + new_kwargs = immutable_dict(new_kwargs) + node.meta["arg_kwarg_vals"] = (args, new_kwargs) elif ( inplaceable_op := inplaceable_foreach_ops.get(node.target, None) ) is not None: diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 1891c7d15dca..d337dcae4c99 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -251,14 +251,73 @@ def replace_by_example( else contextlib.nullcontext() ) + def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool: + if len(nodes) != 1: + return False + node = nodes[0] + if "arg_kwarg_vals" not in node.meta: + return False + return node.target in OrderedSet( + [ + torch.ops.higher_order.triton_kernel_wrapper_functional, + ] + ) + with context: if trace_fn is None: trace_fn = functools.partial( fwd_only, run_functional_passes=run_functional_passes ) - replacement = trace_fn( - replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) - ) + + if should_propagate_arg_kwarg_vals(self.nodes): + # Our strategy is: + # 1) trace out the graph with arg_kwarg_vals (which have accurate eager-mode metadata) + # 2) trace out the graph with vals (which have the accurate Inductor metadata) + # 3) Propagate the arg_kwarg_vals from the first graph to the second. + # 4) Use the second graph as the replacement graph. + + # Construct a map of node -> FakeTensor val in arg_kwarg_vals + node_to_val = {} + + fake_args, fake_kwargs = self.nodes[0].meta["arg_kwarg_vals"] + fake_kwargs = {**fake_kwargs} + match_args, match_kwargs = tuple(self.args), self.kwargs + + def record(node: torch.fx.Node, val: Any) -> None: + if isinstance(node, torch.fx.Node): + node_to_val[node] = val + + torch.utils._pytree.tree_map( + record, (match_args, match_kwargs), (fake_args, fake_kwargs) + ) + # map args to their FakeTensor val in arg_kwarg_vals + example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg]) + + # first graph + graph_with_eager_vals = trace_fn(replacement_fn, example_vals) + + # second graph + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(graph_with_eager_vals, example_vals) + + # propagate metadata from first graph to second + # NB: This assertion might not be true in general, but it is true for + # the two use cases we have + # (triton_kernel_wrapper_functional, auto_functionalized) + assert len(graph_with_eager_vals.graph.nodes) == len( + replacement.graph.nodes + ) + for old_node, new_node in zip( + graph_with_eager_vals.graph.nodes, replacement.graph.nodes + ): + if "arg_kwarg_vals" in old_node.meta: + new_node.meta["arg_kwarg_vals"] = old_node.meta[ + "arg_kwarg_vals" + ] + + else: + example_vals = torch.fx.map_arg(args, lambda arg: arg.meta["val"]) + replacement = trace_fn(replacement_fn, example_vals) if len(self.nodes) == 1: for n in replacement.graph.nodes: _transfer_meta( @@ -1083,6 +1142,11 @@ def run_node(self, node: torch.fx.Node) -> Any: old_node=node, pass_name="Interpreter_Replacer", ) + # This function copy-pastes the replacement graph into + # the graph. If the replacement graph had any arg_kwarg_vals, + # or val/tensor_meta, we propagate those over. 
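The reason two sets of values are kept at all (an illustrative aside, not part of the change): the eager-mode metadata and the metadata Inductor later settles on can disagree, most visibly in strides. For example:

    import torch

    eager_val = torch.randn(2, 3, 8, 8).to(memory_format=torch.channels_last)
    inductor_val = torch.empty_strided((2, 3, 8, 8), (192, 64, 8, 1))  # plain contiguous layout
    assert eager_val.shape == inductor_val.shape
    assert eager_val.stride() != inductor_val.stride()

Stride-sensitive consumers such as the user-written Triton kernels in the tests only behave correctly when they see the eager strides, which is why the eager values ride along on the replacement nodes.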
+ if "arg_kwarg_vals" in node.meta: + result.meta["arg_kwarg_vals"] = node.meta["arg_kwarg_vals"] if "val" in node.meta and "val" not in result.meta: result.meta["val"] = node.meta["val"] if isinstance(node.meta["val"], torch.Tensor): diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index ce1814dd7f29..e40cb13d5558 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -116,7 +116,6 @@ def __exit__(self, *args): "_numeric_debug_handle", # TODO deprecated "custom", "partitioner_tag", - "arg_kwarg_vals", ] From c69c3c885e82becda528330d22d913b7055b33d7 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 1 Apr 2025 21:07:25 -0700 Subject: [PATCH 096/332] Add needs_exact_strides operator tag for Inductor to force exact strides (#148063) Inductor will force exact strides on a custom operator tagged with needs_exact_strides. I'll make this the default in a follow-up PR. Test Plan: - tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/148063 Approved by: https://github.com/eellison ghstack dependencies: #148046 --- aten/src/ATen/native/tags.yaml | 14 ++++-- test/inductor/test_triton_kernels.py | 62 +++++++++++++++++++++------ torch/_inductor/graph.py | 33 +++++++++++++- torch/fx/experimental/proxy_tensor.py | 23 ++++++---- 4 files changed, 104 insertions(+), 28 deletions(-) diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index ff4a7730fcc5..948cbe0f4028 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -42,19 +42,25 @@ desc: | This tag indicates if an operator doesn't guarantee bitwise equivalence across different runs of an operator with identical inputs. +- tag: needs_exact_strides + desc: | + This tag indicates that the operator should be passed Tensors following + the same strides as observed in eager when compiled in inductor. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. - tag: needs_fixed_stride_order desc: | This tag indicates that the operator should be passed Tensors following the same stride permutation as observed in eager when compiled in inductor. - Only one of {needs_fixed_stride_order, flexible_layout} can apply; if - multiple are assigned then we assume the most restrictive one. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. - tag: flexible_layout desc: | This tag indicates that the custom operator can accept inputs with varying strides/storage_offset and that when compiled, Inductor is allowed to change the strides/storage_offset of inputs to the custom operator. - Only one of {needs_fixed_stride_order, flexible_layout} can apply; if - multiple are assigned then we assume the most restrictive one. + Only one of {needs_exact_strides, needs_fixed_stride_order, flexible_layout} + can apply; if multiple are assigned then we assume the most restrictive one. 
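For illustration, this is how a custom operator opts into the new tag from Python (a sketch along the lines of the test added further below; "mylib::add_op" is a made-up name):

    import torch

    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        lib.define(
            "add_op(Tensor x, Tensor y) -> Tensor",
            tags=[torch._C.Tag.needs_exact_strides],
        )
        lib.impl("add_op", lambda x, y: x + y, "CompositeExplicitAutograd")

With this tag, Inductor materializes the op's inputs with exactly the strides observed in eager mode, instead of merely preserving the stride permutation.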
# NOTE [Core ATen Ops] - tag: core diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 39b220290d03..6e02a928bb73 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3376,7 +3376,8 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertEqual(z, (x + y) * 2) @requires_gpu - def test_preserves_strides(self): + @common_utils.parametrize("variant", ["triton_kernel", "custom_op"]) + def test_preserves_strides(self, variant): import triton import triton.language as tl @@ -3400,12 +3401,10 @@ def add_kernel( x = torch.randn(4, 4, 2, 2, device=GPU_TYPE) other = torch.randn(4, 4, 2, 2, device=GPU_TYPE) - def f(x, other): - y = x.transpose(2, 3).contiguous().transpose(2, 3) - z = y.sin().transpose(2, 3) + def add_triton(y, z): grid = (z.numel(),) - out = torch.empty_like(other) - add_kernel[grid](z, other, out, z.numel(), BLOCK_SIZE=16) + out = torch.empty_like(z, memory_format=torch.contiguous_format) + add_kernel[grid](y, z, out, z.numel(), BLOCK_SIZE=16) return out class _CustomPass(PatternMatcherPass): @@ -3427,8 +3426,8 @@ def _(match, *args, **kwargs): def decomp(*flat_args): args, kwargs = pytree.tree_unflatten(flat_args, spec) - return torch.ops.aten.permute(*args, **kwargs).clone( - memory_format=torch.channels_last + return torch.ops.mylib.force_channels_last( + torch.ops.aten.permute(*args, **kwargs) ) nonlocal called @@ -3437,12 +3436,47 @@ def decomp(*flat_args): from torch._inductor import config - with config.patch( - post_grad_custom_post_pass=g, - ): - f_compile = torch.compile(f) - self.assertEqual(f(x, other), f_compile(x, other)) - self.assertTrue(called) + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define( + "force_channels_last(Tensor x) -> Tensor", + tags=[torch._C.Tag.flexible_layout], + ) + + def impl2(x): + return x.clone(memory_format=torch.channels_last) + + lib.impl("force_channels_last", impl2, "CompositeExplicitAutograd") + + lib.define( + "add_op(Tensor x, Tensor y) -> Tensor", + tags=[torch._C.Tag.needs_exact_strides], + ) + + def impl(x, y): + return add_triton(x, y) + + def meta(x, y): + return torch.empty_like(y, memory_format=torch.contiguous_format) + + lib.impl("add_op", impl, "CompositeExplicitAutograd") + lib.impl("add_op", meta, "Meta") + + def f(x, other): + y = x.transpose(2, 3).contiguous().transpose(2, 3) + z = y.sin().transpose(2, 3) + if variant == "triton_kernel": + return add_triton(y, z) + elif variant == "custom_op": + return torch.ops.mylib.add_op.default(y, z) + else: + raise AssertionError("should not be hit") + + with config.patch( + post_grad_custom_post_pass=g, + ): + f_compile = torch.compile(f) + self.assertEqual(f(x, other), f_compile(x, other)) + self.assertTrue(called) @requires_gpu @common_utils.parametrize("dynamic", [False, True]) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index a2e668d698b9..cfbc01d4f51e 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1151,7 +1151,9 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> # use contiguous unless the (custom) op asks something else # explicitly - if torch._C.Tag.needs_fixed_stride_order in target.tags: + if torch._C.Tag.needs_exact_strides in target.tags: + decided_constraint = constrain_to_fake_tensors # type: ignore[assignment] + elif torch._C.Tag.needs_fixed_stride_order in target.tags: decided_constraint = constrain_to_fx_strides # type: ignore[assignment] elif 
torch._C.Tag.flexible_layout in target.tags: decided_constraint = None # type: ignore[assignment] @@ -1192,7 +1194,34 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> layout_constraints = maybe_layout_constraints(target) if layout_constraints: old_args, old_kwargs = args, kwargs - args, kwargs = layout_constraints(n, *args, **kwargs) + if layout_constraints is constrain_to_fake_tensors: + # only constrain_to_fake_tensor if this exists. + # otherwise, no constraints at all: the implication is + # that this operator was inserted by a custom pass + # so we'll give them the freedom. + if "arg_kwarg_vals" in n.meta: + fake_args, fake_kwargs = n.meta["arg_kwarg_vals"] + + # (fake_args, fake_kwargs) might not align with (args, kwargs). + # we need to normalize them based on the schema + assert isinstance(target, torch._ops.OpOverload) + + def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]: + result = torch.fx.operator_schemas.normalize_function( + target, args, kwargs + ) + assert result is not None + return result[0], result[1] + + fake_args, fake_kwargs = normalize(fake_args, fake_kwargs) + args, kwargs = normalize(args, kwargs) + old_args, old_kwargs = normalize(old_args, old_kwargs) + + args, kwargs = constrain_to_fake_tensors( + args, kwargs, fake_args, fake_kwargs + ) + else: + args, kwargs = layout_constraints(n, *args, **kwargs) out = lowerings[target](*args, **kwargs) # type: ignore[index] diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index eb2cb0a81f7a..bd20cb242b41 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1124,20 +1124,27 @@ def map_fn(v: Any) -> Optional[_ExtractValType]: return None return extract_val(v.meta["val"]) - # TODO: opt-in mechanism ? - if isinstance( - target, - ( - torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, - torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, - ), - ): + if _should_save_arg_kwarg_vals(target): arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp) return node +def _should_save_arg_kwarg_vals(target: Any) -> bool: + if isinstance( + target, + ( + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + ), + ): + return True + if isinstance(target, torch._ops.OpOverload): + return torch._C.Tag.needs_exact_strides in target.tags + return False + + def _make_temp_remove_mode_context_manager( mode_ty: type[TorchFunctionMode], ) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: From 4d121d2b02dade3f6f909ddcfcdccbbe05f547d5 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 1 Apr 2025 21:07:26 -0700 Subject: [PATCH 097/332] Implement needs_exact_strides for mutable custom operators (#148091) Mutable custom operators get wrapped into an auto_functionalized HOP, so we need to store the arg_kwarg_vals on the auto_functionalized HOP itself. When Inductor does the re-inplacing, it'll use the pattern matcher to decompose the auto_functionalized HOP back into the original op (and 0+ other view or clone operations). The pattern matcher uses the arg_kwarg_vals to trace the subgraph to do the decomposition, so it ultimately sets arg_kwarg_vals on the original op's node correctly. 
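For concreteness, a sketch of the case this covers, using the mutable schema from the new test (names are illustrative):

    import torch

    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        lib.define(
            "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()",
            tags=[torch._C.Tag.needs_exact_strides],
        )

Under torch.compile such a call is captured roughly as auto_functionalized(torch.ops.mylib.add_out_op.default, x=..., y=..., out=...), so the eager (args, kwargs) values are stored on that wrapper node and survive until re-inplacing decomposes it back into the original op.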
Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/148091 Approved by: https://github.com/eellison ghstack dependencies: #148046, #148063 --- test/inductor/test_triton_kernels.py | 22 ++++++++++++++++++++-- torch/_inductor/pattern_matcher.py | 2 ++ torch/fx/experimental/proxy_tensor.py | 27 +++++++++++++++++++++++++-- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 6e02a928bb73..4966821120c5 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3376,7 +3376,9 @@ def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: self.assertEqual(z, (x + y) * 2) @requires_gpu - @common_utils.parametrize("variant", ["triton_kernel", "custom_op"]) + @common_utils.parametrize( + "variant", ["triton_kernel", "custom_op", "mutable_custom_op"] + ) def test_preserves_strides(self, variant): import triton import triton.language as tl @@ -3461,6 +3463,18 @@ def meta(x, y): lib.impl("add_op", impl, "CompositeExplicitAutograd") lib.impl("add_op", meta, "Meta") + lib.define( + "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()", + tags=[torch._C.Tag.needs_exact_strides], + ) + + def impl_out(x, y, out): + grid = (y.numel(),) + add_kernel[grid](x, y, out, y.numel(), BLOCK_SIZE=16) + + lib.impl("add_out_op", impl_out, "CompositeExplicitAutograd") + lib.impl("add_out_op", lambda x, y, out: None, "Meta") + def f(x, other): y = x.transpose(2, 3).contiguous().transpose(2, 3) z = y.sin().transpose(2, 3) @@ -3468,13 +3482,17 @@ def f(x, other): return add_triton(y, z) elif variant == "custom_op": return torch.ops.mylib.add_op.default(y, z) + elif variant == "mutable_custom_op": + out = torch.empty_like(y, memory_format=torch.contiguous_format) + torch.ops.mylib.add_out_op(y, z, out) + return out else: raise AssertionError("should not be hit") with config.patch( post_grad_custom_post_pass=g, ): - f_compile = torch.compile(f) + f_compile = torch.compile(f, fullgraph=True) self.assertEqual(f(x, other), f_compile(x, other)) self.assertTrue(called) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index d337dcae4c99..01ff810e03db 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -260,6 +260,8 @@ def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool: return node.target in OrderedSet( [ torch.ops.higher_order.triton_kernel_wrapper_functional, + torch.ops.higher_order.auto_functionalized, + torch.ops.higher_order.auto_functionalized_v2, ] ) diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index bd20cb242b41..739d37349953 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1124,14 +1124,19 @@ def map_fn(v: Any) -> Optional[_ExtractValType]: return None return extract_val(v.meta["val"]) - if _should_save_arg_kwarg_vals(target): + if _should_save_arg_kwarg_vals(target, (args, kwargs)): arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp) return node -def _should_save_arg_kwarg_vals(target: Any) -> bool: +def _should_save_arg_kwarg_vals( + target: Any, + args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, +) -> bool: + if not callable(target): + return False if isinstance( target, ( @@ -1140,6 +1145,24 @@ def 
_should_save_arg_kwarg_vals(target: Any) -> bool: ), ): return True + if args_kwargs is not None and ( + target is torch.ops.higher_order.auto_functionalized + or target is torch.ops.higher_order.auto_functionalized_v2 + ): + args = args_kwargs[0] + assert isinstance(args[0], torch._ops.OpOverload) + return _should_save_arg_kwarg_vals(args[0], None) + if target is torch.ops.higher_order.with_effects: + # TODO: inductor lowering for with_effects needs to be updated to propagate + # the arg_kwarg_vals + return False + if isinstance(target, torch._ops.HigherOrderOperator): + if pytree.tree_any(_should_save_arg_kwarg_vals, args_kwargs): + raise RuntimeError( + f"NYI: The HOP {target} has an input that is an OpOverload that " + f"needs exact strides. We probably need special logic to " + f"propagate the FakeTensor vals. Please file an issue." + ) if isinstance(target, torch._ops.OpOverload): return torch._C.Tag.needs_exact_strides in target.tags return False From aae36929edf892292677d5a22d1cd4f1301b3d61 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 1 Apr 2025 21:07:26 -0700 Subject: [PATCH 098/332] Rename node.meta["arg_kwarg_vals"] to node.meta["eager_input_vals"] (#148092) And added a comment about it. Otherwise it might be confusing Test Plan: - wait for CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/148092 Approved by: https://github.com/eellison ghstack dependencies: #148046, #148063, #148091 --- torch/_inductor/fx_passes/reinplace.py | 8 ++++---- torch/_inductor/graph.py | 10 ++++----- torch/_inductor/pattern_matcher.py | 28 +++++++++++++------------- torch/fx/experimental/proxy_tensor.py | 15 +++++++++----- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 48541bcc5e34..ee258dfd4158 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -720,14 +720,14 @@ def tensor_with_same_storage_already_reinplaced(arg): kwargs = dict(node.kwargs) kwargs["tensors_to_clone"] = tensors_to_clone node.kwargs = immutable_dict(kwargs) - if "arg_kwarg_vals" in node.meta: - # We changed the kwargs, so we need to update arg_kwarg_vals + if "eager_input_vals" in node.meta: + # We changed the kwargs, so we need to update eager_input_vals # to something sane. - args, kwargs = node.meta["arg_kwarg_vals"] + args, kwargs = node.meta["eager_input_vals"] new_kwargs = {**kwargs} new_kwargs["tensors_to_clone"] = immutable_list(tensors_to_clone) new_kwargs = immutable_dict(new_kwargs) - node.meta["arg_kwarg_vals"] = (args, new_kwargs) + node.meta["eager_input_vals"] = (args, new_kwargs) elif ( inplaceable_op := inplaceable_foreach_ops.get(node.target, None) ) is not None: diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index cfbc01d4f51e..38f359c5f255 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1199,8 +1199,8 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> # otherwise, no constraints at all: the implication is # that this operator was inserted by a custom pass # so we'll give them the freedom. - if "arg_kwarg_vals" in n.meta: - fake_args, fake_kwargs = n.meta["arg_kwarg_vals"] + if "eager_input_vals" in n.meta: + fake_args, fake_kwargs = n.meta["eager_input_vals"] # (fake_args, fake_kwargs) might not align with (args, kwargs). 
# we need to normalize them based on the schema @@ -1535,9 +1535,9 @@ def debug(msg: str) -> None: old_args = args # type: ignore[possibly-undefined] old_kwargs = kwargs # type: ignore[possibly-undefined] - if arg_kwarg_vals := n.meta.get("arg_kwarg_vals"): - inp_args = arg_kwarg_vals[0] - inp_kwargs = arg_kwarg_vals[1] + if eager_input_vals := n.meta.get("eager_input_vals"): + inp_args = eager_input_vals[0] + inp_kwargs = eager_input_vals[1] args, kwargs = constrain_to_fake_tensors( args, kwargs, inp_args, inp_kwargs ) diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 01ff810e03db..792a6b4385a2 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -251,11 +251,11 @@ def replace_by_example( else contextlib.nullcontext() ) - def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool: + def should_propagate_eager_input_vals(nodes: list[torch.fx.Node]) -> bool: if len(nodes) != 1: return False node = nodes[0] - if "arg_kwarg_vals" not in node.meta: + if "eager_input_vals" not in node.meta: return False return node.target in OrderedSet( [ @@ -271,17 +271,17 @@ def should_propagate_arg_kwarg_vals(nodes: list[torch.fx.Node]) -> bool: fwd_only, run_functional_passes=run_functional_passes ) - if should_propagate_arg_kwarg_vals(self.nodes): + if should_propagate_eager_input_vals(self.nodes): # Our strategy is: - # 1) trace out the graph with arg_kwarg_vals (which have accurate eager-mode metadata) + # 1) trace out the graph with eager_input_vals (which have accurate eager-mode metadata) # 2) trace out the graph with vals (which have the accurate Inductor metadata) - # 3) Propagate the arg_kwarg_vals from the first graph to the second. + # 3) Propagate the eager_input_vals from the first graph to the second. # 4) Use the second graph as the replacement graph. - # Construct a map of node -> FakeTensor val in arg_kwarg_vals + # Construct a map of node -> FakeTensor val in eager_input_vals node_to_val = {} - fake_args, fake_kwargs = self.nodes[0].meta["arg_kwarg_vals"] + fake_args, fake_kwargs = self.nodes[0].meta["eager_input_vals"] fake_kwargs = {**fake_kwargs} match_args, match_kwargs = tuple(self.args), self.kwargs @@ -292,7 +292,7 @@ def record(node: torch.fx.Node, val: Any) -> None: torch.utils._pytree.tree_map( record, (match_args, match_kwargs), (fake_args, fake_kwargs) ) - # map args to their FakeTensor val in arg_kwarg_vals + # map args to their FakeTensor val in eager_input_vals example_vals = torch.fx.map_arg(args, lambda arg: node_to_val[arg]) # first graph @@ -312,9 +312,9 @@ def record(node: torch.fx.Node, val: Any) -> None: for old_node, new_node in zip( graph_with_eager_vals.graph.nodes, replacement.graph.nodes ): - if "arg_kwarg_vals" in old_node.meta: - new_node.meta["arg_kwarg_vals"] = old_node.meta[ - "arg_kwarg_vals" + if "eager_input_vals" in old_node.meta: + new_node.meta["eager_input_vals"] = old_node.meta[ + "eager_input_vals" ] else: @@ -1145,10 +1145,10 @@ def run_node(self, node: torch.fx.Node) -> Any: pass_name="Interpreter_Replacer", ) # This function copy-pastes the replacement graph into - # the graph. If the replacement graph had any arg_kwarg_vals, + # the graph. If the replacement graph had any eager_input_vals, # or val/tensor_meta, we propagate those over. 
- if "arg_kwarg_vals" in node.meta: - result.meta["arg_kwarg_vals"] = node.meta["arg_kwarg_vals"] + if "eager_input_vals" in node.meta: + result.meta["eager_input_vals"] = node.meta["eager_input_vals"] if "val" in node.meta and "val" not in result.meta: result.meta["val"] = node.meta["val"] if isinstance(node.meta["val"], torch.Tensor): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 739d37349953..4193606d849d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1124,14 +1124,19 @@ def map_fn(v: Any) -> Optional[_ExtractValType]: return None return extract_val(v.meta["val"]) - if _should_save_arg_kwarg_vals(target, (args, kwargs)): + if _should_save_eager_input_vals(target, (args, kwargs)): + # NOTE "eager_input_vals" + # We save the original (args, kwargs) FakeTensor values for nodes + # that have exact stride requirements. This is useful downstream. + # We use this information inside Inductor to ensure that inputs to + # stride-sensitive operators have the correct strides. arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] - node.meta["arg_kwarg_vals"] = (arg_inp, kwarg_inp) + node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) return node -def _should_save_arg_kwarg_vals( +def _should_save_eager_input_vals( target: Any, args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, ) -> bool: @@ -1151,13 +1156,13 @@ def _should_save_arg_kwarg_vals( ): args = args_kwargs[0] assert isinstance(args[0], torch._ops.OpOverload) - return _should_save_arg_kwarg_vals(args[0], None) + return _should_save_eager_input_vals(args[0], None) if target is torch.ops.higher_order.with_effects: # TODO: inductor lowering for with_effects needs to be updated to propagate # the arg_kwarg_vals return False if isinstance(target, torch._ops.HigherOrderOperator): - if pytree.tree_any(_should_save_arg_kwarg_vals, args_kwargs): + if pytree.tree_any(_should_save_eager_input_vals, args_kwargs): raise RuntimeError( f"NYI: The HOP {target} has an input that is an OpOverload that " f"needs exact strides. 
We probably need special logic to " From 5f62d07ec67bd1fae504e095e28098574b1fa24d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 3 Mar 2025 20:33:01 +0000 Subject: [PATCH 099/332] Fix log2, PowByNatural printing (#147592) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147592 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 11 +++++++++++ torch/_inductor/codegen/halide.py | 27 ++++++++++++++++++--------- torch/_inductor/codegen/triton.py | 11 ++++++++++- torch/utils/_sympy/printers.py | 11 +++++++++++ 4 files changed, 50 insertions(+), 10 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index a2357af8ee84..01f64a479f20 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6569,6 +6569,7 @@ def fn(a, b): ), ) + @skip_if_halide # log2 not implemented for halide def test_log2(self): def fn(x): return torch.log2(x), torch.log2(x + 1) - 2 @@ -6587,6 +6588,7 @@ def fn(x): (torch.randn([8, 8]) + 10,), ) + @skip_if_halide # log2 not implemented for halide def test_log_fp64(self): def fn(x): return torch.log(x), torch.log2(x) @@ -10340,6 +10342,15 @@ def fn(arg0_1): [x], ) + @skip_if_halide # log2 not yet implemented + @skip_if_triton_cpu # log2 implemented only in Dec 2024 + def test_pow_by_natural_log2_dynamic_shapes(self): + @torch.compile(dynamic=True) + def fn(x): + return x + 2 ** (math.floor(math.log2(x.shape[0]) + 1)) + + self.common(fn, [torch.randn(5)]) + def test_setitem_with_int_parameter(self): x = torch.zeros(7, device=self.device) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 9f4469698207..28dbbfb446ba 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -96,6 +96,8 @@ def _print_floor(self, expr): assert len(expr.args) == 1 return self.cast_index(f"hl.floor({self._print(expr.args[0])})") + _print_FloorToInt = _print_floor + def _print_Trunc(self, expr): assert len(expr.args) == 1 return self.cast_index(f"hl.trunc({self._print(expr.args[0])})") @@ -140,39 +142,42 @@ def _print_Abs(self, expr): def _print_OpaqueUnaryFn_cos(self, expr): assert len(expr.args) == 1 - return f"hl.cos(({self._print(expr.args[0])})" + return f"hl.cos({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_cosh(self, expr): assert len(expr.args) == 1 - return f"hl.cosh(({self._print(expr.args[0])})" + return f"hl.cosh({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_acos(self, expr): assert len(expr.args) == 1 - return f"hl.acos(({self._print(expr.args[0])})" + return f"hl.acos({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_sin(self, expr): assert len(expr.args) == 1 - return f"hl.sin(({self._print(expr.args[0])})" + return f"hl.sin({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_sinh(self, expr): assert len(expr.args) == 1 - return f"hl.sinh(({self._print(expr.args[0])})" + return f"hl.sinh({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_asin(self, expr): assert len(expr.args) == 1 - return f"hl.asin(({self._print(expr.args[0])})" + return f"hl.asin({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_tan(self, expr): assert len(expr.args) == 1 - return f"hl.tan(({self._print(expr.args[0])})" + return f"hl.tan({self._print(expr.args[0])})" def _print_OpaqueUnaryFn_tanh(self, expr): assert len(expr.args) == 1 - return f"hl.tanh(({self._print(expr.args[0])})" + return f"hl.tanh({self._print(expr.args[0])})" def 
_print_OpaqueUnaryFn_atan(self, expr): assert len(expr.args) == 1 - return f"hl.atan(({self._print(expr.args[0])})" + return f"hl.atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_log2(self, expr): + raise NotImplementedError("log2") def _print_FloorDiv(self, expr): if expr.is_integer: @@ -453,6 +458,10 @@ def pow(a, b): def log(x): return f"hl.log({x})" # hl.fast_log fails accuracy + @staticmethod + def log2(x): + raise NotImplementedError("log2") + @staticmethod def isinf(x): # workaround https://github.com/halide/Halide/issues/8309 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 2fe5c304c24a..8f7353500f10 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -605,7 +605,12 @@ def _print_FloatPow(self, expr: sympy.Expr) -> str: f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" ) - _print_PowByNatural = _print_FloatPow + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + if expr.args[0].is_Integer: + return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})" + return ( + f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})" + ) def _print_Where(self, expr: sympy.Expr) -> str: c = self.doprint(expr.args[0]) @@ -678,6 +683,10 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return ( diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 33b4e6e0652d..60e6b37f1340 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -264,6 +264,10 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.log2({self._print(expr.args[0])})" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -351,6 +355,10 @@ def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: # TODO: PowByNatural: we need to implement our own int-int pow. 
Do NOT # use std::pow, that operates on floats def _print_PowByNatural(self, expr: sympy.Expr) -> str: + # Implement the special-case of 2**x for now + base, exp = expr.args + if base == 2: + return f"(1 << ({self._print(exp)}))" raise NotImplementedError( f"_print_PowByNatural not implemented for {type(self)}" ) @@ -465,6 +473,9 @@ def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: return f"std::sqrt({self._print(expr.args[0])})" + def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str: + return f"std::log2({self._print(expr.args[0])})" + def _print_RoundToInt(self, expr: sympy.Expr) -> str: assert len(expr.args) == 1 # TODO: dispatch to llrint depending on index type From 82ceebce584e4bd19cd7a1064a549e6293085789 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 1 Apr 2025 18:11:48 +0000 Subject: [PATCH 100/332] [inductor] Lowerings for max_pool3d (#148210) Pull Request resolved: https://github.com/pytorch/pytorch/pull/148210 Approved by: https://github.com/eellison --- test/inductor/test_mps_basic.py | 4 +- test/inductor/test_torchinductor.py | 23 +-- torch/_functorch/partitioners.py | 2 +- torch/_inductor/codegen/triton.py | 4 +- torch/_inductor/decomposition.py | 69 ++++++--- torch/_inductor/fx_passes/quantization.py | 2 +- torch/_inductor/inductor_prims.py | 120 +++++++------- torch/_inductor/lowering.py | 181 +++++++++++++--------- 8 files changed, 244 insertions(+), 161 deletions(-) diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 4d16f4301de8..ee2cc4e4fbba 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -201,8 +201,8 @@ def fn(a): "test_lgamma", "test_linear_float64", "test_log_fp64", - "test_low_memory_max_pool_dilation_1", - "test_low_memory_max_pool_dilation_2", + "test_low_memory_max_pool_dilation_1_dim_2", + "test_low_memory_max_pool_dilation_2_dim_2", "test_max_min", "test_max_pool2d2", "test_multilayer_prime_size", diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 01f64a479f20..30aafa062068 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -4276,34 +4276,35 @@ def fn2(a): ) @parametrize("dilation", (1, 2)) - def test_low_memory_max_pool(self, dilation: int): + @parametrize("dim", (2, 3)) + def test_low_memory_max_pool(self, dilation: int, dim: int): prims = torch.ops.prims def fn(x): - kernel_size = [3, 3] - stride = [2, 2] - padding = [1, 1] + kernel_size = [3, 3] if dim == 2 else [3, 3, 2] + stride = [2] * dim + padding = [1] * dim ceil_mode = False - vals, offsets = prims._low_memory_max_pool2d_with_offsets( + vals, offsets = prims._low_memory_max_pool_with_offsets( x, kernel_size, stride, padding, - [dilation] * 2, + [dilation] * dim, ceil_mode, ) - indices = prims._low_memory_max_pool2d_offsets_to_indices( + indices = prims._low_memory_max_pool_offsets_to_indices( offsets, - kernel_size[1], - x.size(-1), + kernel_size, + x.shape[-dim:], stride, padding, - dilation=[dilation] * 2, + dilation=[dilation] * dim, ) return vals, indices, offsets - self.common(fn, (torch.randn(1, 3, 10, 10),)) + self.common(fn, (torch.randn(1, 3, *[10] * dim),)) @xfail_if_mps def test_to_dtype(self): diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 51292fb00985..97b53b6c9a88 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1538,7 +1538,7 @@ def 
get_default_op_list() -> OpTypes: aten.argmax, aten.maximum, prims.iota, - prims._low_memory_max_pool2d_offsets_to_indices, + prims._low_memory_max_pool_offsets_to_indices, ] # noqa: E501,B950 # Natalia said that we should allow recomputing indexing :) default_recomputable_ops += [aten.index, aten.gather] diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 8f7353500f10..5aaab1ed47ed 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -2552,17 +2552,19 @@ def _mask_value(value, default) -> CSEVariable: masked_value = _mask_value(value, default) if reduction_type in ("argmax", "argmin"): + accumulator_dtype = V.kernel.get_index_dtype_as_torch_dtype() accumulator_index = str( self.cse.generate( self.compute, f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", - dtype=V.kernel.get_index_dtype_as_torch_dtype(), + dtype=accumulator_dtype, ) ) root_op = {"argmax": "max", "argmin": "min"}[reduction_type] final_argreduce( self.compute, result_var, masked_value, accumulator_index ) + result_var.dtype = accumulator_dtype elif reduction_type == "welford_reduce": if self.cooperative_reduction: # cooperative reductions require full welford for correctness diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index ddf044ebf1ff..82d0eb6a1320 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -2,6 +2,7 @@ import functools import logging import math +import operator import sys import typing from typing import Any, Callable, Optional, TypeVar, Union @@ -963,38 +964,40 @@ def index_reduce( ) -@register_decomposition(aten.max_pool2d_with_indices) -def max_pool2d_with_indices( +def _max_pool_with_indices( x: torch.Tensor, kernel_size: list[int], - stride: Optional[Union[int, list[int]]] = None, - padding: Union[int, list[int]] = 0, - dilation: Union[int, list[int]] = 1, - ceil_mode: bool = False, + stride: Optional[Union[int, list[int]]], + padding: Union[int, list[int]], + dilation: Union[int, list[int]], + ceil_mode: bool, + dim: int, ) -> tuple[torch.Tensor, torch.Tensor]: if dilation == 1: - dilation = [1, 1] + dilation = [1] * dim if padding == 0: - padding = [0, 0] + padding = [0] * dim if not stride: stride = kernel_size - kernel_size = pad_listlike(kernel_size, 2) - dilation = pad_listlike(dilation, 2) - padding = pad_listlike(padding, 2) - stride = pad_listlike(stride, 2) + kernel_size = pad_listlike(kernel_size, dim) + dilation = pad_listlike(dilation, dim) + padding = pad_listlike(padding, dim) + stride = pad_listlike(stride, dim) - window_size = kernel_size[0] * kernel_size[1] - # We fallback when the window size is too large + window_size = functools.reduce(operator.mul, kernel_size) + # We fallback when using non-default dilation or when the window size is too large if ( - torch._inductor.lowering.should_fallback_max_pool2d_with_indices(kernel_size) + torch._inductor.lowering.should_fallback_max_pool_with_indices( + kernel_size, n_dim=dim + ) or window_size > torch.iinfo(torch.int8).max ): return NotImplemented - vals, offsets = prims._low_memory_max_pool2d_with_offsets( + vals, offsets = prims._low_memory_max_pool_with_offsets( x, kernel_size, stride, @@ -1002,10 +1005,10 @@ def max_pool2d_with_indices( dilation, ceil_mode, ) - indices = prims._low_memory_max_pool2d_offsets_to_indices( + indices = prims._low_memory_max_pool_offsets_to_indices( offsets, - kernel_size[1], - x.size(-1), + kernel_size, + x.shape[-dim:], stride, padding, 
dilation, @@ -1013,6 +1016,34 @@ def max_pool2d_with_indices( return vals, indices +@register_decomposition(aten.max_pool2d_with_indices) +def max_pool2d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=2 + ) + + +@register_decomposition(aten.max_pool3d_with_indices) +def max_pool3d_with_indices( + x: torch.Tensor, + kernel_size: list[int], + stride: Optional[Union[int, list[int]]] = None, + padding: Union[int, list[int]] = 0, + dilation: Union[int, list[int]] = 1, + ceil_mode: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, dim=3 + ) + + @register_decomposition(aten.adaptive_max_pool2d) def adaptive_max_pool2d( x: torch.Tensor, output_size: list[int] diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index e1dff0162cb5..65eacb32dff4 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -940,7 +940,7 @@ def _register_quantization_maxpool2d(): *max_pool2d_args, ) dequantize_lowmem_maxpool2d_pattern = CallFunction( - prims._low_memory_max_pool2d_with_offsets.default, + prims._low_memory_max_pool_with_offsets.default, get_dequantize_per_tensor_activation_pattern(), KeywordArg("kernel_size"), *max_pool2d_args, diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 170c2f00d44a..d764744d857a 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs from __future__ import annotations +import functools import logging +import operator from typing import Optional, TYPE_CHECKING import torch @@ -119,7 +121,28 @@ def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: ) -def _low_memory_max_pool2d_with_offsets_aten( +def _flattened_index_to_nd(indices, width): + dim = len(width) + + if dim == 1: + return [indices] + elif dim >= 2: + m = functools.reduce(operator.mul, width[1:]) + ih = indices // m + indices_new = indices - (ih * m) + return [ih, *_flattened_index_to_nd(indices_new, width[1:])] + else: + raise ValueError(f"Unknown dim: {dim}") + + +def _flatten_index(indices, width): + result = indices[0] + for d in range(1, len(indices)): + result = width[d] * result + indices[d] + return result + + +def _low_memory_max_pool_with_offsets_aten( self, kernel_size, stride, @@ -127,80 +150,69 @@ def _low_memory_max_pool2d_with_offsets_aten( dilation, ceil_mode, ): - vals, indices = torch.ops.aten.max_pool2d_with_indices( - self, kernel_size, stride, padding, dilation, ceil_mode - ) - - input_width = self.shape[-1] - kernel_width = kernel_size[1] - - bh_shape = [1] * self.ndim - bh_shape[-2] = -1 - bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view( - bh_shape - ) - - bw_shape = [1] * self.ndim - bw_shape[-1] = -1 - bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view( - bw_shape - ) + dim = len(kernel_size) + if dim == 2: + vals, indices = torch.ops.aten.max_pool2d_with_indices( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + else: + vals, indices = torch.ops.aten.max_pool3d_with_indices( + self, kernel_size, stride, padding, 
dilation, ceil_mode + ) - hbase = bh * stride[0] - padding[0] - wbase = bw * stride[1] - padding[1] + idhw = _flattened_index_to_nd(indices, self.shape[-dim:]) - ih = indices // input_width - iw = indices - (ih * input_width) + dhw_inc = [] - h_inc = (ih - hbase) // dilation[0] - w_inc = (iw - wbase) // dilation[1] + for d in range(dim): + bh_shape = [1] * self.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + indices.shape[-dim + d], dtype=torch.int64, device=self.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + h_inc = (idhw[d] - hbase) // dilation[d] + dhw_inc.append(h_inc) - offsets = h_inc * kernel_width + w_inc + offsets = _flatten_index(dhw_inc, kernel_size) return vals, offsets.to(torch.int8) -def _low_memory_max_pool2d_offsets_to_indices_aten( +def _low_memory_max_pool_offsets_to_indices_aten( offsets, - kernel_width, - input_width, + kernel_size, + input_size, stride, padding, dilation, ): + dim = len(kernel_size) offsets = offsets.to(torch.int64) - h_inc = offsets // kernel_width - w_inc = offsets - (h_inc * kernel_width) - - bh_shape = [1] * offsets.ndim - bh_shape[-2] = -1 - bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view( - bh_shape - ) - - bw_shape = [1] * offsets.ndim - bw_shape[-1] = -1 - bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view( - bw_shape - ) + dhw_inc = _flattened_index_to_nd(offsets, kernel_size) - hbase = bh * stride[0] - padding[0] - wbase = bw * stride[1] - padding[1] + idhw = [] + for d in range(dim): + bh_shape = [1] * offsets.ndim + bh_shape[-dim + d] = -1 + bh = torch.arange( + offsets.shape[-dim + d], dtype=torch.int64, device=offsets.device + ).view(bh_shape) + hbase = bh * stride[d] - padding[d] + idhw.append(hbase + dhw_inc[d] * dilation[d]) - ih = hbase + h_inc * dilation[0] - iw = wbase + w_inc * dilation[1] - return ih * input_width + iw + return _flatten_index(idhw, input_size) -_low_memory_max_pool2d_with_offsets = make_prim( - "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 - _low_memory_max_pool2d_with_offsets_aten, +_low_memory_max_pool_with_offsets = make_prim( + "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + _low_memory_max_pool_with_offsets_aten, return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), doc="Instead of returning indices, returns indices offsets.", ) -_low_memory_max_pool2d_offsets_to_indices = make_prim( - "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", # noqa: B950 - _low_memory_max_pool2d_offsets_to_indices_aten, +_low_memory_max_pool_offsets_to_indices = make_prim( + "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", # noqa: B950 + _low_memory_max_pool_offsets_to_indices_aten, doc="Convert small int offsets to regular indices.", ) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 64b505d5cdac..9996857b29d2 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -222,6 +222,7 @@ def add_layout_constraint(fn, constraint): aten.convolution, aten.convolution_backward, aten.max_pool2d_with_indices, + 
aten.max_pool3d_with_indices, aten.max_pool2d_with_indices_backward, aten.mm, aten.upsample_nearest2d, @@ -2616,7 +2617,6 @@ def is_aligned(x): make_fallback(aten._adaptive_avg_pool3d) # @isuruf make_fallback(aten.adaptive_max_pool3d) # @isuruf make_fallback(aten.fractional_max_pool3d) # @isuruf -make_fallback(aten.max_pool3d_with_indices) # @isuruf (can this one be implemented?) # 1) Easy @@ -4348,57 +4348,62 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None return x_out, ceil_mode -def should_fallback_max_pool2d_with_indices(kernel_size): - kernel_size = pad_listlike(kernel_size, 2) - window_size = kernel_size[0] * kernel_size[1] +def should_fallback_max_pool_with_indices(kernel_size, *, n_dim): + kernel_size = pad_listlike(kernel_size, n_dim) + window_size = functools.reduce(operator.mul, kernel_size) return window_size > 25 -def max_pool2d_checks( - x, kernel_size, stride, padding, dilation, *, assert_fallback=None +def max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None ): if padding == 0: - padding = [0, 0] + padding = [0] * n_dim if dilation == 1: - dilation = [1, 1] + dilation = [1] * n_dim if not stride: stride = kernel_size - kernel_size = pad_listlike(kernel_size, 2) - stride = pad_listlike(stride, 2) - padding = pad_listlike(padding, 2) - dilation = pad_listlike(dilation, 2) + kernel_size = pad_listlike(kernel_size, n_dim) + stride = pad_listlike(stride, n_dim) + padding = pad_listlike(padding, n_dim) + dilation = pad_listlike(dilation, n_dim) assert isinstance(x, TensorBox) - assert len(kernel_size) == 2 - assert len(stride) == 2 - assert len(padding) == 2 - assert len(dilation) == 2 - assert len(x.get_size()) in (3, 4) + assert len(kernel_size) == n_dim + assert len(stride) == n_dim + assert len(padding) == n_dim + assert len(dilation) == n_dim + assert len(x.get_size()) in (n_dim + 1, n_dim + 2) - use_fallback = should_fallback_max_pool2d_with_indices(kernel_size) + use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim) if assert_fallback is not None: assert use_fallback == assert_fallback return kernel_size, stride, padding, dilation, use_fallback -def _max_pool2d_with_offsets( +def _max_pool_with_offsets( x, kernel_size, stride, padding, dilation, - ceil_mode=False, + ceil_mode, + *, + n_dim, ): x.realize_hint() - *batch, h, w = x.get_size() + batch = x.shape[:-n_dim] + dhw = x.shape[-n_dim:] - h_out, ceil_mode1 = pooling_size( - h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation - ) - w_out, ceil_mode2 = pooling_size( - w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation + dhw_out, ceil_mode = zip( + *[ + pooling_size( + dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation + ) + for d in range(n_dim) + ] ) dtype = x.dtype @@ -4408,27 +4413,18 @@ def _max_pool2d_with_offsets( else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min) ) - new_size = list(batch) + [h_out, w_out] - if ( - padding[0] - or padding[1] - or ceil_mode1 - or ceil_mode2 - or (dilation[0] > 1) - or (dilation[1] > 1) - ): - x_loader = constant_boundary_condition(x, min_value, dim=2) + new_size = list(batch) + list(dhw_out) + if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation): + x_loader = constant_boundary_condition(x, min_value, dim=n_dim) else: x_loader = x.make_loader() - dim = 2 - def fn_inner(idx, reduction_idx): - prefix = idx[:-dim] - bh = idx[-dim:] + prefix = idx[:-n_dim] + bh = idx[-n_dim:] ih = [ (bh[i] * stride[i]) + 
(reduction_idx[i] * dilation[i]) - padding[i] - for i in range(dim) + for i in range(n_dim) ] return x_loader([*prefix, *ih]) @@ -4462,8 +4458,8 @@ def fn_inner(idx, reduction_idx): return result, offsets -@register_lowering(prims._low_memory_max_pool2d_with_offsets, type_promotion_kind=None) -def _low_memory_max_pool2d_with_offsets( +@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None) +def _low_memory_max_pool_with_offsets( x, kernel_size, stride, @@ -4471,53 +4467,60 @@ def _low_memory_max_pool2d_with_offsets( dilation, ceil_mode=False, ): + n_dim = len(kernel_size) + # assert we are not on a fallback path, the inductor decomp should have guaranteed this - kernel_size, stride, padding, dilation, _ = max_pool2d_checks( + kernel_size, stride, padding, dilation, _ = max_pool_checks( x, kernel_size, stride, padding, dilation, + n_dim, assert_fallback=False, ) with config.patch(unroll_reductions_threshold=25): - result, offsets = _max_pool2d_with_offsets( + result, offsets = _max_pool_with_offsets( x, kernel_size, stride, padding, dilation, ceil_mode, + n_dim=n_dim, ) return result, to_dtype(offsets, torch.int8) @register_lowering( - prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None + prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None ) -def _low_memory_max_pool2d_offsets_to_indices( - offsets, kernel_width, input_width, stride, padding, dilation +def _low_memory_max_pool_offsets_to_indices( + offsets, kernel_size, input_size, stride, padding, dilation ): - # TODO: Generalize to other max pooling flavors, and arbitrary dim - + # TODO: Generalize to other max pooling flavors + n_dim = len(kernel_size) offsets_loader = offsets.make_loader() - def increments_to_index(h_inc, w_inc, bh, bw): - w_in = ops.index_expr(input_width, torch.int64) - hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) - wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) - ih = hbase + h_inc * ops.constant(dilation[0], torch.int64) - iw = wbase + w_inc * ops.constant(dilation[1], torch.int64) - return ih * w_in + iw + def increments_to_index(dhw_inc, bh): + w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)] + hbase = [ + ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64) + for d in range(n_dim) + ] + idhw = [ + hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64) + for d in range(n_dim) + ] + return inductor_prims._flatten_index(idhw, w_in) def offsets_to_indices(idx): - *prefix, bh, bw = idx - offset = offsets_loader([*prefix, bh, bw]) - kw_const = ops.constant(kernel_width, torch.int32) - h_inc = offset // kw_const - w_inc = offset - (h_inc * kw_const) - return increments_to_index(h_inc, w_inc, bh, bw) + bh = idx[-n_dim:] + offset = offsets_loader(idx) + k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)] + dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const) + return increments_to_index(dhw_inc, bh) indices = Pointwise.create( device=offsets.get_device(), @@ -4528,6 +4531,35 @@ def offsets_to_indices(idx): return indices +def _max_pool_with_indices( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + n_dim, +): + kernel_size, stride, padding, dilation, _ = max_pool_checks( + x, kernel_size, stride, padding, dilation, n_dim=n_dim + ) + + out, offsets = _max_pool_with_offsets( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim + ) + + indices = _low_memory_max_pool_offsets_to_indices( + offsets, + 
kernel_size, + x.shape[-n_dim:], + stride, + padding, + dilation, + ) + + return out, indices + + # Fallback when we do not decompose to the low-memory path. @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) def max_pool2d_with_indices( @@ -4538,20 +4570,25 @@ def max_pool2d_with_indices( dilation=1, ceil_mode=False, ): - kernel_size, stride, padding, dilation, _ = max_pool2d_checks( - x, kernel_size, stride, padding, dilation + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2 ) - out, offsets = _max_pool2d_with_offsets( - x, kernel_size, stride, padding, dilation, ceil_mode - ) - indices = _low_memory_max_pool2d_offsets_to_indices( - offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation +# Fallback when we do not decompose to the low-memory path. +@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None) +def max_pool3d_with_indices( + x, + kernel_size, + stride=None, + padding=0, + dilation=1, + ceil_mode=False, +): + return _max_pool_with_indices( + x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3 ) - return out, indices - fallback_max_pool2d_with_indices_backward = fallback_handler( aten.max_pool2d_with_indices_backward.default, From 42c7c7f15f5e38c871f119ea12e27655c3d2dfba Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 1 Apr 2025 21:57:25 -0700 Subject: [PATCH 101/332] [invoke_subgraph] Filter out grad_out where fw_out requires_grad is False (#150486) I am not sure if this is the right way. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150486 Approved by: https://github.com/zou3519 ghstack dependencies: #150082, #150450 --- test/higher_order_ops/test_invoke_subgraph.py | 14 ++-- torch/_higher_order_ops/invoke_subgraph.py | 72 ++++++++++++------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index aa6ecdd15928..cd6c97fcf878 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -559,23 +559,27 @@ def test_simple_module(self): @mark_compile_region def gn(x): - return mod(x) + return torch.cos(x), mod(x) def fn(x): - return gn(x) + out = gn(x) + return out[0] + out[1] opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) # requires_grad is False deliberately to force None the joint_graph # outputs x = torch.randn(8, 8, requires_grad=False) + x_clone = x.detach().clone().requires_grad_(False) - ref = mod(x) - res = opt_fn(x) - self.assertEqual(ref, res) + ref = fn(x) + res = opt_fn(x_clone) ref.sum().backward() res.sum().backward() + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + def test_fail_with_direct_invoke_subgraph(self): from torch._higher_order_ops import invoke_subgraph diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 1e2c2ce95a30..840e32e29dfc 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -2,6 +2,7 @@ from contextlib import nullcontext +from dataclasses import dataclass, field from typing import Optional, Union import torch @@ -37,6 +38,15 @@ invoke_subgraph_counter = 0 +# During the tracing of the joint graph, we construct this information. This is +# used to filter out grad_outs/tangents in the `backward` method of +# InvokeSubgraphAutogradOp. 
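Concretely (a small sketch, not taken from the patch): for a forward that returns a differentiable value, a None, and a tensor that does not require grad, the two index sets defined below come out as follows.

    import torch

    x = torch.randn(4, requires_grad=True)
    fw_outs = (torch.cos(x), None, x.detach())
    # the bookkeeping loop in create_fw_bw_graph records:
    #   indexes_with_none    == {1}   (output 1 is None)
    #   indexes_with_no_grad == {2}   (output 2 does not require grad)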
+@dataclass +class FilterTangentInfo: + indexes_with_none: set[int] = field(default_factory=set) + indexes_with_no_grad: set[int] = field(default_factory=set) + + class InvokeSubgraphHOP(HigherOrderOperator): def __init__(self) -> None: super().__init__("invoke_subgraph") @@ -189,24 +199,34 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): else fake_mode.shape_env.ignore_fresh_unbacked_symbols() ) - if grad_outputs is None: - # Infer grad_outputs to be the same properties as the fw_outputs - # if they're not passed in - with context: - grad_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) - num_fw_outs = len(grad_outputs) + num_fw_outs = len(fw_outs) # Collect the indexes of none in the output to check that the grad # is None at the corresponding index in the backward. This check is # performed in the autograd.Function - InvokeSubgraphAutogradOp. - none_indexes_in_fwd_out = set() + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. + filter_tangent_info = FilterTangentInfo() - for idx, grad in enumerate(grad_outputs): - if grad is None: - none_indexes_in_fwd_out.add(idx) + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + filter_tangent_info.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + filter_tangent_info.indexes_with_no_grad.add(idx) - grad_outputs = [grad for grad in grad_outputs if grad is not None] + if grad_outputs is None: + # Infer grad_outputs to be the same properties as the fw_outputs + # if they're not passed in + # Although fw_outs are equivalent to grad_outputs for tracing + # purposes, we have to carefully handle the None and fw_out that do + # not have require_grad. At those indexes, we will have None in the + # backward graph. + grad_outputs = fw_outs + grad_outputs = [grad for grad in grad_outputs if grad is not None] + grad_outputs = [grad for grad in grad_outputs if grad.requires_grad] if any( not isinstance(out, torch.Tensor) @@ -227,7 +247,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): fw_inputs, grad_outputs, ) - return fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out + return fw_graph, bw_graph, num_fw_outs, filter_tangent_info class InvokeSubgraphAutogradOp(torch.autograd.Function): @@ -243,14 +263,14 @@ def forward( bw_graph, identifier, num_fw_outs, - none_indexes_in_fwd_out, + filter_tangent_info, *operands, ): ctx._fw_graph = fw_graph ctx._bw_graph = bw_graph ctx._identifier = identifier ctx._num_fw_outs = num_fw_outs - ctx._none_indexes_in_fwd_out = none_indexes_in_fwd_out + ctx._filter_tangent_info = filter_tangent_info with torch._C._AutoDispatchBelowAutograd(): out = invoke_subgraph( @@ -264,7 +284,7 @@ def forward( # Check that None is at expected indexes. for idx, o in enumerate(out): if o is None: - assert idx in none_indexes_in_fwd_out + assert idx in filter_tangent_info.indexes_with_none return out @@ -274,18 +294,22 @@ def backward(ctx, *grad_outs): identifier = ctx._identifier primals = saved_tensors_and_symints(ctx) num_fw_outs = ctx._num_fw_outs - none_indexes_in_fwd_out = ctx._none_indexes_in_fwd_out + filter_tangent_info = ctx._filter_tangent_info # While tracing we made the assumption that tangents are contiguous. So, - # force the grad_outs to be contiguous. Some of the grads can be None, - # because the forward outs could be None. Filter them out. + # force the grad_outs to be contiguous. 
+ # Also filter out grads that are None or do not require_grad. This was + # the assumption we made during the tracing of joint_graph. contiguous_grad_outs = [] for idx, o in enumerate(grad_outs): - if o is not None: - contiguous_grad_outs.append(o.contiguous()) + if o is None: + assert idx in filter_tangent_info.indexes_with_none + elif idx in filter_tangent_info.indexes_with_no_grad: + # Deliberately skip over the grad_outs which we know should be + # None because the corresponding fwd_out does not require_grad. + pass else: - # Check that None is at expected indexes. - assert idx in none_indexes_in_fwd_out + contiguous_grad_outs.append(o.contiguous()) contiguous_grad_outs = tuple(contiguous_grad_outs) # bw_graph is a joint graph with signature (*primals_and_tangents) and @@ -331,13 +355,13 @@ def _(subgraph, identifier, operands): ): return saved_autograd_fn(*operands) - fw_graph, bw_graph, num_fw_outs, none_indexes_in_fwd_out = create_fw_bw_graph( + fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph( subgraph, operands ) def autograd_fn_callable(*args): return InvokeSubgraphAutogradOp.apply( - fw_graph, bw_graph, identifier, num_fw_outs, none_indexes_in_fwd_out, *args + fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args ) # Save the autograd_fn_callable in the dispatch set cache. From 8102272d8c5b5a3063446ec67877eea495e6d323 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Wed, 2 Apr 2025 15:48:11 +0000 Subject: [PATCH 102/332] [BE] Fix triton windows build (#150512) Fixes #150480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150512 Approved by: https://github.com/atalman Co-authored-by: Andrey Talman --- .github/scripts/windows/build_triton.bat | 3 ++- .github/workflows/build-triton-wheel.yml | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/scripts/windows/build_triton.bat b/.github/scripts/windows/build_triton.bat index 245740c66cdb..97cd535a4988 100644 --- a/.github/scripts/windows/build_triton.bat +++ b/.github/scripts/windows/build_triton.bat @@ -9,7 +9,8 @@ if "%PY_VERS%" == "3.13t" ( ) else ( call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS% ) -call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake setuptools==72.1.0 ninja +:: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480 +call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja dir "%VC_INSTALL_PATH%" diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 9921b018fcc3..b4e9ec34f3da 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -12,6 +12,8 @@ on: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt + - .github/scripts/windows/install_vs2022.ps1 + - .github/scripts/windows/build_triton.bat - .ci/docker/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton-xpu.txt workflow_dispatch: @@ -20,6 +22,8 @@ on: - .github/workflows/build-triton-wheel.yml - .github/scripts/build_triton_wheel.py - .github/ci_commit_pins/triton.txt + - .github/scripts/windows/install_vs2022.ps1 + - .github/scripts/windows/build_triton.bat - .ci/docker/ci_commit_pins/triton.txt - .ci/docker/ci_commit_pins/triton-xpu.txt @@ -244,7 +248,6 @@ jobs: .github/scripts/windows/build_triton.bat mkdir -p "${RUNNER_TEMP}/artifacts/" mv 
./*.whl "${RUNNER_TEMP}/artifacts/" - - uses: actions/upload-artifact@v4.4.0 with: name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }} From f38566dfe43f7b63f795b609741d04404bdf8d67 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Wed, 2 Apr 2025 16:07:18 +0000 Subject: [PATCH 103/332] [MPSInductor] Disable mm/bmm decompositions (#150541) Disables mm/bmm decompositions. torch.compile on MPS was speeding up stories15M (~4x) but it was making stories110M much slower. Self-contained reproducer to demonstrate the difference (before the change, after it should be identical) ```python import torch import timeit def bench_mm(f, x, y): from torch.utils.benchmark import Timer return Timer(stmt="f(x, y); torch.mps.synchronize()", globals={"x": x, "y": y, "f": f}, language="python", timer=timeit.default_timer).blocked_autorange() x = torch.rand(1024, 512, device='mps') y = torch.rand(512, 1, device='mps') mm_c = torch.compile(torch.mm, options={"coordinate_descent_tuning": False}) mm_c_cdt = torch.compile(torch.mm, options={"coordinate_descent_tuning": True}) print(f"Compiled torch.mm perf (with cdt disabled) for 1024x512 and 512x1 matrices are {bench_mm(mm_c, x, y).median}") print(f"Compiled torch.mm perf (with cdt enabled) for 1024x512 and 512x1 matrices are {bench_mm(mm_c_cdt, x, y).median}") ``` Disabling the inductor mm decomposition, speeds up stories15M further (~6x) and speeds up stories110M (~7x) The table below show average tokens/sec across 5 runs on M1 Pro for stories15M and stories110M: | | stories15M | stories110M | |------------------------|------------|-------------| | without compile | 99.40 | 53.11 | | compile before change | 367.68 | 19.43 | | compile after change | 582.96 | 355.07 | stories110M (without compile) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps [...] Average tokens/sec: 53.11 ``` stories110M (compile before change) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps --compile [...] Average tokens/sec: 19.43 ``` stories110M (compile after change) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps --compile [...] Average tokens/sec: 355.07 ``` stories15M (without compile) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps [...] Average tokens/sec: 99.40 ``` stories15M (compile before change) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps --compile [...] Average tokens/sec: 367.68 ``` stories15M (compile after change) ``` (gptfast) mcandales@mcandales-mbp gpt-fast % python generate.py --checkpoint_path checkpoints/stories110M/stories110M.pt --prompt "Once upon a time" --device mps --compile [...] 
Average tokens/sec: 582.96 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150541 Approved by: https://github.com/malfet --- torch/_inductor/decomposition.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 82d0eb6a1320..2dd8a47feb4a 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -262,7 +262,9 @@ def bmm( self: torch.Tensor, batch2: torch.Tensor, ) -> torch.Tensor: - if config.coordinate_descent_tuning and self.device.type != "cpu": + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious( batch2.shape[2] == 1 ): @@ -316,7 +318,10 @@ def mm( ) -> torch.Tensor: # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning. # todo: Look into why and fix it (hopefully) - if config.coordinate_descent_tuning and self.device.type != "cpu": + + # TODO: Re-enable for mps once our reductions are performant enough + # (https://github.com/pytorch/pytorch/issues/150121) + if config.coordinate_descent_tuning and self.device.type not in ["cpu", "mps"]: if guard_size_oblivious(self.shape[0] == 1) or guard_size_oblivious( input2.shape[1] == 1 ): From 532530be34bf396cc632b5c8d95253953e3f7717 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 16:40:54 +0000 Subject: [PATCH 104/332] Revert "[Profiler] Fix Empty C Call Queue (#150370)" This reverts commit 5734909f343ab1de44ed5ab23311d43a9c6afaed. Reverted https://github.com/pytorch/pytorch/pull/150370 on behalf of https://github.com/clee2000 due to broke some profiler tests when building with debug asserts profiler/test_memory_profiler.py::TestMemoryProfiler::test_config_check [GH job link](https://github.com/pytorch/pytorch/actions/runs/14211763078/job/39822158330) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/3ac5a499ddac701f607a9f7206f9bec8871e1cbb) ([comment](https://github.com/pytorch/pytorch/pull/150370#issuecomment-2773146070)) --- torch/csrc/autograd/profiler_python.cpp | 39 ++----------------------- 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 02ab02856864..a98d1a8b7934 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -709,8 +709,6 @@ class PythonTracer final : public python_tracer::PythonTracerBase { const std::vector interpreterThreads() const; - PyObject* get_callable_from_frame(PyFrameObject* frame); - std::atomic active_lock_{false}; bool active_{false}; @@ -789,13 +787,6 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { recordPyCall(thread_local_results_.back(), it->get(), true); - PyFrameObject* frame = it->get(); - PyObject* callable = get_callable_from_frame(frame); - if (callable) { - // Call recordCCall with the callable and the frame - recordCCall(thread_local_results_.back(), it->get(), callable); - } - auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -910,26 +901,6 @@ void PythonTracer::recordCCall( 
queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); } -PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) { - if (frame == nullptr) { - return nullptr; - } - // Get the code object associated with the frame - auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); - if (code == nullptr) { - return nullptr; - } - // Get the function name (if needed) - auto name = THPUtils_unpackStringView(code->co_name).data(); - // To get the function object, you will need to look in the globals or the - // frame's f_globals - PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name); - if (func) { - Py_INCREF(func); // Make sure the returned function has a reference - } - return func; // Returns a PyObject* (the function) -} - // ============================================================================ // == Post processing ========================================================= // ============================================================================ @@ -1012,13 +983,9 @@ class PostProcess { using stack_t = std::vector>; const auto initial_size = out.size(); auto pop = [](stack_t& stack, c10::time_t t) { - if (!stack.empty()) { - std::get>(stack.back()->extra_fields_).end_time_ns_ = t; - stack.pop_back(); - } else { - TORCH_WARN_ONCE( - "Python replay stack is empty during pop operation! May result in incorrect stack tracing."); - } + TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty."); + std::get>(stack.back()->extra_fields_).end_time_ns_ = t; + stack.pop_back(); }; ska::flat_hash_map stacks; From 98453c135a7778d12ff881d8b0a717257be9fc38 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:39 -0700 Subject: [PATCH 105/332] [dynamo] Support Tensor subclass that has dynamic attributes or calls `Parameter.__torch_function__` (#149482) This fixes most of https://github.com/huggingface/diffusers/issues/10795, except for `torch.Tensor._make_subclass`, which will be fixed in a subsequent patch. The relevant tensor subclass from the aforementioned issue is defined here: https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435. There are two things to note about the tensor subclass: 1. it calls `super().__torch_function__`, which is `torch._C._disabled_torch_function_impl`, so this patch updates `SuperVariable.call_method` to handle it (we can't do a simpler polyfill due to some bug with `var_getattr` raising `NotImplementedError`, which forgot to restore symbolic context). 2. it sets and reads attributes (`quant_type`), and defines new methods (`as_data`), so this patch adds support for those. 3. it has a `__init__`, which Dynamo needs to trace through in `TensorSubclassVariable.call_function`. 
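For illustration, here is a condensed sketch of the pattern this patch unlocks. The `MetaParameter` class is a hypothetical stand-in for the diffusers `GGUFParameter`, adapted from the new test below rather than copied from diffusers, so treat it as a sketch of the supported shape, not the exact upstream code:

```python
import torch


class MetaParameter(torch.nn.Parameter):
    def __new__(cls, data, requires_grad=False, quant_type=None):
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, *args, quant_type=None, **kwargs):
        self.quant_type = quant_type  # dynamic attribute, read back under compile

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Parameter.__torch_function__ is torch._C._disabled_torch_function_impl,
        # which SuperVariable.call_method now knows how to simulate.
        return super().__torch_function__(func, types, args, kwargs)


def f(p):
    # reads the dynamic attribute set in __init__
    return p * 2 + p.quant_type


with torch._dynamo.config.patch("traceable_tensor_subclasses", {MetaParameter}):
    p = MetaParameter(torch.ones(2), quant_type=42)
    ref = f(p)
    res = torch.compile(f, backend="eager", fullgraph=True)(p)
    assert torch.equal(ref, res)
```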
Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482 Approved by: https://github.com/jansel, https://github.com/mlazos --- test/dynamo/test_misc.py | 16 +++ test/dynamo/test_subclasses.py | 134 ++++++++++++++++++ .../TestTorch.test_tensor_ressurecting_clear | 1 + ...edding_swap_True_set_grad_True_cpu_float32 | 0 ..._PReLU_swap_True_set_grad_True_cpu_float32 | 0 ...MSNorm_swap_True_set_grad_True_cpu_float32 | 0 ...dding_swap_True_set_grad_True_cuda_float32 | 0 ...PReLU_swap_True_set_grad_True_cuda_float32 | 0 ...SNorm_swap_True_set_grad_True_cuda_float32 | 0 torch/_dynamo/side_effects.py | 66 ++++++--- torch/_dynamo/trace_rules.py | 1 - torch/_dynamo/variables/builder.py | 89 ++++++++---- torch/_dynamo/variables/builtin.py | 14 ++ torch/_dynamo/variables/misc.py | 31 ++++ torch/_dynamo/variables/tensor.py | 44 ++++-- torch/_dynamo/variables/torch_function.py | 64 +++++---- 16 files changed, 373 insertions(+), 87 deletions(-) create mode 100644 test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a6b3a29eb4d3..8579ee8e1b2e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -586,6 +586,22 @@ def f(x): ref = f(x) self.assertEqual(res, ref) + def test_newly_constructed_tensor_attr_mutation(self): + def f(x): + y = x + 10 + y.grad = x + y.foo = 42 + return y + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.ones(5) + + res = opt_f(x) + ref = f(x) + self.assertEqual(res, ref) + self.assertEqual(res.grad, ref.grad) + self.assertEqual(res.foo, ref.foo) + def test_closure_recompiles(self): cnt = CompileCounter() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ef2acadac89d..7fefc281089b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,6 +954,140 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): + # This is a slight variation of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + # which basically + # 1. uses tensor subclass to attach quantization metadata onto tensors + # 2. preserve them across torch ops + # 3. use the metadata to dequantize the tensor + # 4. convert it to a regular tensor. + # + # The test is meant to make sure Dynamo won't graph break over it. 
+ class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + return self + + def __init__(self, *args, quant_type=None, **kwargs): + self.quant_type = quant_type + + def as_tensor(self): + return torch.Tensor(self.data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + quant_type = None + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + quant_type = arg[0].quant_type + break + if isinstance(arg, GGUFParameter): + quant_type = arg.quant_type + break + if isinstance(result, torch.Tensor): + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif isinstance(result, (tuple, list)): + # Preserve the original type (tuple or list) + wrapped = [ + cls(x, quant_type=quant_type) + if isinstance(x, torch.Tensor) + else x + for x in result + ] + return type(result)(wrapped) + else: + return result + + def f(x): + tmp = x * 2 + tmp = tmp + tmp.quant_type + tmp = tmp.as_tensor() + return tmp * 3 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = GGUFParameter(torch.ones(2), quant_type=42) + with traceable_subclass(GGUFParameter): + res = f(x) + ref = opt_f(x) + self.assertEqual(res, ref) + + def test_newly_constructed_tensor_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from constructor call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = MySubclass(torch.ones(2)) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_as_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from as_subclass call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = torch.ones(2).as_subclass(MySubclass) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_tensor_subclass_attr_codegen_tos(self): + # This repros a very subtle interaction between + # `TensorWithTFOverrideVariable` attribute mutation codegen and + # `PyCodegen.top_of_stack`. It was uncovered from + # `test_tensor_subclass_deepcopy`. + class MySubclass(torch.Tensor): + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_subclass(cls, torch.ones(0)) + r.elem = elem + return r + + def f(t): + return MySubclass(t.elem.clone()) + + opt_f = compile_full_eager(f) + + t = MySubclass(torch.ones(2)) + with traceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + # TODO uncomment once we trace into `__new__`. 
+ # self.assertEqual(res, ref) + # self.assertEqual(res.elem, ref.elem) + self.assertEqual(type(res), type(ref)) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear new file mode 100644 index 000000000000..276a4f74bbca --- /dev/null +++ b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149881 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 4c85d98cfd16..1deb09e2cc1e 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -295,9 +295,7 @@ def _track_obj( variable: VariableTracker, mutation_type_cls=ValueMutationExisting, ): - """Start tracking a new variable for mutation""" - assert variable.source is not None - + """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( f"{variable} is already tracked for mutation. This could be " @@ -576,12 +574,18 @@ def _get_modified_vars(self): return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen): - # Make sure we codegen these modified VT to their source by default, so - # that mutation and aliasing are properly accounted for. + # We must codegen modified VT to their source by default, so that + # mutation and aliasing are properly accounted for. + # + # Since newly constructed objects don't have a source, we manually + # codegen their construction and store them to a newly assigned local + # source. Note that `ValueMutationNew` isn't tracked by SideEffects. 
for var in self._get_modified_vars(): - if isinstance(var.mutation_type, AttributeMutationNew) and isinstance( - var, variables.CellVariable - ): + if not isinstance(var.mutation_type, AttributeMutationNew): + assert var.source is not None + continue + + if isinstance(var, variables.CellVariable): # Cells created in the root frame are created either by # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit # `make_cell` for the non-root-frame cells here. @@ -595,18 +599,38 @@ def codegen_save_tempvars(self, cg: PyCodegen): var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: var.source = LocalCellSource(var.local_name) - elif isinstance(var.mutation_type, AttributeMutationNew): - if isinstance(var, variables.AutogradFunctionContextVariable): - unimplemented_v2( - gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", - context="", - explanation="We cannot reconstruct a torch.autograd.Function's context object.", - hints=[], - ) - + elif isinstance(var, variables.TensorVariable): + # NOTE: for historical reasons we never assigned local sources + # to newly constructed tensor object, so we keep it that way. + # They are always loaded from output of the fx graph, so one can + # think of it as having a "OutputGraphSource" for codegen + # purposes. + # + # However, tensor subclass objects are different, because the + # reconstruction logic in `PyCodegen` loads the data tensor from + # graph output and then calls `as_subclass`, meaning we must + # assign a source to it to ensure we only reconstruct one + # subclass instance. + if isinstance( + var, variables.torch_function.TensorWithTFOverrideVariable + ): + # Don't codegen from temp source assigned from the 1st pass. + cg(var, allow_cache=False) + cg.add_cache(var) + # `add_cache` generates STORE and consumes TOS, but we never + # cleared it. TODO move this call into `add_cache` + cg.clear_tos() + var.source = LocalSource(cg.tempvars[var]) + elif isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented_v2( + gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", + context="", + explanation="We cannot reconstruct a torch.autograd.Function's context object.", + hints=[], + ) + else: # Reconstruct the bytecode for # base_cls.__new__(user_cls, *args) - if isinstance(var, variables.UserDefinedObjectVariable): def load_new_method(): @@ -630,10 +654,6 @@ def load_new_method(): cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) - else: - # The remaning cases here are `AttributeMutationExisting` and - # `MutableSideEffects`, which have sources already. 
- assert var.source is not None for ctx, args in self.save_for_backward: cg(ctx.source) @@ -993,7 +1013,7 @@ def codegen_update_mutated(self, cg: PyCodegen): else: cg.tx.output.update_co_names(name) cg(value) - cg(var.source) + cg(var) suffixes.append([create_instruction("STORE_ATTR", argval=name)]) elif isinstance(var, variables.ListIteratorVariable): for _ in range(var.index): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 05739259dc5b..4b3eb10d09e7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -510,7 +510,6 @@ "torch._C._debug_set_fusion_group_inlining", "torch._C._demangle", "torch._C._disabled_torch_dispatch_impl", - "torch._C._disabled_torch_function_impl", "torch._C._dispatch_call_boxed", "torch._C._dispatch_check_all_invariants", "torch._C._dispatch_check_invariants", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 49d0c162d68a..1cd8001e4c34 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -140,6 +140,7 @@ wrap_fake_exception, ) from .base import ( + AttributeMutationNew, typestr, ValueMutationExisting, ValueMutationNew, @@ -2470,7 +2471,9 @@ def _wrap_fx_preexisting_tensor( f"wrapped by this instance of Dynamo. Found: {tensor}" ) - return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls) + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) # This is 2 in the above comment (wrapping the output of a traced op) @@ -2504,36 +2507,23 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe import torch._utils if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - is_buffer = isinstance(example_value, torch.nn.Buffer) - - # NB: In most (all?) cases, this does not actually do a clone. - # (WARNING: this means that if we mutate metadata on the fake - # tensor, the stored example value will update too!) - example_value = _clone_input(example_value, tx.fake_mode) - set_example_value(proxy.node, example_value) - # We bind the unbacked symints in sizes/trdies of tensor lazily. - # So that subgraphs can access the unbacked symbol's proxy in parent graph - # when lifting unbacked symbols of input tensors to subgraph inputs. - # We do it lazily because the tensor may not be used in subgraphs. - tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) - specialized_props = target_cls.specialize(example_value) - # TODO: not sure about this fake mode test - if ( - isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) - and example_value.fake_mode is tx.fake_mode - ): - tensor_type = subclass_type if subclass_type else torch.Tensor - specialized_props["class_type"] = ( - torch.nn.Parameter - if is_parameter - else torch.nn.Buffer - if is_buffer - else tensor_type - ) - - options.update(specialized_props) - return target_cls(proxy, **options) + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # speically here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. 
+ # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var elif ( hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" @@ -2702,6 +2692,43 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ) +def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options +): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + options.update(specialized_props) + return target_cls(proxy, **options) + + def get_automatic_dynamic_shapes_mark_as(): if config.automatic_dynamic_shapes_mark_as == "dynamic": return DimDynamic.DYNAMIC diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c66c369876b9..9c11423162d3 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1933,6 +1933,20 @@ def call_setattr( "the middle of the graph, which aot_autograd does not currently know how to handle. " ) elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented_v2( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause # tracked fakes to produce incorrect guards. 
This is sound because the TensorVariable diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 7eaa01c2a5da..c6a2124c871a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -161,6 +161,14 @@ def call_method( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -266,6 +274,29 @@ def call_method( source = self.source and AttrSource(self.source, attr_name) return VariableTracker.build(tx, attr_value, source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. + # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + output_old = tx.output.torch_function_enabled + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.output.torch_function_enabled = False + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.output.torch_function_enabled = output_old + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5b10a643ad94..99bd1f3eb552 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -67,7 +67,7 @@ set_example_value, tensortype_to_dtype, ) -from .base import VariableTracker +from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable @@ -789,9 +789,14 @@ def method_as_subclass(self, cls): tx = InstructionTranslator.current_tx() py_cls = cls.as_python_constant() - return TensorWithTFOverrideVariable.from_tensor_var( + var = TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, cls.source ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var def method_get_device(self): if isinstance(self.device, torch.device): @@ -1443,14 +1448,37 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable + # Handle `Subclass(existing_tensor)` calls. 
+ def impl(): + if len(args) == 1 and isinstance(args[0], TensorVariable): + from .torch_function import TensorWithTFOverrideVariable + + # This simulates `__new__` and _assumes_ it doesn't have + # side-effects that matters to Dynamo tracing. TODO trace through + # `__new__`. + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, args[0], self.value, self.source + ) - return TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source - ) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + cls_kwargs = kwargs or {} + VariableTracker.build(tx, init_func).call_function( + tx, [var], cls_kwargs + ) + return var + + return super().call_function(tx, args, kwargs) - return super().call_function(tx, args, kwargs) + var = impl() + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index e51f6ccd6c9d..3946dccd8dc7 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -62,6 +62,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -592,12 +593,9 @@ def __init__(self, *args, **kwargs) -> None: def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. - # This isn't sound if a __torch_function__ tensor subclass defines a - # constructor, but if only a __torch_function__ impl is defined, this is - # okay to call. It is up to the user whether this is correct behavior - # or not. import torch + # This simulates shallow-copying the tensor object. 
kwargs = dict(tensor_var.__dict__) assert kwargs.pop("class_type") is torch.Tensor, ( "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" @@ -640,30 +638,48 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + if hasattr(torch.Tensor, name): + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overridden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) - if tx.output.torch_function_enabled and hasattr(torch.Tensor, name): - if self.source: - install_guard( - AttrSource(AttrSource(self.source, "__class__"), name).make_guard( - GuardBuilder.FUNCTION_MATCH + if tx.output.torch_function_enabled: + if self.source: + install_guard( + AttrSource( + AttrSource(self.source, "__class__"), name + ).make_guard(GuardBuilder.FUNCTION_MATCH) ) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + [self], + {}, ) - get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) - - return self.call_torch_function( - tx, - get_fn, - TupleVariable([self.class_type_var(tx)]), - [self], - {}, - ) else: - return super().var_getattr(tx, name) + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. + try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + if isinstance(attr, types.FunctionType): + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + func_source = AttrSource(cls_source, name) + install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return UserMethodVariable(attr, self) + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): return call_torch_function( From 203e1d681d1a4eb7794dfaeaebfa497242dde17d Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:40 -0700 Subject: [PATCH 106/332] [dynamo] Support `torch.Tensor._make_subclass` and tracing through tensor subclass `__new__` (#149483) This builds off the previous patch in the stack, and fully fixes https://github.com/huggingface/diffusers/issues/10795. Essentially, tensor subclass in the issue uses `torch.Tensor._make_subclass`, which has a pretty simple shallow-copy plus type change semantics, as far as Dynamo is concerned. So this patch adds a polyfill for it. As a result, this allows us to trace through many user-defined `__new__` in tensor subclass (it's similar to how we trace through user-defined `__new__` for `UserDefinedClassVariable`), so this patch also faithfully trace through these `__new__` methods. 
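A rough sketch of what this enables, condensed from the `test_make_subclass` and `__new__`-tracing tests added below (names here are illustrative; the polyfill itself models `_make_subclass` as, roughly, detach plus `as_subclass`):

```python
import torch


class MySubclass(torch.Tensor):
    def __new__(cls, data):
        # user-defined __new__ that Dynamo now traces, backed by the
        # _make_subclass polyfill
        return torch.Tensor._make_subclass(cls, data)


def f(x):
    y = MySubclass(x)                                 # downcast via the traced __new__
    z = torch.Tensor._make_subclass(torch.Tensor, x)  # upcast back to a plain Tensor
    return y + 1, z + 2


with torch._dynamo.config.patch("traceable_tensor_subclasses", {MySubclass}):
    x = torch.randn(2, 2)
    ref = f(x)
    res = torch.compile(f, backend="eager", fullgraph=True)(x)
    assert all(torch.equal(a, b) for a, b in zip(ref, res))
```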
Differential Revision: [D71906139](https://our.internmc.facebook.com/intern/diff/D71906139) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149483 Approved by: https://github.com/zou3519, https://github.com/mlazos ghstack dependencies: #149482 --- test/dynamo/test_subclasses.py | 37 ++++++++-- .../TestGradNewOnesOverride.test_newones | 0 .../TestIterator.test_iterator | 0 .../TestNamedTuple.test_max | 0 .../TestPickle.test_pickle | 0 torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/tensor.py | 37 ++++++++++ torch/_dynamo/variables/tensor.py | 67 +++++++++++++------ 8 files changed, 116 insertions(+), 26 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones delete mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator delete mode 100644 test/dynamo_expected_failures/TestNamedTuple.test_max delete mode 100644 test/dynamo_expected_failures/TestPickle.test_pickle create mode 100644 torch/_dynamo/polyfills/tensor.py diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 7fefc281089b..99b7ab9784ae 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,6 +954,34 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) + def test_make_subclass(self): + # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo + # models its aliasing relationships correctly. + class MySubclass(torch.Tensor): + pass + + def fn(x): + # Downcast then upcast + y = torch.Tensor._make_subclass(MySubclass, x) + z = torch.Tensor._make_subclass(torch.Tensor, x) + # Now `x, y, z` should have the same underlying data. + x += 1 + y += 2 + z += 3 + res = x * y + z + return res + + with traceable_subclass(MySubclass): + x0 = torch.randn(2, 2) + x1 = x0.clone() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 @@ -974,7 +1002,9 @@ def __init__(self, *args, quant_type=None, **kwargs): self.quant_type = quant_type def as_tensor(self): - return torch.Tensor(self.data) + return torch.Tensor._make_subclass( + torch.Tensor, self, self.requires_grad + ) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1083,9 +1113,8 @@ def f(t): res = f(t) ref = opt_f(t) - # TODO uncomment once we trace into `__new__`. 
- # self.assertEqual(res, ref) - # self.assertEqual(res.elem, ref.elem) + self.assertEqual(res, ref) + self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) def test_compile_with_fake_tensor_dynamic_dim(self): diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNamedTuple.test_max b/test/dynamo_expected_failures/TestNamedTuple.test_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPickle.test_pickle b/test/dynamo_expected_failures/TestPickle.test_pickle deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d9be4e9febc9..f60aa57a5d40 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -21,6 +21,7 @@ "pytree", "sys", "fx", + "tensor", ) POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/tensor.py b/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 000000000000..002ccf5d1d4f --- /dev/null +++ b/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,37 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. + # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, "_make_subclass only supports requires_grad as keyword arg" + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. 
+ return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 99bd1f3eb552..44b3ffc27689 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -797,6 +797,15 @@ def method_as_subclass(self, cls): object(), var, mutation_type_cls=AttributeMutationNew ) return var + unimplemented_v2( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) def method_get_device(self): if isinstance(self.device, torch.device): @@ -1448,32 +1457,46 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - # Handle `Subclass(existing_tensor)` calls. - def impl(): - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable - - # This simulates `__new__` and _assumes_ it doesn't have - # side-effects that matters to Dynamo tracing. TODO trace through - # `__new__`. + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable + + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if ( + len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? var = TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source + tx, data, self.value, self.source ) + else: + unimplemented_v2( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs + ) - # Let Dynamo trace through custom `__init__` - init_func = self.value.__init__ - # TODO builder should be able to handle `torch.Tensor.__init__`, - # which is `object.__init__`, so that we can remove this check. - if init_func is not torch.Tensor.__init__: - cls_kwargs = kwargs or {} - VariableTracker.build(tx, init_func).call_function( - tx, [var], cls_kwargs - ) - return var - - return super().call_function(tx, args, kwargs) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) - var = impl() # See NOTE [Side effect tracking for newly constructed tensor] tx.output.side_effects._track_obj( object(), var, mutation_type_cls=AttributeMutationNew From 7e53c58687482d58461e1dd8e09f59a9daf8f7b3 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:40 -0700 Subject: [PATCH 107/332] [dynamo] Support tensor subclass with overriden tensor methods and properties (#149484) This fixes most of the "torch.compile X tensor-subclass" issues encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. 
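Concretely, the kind of subclass that now compiles is sketched below. This is a condensed, hypothetical stand-in that mirrors the new `test_subclass_override_shape_and_to` test; the real definition is linked in the next paragraph:

```python
import torch


class MySubclass(torch.Tensor):
    # overridden torch.Tensor method
    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        return new

    # overridden tensor property
    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()  # torch.Tensor.size() returns a torch.Size
        return self.tensor_shape


def f(x):
    x_shape = x.shape
    y = x.to("cpu")
    return x + 1, y + 2, x_shape


with torch._dynamo.config.patch("traceable_tensor_subclasses", {MySubclass}):
    x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass))
    out = torch.compile(f, backend="eager", fullgraph=True)(x)
```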
The relevant tensor subclass definition is here: https://github.com/city96/ComfyUI-GGUF/blob/298192ed60f8ca821c6fe5f8030cae23424cada5/ops.py#L18-L65. A few things to note about the tensor subclass: 1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`, `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr` to support that. 2. it overrides the `shape` property, so this patch updates `TensorWithTFOverrideVariable.var_getattr` to support property as well. 3. it has calls to `torch.Tensor.size`, which returns `torch.Size`, which gets reconstructed in `torch.Tensor.__torch_function__`, so this patch adds support for calling `torch.Size(...)` on non-constant inputs. Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: #149482, #149483 --- test/dynamo/test_subclasses.py | 131 +++++++++++++++++----- torch/_dynamo/variables/misc.py | 15 ++- torch/_dynamo/variables/torch_function.py | 33 ++++-- torch/_dynamo/variables/user_defined.py | 5 + 4 files changed, 137 insertions(+), 47 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 99b7ab9784ae..df6397df8257 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -757,26 +757,22 @@ def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) def sigmoid(self): return None - @torch.compile(backend="eager", fullgraph=True) def fn(x): x.sigmoid() - msg = ( - "Accessing overridden method/attribute sigmoid on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_user_overidden_attr_unsupported(self): class LocalSubclass(torch.Tensor): @@ -792,10 +788,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def fn(x): return x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + msg = "Currently only support accessing overridden attributes that are functions or properties, but got " with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): @@ -804,13 +797,11 @@ def fn(x): def test_user_overidden_property_unsupported(self): class LocalSubclass(torch.Tensor): - def __init__(self) -> None: + def __init__(self, *args, **kwargs) -> None: self._ndim = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) @property @@ -821,19 +812,17 @@ def ndim(self): def ndim(self, value): self._ndim = value - @torch.compile(backend="eager", fullgraph=True) def fn(x): - 
return x.ndim + return x + x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = LocalSubclass(torch.ones(2, 2)) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_overridden_method_guarding(self): class LocalSubclass(torch.Tensor): @@ -982,6 +971,88 @@ def fn(x): self.assertEqual(res_exp, res_act) self.assertEqual(x0, x1) + def test_subclass_override_shape_and_to(self): + # This is a slight variabtion of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + class MySubclass(torch.Tensor): + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + return new + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + def fn(x): + x_shape = x.shape + y = x.to("cpu") + return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape + + with traceable_subclass(MySubclass): + x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + self.assertEqual(x0.tensor_shape, x1.tensor_shape) + + def test_subclass_dont_invoke_torch_function_on_overriden_method(self): + # We shouldn't fire `__torch_function__` for overriden tensor methods. + class MySubclass(torch.Tensor): + def to(self, device): + return self * len(device) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func is torch.Tensor.to: + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x.to("cpu") + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_subclass_dont_invoke_torch_function_on_overriden_attr(self): + from types import MethodWrapperType + + # We shouldn't fire `__torch_function__` for overriden tensor attrs. 
+ class MySubclass(torch.Tensor): + def ndim(self): + return 42 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if type(func) is MethodWrapperType and func.__name__ == "ndim": + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x + x.ndim() + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c6a2124c871a..2c92599a8b28 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -32,7 +32,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, variables +from .. import config, trace_rules, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 @@ -297,6 +297,14 @@ def call_method( tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( tx_old ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) unimplemented(f"non-function or method super: {inner_fn}") @@ -669,11 +677,10 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ): - from ..trace_rules import is_callable_allowed from .builder import wrap_fx_proxy if name == "apply": - if is_callable_allowed(self.fn_cls): + if trace_rules.is_callable_allowed(self.fn_cls): trampoline_autograd_apply = produce_trampoline_autograd_apply( self.fn_cls ) @@ -691,8 +698,6 @@ def call_method( elif name == "backward": return self.call_backward(tx, args, kwargs) else: - from .. import trace_rules - source = AttrSource(self.source, name) if self.source is not None else None try: obj = inspect.getattr_static(self.fn_cls, name) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 3946dccd8dc7..9f24f669e398 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -597,8 +597,9 @@ def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # This simulates shallow-copying the tensor object. 
kwargs = dict(tensor_var.__dict__) - assert kwargs.pop("class_type") is torch.Tensor, ( - "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" ) torch_fn_var = build_torch_function_fn(tx, class_type, cls_source) var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs) @@ -638,13 +639,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if hasattr(torch.Tensor, name): - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - + # Handle non-overriden attributes inherited from `torch.Tensor`. + attr_is_overriden = _is_attr_overidden(tx, self, name) + if hasattr(torch.Tensor, name) and not attr_is_overriden: if tx.output.torch_function_enabled: if self.source: install_guard( @@ -674,11 +671,23 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: import types + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) if isinstance(attr, types.FunctionType): - cls_source = GlobalSource(self.global_mangled_class_name(tx)) - func_source = AttrSource(cls_source, name) - install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)) return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = UserMethodVariable(getter, self, source=getter_source) + return getter_var.call_function(tx, [], {}) + + elif attr_is_overriden: + unimplemented( + f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950 + ) + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index b842a552649f..2d22e0d35805 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,7 @@ ) from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker from .dicts import DefaultDictVariable +from .lists import SizeVariable try: @@ -579,6 +580,10 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) items = list(args) From 238109ad3245c5485f9e83b4b02d258b09329042 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:41 -0700 Subject: [PATCH 108/332] [dynamo] Always trace into tensor subclass `__torch_function__` (#149792) This patch effectively ignores traceable_tensor_subclasses, allowing Dynamo to always try tracing into the `__torch_function__` of tensor subclass. This helps us with 2 things: 1. 
allowing users to directly benefit from better compilation of tensor subclass, by just upgrading pytorch, without having to change legacy library code (see earlier patches in the stack for examples). 2. potentially exposing more issues in compiling tensor subclass, so we can get signals and improve them. As a consequence, it exposed and fixes 2 subtle bugs: 1. In `build_torch_function_fn`, we could get `torch._C._disabled_torch_function_impl` because we have a `Parameter` subclass without `__torch_function__` override or if we have a tensor subclass with `__torch_dispatch__` override. We graph break on this for now, and plan to add support -- the logic for simulating `torch._C._disabled_torch_function_impl` is already in `SuperVariable`, we just need to reuse it. 2. Sometimes we create `SyntheticLocalSource` and need to remove all the guards installed on it, but we only removed the ones whose source _is_ the created synthetic source `s`, but forgot about chained source like `s.foo`, this showed up as `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`. Differential Revision: [D71906141](https://our.internmc.facebook.com/intern/diff/D71906141) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149792 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: #149482, #149483, #149484 --- test/dynamo/test_subclasses.py | 28 ++++++ ..._preserve_torch_function_when_return_as_is | 10 ++ .../TestGradNewOnesOverride.test_newones | 1 + .../TestIterator.test_iterator | 1 + .../TestLazyModules.test_lazy_module_buffer | 1 + ...estLazyModules.test_lazy_module_jit_buffer | 1 + .../TestTorchFunctionMode.test_subclass_hash | 10 ++ ..._on_invalid_torch_function_tensor_subclass | 3 + test/profiler/test_profiler_tree.py | 1 + torch/_dynamo/config.py | 26 ++--- torch/_dynamo/source.py | 6 ++ torch/_dynamo/utils.py | 1 + torch/_dynamo/variables/builder.py | 99 ++++++++++++------- torch/_dynamo/variables/tensor.py | 11 +-- torch/_dynamo/variables/torch.py | 5 +- torch/_dynamo/variables/torch_function.py | 3 - torch/_guards.py | 8 +- torch/nn/attention/bias.py | 8 +- 18 files changed, 151 insertions(+), 72 deletions(-) create mode 100644 test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is create mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones create mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator create mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer create mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer create mode 100644 test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash create mode 100644 test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index df6397df8257..0e7d54c28448 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -40,6 +40,10 @@ def traceable_subclass(c): return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) +def nontraceable_subclass(c): + return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c}) + + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) self.assertEqual(actual_recompiles, expected_recompiles) @@ -1188,6 +1192,30 @@ def f(t): 
self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) + def test_nontraceable_tensor_subclass(self): + # This will error if Dynamo tries to wrap it as a tensor variable, + # because that involves calling certain methods to inspect the tensor + # property, which will blow up in the overriden `__torch_function__`. + class MySubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("one shall not pass") + + def f(t): + return t.foo + torch.ones(10) + + opt_f = torch.compile(f, backend="eager", fullgraph=False) + + t = MySubclass(torch.ones(2)) + t.foo = 42 + # Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo + # from wrapping `t`. + with nontraceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is new file mode 100644 index 000000000000..f243ff1904b0 --- /dev/null +++ b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is @@ -0,0 +1,10 @@ +- Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._data = data` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones new file mode 100644 index 000000000000..24f34ca8e8e6 --- /dev/null +++ b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149975 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator new file mode 100644 index 000000000000..880a24b122bb --- /dev/null +++ b/test/dynamo_expected_failures/TestIterator.test_iterator @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/150005 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git 
a/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash new file mode 100644 index 000000000000..beb4bf5d003a --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash @@ -0,0 +1,10 @@ +Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._diag = _diag` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass new file mode 100644 index 000000000000..c2ddc08d1e40 --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass @@ -0,0 +1,3 @@ +Dynamo cannot query properties of the tensor subclass object when wrapping it +into a VT, because it has a `__torch_function__` that only allows limited +torch ops. diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 7dac5fb70905..48bbbf01727f 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -690,6 +690,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ...""", ) + @skipIfTorchDynamo("segfaults in 3.13+") @unittest.skipIf( TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5d58efdeed09..b59e1c49e607 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -152,26 +152,16 @@ # Non-Inductor backends can use this list for graph freezing. prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1" - -# This feature doesn't really work. We offer this flag for experimental -# purposes / if you want to help us build out support. -# -# torchdynamo has limited support for tensor subclasses that implement -# __torch_function__ see [Note: __torch_function__] in torch_function.py. -# Our current support is limited to tensor subclasses -# that DO NOT store metadata on the tensor (in general, dynamo does not -# support Python code that stores extra attributes on tensors at present). -# If your tensor subclass purely changes function call behavior via -# __torch_function__, you can allow torchdynamo to trace into it by -# adding it to traceable_tensor_subclasses. We don't do any safety checks, -# so it is up to you to ensure that your subclass is well behaved. See also -# https://github.com/pytorch/torchdynamo/issues/1948 -# -# We do NOT currently support __torch_dispatch__. The implementation is -# currently buggy, the main show stopper for nontrivial use is -# https://github.com/pytorch/torchdynamo/issues/1952 +# NOTE this has been deprecated, it does nothing now. 
traceable_tensor_subclasses: set[type[Any]] = set() +# If a tensor subclass is put into this set, Dynamo will model its instasnces in +# a very conservative and limited way (most likely causing lots of graph breaks +# if one apply tensor ops on these instances). This is useful if you encounter +# internal compiler errors from Dynamo which are caused by tensor subclasses, +# and you are willing to tolerate potential graph breaks rather than hard error. +nontraceable_tensor_subclasses: set[type[Any]] = set() + # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index e01c166c97d2..4116f110b21d 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -842,6 +842,12 @@ def is_from_local_source(source: Source, *, only_allow_input=False): return True +def is_from_source(source: Source, target: Source): + if isinstance(source, ChainedSource): + return is_from_source(source.base, target) + return source == target + + def is_from_unspecialized_param_buffer_source(source: Source): if isinstance(source, UnspecializedParamBufferSource): return True diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8ee9289633b1..8fa038ce7116 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1408,6 +1408,7 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: "reorderable_logging_functions", "ignore_logger_methods", "traceable_tensor_subclasses", + "nontraceable_tensor_subclasses", "_custom_ops_profile", } diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 1cd8001e4c34..d5cea823b7f6 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -68,7 +68,11 @@ SymbolicContext, ) from torch.fx.immutable_collections import immutable_dict, immutable_list -from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) from torch.utils._sympy.value_ranges import ValueRanges from torch.utils.weak import TensorWeakRef @@ -612,11 +616,30 @@ def create_2d_tma_descriptor(): return id_dispatch(self, value) # Everything else (NB: order matters!) - if is_traceable_wrapper_subclass(value) or istype( - value, config.traceable_tensor_subclasses + if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses ): - return self.wrap_tensor(value) - elif is_namedtuple(value): + if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: + # This case it's either tensor or subclass with default + # torch_dispatch (they might override torch_function or not), + # and we can always trace into them. + return self.wrap_tensor(value) + elif is_traceable_wrapper_subclass(value): + # For non-default torch_dispatch, we have more requirements. 
+ return self.wrap_tensor(value) + + if is_namedtuple(value): self.install_guards(GuardBuilder.SEQUENCE_LENGTH) output = [ LazyVariableTracker.create( @@ -930,11 +953,6 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif ( - isinstance(value, torch._C._TensorMeta) - and value in config.traceable_tensor_subclasses - ): - return TensorSubclassVariable(value, source=self.source) elif ( istype(value, contextlib.nullcontext) and inspect.getattr_static(value, "enter_result", None) is None @@ -1187,6 +1205,20 @@ def build_key_value(i, k, v): if value is torch.autograd._unsafe_preserve_version_counter: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). + and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) # This is a userdefined class, so install an ID_MATCH even if its a # global variable. self.install_guards(GuardBuilder.ID_MATCH) @@ -1729,7 +1761,22 @@ def wrap_tensor(self, value: torch.Tensor): # Guards are added inside register_attr_or_module ) - if type(value) in config.traceable_tensor_subclasses: + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): # Ordinarily, we would fakeify a tensor so that it can get dynamic # shapes and be computed on without triggering actual operations. # However, how can we fakeify a tensor subclass? Ordinary @@ -1747,24 +1794,13 @@ def wrap_tensor(self, value: torch.Tensor): # To simplify things for now, the __dict__ tracking bits haven't # been implemented yet, but they can be added into this design at # a later point in time. - subclass_type = type(value) - else: - assert type(value) in ( - torch.Tensor, - torch.nn.Parameter, - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ) or is_traceable_wrapper_subclass(value), type(value) subclass_type = None - - # NB: this just says we accessed a tensor from the same source again - # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). - # This is distinct from two distinct sources mapping to the same - # Tensor (per id())! No guard is necessary here. See below for the - # other case. 
- is_duplicate_tensor = source in self.tx.output.input_source_to_var - if is_duplicate_tensor: - return self.tx.output.input_source_to_var[source] + else: + subclass_type = type(value) + options["torch_function_fn"] = build_torch_function_fn( + self.tx, value, self.source + ) + self.install_guards(GuardBuilder.TYPE_MATCH) if get_static_address_type(value) == "guarded": self.install_guards(GuardBuilder.ID_MATCH) @@ -1772,13 +1808,6 @@ def wrap_tensor(self, value: torch.Tensor): # By this point, we should have deduplicated all tensors self.assert_not_wrapped_by_this_graph(value) - options = {} - if type(value) in config.traceable_tensor_subclasses: - options["torch_function_fn"] = build_torch_function_fn( - self.tx, value, self.source - ) - self.install_guards(GuardBuilder.TYPE_MATCH) - if ( isinstance(value, torch.Tensor) and value.is_nested diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 44b3ffc27689..c477979fa9e3 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -70,6 +70,7 @@ from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable +from .user_defined import UserDefinedClassVariable try: @@ -410,8 +411,6 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): return ConstantVariable(ret_val) def var_getattr(self, tx: "InstructionTranslator", name): - from . import UserDefinedClassVariable - if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): unimplemented( @@ -614,7 +613,7 @@ def call_method( """ # This is seen in inspect signature where we check if the value is a default value - if name == "__eq__" and isinstance(args[0], variables.UserDefinedClassVariable): + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): return variables.ConstantVariable(False) try: @@ -1446,11 +1445,7 @@ def from_tensor_variable(cls, tensor_variable): return FakeItemVariable(**dict(tensor_variable.__dict__)) -class TensorSubclassVariable(VariableTracker): - def __init__(self, value, *args, **kwargs) -> None: - self.value = value - super().__init__(*args, **kwargs) - +class TensorSubclassVariable(UserDefinedClassVariable): def call_function( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1e7a9baf9494..40821a16e5e5 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -76,6 +76,7 @@ from .torch_function import ( can_dispatch_torch_function, dispatch_torch_function, + TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) @@ -1350,7 +1351,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if is_traceable_wrapper_subclass_type(data.class_type): + if isinstance( + data, TensorWithTFOverrideVariable + ) or is_traceable_wrapper_subclass_type(data.class_type): unimplemented("Parameter constructor with tensor subclass NYI") if not can_convert_to_tracable_parameter(): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 9f24f669e398..330faf9bf902 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -24,9 +24,6 @@ See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w for more information on the design. 
- -To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses -in torch/_dynamo/config.py """ import collections diff --git a/torch/_guards.py b/torch/_guards.py index ad5f4a7b130a..b6b36f637101 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -631,8 +631,12 @@ def update(self, *others: set[Guard]): self.add(g, skip=1) def remove_guards_with_source(self, source): - """Delete all guards with a given source""" - self.inner = {g for g in self.inner if g.originating_source != source} + """Delete all guards that contains a given source""" + from ._dynamo.source import is_from_source + + self.inner = { + g for g in self.inner if not is_from_source(g.originating_source, source) + } class GuardsContext(Checkpointable[GuardsCheckpointState]): diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index da7acb957d96..36c0a18cdd12 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -283,11 +283,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" if kwargs is None: kwargs = {} - if func != torch.nn.functional.scaled_dot_product_attention: - raise NotImplementedError( - "CausalBias only supports scaled_dot_product_attention" - ) - return cls._dispatch(*args, **kwargs) + if func is torch.nn.functional.scaled_dot_product_attention: + return cls._dispatch(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) def __repr__(self): # type:ignore[override] return self._materialize().__repr__() From e62d958f02afd74e4d4c023fa97d6c7966e00e26 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Wed, 2 Apr 2025 17:49:32 +0000 Subject: [PATCH 109/332] [Inductor] Reland Merge Triton ScaledMM as epilogue to MM template #150045 (#150441) Merges https://github.com/pytorch/pytorch/pull/150438 and https://github.com/pytorch/pytorch/pull/150045. https://github.com/pytorch/pytorch/pull/150045 was already landed, but did not include a change that makes it unable to land internally. 
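For context, the op being lowered here is `aten._scaled_mm`; a rough, hypothetical usage sketch of the call pattern this template path serves under `torch.compile` (dtypes, strides, and shapes below are illustrative and assume a CUDA device with fp8 support):

```python
# Hypothetical sketch only; requires a GPU with fp8 support.
import torch

def fp8_mm(a, b, scale_a, scale_b):
    # mat_a must be row-major and mat_b column-major (see check_supported_striding).
    return torch._scaled_mm(a, b, scale_a, scale_b,
                            out_dtype=torch.bfloat16, use_fast_accum=True)

a = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn)      # row-major (128, 64)
b = torch.randn(128, 64, device="cuda").to(torch.float8_e4m3fn).t()  # column-major (64, 128)
scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scales; row-wise is also supported
scale_b = torch.tensor(1.0, device="cuda")

out = torch.compile(fp8_mm)(a, b, scale_a, scale_b)  # reaches tuned_scaled_mm below
```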
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150441 Approved by: https://github.com/clee2000 --- torch/_inductor/kernel/mm.py | 407 ++++++++++++++++++- torch/_inductor/kernel/mm_common.py | 70 ++++ torch/_inductor/kernel/mm_scaled.py | 608 ---------------------------- torch/_inductor/utils.py | 8 +- 4 files changed, 469 insertions(+), 624 deletions(-) delete mode 100644 torch/_inductor/kernel/mm_scaled.py diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index e4389ce9e78c..3a7d87fc8596 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Optional +from typing import Any, Optional import torch from torch._dynamo.utils import counters @@ -21,10 +21,16 @@ from ..codegen.rocm.ck_universal_gemm_template import CKGemmTemplate from ..codegen.wrapper import PythonWrapperCodegen from ..ir import FlexibleLayout, is_triton -from ..lowering import register_lowering +from ..lowering import ( + add_layout_constraint, + constrain_to_fx_strides, + lowerings as L, + register_lowering, +) from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, + realize_inputs, TritonTemplate, ) from ..utils import ( @@ -46,6 +52,8 @@ mm_options, persistent_mm_grid, persistent_mm_options, + scale_mm_epilogue, + scaled_mm_options, should_fallback_to_aten, ) @@ -121,7 +129,12 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -192,7 +205,11 @@ idx_m = b_k_idx_vals idx_n = offs_b_n[None, :] {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", indent_width=8)}} - acc += tl.dot(a, b, allow_tf32=ALLOW_TF32) + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} # rematerialize rm and rn to save registers rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -295,20 +312,195 @@ allow_tf32=ALLOW_TF32, ) - if ki == k_tiles - 1: - # rematerialize rm and rn to save registers - rcm = rm + tl.arange(0, BLOCK_M) - rcn = rn + tl.arange(0, BLOCK_N) - idx_m = rcm[:, None] - idx_n = rcn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + {% if ki == k_tiles - 1 %} + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + {% endif %} """, ) +load_scales = r""" +@triton.jit +def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): + if SCALING_ROWWISE: + # For row-wise scaling, we'll return the pointers + return a_scale_ptr, b_scale_ptr + else: + # For per-tensor scaling, we'll load the scalar values + a_scale = 
tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr) + return a_scale, b_scale +""" + + +apply_scaling = r""" +@triton.jit +def apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE: tl.constexpr, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, +): + if SCALING_ROWWISE: + # For row-wise scaling, we need to load the scales for each row/column + a_scales = tl.load( + a_scale + (offs_cm * stride_a_scale_m), + mask=offs_cm < M, + other=0.0, + ) + b_scales = tl.load( + b_scale + (offs_cn * stride_b_scale_n), + mask=offs_cn < N, + other=0.0, + ) + acc_scale = a_scales[:, None] * b_scales[None, :] + else: + # For per-tensor scaling, we can directly use the loaded scalar values + acc_scale = a_scale * b_scale + + return accumulator * acc_scale +""" + + +device_tma = r""" +{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + if SCALING_ROWWISE: + stride_a_scale_m = 1 + stride_b_scale_n = 1 + else: + stride_a_scale_m = 0 + stride_b_scale_n = 0 + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE + a_desc_ptr = workspace_base + b_desc_ptr = workspace_base + TMA_SIZE + + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=a_desc_ptr, + global_address=A, + load_size=[BLOCK_M, BLOCK_K], + global_size=[M, K], + element_ty=A.dtype.element_ty, + ) + triton.language.extra.cuda.experimental_device_tensormap_create2d( + desc_ptr=b_desc_ptr, + global_address=B, + load_size=[BLOCK_N, BLOCK_K], + global_size=[N, K], + element_ty=B.dtype.element_ty, + ) + + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) + tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) + + tiles_per_SM = num_tiles // NUM_SMS + if start_pid < num_tiles % NUM_SMS: + tiles_per_SM += 1 + + tile_id = start_pid - NUM_SMS + ki = -1 + + pid_m = 0 + pid_n = 0 + offs_am = 0 + offs_bn = 0 + + num_pid_in_group = GROUP_M * num_pid_n + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) + + for _ in range(0, k_tiles * tiles_per_SM): + ki = tl.where(ki == k_tiles - 1, 0, ki + 1) + if ki == 0: + tile_id += NUM_SMS + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + offs_k = ki * BLOCK_K + + a = tl._experimental_descriptor_load( + a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + ) + b = tl._experimental_descriptor_load( + b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + ) + if USE_FAST_ACCUM: + accumulator = tl.dot(a, b.T, accumulator) + else: + accumulator += tl.dot(a, b.T) + + {% if ki == k_tiles - 1 %} + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, 
+ offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) + + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + {% endif %} +""" + + +scaled_mm_device_tma_template = TritonTemplate( + name="scaled_mm_device_tma", + grid=persistent_mm_grid, + source=device_tma + load_scales + apply_scaling, +) + # prevent duplication registration of extern functions @functools.lru_cache(None) @@ -330,6 +522,10 @@ def lazy_register_extern_choice(fn): has_out_variant=False, ) +aten__fp8_mm = ExternKernelChoice( + torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out +) + def _is_int8_mat(mat): return mat.get_dtype() in (torch.int8, torch.uint8) @@ -340,6 +536,16 @@ def _is_large_block_for_cpu(m, n, k): return m * n > 2**13 +@functools.lru_cache +def using_b200() -> bool: + """Returns true if the device is a NVIDIA B200, otherwise returns false.""" + if not torch.cuda.is_available(): + return False + # compute capability 10.0 or 10.0a is NVIDIA B200 + device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) + return device_properties.major == 10 + + def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): """ Giving torch.addmm a 1D tensor calls a different (faster) cublasLt @@ -351,6 +557,32 @@ def bias_addmm(inp, mat1, mat2, *, out=None, alpha=1, beta=1): return torch.addmm(inp, mat1, mat2, out=out, alpha=alpha, beta=beta) +def check_supported_striding(mat_a, mat_b) -> None: + def is_row_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[1], 1) + + def is_col_major(stride) -> bool: + return V.graph.sizevars.statically_known_equals(stride[0], 1) + + def has_zero_dim(size) -> bool: + return bool( + V.graph.sizevars.statically_known_equals(size[0], 0) + or V.graph.sizevars.statically_known_equals(size[1], 0) + ) + + # Check mat_a (self) stride requirements + torch._check( + is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), + lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", + ) + + # Check mat_b stride requirements + torch._check( + is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), + lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", + ) + + aten_bias_addmm = ExternKernelChoice(bias_addmm, None) @@ -750,6 +982,151 @@ def tuned_sparse_semi_structured_mm( ) +add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) + + +@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] +def tuned_scaled_mm( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + layout=None, +): + m, n, k, layout, mat_a, mat_b = mm_args( + mat_a, mat_b, layout=layout, out_dtype=out_dtype + ) + # below is for getting an overview logging info of inductor mms + counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 + log.info( + "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", + m, + n, + k, + mat_a.get_dtype(), + mat_b.get_dtype(), + layout, + ) + + device_type = ir.get_device_type(mat_a) + check_supported_striding(mat_a, mat_b) + + scale_a_real, scale_b_real = realize_inputs(scale_a, scale_b) + + input_nodes: tuple[Any, ...] 
+ + if not bias: + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real) + else: + bias_real = realize_inputs(bias) + input_nodes = (mat_a, mat_b, scale_a_real, scale_b_real, bias_real) + + aten_choice = aten__fp8_mm.bind( + input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum + ) + + choices = [] + if use_aten_gemm_kernels(): + choices.append(aten_choice) + + _, is_nonzero = _is_static_problem(layout) + + scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) + scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( + device_type + ) + + if is_nonzero and use_triton_template(layout, enable_float8=True): + triton_input_nodes: tuple[Any, ...] + if bias and len(mat_b.get_size()) == len(bias.get_size()) + 1: + # Need to unsqueeze bias from [N] -> [1, N] + triton_bias = L[aten.unsqueeze](bias, 0) + else: + triton_bias = bias + + if len(scale_a.get_size()) == 0 or len(scale_b.get_size()) == 0: + assert len(scale_a.get_size()) == len(scale_b.get_size()) + # Need to unsqueeze scale from [] -> [1, 1] + triton_scale_a = L[aten.unsqueeze](L[aten.unsqueeze](scale_a, 0), 1) + triton_scale_b = L[aten.unsqueeze](L[aten.unsqueeze](scale_b, 0), 1) + else: + triton_scale_a = scale_a + triton_scale_b = scale_b + + if bias: + triton_input_nodes = ( + mat_a, + mat_b, + triton_scale_a, + triton_scale_b, + triton_bias, + ) + suffix_args = 3 + else: + triton_input_nodes = (mat_a, mat_b, triton_scale_a, triton_scale_b) + suffix_args = 2 + + # TODO (paulzhan): There is no template that exists for bias and TMA + # Don't run tma template currently if bias exists + if use_triton_tma_template(mat_a, mat_b) and not bias: + for config in scaled_persistent_mm_configs(m, n, k): + kwargs = scaled_mm_options( + config, + m, + n, + k, + layout, + scale_a, + scale_b, + use_fast_accum, + device_tma=True, + ) + scaled_mm_device_tma_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + workspace_arg=get_tma_workspace_arg( + num_tma_descriptors=2, + device=mat_a.get_device(), + ), + **kwargs, + ) + + for config in scaled_mm_configs(m, n, k): + if k == 16 and config.kwargs["BLOCK_M"] >= 64: + continue # Triton crashes in this case + + # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid + # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape + if using_b200() and k < 32: + continue + + kwargs = scaled_mm_options( + config, m, n, k, layout, scale_a, scale_b, use_fast_accum + ) + # possibly appends a TritonTemplateCaller to choices + mm_template.maybe_append_choice( + choices, + input_nodes=triton_input_nodes, + layout=layout, + **kwargs, + suffix_args=suffix_args, + epilogue_fn=scale_mm_epilogue(), + ) + + if is_nonzero and use_ck_gemm_template(layout, m, n, k): + CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) + + if should_fallback_to_aten(choices): + return aten_choice.output_node() + + return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) + + @functools.lru_cache(None) def _is_sm7x_or_older_gpu(index: Optional[int]) -> bool: props = torch.cuda.get_device_properties(index or 0) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index d990536c4362..663e78dc199c 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -76,6 +76,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): GROUP_M=8, EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, + 
USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=acc_type(layout.dtype), num_stages=config.num_stages, num_warps=config.num_warps, @@ -92,6 +93,47 @@ def persistent_mm_options(mat1, mat2): ) +def scaled_mm_options( # type: ignore[no-untyped-def] + config, # triton.Config + sym_m: sympy.core.numbers.Integer, + sym_n: sympy.core.numbers.Integer, + sym_k: sympy.core.numbers.Integer, + layout: Layout, + scale_a, + scale_b, + use_fast_accum: bool, + device_tma: bool = False, +) -> dict[str, Any]: + def are_compatible_scales(size_a, size_b) -> bool: + # Same sized scales are compatable + if len(size_a) == len(size_b): + return True + + # Both need to be scalars or len(1) tensors + if len(size_a) <= 1 and len(size_b) <= 1: + return True + + return False + + size_a, size_b = scale_a.get_size(), scale_b.get_size() + assert are_compatible_scales(size_a, size_b), ( + "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " + f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." + ) + + mm_template_options = mm_options(config, sym_m, sym_n, sym_k, layout) + + mm_template_options["ACC_TYPE"] = "tl.float32" + mm_template_options["USE_FAST_ACCUM"] = use_fast_accum + mm_template_options["SCALING_ROWWISE"] = len(size_a) == 2 + + if device_tma: + mm_template_options["TMA_SIZE"] = TMA_DESCRIPTOR_SIZE + mm_template_options["NUM_SMS"] = get_num_sms() + + return mm_template_options + + def mm_args( mat1, mat2, @@ -154,6 +196,34 @@ def epilogue(acc, bias): return epilogue +def scale_mm_epilogue(): + """ + Create an epilogue function that applies scaling to matrix multiplication result + using the given scale factors. + + Args: + dtype: The data type of the output + scale_a: Scale factor for matrix A + scale_b: Scale factor for matrix B + + Returns: + Epilogue function that takes the accumulator and applies scaling + """ + + def epilogue(acc, inv_a_scale, inv_b_scale, bias=None): + # The epilogue function receives the accumulator (result of mat1 @ mat2) + # and applies the scaling factors + # In the original scaled_mm, we use inverse scales, so we multiply by them + mul_scales = V.ops.mul(inv_a_scale, inv_b_scale) + mul_acc = V.ops.mul(acc, mul_scales) + if bias is not None: + return V.ops.add(mul_acc, bias) + else: + return mul_acc + + return epilogue + + def _is_static_problem(layout: Layout) -> tuple[bool, bool]: """ Check if input tensors and output layout have static shapes and non-zero sizes. 
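To make the new `scale_mm_epilogue` above concrete, here is a rough eager-mode reference of the semantics it fuses onto the MM template (a sketch assuming float32 accumulation and broadcastable scales; this is not the actual Triton codegen):

```python
# Hypothetical reference: multiply the float32 accumulator by scale_a * scale_b,
# then add an optional bias. Broadcasting covers both per-tensor scalar scales
# and row-wise scales shaped (M, 1) and (1, N).
import torch

def scaled_mm_reference(a_fp8, b_fp8, scale_a, scale_b, bias=None, out_dtype=torch.bfloat16):
    acc = a_fp8.to(torch.float32) @ b_fp8.to(torch.float32)  # ACC_TYPE = tl.float32
    acc = acc * (scale_a * scale_b)                           # the fused epilogue
    if bias is not None:
        acc = acc + bias
    return acc.to(out_dtype)
```

This mirrors the `V.ops.mul(acc, V.ops.mul(inv_a_scale, inv_b_scale))` plus optional `V.ops.add` emitted above, which is what lets the scaling (and bias) land in the same kernel as the matmul instead of running as a separate pass.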
diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py deleted file mode 100644 index aa917e120168..000000000000 --- a/torch/_inductor/kernel/mm_scaled.py +++ /dev/null @@ -1,608 +0,0 @@ -import functools -import logging -from collections.abc import Sequence -from typing import Any, Optional - -import sympy - -import torch -from torch._dynamo.utils import counters -from torch._inductor.codegen.rocm.ck_universal_gemm_template import CKGemmTemplate -from torch.utils._triton import has_triton_tma_device - -from ..config import triton as triton_config -from ..ir import _IntLike, ChoiceCaller, get_device_type, Layout, StorageBox, TensorBox -from ..lowering import add_layout_constraint, constrain_to_fx_strides, register_lowering -from ..select_algorithm import ( - autotune_select_algorithm, - ExternKernelChoice, - realize_inputs, - TritonTemplate, -) -from ..utils import ( - get_num_sms, - get_tma_workspace_arg, - TMA_DESCRIPTOR_SIZE, - use_aten_gemm_kernels, - use_ck_gemm_template, - use_triton_template, -) -from ..virtualized import V -from .mm_common import ( - _is_static_problem, - mm_args, - mm_grid, - persistent_mm_grid, - should_fallback_to_aten, -) - - -log = logging.getLogger(__name__) -aten = torch.ops.aten - -load_scales = r""" -@triton.jit -def load_scales(a_scale_ptr, b_scale_ptr, SCALING_ROWWISE: tl.constexpr): - if SCALING_ROWWISE: - # For row-wise scaling, we'll return the pointers - return a_scale_ptr, b_scale_ptr - else: - # For per-tensor scaling, we'll load the scalar values - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr) - return a_scale, b_scale -""" - - -apply_scaling = r""" -@triton.jit -def apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE: tl.constexpr, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, -): - if SCALING_ROWWISE: - # For row-wise scaling, we need to load the scales for each row/column - a_scales = tl.load( - a_scale + (offs_cm * stride_a_scale_m), - mask=offs_cm < M, - other=0.0, - ) - b_scales = tl.load( - b_scale + (offs_cn * stride_b_scale_n), - mask=offs_cn < N, - other=0.0, - ) - acc_scale = a_scales[:, None] * b_scales[None, :] - else: - # For per-tensor scaling, we can directly use the loaded scalar values - acc_scale = a_scale * b_scale - - return accumulator * acc_scale -""" - - -device_tma = r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - if SCALING_ROWWISE: - stride_a_scale_m = 1 - stride_b_scale_n = 1 - else: - stride_a_scale_m = 0 - stride_b_scale_n = 0 - - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - k_tiles = tl.cdiv(K, BLOCK_K) - num_tiles = num_pid_m * num_pid_n - - workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE - a_desc_ptr = workspace_base - b_desc_ptr = workspace_base + TMA_SIZE - - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=a_desc_ptr, - global_address=A, - load_size=[BLOCK_M, BLOCK_K], - global_size=[M, K], - element_ty=A.dtype.element_ty, - ) - triton.language.extra.cuda.experimental_device_tensormap_create2d( - desc_ptr=b_desc_ptr, - global_address=B, - load_size=[BLOCK_N, BLOCK_K], - global_size=[N, K], - element_ty=B.dtype.element_ty, - ) - - 
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr) - tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr) - - tiles_per_SM = num_tiles // NUM_SMS - if start_pid < num_tiles % NUM_SMS: - tiles_per_SM += 1 - - tile_id = start_pid - NUM_SMS - ki = -1 - - pid_m = 0 - pid_n = 0 - offs_am = 0 - offs_bn = 0 - - num_pid_in_group = GROUP_M * num_pid_n - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - a_scale, b_scale = load_scales(A_inverse_scale, B_inverse_scale, SCALING_ROWWISE) - - for _ in range(0, k_tiles * tiles_per_SM): - ki = tl.where(ki == k_tiles - 1, 0, ki + 1) - if ki == 0: - tile_id += NUM_SMS - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - offs_k = ki * BLOCK_K - - a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty - ) - b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty - ) - if USE_FAST_ACCUM: - accumulator = tl.dot(a, b.T, accumulator) - else: - accumulator += tl.dot(a, b.T) - - if ki == k_tiles - 1: - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) - - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) -""" - - -scaled_mm_device_tma_template = TritonTemplate( - name="scaled_mm_device_tma", - grid=persistent_mm_grid, - source=device_tma + load_scales + apply_scaling, -) - - -scaled_mm_template = TritonTemplate( - name="scaled_mm", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) 
- if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -# Inductor does not allow optional tensor input arguments currently (pass None as an -# input node to template choices), but since for _scaled_mm there is only one such arg -# (bias), work around by having a second template when bias is provided. -scaled_mm_bias_template = TritonTemplate( - name="scaled_mm_bias", - grid=mm_grid, - source=r""" -{{def_kernel("A", "B", "A_inverse_scale", "B_inverse_scale", "bias_ptr")}} - M = {{size("A", 0)}} - N = {{size("B", 1)}} - K = {{size("A", 1)}} - if M * N == 0: - # early exit due to zero-size input(s) - return - stride_am = {{stride("A", 0)}} - stride_ak = {{stride("A", 1)}} - stride_bk = {{stride("B", 0)}} - stride_bn = {{stride("B", 1)}} - - # based on triton.ops.matmul - pid = tl.program_id(0) - grid_m = (M + BLOCK_M - 1) // BLOCK_M - grid_n = (N + BLOCK_N - 1) // BLOCK_N - - # re-order program ID for better L2 performance - width = GROUP_M * grid_n - group_id = pid // width - group_size = min(grid_m - group_id * GROUP_M, GROUP_M) - pid_m = group_id * GROUP_M + (pid % group_size) - pid_n = (pid % width) // (group_size) - - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) - rk = tl.arange(0, BLOCK_K) - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - for k in range(K, 0, -BLOCK_K): - if EVEN_K: - a = tl.load(A) - b = tl.load(B) - else: - a = tl.load(A, mask=rk[None, :] < k, other=0.) - b = tl.load(B, mask=rk[:, None] < k, other=0.) 
- if USE_FAST_ACCUM: - acc = tl.dot(a, b, acc, out_dtype=ACC_TYPE) - else: - acc += tl.dot(a, b, out_dtype=ACC_TYPE) - A += BLOCK_K * stride_ak - B += BLOCK_K * stride_bk - - if SCALING_ROWWISE: - inv_a_scale_row = tl.load(A_inverse_scale + rm, mask=rm < M) - inv_b_scale_row = tl.load(B_inverse_scale + rn, mask=rn < N) - inv_scale_row = inv_a_scale_row[:, None] * inv_b_scale_row[None, :] - acc *= inv_scale_row - else: - # for tensor-wise scaling, the scales are scalars - inv_a_scale = tl.load(A_inverse_scale) - inv_b_scale = tl.load(B_inverse_scale) - inv_scale = inv_a_scale * inv_b_scale - acc *= inv_scale - - # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - - # bias - bias = tl.load(bias_ptr + rn, mask=rn < N) - acc += bias - - idx_m = rm[:, None] - idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask")}} -""", -) - - -aten__fp8_mm = ExternKernelChoice( - torch._scaled_mm, "at::_scaled_mm_out", op_overload=aten._scaled_mm.out -) - - -def are_compatible_scales(size_a: Sequence[int], size_b: Sequence[int]) -> bool: - # Same sized scales are compatable - if len(size_a) == len(size_b): - return True - - # Both need to be scalars or len(1) tensors - if len(size_a) <= 1 and len(size_b) <= 1: - return True - - return False - - -def check_supported_striding(mat_a: TensorBox, mat_b: TensorBox) -> None: - def is_row_major(stride: Sequence[_IntLike]) -> bool: - return stride[1] == 1 - - def is_col_major(stride: Sequence[_IntLike]) -> bool: - return stride[0] == 1 - - def has_zero_dim(size: Sequence[_IntLike]) -> bool: - return bool(size[0] == 0 or size[1] == 0) - - # Check mat_a (self) stride requirements - torch._check( - is_row_major(mat_a.get_stride()) or has_zero_dim(mat_a.get_size()), - lambda: f"mat_a must be row_major, got stride {mat_a.get_stride()}", - ) - - # Check mat_b stride requirements - torch._check( - is_col_major(mat_b.get_stride()) or has_zero_dim(mat_b.get_size()), - lambda: f"mat_b must be col_major, got stride {mat_b.get_stride()}", - ) - - -def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." 
- ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - TMA_SIZE=TMA_DESCRIPTOR_SIZE, - NUM_SMS=get_num_sms(), - **config.kwargs, - ) - - -def scaled_mm_options( # type: ignore[no-untyped-def] - config, # triton.Config - sym_m: sympy.core.numbers.Integer, - sym_n: sympy.core.numbers.Integer, - sym_k: sympy.core.numbers.Integer, - layout: Layout, - scale_a: StorageBox, - scale_b: StorageBox, - use_fast_accum: bool, -) -> dict[str, Any]: - even_k_symbolic = ( - sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] - ) - - size_a, size_b = scale_a.get_size(), scale_b.get_size() - assert are_compatible_scales(size_a, size_b), ( - "Expect scale_a and scale_b to be either both scalars (including single-element tensors) " - f"or 1-dimensional tensors with the same size. Got scale_a: {len(size_a)} and scale_b: {len(size_b)}." - ) - return dict( - GROUP_M=8, - EVEN_K=even_k_symbolic, - ACC_TYPE="tl.float32", - USE_FAST_ACCUM=use_fast_accum, - num_stages=config.num_stages, - num_warps=config.num_warps, - # tensor-wise scaling if scalar scales - SCALING_ROWWISE=len(scale_a.get_size()) == 2, - **config.kwargs, - ) - - -add_layout_constraint(aten._scaled_mm.default, constrain_to_fx_strides) - - -def use_persistent_tma(k: sympy.core.numbers.Integer, has_bias: bool) -> bool: - available = has_triton_tma_device() and triton_config.enable_persistent_tma_matmul - # _determine_swizzle_mode_2d requires BLOCK_K to be at least 32 contiguous bytes - # When K is 16, BLOCK_K = 16 and is not valid - min_k = k >= 32 - return available and min_k and not has_bias - - -@register_lowering(aten._scaled_mm.default, type_promotion_kind=None) # type: ignore[misc] -def tuned_scaled_mm( - mat_a: TensorBox, - mat_b: TensorBox, - scale_a: TensorBox, - scale_b: TensorBox, - bias: Optional[TensorBox] = None, - scale_result: Optional[TensorBox] = None, - out_dtype: Optional[torch.dtype] = None, - use_fast_accum: bool = False, - layout: Optional[Layout] = None, -) -> TensorBox: - m, n, k, layout, mat_a, mat_b = mm_args( - mat_a, mat_b, layout=layout, out_dtype=out_dtype - ) - - # below is for getting an overview logging info of inductor mms - counters["aten_mm_info"][f"aten._scaled_mm.default_{m}_{n}_{k}"] += 1 - log.info( - "Tuned aten._scaled_mm.default: m=%s, n=%s, k=%s, mat1_dtype=%s, mat2_dtype=%s, output_layout=%s", - m, - n, - k, - mat_a.get_dtype(), - mat_b.get_dtype(), - layout, - ) - - device_type = get_device_type(mat_a) - - check_supported_striding(mat_a, mat_b) - - scale_a, scale_b = realize_inputs(scale_a, scale_b) - - input_nodes: tuple[Any, ...] 
- # workaround for Inductor not supporting optional tensor input arguments - if bias is None: - input_nodes = (mat_a, mat_b, scale_a, scale_b) - triton_template = scaled_mm_template - else: - bias = realize_inputs(bias) - input_nodes = (mat_a, mat_b, scale_a, scale_b, bias) - triton_template = scaled_mm_bias_template - - aten_choice = aten__fp8_mm.bind( - input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum - ) - - choices: list[ChoiceCaller] = [] - if use_aten_gemm_kernels(): - choices.append(aten_choice) - - _, is_nonzero = _is_static_problem(layout) - - scaled_mm_configs = V.choices.get_scaled_mm_configs(device_type) - scaled_persistent_mm_configs = V.choices.get_scaled_persistent_mm_configs( - device_type - ) - - if is_nonzero and use_triton_template(layout, enable_float8=True): - if use_persistent_tma(k, bias is not None): - for config in scaled_persistent_mm_configs(m, n, k): - kwargs = scaled_mm_options_device_tma( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - input_nodes = (mat_a, mat_b, scale_a, scale_b) - scaled_mm_device_tma_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - workspace_arg=get_tma_workspace_arg( - num_tma_descriptors=2, - device=mat_a.get_device(), - ), - **kwargs, - ) - else: - for config in scaled_mm_configs(m, n, k): - if k == 16 and config.kwargs["BLOCK_M"] >= 64: - continue # Triton crashes in this case - - # On NVIDIA B200 GPUs, K dim must be >= 32 for tcgen05.mma.kind::f8f6f4.* PTX instruction to be valid - # source: https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-matrix-shape - if using_b200() and k < 32: - continue - - kwargs = scaled_mm_options( - config, m, n, k, layout, scale_a, scale_b, use_fast_accum - ) - # possibly appends a TritonTemplateCaller to choices - triton_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - ) - - if is_nonzero and use_ck_gemm_template(layout, m, n, k): - CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) - - if should_fallback_to_aten(choices): - return aten_choice.output_node() - - return autotune_select_algorithm("scaled_mm", choices, input_nodes, layout) - - -@functools.lru_cache -def using_b200() -> bool: - """Returns true if the device is a NVIDIA B200, otherwise returns false.""" - if not torch.cuda.is_available(): - return False - # compute capability 10.0 or 10.0a is NVIDIA B200 - device_properties = torch.cuda.get_device_properties(torch.cuda.current_device()) - return device_properties.major == 10 diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index e93ed88bcbda..bca3f024d134 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1377,7 +1377,7 @@ def _is_tma_compatible(x: IRNode) -> bool: return False dtype = x.get_dtype() - if dtype not in (torch.float16, torch.bfloat16): + if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn): return False layout = x.get_layout() @@ -1388,6 +1388,12 @@ def _is_tma_compatible(x: IRNode) -> bool: inner_dim = layout.size[1] if transposed: inner_dim = layout.size[0] + + if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt( + inner_dim, 32 + ): + return False + inner_bytes = inner_dim * dtype.itemsize return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT) From cb4cd6166e5f6d7c0180dcca27454691ef910625 Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 2 Apr 2025 19:13:44 +0000 Subject: [PATCH 110/332] Address Cmake update issue in windows 
magma builds (#150549) 1. Fixes Cmake update error: https://github.com/pytorch/pytorch/actions/runs/14223930697/job/39858632864 ``` CMake Error at CMakeLists.txt:1 (cmake_minimum_required): Compatibility with CMake < 3.5 has been removed from CMake. Update the VERSION argument value. Or, use the ... syntax to tell CMake that the project requires at least but has been updated to work with policies introduced by or earlier. Or, add -DCMAKE_POLICY_VERSION_MINIMUM=3.5 to try configuring anyway. ``` 2. Removes deprecated CUDA 12.4 build Pull Request resolved: https://github.com/pytorch/pytorch/pull/150549 Approved by: https://github.com/clee2000 --- .github/scripts/windows/build_magma.bat | 3 ++- .github/workflows/build-magma-windows.yml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/scripts/windows/build_magma.bat b/.github/scripts/windows/build_magma.bat index beabb0070554..b8701ddde3fc 100644 --- a/.github/scripts/windows/build_magma.bat +++ b/.github/scripts/windows/build_magma.bat @@ -54,7 +54,8 @@ cmake .. -DGPU_TARGET="%GPU_TARGET%" ^ -DCMAKE_BUILD_TYPE=%CONFIG% ^ -DCMAKE_GENERATOR=Ninja ^ -DCMAKE_INSTALL_PREFIX=..\install\ ^ - -DCUDA_ARCH_LIST="%CUDA_ARCH_LIST%" + -DCUDA_ARCH_LIST="%CUDA_ARCH_LIST%" ^ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 if errorlevel 1 exit /b 1 cmake --build . --target install --config %CONFIG% -- -j%NUMBER_OF_PROCESSORS% diff --git a/.github/workflows/build-magma-windows.yml b/.github/workflows/build-magma-windows.yml index 85f2884e5351..4a3fb9855a06 100644 --- a/.github/workflows/build-magma-windows.yml +++ b/.github/workflows/build-magma-windows.yml @@ -22,7 +22,7 @@ jobs: runs-on: windows-2019 strategy: matrix: - cuda_version: ["128", "126", "124", "118"] + cuda_version: ["128", "126", "118"] config: ["Release", "Debug"] env: CUDA_VERSION: ${{ matrix.cuda_version }} From d4298f2136d06264cbbef9806ad46b14e707a60a Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 2 Apr 2025 19:42:43 +0000 Subject: [PATCH 111/332] [CI] Use system nccl in build (#150226) Install nccl in the docker image (which is already being done in some docker images), and use USE_SYSTEM_NCCL=1 in CI builds It takes some time to build nccl and doesn't happen in parallel, so theres less benefit in switching to a bigger runner and using more processes The other changes in this PR are because there is an install_cuda script and an install_cuda_aarch64 script and they both build nccl from source and define their own pins for the nccl version. There is also a .ci/docker/nccl-cu11.txt and cu12.txt that define the pins, and this is an attempt to unify them. 
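As an aside for anyone validating the resulting images: a quick way to confirm that a build actually picked up NCCL, and which version got linked, is the check below (an illustrative sketch, not part of this patch; it just mirrors the manual torch.distributed.is_nccl_available check mentioned in the TODO list further down and assumes a CUDA-enabled build).

```python
import torch
import torch.distributed as dist

print(torch.version.cuda)          # CUDA toolkit the build was compiled against
print(dist.is_nccl_available())    # True only if torch was built with NCCL support
print(torch.cuda.nccl.version())   # NCCL version that ended up linked in
```

As for the image-size impact of copying these extra files into the docker build: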
Unfortunately this leads to a lot of files needing to be copied to the docker build Generally seems to increase docker pull times by <1 min, P1768456379 but its hard to tell what the real increase is 15761 mib -> 16221 [linux-focal-cuda11.8-py3.10-gcc9 / test (distributed](https://github.com/pytorch/pytorch/actions/runs/14114171729/job/39545500161#logs) `jq '[.layers[].size, .config.size] | add / 1024 / 1024'` Example https://hud.pytorch.org/pytorch/pytorch/commit/6eb3c2e2822c50d8a87b43938a9cf7ef0561ede2#39520169577-box ![image](https://github.com/user-attachments/assets/d44ef415-6e48-41ef-ac83-f19bab47560c) TODO: * Figure out a way to verify that nccl was built + works properly when it is expected (this time i just checked torch.distributed.is_nccl_available) * Merge the cusparse installation scripts * Merge the cuda installation scripts * Either split the nccl, cuda, and cusparse installations always, or make the always together in one bash script distributed/test_distributed_spawn Pull Request resolved: https://github.com/pytorch/pytorch/pull/150226 Approved by: https://github.com/seemethere, https://github.com/atalman --- .ci/docker/almalinux/Dockerfile | 2 + .ci/docker/common/install_cuda.sh | 46 ++++---------------- .ci/docker/common/install_cuda_aarch64.sh | 12 +---- .ci/docker/common/install_nccl.sh | 26 +++++++++++ .ci/docker/libtorch/Dockerfile | 2 + .ci/docker/linter-cuda/Dockerfile | 4 +- .ci/docker/manywheel/Dockerfile | 4 +- .ci/docker/manywheel/Dockerfile_2_28 | 4 +- .ci/docker/manywheel/Dockerfile_cuda_aarch64 | 4 +- .ci/docker/ubuntu-cuda/Dockerfile | 10 +++++ .ci/docker/ubuntu/Dockerfile | 9 +++- 11 files changed, 70 insertions(+), 53 deletions(-) create mode 100644 .ci/docker/common/install_nccl.sh diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index 5f17a6332dd1..7548bd28bcc0 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -44,6 +44,8 @@ FROM base as cuda ARG CUDA_VERSION=12.4 RUN rm -rf /usr/local/cuda-* ADD ./common/install_cuda.sh install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION} # Preserve CUDA_VERSION for the builds ENV CUDA_VERSION=${CUDA_VERSION} diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index 10f3c7733f4f..3959880b53c5 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -2,7 +2,6 @@ set -ex -NCCL_VERSION=v2.26.2-1 CUDNN_VERSION=9.5.1.17 function install_cusparselt_040 { @@ -40,8 +39,7 @@ function install_cusparselt_063 { function install_118 { CUDNN_VERSION=9.1.0.70 - NCCL_VERSION=v2.21.5-1 - echo "Installing CUDA 11.8 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.4.0" + echo "Installing CUDA 11.8 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.4.0" rm -rf /usr/local/cuda-11.8 /usr/local/cuda # install CUDA 11.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run @@ -59,14 +57,7 @@ function install_118 { cd .. 
rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=11.8 bash install_nccl.sh install_cusparselt_040 @@ -75,7 +66,7 @@ function install_118 { function install_124 { CUDNN_VERSION=9.1.0.70 - echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.2" + echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" rm -rf /usr/local/cuda-12.4 /usr/local/cuda # install CUDA 12.4.1 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda_12.4.1_550.54.15_linux.run @@ -93,14 +84,7 @@ function install_124 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.4 bash install_nccl.sh install_cusparselt_062 @@ -108,7 +92,7 @@ function install_124 { } function install_126 { - echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.6 /usr/local/cuda # install CUDA 12.6.3 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.6.3/local_installers/cuda_12.6.3_560.35.05_linux.run @@ -126,14 +110,7 @@ function install_126 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.6 bash install_nccl.sh install_cusparselt_063 @@ -241,7 +218,7 @@ function prune_126 { function install_128 { CUDNN_VERSION=9.8.0.87 - echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.8 /usr/local/cuda # install CUDA 12.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run @@ -259,14 +236,7 @@ function install_128 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. 
- rm -rf nccl + CUDA_VERSION=12.8 bash install_nccl.sh install_cusparselt_063 diff --git a/.ci/docker/common/install_cuda_aarch64.sh b/.ci/docker/common/install_cuda_aarch64.sh index 3f154a103aa7..ae4983712989 100644 --- a/.ci/docker/common/install_cuda_aarch64.sh +++ b/.ci/docker/common/install_cuda_aarch64.sh @@ -3,7 +3,6 @@ set -ex -NCCL_VERSION=v2.26.2-1 CUDNN_VERSION=9.8.0.87 function install_cusparselt_063 { @@ -18,7 +17,7 @@ function install_cusparselt_063 { } function install_128 { - echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.3" + echo "Installing CUDA 12.8.0 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.3" rm -rf /usr/local/cuda-12.8 /usr/local/cuda # install CUDA 12.8.0 in the same container wget -q https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux_sbsa.run @@ -36,14 +35,7 @@ function install_128 { cd .. rm -rf tmp_cudnn - # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses - # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build - git clone -b ${NCCL_VERSION} --depth 1 https://github.com/NVIDIA/nccl.git - cd nccl && make -j src.build - cp -a build/include/* /usr/local/cuda/include/ - cp -a build/lib/* /usr/local/cuda/lib64/ - cd .. - rm -rf nccl + CUDA_VERSION=12.8 bash install_nccl.sh install_cusparselt_063 diff --git a/.ci/docker/common/install_nccl.sh b/.ci/docker/common/install_nccl.sh new file mode 100644 index 000000000000..17d80ebe7d27 --- /dev/null +++ b/.ci/docker/common/install_nccl.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +set -ex + +NCCL_VERSION="" +if [[ ${CUDA_VERSION:0:2} == "11" ]]; then + NCCL_VERSION=$(cat ci_commit_pins/nccl-cu11.txt) +elif [[ ${CUDA_VERSION:0:2} == "12" ]]; then + NCCL_VERSION=$(cat ci_commit_pins/nccl-cu12.txt) +else + echo "Unexpected CUDA_VERSION ${CUDA_VERSION}" + exit 1 +fi + +if [[ -n "${NCCL_VERSION}" ]]; then + # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses + # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build + git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git + pushd nccl + make -j src.build + cp -a build/include/* /usr/local/cuda/include/ + cp -a build/lib/* /usr/local/cuda/lib64/ + popd + rm -rf nccl + ldconfig +fi diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile index f9ae32ad7f8e..e90306767ff6 100644 --- a/.ci/docker/libtorch/Dockerfile +++ b/.ci/docker/libtorch/Dockerfile @@ -49,6 +49,8 @@ RUN bash ./install_mkl.sh && rm install_mkl.sh FROM cpu as cuda ADD ./common/install_cuda.sh install_cuda.sh ADD ./common/install_magma.sh install_magma.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ ENV CUDA_HOME /usr/local/cuda FROM cuda as cuda11.8 diff --git a/.ci/docker/linter-cuda/Dockerfile b/.ci/docker/linter-cuda/Dockerfile index d93f69a149f2..ed8fc7eabba5 100644 --- a/.ci/docker/linter-cuda/Dockerfile +++ b/.ci/docker/linter-cuda/Dockerfile @@ -30,7 +30,9 @@ RUN bash ./install_python.sh && rm install_python.sh /opt/requirements-ci.txt # Install cuda and cudnn ARG CUDA_VERSION COPY ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* ENV 
DESIRED_CUDA ${CUDA_VERSION} ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/manywheel/Dockerfile b/.ci/docker/manywheel/Dockerfile index d7daf989b496..75f2ab9a5ce0 100644 --- a/.ci/docker/manywheel/Dockerfile +++ b/.ci/docker/manywheel/Dockerfile @@ -64,7 +64,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION=10.2 # Install CUDA ADD ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* FROM base as intel # MKL diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index e3ac65f5ca21..fbf74fb81c01 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -36,7 +36,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION=11.8 # Install CUDA ADD ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${BASE_CUDA_VERSION} && rm install_cuda.sh install_nccl.sh ci_commit_pins/nccl-cu* FROM base as intel # MKL diff --git a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 index dfd766b4dd5a..fe2a04fd92db 100644 --- a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 @@ -67,7 +67,9 @@ FROM base as cuda ARG BASE_CUDA_VERSION # Install CUDA ADD ./common/install_cuda_aarch64.sh install_cuda_aarch64.sh -RUN bash ./install_cuda_aarch64.sh ${BASE_CUDA_VERSION} && rm install_cuda_aarch64.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda_aarch64.sh ${BASE_CUDA_VERSION} && rm install_cuda_aarch64.sh install_nccl.sh ci_commit_pins/nccl-cu* FROM base as magma ARG BASE_CUDA_VERSION diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index c9579950e0ac..4739271899c3 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -158,6 +158,16 @@ COPY ./common/install_cusparselt.sh install_cusparselt.sh RUN bash install_cusparselt.sh RUN rm install_cusparselt.sh +# Install NCCL +ARG CUDA_VERSION +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash install_nccl.sh +RUN rm install_nccl.sh /ci_commit_pins/nccl-cu* +ENV USE_SYSTEM_NCCL=1 +ENV NCCL_INCLUDE_DIR="/usr/local/cuda/include/" +ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" + # Install CUDSS ARG CUDA_VERSION COPY ./common/install_cudss.sh install_cudss.sh diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 11888f37bff2..c33abda4aaa7 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -52,9 +52,16 @@ RUN bash ./install_lcov.sh && rm install_lcov.sh # Install cuda and cudnn ARG CUDA_VERSION COPY ./common/install_cuda.sh install_cuda.sh -RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh +COPY ./common/install_nccl.sh install_nccl.sh +COPY ./ci_commit_pins/nccl-cu* /ci_commit_pins/ +RUN bash ./install_cuda.sh ${CUDA_VERSION} && rm install_cuda.sh install_nccl.sh /ci_commit_pins/nccl-cu* ENV DESIRED_CUDA ${CUDA_VERSION} ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH +# No effect if cuda not 
installed +ENV USE_SYSTEM_NCCL=1 +ENV NCCL_INCLUDE_DIR="/usr/local/cuda/include/" +ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" + # (optional) Install UCC ARG UCX_COMMIT From 22030efb6415887979d0ae856c34795f732c0fd7 Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Wed, 2 Apr 2025 19:56:50 +0000 Subject: [PATCH 112/332] expect fail scan test in sigmoid (#150475) Summary: as titled. Test Plan: see modified test. Differential Revision: D72271976 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150475 Approved by: https://github.com/zhxchen17 --- test/export/test_export.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index 4fc5515c7665..d92ef65fb743 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6988,6 +6988,8 @@ def forward(self, x): len([node for node in gm.graph.nodes if node.op == "placeholder"]), 1 ) + # scan is not supported in sigmoid yet + @testing.expectedFailureCppRuntime def test_export_scan_pytree_output(self): def add(carry, accum): return carry + carry, (accum[0]["moo"] + 1, accum[0]["moo2"] + 1) From b03c42109c4e7dd52228f0a2bd65963a1d86523c Mon Sep 17 00:00:00 2001 From: James Wu Date: Tue, 1 Apr 2025 12:17:14 -0700 Subject: [PATCH 113/332] Proactively remove CompiledTritonKernels before loading from cache/starting inductor compile (#150453) We'll still running into this issue intermittently and it's hard to debug; so I thought a more aggressive cache clear strategy may fix it as a stopgap until we can Statically launch cuda kernels and avoid some of this stuff Differential Revision: [D72257973](https://our.internmc.facebook.com/intern/diff/D72257973/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150453 Approved by: https://github.com/oulgen --- torch/_functorch/_aot_autograd/autograd_cache.py | 2 ++ torch/_inductor/compile_fx.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 4533d3f12cae..f23fbc84bad2 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -373,6 +373,8 @@ def load(self, example_inputs) -> CompiledFxGraph: # so we can call it only after we're sure both forward and backward have # TODO: We don't cache debug lines for now, but we should for improved debugging + # Clear CompiledTritonKernels before loading from FXGraphCache + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() remote_cache = None constants = CompiledFxGraphConstants() if should_use_remote_fx_graph_cache(): diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 0e86da9e94d8..86c79ed4a3bc 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -651,6 +651,10 @@ def _compile_fx_inner( """ aot_mode: bool = V.aot_compilation + # Clean up Compiled Triton Kernels per inductor compile, as the future objects + # may not be valid for use after they are run/autotuned + torch._inductor.async_compile.CompiledTritonKernels.cache_clear() + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: # trigger the real recompilation for _LazyGraphModule before returning # the forward method. 
@@ -879,8 +883,8 @@ def _compile_fx_inner( log.info("{:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 log.info("-" * 100) - # Clear Compiled Triton Kernels per inductor compile, as the future objects - # may not be valid for use after they are run/autotuned + # Not strictly necessary, but good to clean up straggling futures + # that are unused to reclaim memory. torch._inductor.async_compile.CompiledTritonKernels.cache_clear() _step_logger()( From af5c1b96e251422ad5fb05f98c1f0095f9c9d1cf Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 2 Apr 2025 18:32:12 +0000 Subject: [PATCH 114/332] ci: Set minimum cmake version for halide build (#150560) This was failing due to pybind being strict about their cmake version requirements. This resolves errors like: ``` 652.1 Compatibility with CMake < 3.5 has been removed from CMake. 652.1 652.1 Update the VERSION argument value. Or, use the ... syntax 652.1 to tell CMake that the project requires at least but has been updated 652.1 to work with policies introduced by or earlier. 652.1 652.1 Or, add -DCMAKE_POLICY_VERSION_MINIMUM=3.5 to try configuring anyway. 652.1 652.1 652.1 -- Configuring incomplete, errors occurred! ``` Tested this locally with the following command: ``` ./build.sh pytorch-linux-jammy-py3.12-halide -t 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-jammy-py3.12-halide:8a8989876ff1aa1d5b0e465177afebbc7a9da921 ``` Closes https://github.com/pytorch/pytorch/issues/150420 Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/150560 Approved by: https://github.com/clee2000, https://github.com/ZainRizvi, https://github.com/atalman, https://github.com/malfet --- .ci/docker/common/install_halide.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.ci/docker/common/install_halide.sh b/.ci/docker/common/install_halide.sh index 0cfcfbce107b..ed1d7d33649d 100644 --- a/.ci/docker/common/install_halide.sh +++ b/.ci/docker/common/install_halide.sh @@ -35,7 +35,9 @@ git clone https://github.com/halide/Halide.git pushd Halide git checkout ${COMMIT} && git submodule update --init --recursive pip_install -r requirements.txt -cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -S . -B build +# NOTE: pybind has a requirement for cmake > 3.5 so set the minimum cmake version here with a flag +# Context: https://github.com/pytorch/pytorch/issues/150420 +cmake -G Ninja -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DCMAKE_BUILD_TYPE=Release -S . -B build cmake --build build test -e ${CONDA_PREFIX}/lib/python3 || ln -s python${ANACONDA_PYTHON_VERSION} ${CONDA_PREFIX}/lib/python3 cmake --install build --prefix ${CONDA_PREFIX} From e54556734084952449ebd0022f2a621906bd698e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 20:30:32 +0000 Subject: [PATCH 115/332] Revert "[dynamo] Always trace into tensor subclass `__torch_function__` (#149792)" This reverts commit 238109ad3245c5485f9e83b4b02d258b09329042. 
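For readers skimming the diff below: this revert takes dynamo back to the opt-in model, where a `__torch_function__` tensor subclass must be registered in `traceable_tensor_subclasses` before dynamo traces into it. A minimal sketch of that opt-in usage, pieced together from the tests touched in this diff (`LocalSubclass` and `fn` are purely illustrative names):

```python
import torch
import torch._dynamo

class LocalSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return torch.add(x, 1.0)

# Opt-in registration restored by this revert; without it, dynamo does not
# trace into the subclass's __torch_function__ override.
with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
    fn(torch.ones(2, 2).as_subclass(LocalSubclass))
```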
Reverted https://github.com/pytorch/pytorch/pull/149792 on behalf of https://github.com/malfet due to Broke trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/b03c42109c4e7dd52228f0a2bd65963a1d86523c/1?per_page=50&name_filter=clang10%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522)) --- test/dynamo/test_subclasses.py | 28 ------ ..._preserve_torch_function_when_return_as_is | 10 -- .../TestGradNewOnesOverride.test_newones | 1 - .../TestIterator.test_iterator | 1 - .../TestLazyModules.test_lazy_module_buffer | 1 - ...estLazyModules.test_lazy_module_jit_buffer | 1 - .../TestTorchFunctionMode.test_subclass_hash | 10 -- ..._on_invalid_torch_function_tensor_subclass | 3 - test/profiler/test_profiler_tree.py | 1 - torch/_dynamo/config.py | 26 +++-- torch/_dynamo/source.py | 6 -- torch/_dynamo/utils.py | 1 - torch/_dynamo/variables/builder.py | 99 +++++++------------ torch/_dynamo/variables/tensor.py | 11 ++- torch/_dynamo/variables/torch.py | 5 +- torch/_dynamo/variables/torch_function.py | 3 + torch/_guards.py | 8 +- torch/nn/attention/bias.py | 8 +- 18 files changed, 72 insertions(+), 151 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is delete mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones delete mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator delete mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer delete mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer delete mode 100644 test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash delete mode 100644 test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 0e7d54c28448..df6397df8257 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -40,10 +40,6 @@ def traceable_subclass(c): return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) -def nontraceable_subclass(c): - return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c}) - - def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) self.assertEqual(actual_recompiles, expected_recompiles) @@ -1192,30 +1188,6 @@ def f(t): self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) - def test_nontraceable_tensor_subclass(self): - # This will error if Dynamo tries to wrap it as a tensor variable, - # because that involves calling certain methods to inspect the tensor - # property, which will blow up in the overriden `__torch_function__`. - class MySubclass(torch.Tensor): - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - raise RuntimeError("one shall not pass") - - def f(t): - return t.foo + torch.ones(10) - - opt_f = torch.compile(f, backend="eager", fullgraph=False) - - t = MySubclass(torch.ones(2)) - t.foo = 42 - # Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo - # from wrapping `t`. 
- with nontraceable_subclass(MySubclass): - res = f(t) - ref = opt_f(t) - - self.assertEqual(res, ref) - def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is deleted file mode 100644 index f243ff1904b0..000000000000 --- a/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is +++ /dev/null @@ -1,10 +0,0 @@ -- Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) -or properly graph break on it rather than skipping the frame altogether. -https://github.com/pytorch/pytorch/issues/128942 - -Fundamental issue is Dynamo tries to probe tensor object properties, but that -could trigger user-defined `__torch_function__` for tensor subclass objects. - -In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the -`__init__` of the following class, but `self._data = data` hasn't fired yet, and -its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones deleted file mode 100644 index 24f34ca8e8e6..000000000000 --- a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones +++ /dev/null @@ -1 +0,0 @@ -https://github.com/pytorch/pytorch/issues/149975 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator deleted file mode 100644 index 880a24b122bb..000000000000 --- a/test/dynamo_expected_failures/TestIterator.test_iterator +++ /dev/null @@ -1 +0,0 @@ -https://github.com/pytorch/pytorch/issues/150005 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer deleted file mode 100644 index 89dda61098d2..000000000000 --- a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer +++ /dev/null @@ -1 +0,0 @@ -Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer deleted file mode 100644 index 89dda61098d2..000000000000 --- a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer +++ /dev/null @@ -1 +0,0 @@ -Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash deleted file mode 100644 index beb4bf5d003a..000000000000 --- a/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash +++ /dev/null @@ -1,10 +0,0 @@ -Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) -or properly graph break on it rather than skipping the frame altogether. -https://github.com/pytorch/pytorch/issues/128942 - -Fundamental issue is Dynamo tries to probe tensor object properties, but that -could trigger user-defined `__torch_function__` for tensor subclass objects. 
- -In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the -`__init__` of the following class, but `self._diag = _diag` hasn't fired yet, and -its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass deleted file mode 100644 index c2ddc08d1e40..000000000000 --- a/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass +++ /dev/null @@ -1,3 +0,0 @@ -Dynamo cannot query properties of the tensor subclass object when wrapping it -into a VT, because it has a `__torch_function__` that only allows limited -torch ops. diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 48bbbf01727f..7dac5fb70905 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -690,7 +690,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ...""", ) - @skipIfTorchDynamo("segfaults in 3.13+") @unittest.skipIf( TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b59e1c49e607..5d58efdeed09 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -152,15 +152,25 @@ # Non-Inductor backends can use this list for graph freezing. prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1" -# NOTE this has been deprecated, it does nothing now. -traceable_tensor_subclasses: set[type[Any]] = set() -# If a tensor subclass is put into this set, Dynamo will model its instasnces in -# a very conservative and limited way (most likely causing lots of graph breaks -# if one apply tensor ops on these instances). This is useful if you encounter -# internal compiler errors from Dynamo which are caused by tensor subclasses, -# and you are willing to tolerate potential graph breaks rather than hard error. -nontraceable_tensor_subclasses: set[type[Any]] = set() +# This feature doesn't really work. We offer this flag for experimental +# purposes / if you want to help us build out support. +# +# torchdynamo has limited support for tensor subclasses that implement +# __torch_function__ see [Note: __torch_function__] in torch_function.py. +# Our current support is limited to tensor subclasses +# that DO NOT store metadata on the tensor (in general, dynamo does not +# support Python code that stores extra attributes on tensors at present). +# If your tensor subclass purely changes function call behavior via +# __torch_function__, you can allow torchdynamo to trace into it by +# adding it to traceable_tensor_subclasses. We don't do any safety checks, +# so it is up to you to ensure that your subclass is well behaved. See also +# https://github.com/pytorch/torchdynamo/issues/1948 +# +# We do NOT currently support __torch_dispatch__. The implementation is +# currently buggy, the main show stopper for nontrivial use is +# https://github.com/pytorch/torchdynamo/issues/1952 +traceable_tensor_subclasses: set[type[Any]] = set() # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. 
# This is a good way to get your model to work one way or another, but you may diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 4116f110b21d..e01c166c97d2 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -842,12 +842,6 @@ def is_from_local_source(source: Source, *, only_allow_input=False): return True -def is_from_source(source: Source, target: Source): - if isinstance(source, ChainedSource): - return is_from_source(source.base, target) - return source == target - - def is_from_unspecialized_param_buffer_source(source: Source): if isinstance(source, UnspecializedParamBufferSource): return True diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8fa038ce7116..8ee9289633b1 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1408,7 +1408,6 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: "reorderable_logging_functions", "ignore_logger_methods", "traceable_tensor_subclasses", - "nontraceable_tensor_subclasses", "_custom_ops_profile", } diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d5cea823b7f6..1cd8001e4c34 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -68,11 +68,7 @@ SymbolicContext, ) from torch.fx.immutable_collections import immutable_dict, immutable_list -from torch.nn.utils._expanded_weights import ExpandedWeight -from torch.utils._python_dispatch import ( - is_traceable_wrapper_subclass, - is_traceable_wrapper_subclass_type, -) +from torch.utils._python_dispatch import is_traceable_wrapper_subclass from torch.utils._sympy.value_ranges import ValueRanges from torch.utils.weak import TensorWeakRef @@ -616,30 +612,11 @@ def create_2d_tma_descriptor(): return id_dispatch(self, value) # Everything else (NB: order matters!) - if ( - isinstance(value, torch.Tensor) - and type(value) - not in ( - # These torch-native subclasses have overly restrictive - # `__torch_function__` which prevents Dynamo from reading their - # tensor attributes like `is_nested` or calling methods like - # `_is_view`. - torch.nn.parameter.UninitializedBuffer, - torch.nn.parameter.UninitializedParameter, - ExpandedWeight, - ) - and type(value) not in config.nontraceable_tensor_subclasses + if is_traceable_wrapper_subclass(value) or istype( + value, config.traceable_tensor_subclasses ): - if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: - # This case it's either tensor or subclass with default - # torch_dispatch (they might override torch_function or not), - # and we can always trace into them. - return self.wrap_tensor(value) - elif is_traceable_wrapper_subclass(value): - # For non-default torch_dispatch, we have more requirements. 
- return self.wrap_tensor(value) - - if is_namedtuple(value): + return self.wrap_tensor(value) + elif is_namedtuple(value): self.install_guards(GuardBuilder.SEQUENCE_LENGTH) output = [ LazyVariableTracker.create( @@ -953,6 +930,11 @@ def build_key_value(i, k, v): value, source=self.source, ) + elif ( + isinstance(value, torch._C._TensorMeta) + and value in config.traceable_tensor_subclasses + ): + return TensorSubclassVariable(value, source=self.source) elif ( istype(value, contextlib.nullcontext) and inspect.getattr_static(value, "enter_result", None) is None @@ -1205,20 +1187,6 @@ def build_key_value(i, k, v): if value is torch.autograd._unsafe_preserve_version_counter: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PreserveVersionContextVariable.constructor(self.tx) - if ( - # `value` must be a strict subclass of `torch.Tensor` - issubclass(value, torch.Tensor) - and value is not torch.Tensor - # `TensorSubclassVariable` is not for subclass that overrides - # `torch_dispatch`. - and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ - # `TensorSubclassVariable` would lead to construction of - # `TensorWithTFOverrideVariable`, but we don't want that for - # traceable wrapper subclasses (we wrap those subclass instances - # into `TensorVariable`). - and not is_traceable_wrapper_subclass_type(value) - ): - return TensorSubclassVariable(value, source=self.source) # This is a userdefined class, so install an ID_MATCH even if its a # global variable. self.install_guards(GuardBuilder.ID_MATCH) @@ -1761,22 +1729,7 @@ def wrap_tensor(self, value: torch.Tensor): # Guards are added inside register_attr_or_module ) - # NB: this just says we accessed a tensor from the same source again - # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). - # This is distinct from two distinct sources mapping to the same - # Tensor (per id())! No guard is necessary here. See below for the - # other case. - is_duplicate_tensor = source in self.tx.output.input_source_to_var - if is_duplicate_tensor: - return self.tx.output.input_source_to_var[source] - - options = {} - if type(value) in ( - torch.Tensor, - torch.nn.Parameter, - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ) or is_traceable_wrapper_subclass(value): + if type(value) in config.traceable_tensor_subclasses: # Ordinarily, we would fakeify a tensor so that it can get dynamic # shapes and be computed on without triggering actual operations. # However, how can we fakeify a tensor subclass? Ordinary @@ -1794,13 +1747,24 @@ def wrap_tensor(self, value: torch.Tensor): # To simplify things for now, the __dict__ tracking bits haven't # been implemented yet, but they can be added into this design at # a later point in time. - subclass_type = None - else: subclass_type = type(value) - options["torch_function_fn"] = build_torch_function_fn( - self.tx, value, self.source - ) - self.install_guards(GuardBuilder.TYPE_MATCH) + else: + assert type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value), type(value) + subclass_type = None + + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. 
+ is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] if get_static_address_type(value) == "guarded": self.install_guards(GuardBuilder.ID_MATCH) @@ -1808,6 +1772,13 @@ def wrap_tensor(self, value: torch.Tensor): # By this point, we should have deduplicated all tensors self.assert_not_wrapped_by_this_graph(value) + options = {} + if type(value) in config.traceable_tensor_subclasses: + options["torch_function_fn"] = build_torch_function_fn( + self.tx, value, self.source + ) + self.install_guards(GuardBuilder.TYPE_MATCH) + if ( isinstance(value, torch.Tensor) and value.is_nested diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index c477979fa9e3..44b3ffc27689 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -70,7 +70,6 @@ from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable -from .user_defined import UserDefinedClassVariable try: @@ -411,6 +410,8 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): return ConstantVariable(ret_val) def var_getattr(self, tx: "InstructionTranslator", name): + from . import UserDefinedClassVariable + if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): unimplemented( @@ -613,7 +614,7 @@ def call_method( """ # This is seen in inspect signature where we check if the value is a default value - if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): + if name == "__eq__" and isinstance(args[0], variables.UserDefinedClassVariable): return variables.ConstantVariable(False) try: @@ -1445,7 +1446,11 @@ def from_tensor_variable(cls, tensor_variable): return FakeItemVariable(**dict(tensor_variable.__dict__)) -class TensorSubclassVariable(UserDefinedClassVariable): +class TensorSubclassVariable(VariableTracker): + def __init__(self, value, *args, **kwargs) -> None: + self.value = value + super().__init__(*args, **kwargs) + def call_function( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 40821a16e5e5..1e7a9baf9494 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -76,7 +76,6 @@ from .torch_function import ( can_dispatch_torch_function, dispatch_torch_function, - TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) @@ -1351,9 +1350,7 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if isinstance( - data, TensorWithTFOverrideVariable - ) or is_traceable_wrapper_subclass_type(data.class_type): + if is_traceable_wrapper_subclass_type(data.class_type): unimplemented("Parameter constructor with tensor subclass NYI") if not can_convert_to_tracable_parameter(): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 330faf9bf902..9f24f669e398 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -24,6 +24,9 @@ See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w for more information on the design. 
+ +To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses +in torch/_dynamo/config.py """ import collections diff --git a/torch/_guards.py b/torch/_guards.py index b6b36f637101..ad5f4a7b130a 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -631,12 +631,8 @@ def update(self, *others: set[Guard]): self.add(g, skip=1) def remove_guards_with_source(self, source): - """Delete all guards that contains a given source""" - from ._dynamo.source import is_from_source - - self.inner = { - g for g in self.inner if not is_from_source(g.originating_source, source) - } + """Delete all guards with a given source""" + self.inner = {g for g in self.inner if g.originating_source != source} class GuardsContext(Checkpointable[GuardsCheckpointState]): diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 36c0a18cdd12..da7acb957d96 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -283,9 +283,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" if kwargs is None: kwargs = {} - if func is torch.nn.functional.scaled_dot_product_attention: - return cls._dispatch(*args, **kwargs) - return super().__torch_function__(func, types, args, kwargs) + if func != torch.nn.functional.scaled_dot_product_attention: + raise NotImplementedError( + "CausalBias only supports scaled_dot_product_attention" + ) + return cls._dispatch(*args, **kwargs) def __repr__(self): # type:ignore[override] return self._materialize().__repr__() From 01411c739f136581f34e592318c9f8b95b3dd487 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 20:30:33 +0000 Subject: [PATCH 116/332] Revert "[dynamo] Support tensor subclass with overriden tensor methods and properties (#149484)" This reverts commit 7e53c58687482d58461e1dd8e09f59a9daf8f7b3. 
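For context, the behavior restored by this revert is the older limitation that dynamo raises `Unsupported` when a `__torch_function__` subclass overrides a tensor method or attribute. A minimal sketch of what the restored tests below expect (illustrative only; the class and error message mirror `test_user_overidden_method_unsupported` in the diff):

```python
import torch
import torch._dynamo
from torch._dynamo.exc import Unsupported

class LocalSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return super().__torch_function__(func, types, args, kwargs)

    def sigmoid(self):  # user-overridden tensor method
        return None

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    x.sigmoid()

with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):
    try:
        fn(torch.ones(2, 2).as_subclass(LocalSubclass))
    except Unsupported as e:
        # "Accessing overridden method/attribute sigmoid on a tensor subclass
        #  with a __torch_function__ override is not supported"
        print(e)
```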
Reverted https://github.com/pytorch/pytorch/pull/149484 on behalf of https://github.com/malfet due to Broke trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/b03c42109c4e7dd52228f0a2bd65963a1d86523c/1?per_page=50&name_filter=clang10%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522)) --- test/dynamo/test_subclasses.py | 131 +++++----------------- torch/_dynamo/variables/misc.py | 15 +-- torch/_dynamo/variables/torch_function.py | 33 ++---- torch/_dynamo/variables/user_defined.py | 5 - 4 files changed, 47 insertions(+), 137 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index df6397df8257..99b7ab9784ae 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -757,22 +757,26 @@ def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} return super().__torch_function__(func, types, args, kwargs) def sigmoid(self): return None + @torch.compile(backend="eager", fullgraph=True) def fn(x): x.sigmoid() - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn_opt = compile_full_eager(fn) - - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - res_exp = fn(x) - res_act = fn_opt(x) - - self.assertEqual(res_exp, res_act) + msg = ( + "Accessing overridden method/attribute sigmoid on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) def test_user_overidden_attr_unsupported(self): class LocalSubclass(torch.Tensor): @@ -788,7 +792,10 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def fn(x): return x.ndim - msg = "Currently only support accessing overridden attributes that are functions or properties, but got " + msg = ( + "Accessing overridden method/attribute ndim on a tensor" + " subclass with a __torch_function__ override is not supported" + ) with torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): @@ -797,11 +804,13 @@ def fn(x): def test_user_overidden_property_unsupported(self): class LocalSubclass(torch.Tensor): - def __init__(self, *args, **kwargs) -> None: + def __init__(self) -> None: self._ndim = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} return super().__torch_function__(func, types, args, kwargs) @property @@ -812,17 +821,19 @@ def ndim(self): def ndim(self, value): self._ndim = value + @torch.compile(backend="eager", fullgraph=True) def fn(x): - return x + x.ndim - - x = LocalSubclass(torch.ones(2, 2)) - fn_opt = compile_full_eager(fn) - - with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): - res_exp = fn(x) - res_act = fn_opt(x) + return x.ndim - self.assertEqual(res_exp, res_act) + msg = ( + "Accessing overridden method/attribute ndim on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + with torch._dynamo.config.patch( + "traceable_tensor_subclasses", {LocalSubclass} + ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn(x) def 
test_overridden_method_guarding(self): class LocalSubclass(torch.Tensor): @@ -971,88 +982,6 @@ def fn(x): self.assertEqual(res_exp, res_act) self.assertEqual(x0, x1) - def test_subclass_override_shape_and_to(self): - # This is a slight variabtion of - # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 - class MySubclass(torch.Tensor): - def to(self, *args, **kwargs): - new = super().to(*args, **kwargs) - new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) - return new - - @property - def shape(self): - if not hasattr(self, "tensor_shape"): - self.tensor_shape = self.size() - return self.tensor_shape - - def fn(x): - x_shape = x.shape - y = x.to("cpu") - return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape - - with traceable_subclass(MySubclass): - x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) - x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass)) - - fn_opt = compile_full_eager(fn) - - res_exp = fn(x0) - res_act = fn_opt(x1) - self.assertEqual(res_exp, res_act) - self.assertEqual(x0, x1) - self.assertEqual(x0.tensor_shape, x1.tensor_shape) - - def test_subclass_dont_invoke_torch_function_on_overriden_method(self): - # We shouldn't fire `__torch_function__` for overriden tensor methods. - class MySubclass(torch.Tensor): - def to(self, device): - return self * len(device) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if func is torch.Tensor.to: - torch._dynamo.graph_break() - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return x.to("cpu") - - with traceable_subclass(MySubclass): - x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) - - fn_opt = compile_full_eager(fn) - - res_exp = fn(x) - res_act = fn_opt(x) - self.assertEqual(res_exp, res_act) - - def test_subclass_dont_invoke_torch_function_on_overriden_attr(self): - from types import MethodWrapperType - - # We shouldn't fire `__torch_function__` for overriden tensor attrs. - class MySubclass(torch.Tensor): - def ndim(self): - return 42 - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if type(func) is MethodWrapperType and func.__name__ == "ndim": - torch._dynamo.graph_break() - return super().__torch_function__(func, types, args, kwargs) - - def fn(x): - return x + x.ndim() - - with traceable_subclass(MySubclass): - x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) - - fn_opt = compile_full_eager(fn) - - res_exp = fn(x) - res_act = fn_opt(x) - self.assertEqual(res_exp, res_act) - def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2c92599a8b28..c6a2124c871a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -32,7 +32,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, trace_rules, variables +from .. 
import config, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 @@ -297,14 +297,6 @@ def call_method( tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( tx_old ) - elif ( - isinstance(inner_fn, types.MethodDescriptorType) - and inner_fn in trace_rules.get_tensor_method() - ): - # FunctionType but implementation is in C, we support some of these, - # e.g., tensor ops like `torch.Tensor.to`. - fn_var = VariableTracker.build(tx, inner_fn, source) - return fn_var.call_function(tx, [self.objvar] + args, kwargs) unimplemented(f"non-function or method super: {inner_fn}") @@ -677,10 +669,11 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ): + from ..trace_rules import is_callable_allowed from .builder import wrap_fx_proxy if name == "apply": - if trace_rules.is_callable_allowed(self.fn_cls): + if is_callable_allowed(self.fn_cls): trampoline_autograd_apply = produce_trampoline_autograd_apply( self.fn_cls ) @@ -698,6 +691,8 @@ def call_method( elif name == "backward": return self.call_backward(tx, args, kwargs) else: + from .. import trace_rules + source = AttrSource(self.source, name) if self.source is not None else None try: obj = inspect.getattr_static(self.fn_cls, name) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 9f24f669e398..3946dccd8dc7 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -597,9 +597,8 @@ def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # This simulates shallow-copying the tensor object. kwargs = dict(tensor_var.__dict__) - input_tensor_type = kwargs.pop("class_type") - assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( - f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" + assert kwargs.pop("class_type") is torch.Tensor, ( + "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" ) torch_fn_var = build_torch_function_fn(tx, class_type, cls_source) var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs) @@ -639,9 +638,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - # Handle non-overriden attributes inherited from `torch.Tensor`. 
- attr_is_overriden = _is_attr_overidden(tx, self, name) - if hasattr(torch.Tensor, name) and not attr_is_overriden: + if hasattr(torch.Tensor, name): + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overridden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) + if tx.output.torch_function_enabled: if self.source: install_guard( @@ -671,23 +674,11 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: import types - cls_source = GlobalSource(self.global_mangled_class_name(tx)) - attr_source = AttrSource(cls_source, name) if isinstance(attr, types.FunctionType): - install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + func_source = AttrSource(cls_source, name) + install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) return UserMethodVariable(attr, self) - - elif isinstance(attr, property): - getter_source = AttrSource(attr_source, "fget") - getter = attr.fget - getter_var = UserMethodVariable(getter, self, source=getter_source) - return getter_var.call_function(tx, [], {}) - - elif attr_is_overriden: - unimplemented( - f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950 - ) - return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2d22e0d35805..b842a552649f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,7 +82,6 @@ ) from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker from .dicts import DefaultDictVariable -from .lists import SizeVariable try: @@ -580,10 +579,6 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) - elif self.value is torch.Size: - # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. - tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) - return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) items = list(args) From 18908c8cedd2be1185aa345dc392d7019faacea7 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 20:30:33 +0000 Subject: [PATCH 117/332] Revert "[dynamo] Support `torch.Tensor._make_subclass` and tracing through tensor subclass `__new__` (#149483)" This reverts commit 203e1d681d1a4eb7794dfaeaebfa497242dde17d. 
Reverted https://github.com/pytorch/pytorch/pull/149483 on behalf of https://github.com/malfet due to Broke trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/b03c42109c4e7dd52228f0a2bd65963a1d86523c/1?per_page=50&name_filter=clang10%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522)) --- test/dynamo/test_subclasses.py | 37 ++-------- .../TestGradNewOnesOverride.test_newones | 0 .../TestIterator.test_iterator | 0 .../TestNamedTuple.test_max | 0 .../TestPickle.test_pickle | 0 torch/_dynamo/polyfills/loader.py | 1 - torch/_dynamo/polyfills/tensor.py | 37 ---------- torch/_dynamo/variables/tensor.py | 67 ++++++------------- 8 files changed, 26 insertions(+), 116 deletions(-) create mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones create mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator create mode 100644 test/dynamo_expected_failures/TestNamedTuple.test_max create mode 100644 test/dynamo_expected_failures/TestPickle.test_pickle delete mode 100644 torch/_dynamo/polyfills/tensor.py diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 99b7ab9784ae..7fefc281089b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,34 +954,6 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) - def test_make_subclass(self): - # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo - # models its aliasing relationships correctly. - class MySubclass(torch.Tensor): - pass - - def fn(x): - # Downcast then upcast - y = torch.Tensor._make_subclass(MySubclass, x) - z = torch.Tensor._make_subclass(torch.Tensor, x) - # Now `x, y, z` should have the same underlying data. - x += 1 - y += 2 - z += 3 - res = x * y + z - return res - - with traceable_subclass(MySubclass): - x0 = torch.randn(2, 2) - x1 = x0.clone() - - fn_opt = compile_full_eager(fn) - - res_exp = fn(x0) - res_act = fn_opt(x1) - self.assertEqual(res_exp, res_act) - self.assertEqual(x0, x1) - def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 @@ -1002,9 +974,7 @@ def __init__(self, *args, quant_type=None, **kwargs): self.quant_type = quant_type def as_tensor(self): - return torch.Tensor._make_subclass( - torch.Tensor, self, self.requires_grad - ) + return torch.Tensor(self.data) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1113,8 +1083,9 @@ def f(t): res = f(t) ref = opt_f(t) - self.assertEqual(res, ref) - self.assertEqual(res.elem, ref.elem) + # TODO uncomment once we trace into `__new__`. 
+ # self.assertEqual(res, ref) + # self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) def test_compile_with_fake_tensor_dynamic_dim(self): diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestNamedTuple.test_max b/test/dynamo_expected_failures/TestNamedTuple.test_max new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestPickle.test_pickle b/test/dynamo_expected_failures/TestPickle.test_pickle new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index f60aa57a5d40..d9be4e9febc9 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -21,7 +21,6 @@ "pytree", "sys", "fx", - "tensor", ) POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/tensor.py b/torch/_dynamo/polyfills/tensor.py deleted file mode 100644 index 002ccf5d1d4f..000000000000 --- a/torch/_dynamo/polyfills/tensor.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Any - -import torch - -from ..decorators import substitute_in_graph - - -@substitute_in_graph( # type: ignore[arg-type] - torch.Tensor._make_subclass -) -def make_subclass( - cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any -) -> Any: - # This is a rough approximation of `THPVariable_make_subclass`. It should - # suffice for most of Dynamo tracing purposes. - # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 - assert len(kwargs) == 0, "_make_subclass only supports requires_grad as keyword arg" - data = data.detach() - - # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. - if data.requires_grad != requires_grad: - data.requires_grad = requires_grad - - # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. - if cls is torch.Tensor: - return torch.Tensor(data) - - # Calling `as_subclass` because - # 1. Dynamo knows how to handle it - # 2. the C impls match at this point -- both `THPVariable_make_subclass` and - # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. 
- return data.as_subclass(cls) - - -__all__ = [ - "make_subclass", -] diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 44b3ffc27689..99bd1f3eb552 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -797,15 +797,6 @@ def method_as_subclass(self, cls): object(), var, mutation_type_cls=AttributeMutationNew ) return var - unimplemented_v2( - gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", - context=f"{self}.as_subclass({cls})", - explanation="Currently not supported", - hints=[ - "Avoid this call or move it outside `torch.compile` regione", - *graph_break_hints.SUPPORTABLE, - ], - ) def method_get_device(self): if isinstance(self.device, torch.device): @@ -1457,46 +1448,32 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - # Handle `Subclass(existing_tensor, ...)` calls. - from .torch_function import TensorWithTFOverrideVariable - - new_func = self.value.__new__ - if new_func is torch.Tensor.__new__: - if ( - len(args) == 1 - and isinstance(args[0], TensorVariable) - and len(kwargs) == 0 - ): - data = args[0] - # Simulate `torch.Tensor.__new__` as shallow-copying the input - # tensor data with a new type. TODO polyfill? + # Handle `Subclass(existing_tensor)` calls. + def impl(): + if len(args) == 1 and isinstance(args[0], TensorVariable): + from .torch_function import TensorWithTFOverrideVariable + + # This simulates `__new__` and _assumes_ it doesn't have + # side-effects that matters to Dynamo tracing. TODO trace through + # `__new__`. var = TensorWithTFOverrideVariable.from_tensor_var( - tx, data, self.value, self.source - ) - else: - unimplemented_v2( - gb_type="Calling subclass default constructor with more than tensor argument", - context=f"{self.value}(args={args}, kwargs={kwargs})", - explanation="Currently not supported", - hints=[ - "Avoid this constructor call or move it outside " - "`torch.compile` regione", - *graph_break_hints.SUPPORTABLE, - ], + tx, args[0], self.value, self.source ) - else: - # Let Dynamo trace through custom `__new__` - var = VariableTracker.build(tx, new_func).call_function( - tx, [self] + args, kwargs - ) - # Let Dynamo trace through custom `__init__` - init_func = self.value.__init__ - # TODO builder should be able to handle `torch.Tensor.__init__`, - # which is `object.__init__`, so that we can remove this check. - if init_func is not torch.Tensor.__init__: - VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. + if init_func is not torch.Tensor.__init__: + cls_kwargs = kwargs or {} + VariableTracker.build(tx, init_func).call_function( + tx, [var], cls_kwargs + ) + return var + + return super().call_function(tx, args, kwargs) + var = impl() # See NOTE [Side effect tracking for newly constructed tensor] tx.output.side_effects._track_obj( object(), var, mutation_type_cls=AttributeMutationNew From 03c879d59bf2561d4a4390b1c3d996a8795d3f4f Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 20:30:33 +0000 Subject: [PATCH 118/332] Revert "[dynamo] Support Tensor subclass that has dynamic attributes or calls `Parameter.__torch_function__` (#149482)" This reverts commit 98453c135a7778d12ff881d8b0a717257be9fc38. 
Reverted https://github.com/pytorch/pytorch/pull/149482 on behalf of https://github.com/malfet due to Broke trunk, see https://hud.pytorch.org/hud/pytorch/pytorch/b03c42109c4e7dd52228f0a2bd65963a1d86523c/1?per_page=50&name_filter=clang10%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/149482#issuecomment-2773650522)) --- test/dynamo/test_misc.py | 16 --- test/dynamo/test_subclasses.py | 134 ------------------ .../TestTorch.test_tensor_ressurecting_clear | 1 - ...edding_swap_True_set_grad_True_cpu_float32 | 0 ..._PReLU_swap_True_set_grad_True_cpu_float32 | 0 ...MSNorm_swap_True_set_grad_True_cpu_float32 | 0 ...dding_swap_True_set_grad_True_cuda_float32 | 0 ...PReLU_swap_True_set_grad_True_cuda_float32 | 0 ...SNorm_swap_True_set_grad_True_cuda_float32 | 0 torch/_dynamo/side_effects.py | 66 +++------ torch/_dynamo/trace_rules.py | 1 + torch/_dynamo/variables/builder.py | 89 ++++-------- torch/_dynamo/variables/builtin.py | 14 -- torch/_dynamo/variables/misc.py | 31 ---- torch/_dynamo/variables/tensor.py | 44 ++---- torch/_dynamo/variables/torch_function.py | 64 ++++----- 16 files changed, 87 insertions(+), 373 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear create mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 create mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 create mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 create mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 create mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 create mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8579ee8e1b2e..a6b3a29eb4d3 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -586,22 +586,6 @@ def f(x): ref = f(x) self.assertEqual(res, ref) - def test_newly_constructed_tensor_attr_mutation(self): - def f(x): - y = x + 10 - y.grad = x - y.foo = 42 - return y - - opt_f = torch.compile(f, backend="eager", fullgraph=True) - x = torch.ones(5) - - res = opt_f(x) - ref = f(x) - self.assertEqual(res, ref) - self.assertEqual(res.grad, ref.grad) - self.assertEqual(res.foo, ref.foo) - def test_closure_recompiles(self): cnt = CompileCounter() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 7fefc281089b..ef2acadac89d 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,140 +954,6 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) - def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): - # This is a slight variation of - # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 - # which basically - # 1. uses tensor subclass to attach quantization metadata onto tensors - # 2. preserve them across torch ops - # 3. use the metadata to dequantize the tensor - # 4. convert it to a regular tensor. - # - # The test is meant to make sure Dynamo won't graph break over it. 
- class GGUFParameter(torch.nn.Parameter): - def __new__(cls, data, requires_grad=False, quant_type=None): - data = data if data is not None else torch.empty(0) - self = torch.Tensor._make_subclass(cls, data, requires_grad) - return self - - def __init__(self, *args, quant_type=None, **kwargs): - self.quant_type = quant_type - - def as_tensor(self): - return torch.Tensor(self.data) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} - - result = super().__torch_function__(func, types, args, kwargs) - - quant_type = None - for arg in args: - if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): - quant_type = arg[0].quant_type - break - if isinstance(arg, GGUFParameter): - quant_type = arg.quant_type - break - if isinstance(result, torch.Tensor): - return cls(result, quant_type=quant_type) - # Handle tuples and lists - elif isinstance(result, (tuple, list)): - # Preserve the original type (tuple or list) - wrapped = [ - cls(x, quant_type=quant_type) - if isinstance(x, torch.Tensor) - else x - for x in result - ] - return type(result)(wrapped) - else: - return result - - def f(x): - tmp = x * 2 - tmp = tmp + tmp.quant_type - tmp = tmp.as_tensor() - return tmp * 3 - - opt_f = torch.compile(f, backend="eager", fullgraph=True) - - x = GGUFParameter(torch.ones(2), quant_type=42) - with traceable_subclass(GGUFParameter): - res = f(x) - ref = opt_f(x) - self.assertEqual(res, ref) - - def test_newly_constructed_tensor_subclass_attr_mutation(self): - # Make sure the attribute mutation for newly constructed tensor subclass - # object (from constructor call) is handled both during Dynamo tracing - # and codegen-ed to be visible outside `torch.compile`. - class MySubclass(torch.Tensor): - pass - - def f(): - x = MySubclass(torch.ones(2)) - x.bar = 42 - return x, x * x.bar - - opt_f = compile_full_eager(f) - - with traceable_subclass(MySubclass): - res = f() - ref = opt_f() - - self.assertEqual(res, ref) - self.assertEqual(res[0].bar, ref[0].bar) - - def test_as_subclass_attr_mutation(self): - # Make sure the attribute mutation for newly constructed tensor subclass - # object (from as_subclass call) is handled both during Dynamo tracing - # and codegen-ed to be visible outside `torch.compile`. - class MySubclass(torch.Tensor): - pass - - def f(): - x = torch.ones(2).as_subclass(MySubclass) - x.bar = 42 - return x, x * x.bar - - opt_f = compile_full_eager(f) - - with traceable_subclass(MySubclass): - res = f() - ref = opt_f() - - self.assertEqual(res, ref) - self.assertEqual(res[0].bar, ref[0].bar) - - def test_tensor_subclass_attr_codegen_tos(self): - # This repros a very subtle interaction between - # `TensorWithTFOverrideVariable` attribute mutation codegen and - # `PyCodegen.top_of_stack`. It was uncovered from - # `test_tensor_subclass_deepcopy`. - class MySubclass(torch.Tensor): - def __new__(cls, elem, *args, **kwargs): - r = torch.Tensor._make_subclass(cls, torch.ones(0)) - r.elem = elem - return r - - def f(t): - return MySubclass(t.elem.clone()) - - opt_f = compile_full_eager(f) - - t = MySubclass(torch.ones(2)) - with traceable_subclass(MySubclass): - res = f(t) - ref = opt_f(t) - - # TODO uncomment once we trace into `__new__`. 
- # self.assertEqual(res, ref) - # self.assertEqual(res.elem, ref.elem) - self.assertEqual(type(res), type(ref)) - def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear deleted file mode 100644 index 276a4f74bbca..000000000000 --- a/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear +++ /dev/null @@ -1 +0,0 @@ -https://github.com/pytorch/pytorch/issues/149881 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 1deb09e2cc1e..4c85d98cfd16 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -295,7 +295,9 @@ def _track_obj( variable: VariableTracker, mutation_type_cls=ValueMutationExisting, ): - """Start tracking an existing or new variable for mutation""" + """Start tracking a new variable for mutation""" + assert variable.source is not None + if id(item) in self.id_to_variable: raise AssertionError( f"{variable} is already tracked for mutation. This could be " @@ -574,18 +576,12 @@ def _get_modified_vars(self): return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen): - # We must codegen modified VT to their source by default, so that - # mutation and aliasing are properly accounted for. - # - # Since newly constructed objects don't have a source, we manually - # codegen their construction and store them to a newly assigned local - # source. Note that `ValueMutationNew` isn't tracked by SideEffects. + # Make sure we codegen these modified VT to their source by default, so + # that mutation and aliasing are properly accounted for. 
for var in self._get_modified_vars(): - if not isinstance(var.mutation_type, AttributeMutationNew): - assert var.source is not None - continue - - if isinstance(var, variables.CellVariable): + if isinstance(var.mutation_type, AttributeMutationNew) and isinstance( + var, variables.CellVariable + ): # Cells created in the root frame are created either by # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit # `make_cell` for the non-root-frame cells here. @@ -599,38 +595,18 @@ def codegen_save_tempvars(self, cg: PyCodegen): var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: var.source = LocalCellSource(var.local_name) - elif isinstance(var, variables.TensorVariable): - # NOTE: for historical reasons we never assigned local sources - # to newly constructed tensor object, so we keep it that way. - # They are always loaded from output of the fx graph, so one can - # think of it as having a "OutputGraphSource" for codegen - # purposes. - # - # However, tensor subclass objects are different, because the - # reconstruction logic in `PyCodegen` loads the data tensor from - # graph output and then calls `as_subclass`, meaning we must - # assign a source to it to ensure we only reconstruct one - # subclass instance. - if isinstance( - var, variables.torch_function.TensorWithTFOverrideVariable - ): - # Don't codegen from temp source assigned from the 1st pass. - cg(var, allow_cache=False) - cg.add_cache(var) - # `add_cache` generates STORE and consumes TOS, but we never - # cleared it. TODO move this call into `add_cache` - cg.clear_tos() - var.source = LocalSource(cg.tempvars[var]) - elif isinstance(var, variables.AutogradFunctionContextVariable): - unimplemented_v2( - gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", - context="", - explanation="We cannot reconstruct a torch.autograd.Function's context object.", - hints=[], - ) - else: + elif isinstance(var.mutation_type, AttributeMutationNew): + if isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented_v2( + gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", + context="", + explanation="We cannot reconstruct a torch.autograd.Function's context object.", + hints=[], + ) + # Reconstruct the bytecode for # base_cls.__new__(user_cls, *args) + if isinstance(var, variables.UserDefinedObjectVariable): def load_new_method(): @@ -654,6 +630,10 @@ def load_new_method(): cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) + else: + # The remaning cases here are `AttributeMutationExisting` and + # `MutableSideEffects`, which have sources already. 
+ assert var.source is not None for ctx, args in self.save_for_backward: cg(ctx.source) @@ -1013,7 +993,7 @@ def codegen_update_mutated(self, cg: PyCodegen): else: cg.tx.output.update_co_names(name) cg(value) - cg(var) + cg(var.source) suffixes.append([create_instruction("STORE_ATTR", argval=name)]) elif isinstance(var, variables.ListIteratorVariable): for _ in range(var.index): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 4b3eb10d09e7..05739259dc5b 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -510,6 +510,7 @@ "torch._C._debug_set_fusion_group_inlining", "torch._C._demangle", "torch._C._disabled_torch_dispatch_impl", + "torch._C._disabled_torch_function_impl", "torch._C._dispatch_call_boxed", "torch._C._dispatch_check_all_invariants", "torch._C._dispatch_check_invariants", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 1cd8001e4c34..49d0c162d68a 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -140,7 +140,6 @@ wrap_fake_exception, ) from .base import ( - AttributeMutationNew, typestr, ValueMutationExisting, ValueMutationNew, @@ -2471,9 +2470,7 @@ def _wrap_fx_preexisting_tensor( f"wrapped by this instance of Dynamo. Found: {tensor}" ) - return construct_tensor_variable( - target_cls, tx, proxy, tensor, subclass_type, options - ) + return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls) # This is 2 in the above comment (wrapping the output of a traced op) @@ -2507,23 +2504,36 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe import torch._utils if isinstance(example_value, torch.Tensor): - var = construct_tensor_variable( - target_cls, tx, proxy, example_value, subclass_type, options - ) - # NOTE: [Side effect tracking for newly constructed tensor] - # For newly constructed objects that have mutable attributes, we usually - # construct their VariableTracker via `track_object_new`, but since - # tensor variable construction is a bit different, we handle them - # speically here. This ensures that codegen will actually generate the - # attribute mutations on this tensor. - # - # NOTE we pass a dummy object as the `item` argument to avoid - # constructing a dummy _tensor_ object. The object isn't used for - # newly constructed VTs anyways. - tx.output.side_effects._track_obj( - proxy, var, mutation_type_cls=AttributeMutationNew - ) - return var + is_parameter = isinstance(example_value, torch.nn.Parameter) + is_buffer = isinstance(example_value, torch.nn.Buffer) + + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. 
+ tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + tensor_type = subclass_type if subclass_type else torch.Tensor + specialized_props["class_type"] = ( + torch.nn.Parameter + if is_parameter + else torch.nn.Buffer + if is_buffer + else tensor_type + ) + + options.update(specialized_props) + return target_cls(proxy, **options) elif ( hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" @@ -2692,43 +2702,6 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ) -def construct_tensor_variable( - target_cls, tx, proxy, example_value, subclass_type, options -): - """ - Actually construct a tensor variable after all the pre-processing from - wrapping a pre-existing or newly created tensor value. - """ - # NB: In most (all?) cases, this does not actually do a clone. - # (WARNING: this means that if we mutate metadata on the fake - # tensor, the stored example value will update too!) - example_value = _clone_input(example_value, tx.fake_mode) - set_example_value(proxy.node, example_value) - # We bind the unbacked symints in sizes/trdies of tensor lazily. - # So that subgraphs can access the unbacked symbol's proxy in parent graph - # when lifting unbacked symbols of input tensors to subgraph inputs. - # We do it lazily because the tensor may not be used in subgraphs. - tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) - specialized_props = target_cls.specialize(example_value) - # TODO: not sure about this fake mode test - if ( - isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) - and example_value.fake_mode is tx.fake_mode - ): - if subclass_type: - tensor_type = subclass_type - elif isinstance(example_value, torch.nn.Parameter): - tensor_type = torch.nn.Parameter - elif isinstance(example_value, torch.nn.Buffer): - tensor_type = torch.nn.Buffer - else: - tensor_type = torch.Tensor - specialized_props["class_type"] = tensor_type - - options.update(specialized_props) - return target_cls(proxy, **options) - - def get_automatic_dynamic_shapes_mark_as(): if config.automatic_dynamic_shapes_mark_as == "dynamic": return DimDynamic.DYNAMIC diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 9c11423162d3..c66c369876b9 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1933,20 +1933,6 @@ def call_setattr( "the middle of the graph, which aot_autograd does not currently know how to handle. " ) elif name == "data": - # See comments on `test_set_data_on_scoped_tensor` for plans - # to support this. - if obj.source is None: - unimplemented_v2( - gb_type="Failed to mutate tensor data attribute", - context=f"setattr({obj}, {name}, {val})", - explanation="Dyanmo only supports mutating `.data`" - " of tensor created outside `torch.compile` region", - hints=[ - "Don't mutate `.data` on this tensor, or move " - "the mutation out of `torch.compile` region", - ], - ) - # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause # tracked fakes to produce incorrect guards. 
This is sound because the TensorVariable diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c6a2124c871a..7eaa01c2a5da 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -161,14 +161,6 @@ def call_method( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) - # This essentially simulates CPython's `super_getattro`: - # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 - # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. - # - # However, `res`'s type needs to be checked for `tp_descr_get`, and - # applied if it has one. We currently don't have polyfills for all the - # relevant `tp_descr_get`, so we explicitly handle the cases we care - # about here (e.g., note the staticmethod, classmethod cases). if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -274,29 +266,6 @@ def call_method( source = self.source and AttrSource(self.source, attr_name) return VariableTracker.build(tx, attr_value, source) - elif inner_fn is torch._C._disabled_torch_function_impl: - # See `THPModule_disable_torch_function` for the C impl. - # The signature of _disabled_torch_function_impl is similar to - # `__torch_function__`, just without the first `cls` argument: - # * (func, types, args, kwargs) - func = args[0] - tf_kwargs = {} - tf_args = args[2].items - for hash_key_vt, value_vt in args[3].items.items(): - key_str = hash_key_vt.vt.as_python_constant() - tf_kwargs[key_str] = value_vt - - output_old = tx.output.torch_function_enabled - tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled - tx.output.torch_function_enabled = False - tx.symbolic_torch_function_state.torch_function_subclass_enabled = False - try: - return func.call_function(tx, tf_args, tf_kwargs) - finally: - tx.output.torch_function_enabled = output_old - tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( - tx_old - ) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 99bd1f3eb552..5b10a643ad94 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -67,7 +67,7 @@ set_example_value, tensortype_to_dtype, ) -from .base import AttributeMutationNew, VariableTracker +from .base import VariableTracker from .constant import ConstantVariable from .lists import SizeVariable @@ -789,14 +789,9 @@ def method_as_subclass(self, cls): tx = InstructionTranslator.current_tx() py_cls = cls.as_python_constant() - var = TensorWithTFOverrideVariable.from_tensor_var( + return TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, cls.source ) - # See NOTE [Side effect tracking for newly constructed tensor] - tx.output.side_effects._track_obj( - object(), var, mutation_type_cls=AttributeMutationNew - ) - return var def method_get_device(self): if isinstance(self.device, torch.device): @@ -1448,37 +1443,14 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - # Handle `Subclass(existing_tensor)` calls. - def impl(): - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable - - # This simulates `__new__` and _assumes_ it doesn't have - # side-effects that matters to Dynamo tracing. 
TODO trace through - # `__new__`. - var = TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source - ) - - # Let Dynamo trace through custom `__init__` - init_func = self.value.__init__ - # TODO builder should be able to handle `torch.Tensor.__init__`, - # which is `object.__init__`, so that we can remove this check. - if init_func is not torch.Tensor.__init__: - cls_kwargs = kwargs or {} - VariableTracker.build(tx, init_func).call_function( - tx, [var], cls_kwargs - ) - return var + if len(args) == 1 and isinstance(args[0], TensorVariable): + from .torch_function import TensorWithTFOverrideVariable - return super().call_function(tx, args, kwargs) + return TensorWithTFOverrideVariable.from_tensor_var( + tx, args[0], self.value, self.source + ) - var = impl() - # See NOTE [Side effect tracking for newly constructed tensor] - tx.output.side_effects._track_obj( - object(), var, mutation_type_cls=AttributeMutationNew - ) - return var + return super().call_function(tx, args, kwargs) def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 3946dccd8dc7..e51f6ccd6c9d 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -62,7 +62,6 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable -from .functions import UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -593,9 +592,12 @@ def __init__(self, *args, **kwargs) -> None: def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. + # This isn't sound if a __torch_function__ tensor subclass defines a + # constructor, but if only a __torch_function__ impl is defined, this is + # okay to call. It is up to the user whether this is correct behavior + # or not. import torch - # This simulates shallow-copying the tensor object. 
kwargs = dict(tensor_var.__dict__) assert kwargs.pop("class_type") is torch.Tensor, ( "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" @@ -638,48 +640,30 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if hasattr(torch.Tensor, name): - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overridden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) - if tx.output.torch_function_enabled: - if self.source: - install_guard( - AttrSource( - AttrSource(self.source, "__class__"), name - ).make_guard(GuardBuilder.FUNCTION_MATCH) + if tx.output.torch_function_enabled and hasattr(torch.Tensor, name): + if self.source: + install_guard( + AttrSource(AttrSource(self.source, "__class__"), name).make_guard( + GuardBuilder.FUNCTION_MATCH ) - get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) - - return self.call_torch_function( - tx, - get_fn, - TupleVariable([self.class_type_var(tx)]), - [self], - {}, ) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + [self], + {}, + ) else: - # `TensorVariable.var_getattr` doesn't handle user-defined - # function/attribute well, so we explicitly handle them here. - # - # TODO move this logic into `TensorVariable`, or try to merge it - # with similar logic in `UserDefinedObjectVariable`. - try: - attr = inspect.getattr_static(self.class_type, name) - except AttributeError: - pass - else: - import types - - if isinstance(attr, types.FunctionType): - cls_source = GlobalSource(self.global_mangled_class_name(tx)) - func_source = AttrSource(cls_source, name) - install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) - return UserMethodVariable(attr, self) - return super().var_getattr(tx, name) + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): return call_torch_function( From a8f6b40e36bc4afe4e58568620a008c9a8a8704e Mon Sep 17 00:00:00 2001 From: Colin Peppler Date: Tue, 1 Apr 2025 12:53:29 -0700 Subject: [PATCH 119/332] [inductor] skip non-trivial tiling if unbacked symints are present (#150225) Take two of https://github.com/pytorch/pytorch/pull/149994. This time we just skip `convert_tiling_to_3d` and `candidate_tilings` if there exists unbacked symints. 
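(Editor's note: illustrative only, not part of the patch.) The change below amounts to treating "any unbacked symbol in the candidate ranges" as a signal to fall back to trivial tiling, because size hints cannot generally be evaluated for unbacked symints (the lowering change notes it would "choke on the size hint"). A minimal Python sketch of that check; the helper name `has_unbacked_range` is hypothetical, while `free_unbacked_symbols` is the utility the patch itself imports from `torch.fx.experimental.symbolic_shapes`:

    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    def has_unbacked_range(pointwise_ranges, reduction_ranges):
        # Both arguments are lists of sympy expressions (as returned by node.get_ranges()).
        # free_unbacked_symbols returns the unbacked symbols they reference; a non-empty
        # result means size_hint-based tiling heuristics cannot be evaluated and must be skipped.
        return bool(free_unbacked_symbols(list(pointwise_ranges) + list(reduction_ranges)))

Both `candidate_tilings` and `convert_tiling_to_3d` in the diff below apply this same pattern, returning early (`[]` and `None`, respectively) when the check fires.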
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150225 Approved by: https://github.com/eellison --- .../test_torchinductor_strided_blocks.py | 34 +++++++++++++++++++ torch/_inductor/codegen/simd.py | 13 +++++-- torch/_inductor/lowering.py | 8 ++--- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 895c536ed326..ec6c3dc8a578 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -931,6 +931,40 @@ def foo(x, y, z): # Check for 3D tiling self.assertIn("ZBLOCK", code) + @torch._dynamo.config.patch({"capture_scalar_outputs": True}) + @parametrize("num_tile_candidates", (1, 2)) + def test_unbacked_size_on_non_contig_dim(self, num_tile_candidates: int): + # NUM_REPEAT should determine # of candidate_tilings. + NUM_REPEAT = 2 if num_tile_candidates == 2 else 8 + + def foo(x, length): + unbacked = length.item() + torch._check_is_size(unbacked) + + repeated = x.repeat(1, unbacked, NUM_REPEAT) + # permute creates split in middle with unbacked symint is the first range + # ranges: [33*unbacked, NUM_REPEAT, 64] + permute120 = repeated.permute([1, 2, 0]) + return permute120.cos() + + inps = ( + torch.rand((64, 33, 1), device=self.device, dtype=torch.float32), + torch.scalar_tensor(16, device=self.device, dtype=torch.int32), + ) + + with torch._dynamo.config.patch({"capture_scalar_outputs": True}): + run_and_compare( + self, + foo, + *inps, + expected_num_triton_kernels=1, + expected_num_block_pointers=0, + config_patches={ + "triton.max_tiles": 3, + "triton.prefer_nd_tiling": True, + }, + ) + # block_ptr advancements should also be deferrered conditional # on the associated buffer not being removed # in this case the bernoulli operation is fused with the following sum diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index f33c39623acb..db8091c78648 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -18,6 +18,7 @@ import torch import torch._logging +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.fx.immutable_collections import immutable_dict from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, Identity, ModularIndexing @@ -1764,7 +1765,11 @@ def collapse_ranges(ranges: Sequence[sympy.Expr]) -> sympy.Expr: return tilings pointwise_ranges, reduction_ranges = node.get_ranges() - if len(pointwise_ranges) <= 1 and len(reduction_ranges) <= 1: + if ( + len(pointwise_ranges) <= 1 + and len(reduction_ranges) <= 1 + or free_unbacked_symbols(pointwise_ranges + reduction_ranges) + ): return [] # Tile either pointwise or reduction dims. 
@@ -2013,7 +2018,11 @@ def convert_tiling_to_3d( ) -> Optional[dict[str, sympy.Expr]]: a0, a1 = tiling0["x"], tiling0.get("y", 1) b0, b1 = tiling1["x"], tiling1.get("y", 1) - if V.graph.sizevars.size_hint(a1 - b1) == 0: + + if ( + free_unbacked_symbols([a1, b1]) + or V.graph.sizevars.size_hint(a1 - b1) == 0 + ): return None if V.graph.sizevars.size_hint(a1 - b1) < 0: # swap so a0 is bigger diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 9996857b29d2..7fcf79041851 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -40,6 +40,7 @@ Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator +from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing @@ -1088,8 +1089,6 @@ def trunc(x): @register_lowering(aten.expand, type_promotion_kind=None) def expand(x, sizes): - from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols - (x,) = promote_constants([x]) if isinstance(x, ir.BaseConstant): return ExpandView.create(x, tuple(sizes)) @@ -1166,8 +1165,9 @@ def inner_fn(index): return x_loader(index) old_size_product = V.graph.sizevars.size_hint(sympy_product(old_size)) - if old_size_product > 0: - # maybe realize the input + if old_size_product > 0 and not free_unbacked_symbols(new_size): + # maybe realize the input but skip for unbacked symints since it'll + # choke on the size hint. x.mark_reuse( V.graph.sizevars.size_hint(sympy_product(new_size)) // old_size_product ) From 85df0dc2460603116392001ae63b33ac0ee8fc54 Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 1 Apr 2025 15:43:29 -0700 Subject: [PATCH 120/332] [dynamo] emit only 1 graph break message on unrecoverable data-dependent assert fail (#150471) Addresses https://fb.workplace.com/groups/1075192433118967/permalink/1625299684774903/ Pull Request resolved: https://github.com/pytorch/pytorch/pull/150471 Approved by: https://github.com/jansel --- test/dynamo/test_error_messages.py | 62 ++++++++++++++++++++++++------ torch/_dynamo/symbolic_convert.py | 28 +++++++++----- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 3793ade26738..1098b1bfbb2f 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -36,6 +36,14 @@ """ +class GenericCtxMgr: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + class GraphBreakMessagesTest(LoggingTestCase): def test_dynamic_shape_operator(self): def fn(): @@ -569,19 +577,12 @@ def fn(mod, x): ) def test_generic_ctx_mgr_graph_break(self): - class CtxMgr: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - def fn(): - with CtxMgr(): - with CtxMgr(): + with GenericCtxMgr(): + with GenericCtxMgr(): pass - with CtxMgr(): - with CtxMgr(): + with GenericCtxMgr(): + with GenericCtxMgr(): pass torch._dynamo.graph_break() @@ -596,7 +597,7 @@ def fn(): Hint: Move the offending context manager(s) to outside the compiled region. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. 
- Developer debug context: Active generic context managers: [GenericContextWrappingVariable(CtxMgr), GenericContextWrappingVariable(CtxMgr)] + Developer debug context: Active generic context managers: [GenericContextWrappingVariable(GenericCtxMgr), GenericContextWrappingVariable(GenericCtxMgr)] from user code: @@ -834,6 +835,43 @@ def fn(x): """, ) + @make_logging_test(graph_breaks=True) + def test_assert_failure_in_generic_ctx_mgr(self, records): + def fn(x): + with GenericCtxMgr(): + assert x is None + + with self.assertRaises(AssertionError): + torch.compile(fn, backend="eager")(torch.randn(3)) + + # only 1 graph break message + self.assertEqual(len(records), 1) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + """\ +Graph break: skip: from user code at: + File "test_error_messages.py", line N, in fn + assert x is None +""", + ) + self.assertExpectedInline( + munge_exc(records[0].exc_info[1], suppress_suffix=True, skip=0), + """\ +Data-dependent assertion failed (cannot compile partial graph) + Explanation: Dynamo has determined when encountering a data-dependent assert failure that it should not compile the partial graph. + Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. + Hint: Use `torch._assert()` to raise a hard AssertionError when the check fails. This error will propagate back the user code that called the compiled function (i.e. Dynamo wil not trace any exception handling). + Hint: Remove the assert statement. + Hint: Move the assert statement outside of any context managers in order to graph break with partial graph compilation (if fullgraph=False). + + Developer debug context: value: ConstantVariable(bool: False) + + +from user code: + File "test_error_messages.py", line N, in fn + assert x is None""", + ) + def test_no_internal_compiler_stacktrace(self): def fn(): gn() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 80960d6eb94a..fb38b9e1b664 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -590,15 +590,7 @@ def jump_graph_break(self, inst, value, extra_msg=""): hints=_hints, ), ) - if not self.should_compile_partial_graph(): - unimplemented_v2( - gb_type="Should not compile partial graph (data-dependent branching)", - context="", - explanation="Dynamo has determined when encountering data-dependent " - "branching (e.g. `if my_tensor.item() > 0:`) that it should not " - "compile the partial graph.", - hints=[], - ) + assert self.should_compile_partial_graph() # compile a partial subgraph prefix then jump into user code if self.maybe_has_backedge(): msg = ( @@ -642,8 +634,24 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if value.is_python_constant(): if bool(value.as_python_constant()): return self.jump(inst) - else: + elif self.should_compile_partial_graph(): jump_graph_break(self, inst, value) + else: + unimplemented_v2( + gb_type="Data-dependent assertion failed (cannot compile partial graph)", + context=f"value: {value}", + explanation="Dynamo has determined when encountering a data-dependent assert failure " + "that it should not compile the partial graph.", + hints=[ + *graph_break_hints.FUNDAMENTAL, + "Use `torch._assert()` to raise a hard AssertionError when the check fails. " + "This error will propagate back the user code " + "that called the compiled function (i.e. 
Dynamo wil not trace any exception handling).", + "Remove the assert statement.", + "Move the assert statement outside of any context managers in order to graph break with " + "partial graph compilation (if fullgraph=False).", + ], + ) # TODO maybe should respect DtoH sync intention of users later?? # Manually insert torch._assert_async instead of python assert and jump over From 33535b3eee7b4ccede1a7d6913b9807430de8303 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:39 -0700 Subject: [PATCH 121/332] [dynamo] Support Tensor subclass that has dynamic attributes or calls `Parameter.__torch_function__` (#149482) This fixes most of https://github.com/huggingface/diffusers/issues/10795, except for `torch.Tensor._make_subclass`, which will be fixed in a subsequent patch. The relevant tensor subclass from the aforementioned issue is defined here: https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435. There are two things to note about the tensor subclass: 1. it calls `super().__torch_function__`, which is `torch._C._disabled_torch_function_impl`, so this patch updates `SuperVariable.call_method` to handle it (we can't do a simpler polyfill due to some bug with `var_getattr` raising `NotImplementedError`, which forgot to restore symbolic context). 2. it sets and reads attributes (`quant_type`), and defines new methods (`as_data`), so this patch adds support for those. 3. it has a `__init__`, which Dynamo needs to trace through in `TensorSubclassVariable.call_function`. Differential Revision: [D71906140](https://our.internmc.facebook.com/intern/diff/D71906140) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149482 Approved by: https://github.com/jansel, https://github.com/mlazos --- test/dynamo/test_misc.py | 16 +++ test/dynamo/test_subclasses.py | 134 ++++++++++++++++++ .../TestTorch.test_tensor_ressurecting_clear | 1 + ...edding_swap_True_set_grad_True_cpu_float32 | 0 ..._PReLU_swap_True_set_grad_True_cpu_float32 | 0 ...MSNorm_swap_True_set_grad_True_cpu_float32 | 0 ...dding_swap_True_set_grad_True_cuda_float32 | 0 ...PReLU_swap_True_set_grad_True_cuda_float32 | 0 ...SNorm_swap_True_set_grad_True_cuda_float32 | 0 torch/_dynamo/side_effects.py | 66 ++++++--- torch/_dynamo/trace_rules.py | 1 - torch/_dynamo/variables/builder.py | 89 ++++++++---- torch/_dynamo/variables/builtin.py | 14 ++ torch/_dynamo/variables/misc.py | 31 ++++ torch/_dynamo/variables/tensor.py | 44 ++++-- torch/_dynamo/variables/torch_function.py | 64 +++++---- 16 files changed, 373 insertions(+), 87 deletions(-) create mode 100644 test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 delete mode 100644 test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 diff --git a/test/dynamo/test_misc.py 
b/test/dynamo/test_misc.py index a6b3a29eb4d3..8579ee8e1b2e 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -586,6 +586,22 @@ def f(x): ref = f(x) self.assertEqual(res, ref) + def test_newly_constructed_tensor_attr_mutation(self): + def f(x): + y = x + 10 + y.grad = x + y.foo = 42 + return y + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.ones(5) + + res = opt_f(x) + ref = f(x) + self.assertEqual(res, ref) + self.assertEqual(res.grad, ref.grad) + self.assertEqual(res.foo, ref.foo) + def test_closure_recompiles(self): cnt = CompileCounter() diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ef2acadac89d..7fefc281089b 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,6 +954,140 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): + # This is a slight variation of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + # which basically + # 1. uses tensor subclass to attach quantization metadata onto tensors + # 2. preserve them across torch ops + # 3. use the metadata to dequantize the tensor + # 4. convert it to a regular tensor. + # + # The test is meant to make sure Dynamo won't graph break over it. + class GGUFParameter(torch.nn.Parameter): + def __new__(cls, data, requires_grad=False, quant_type=None): + data = data if data is not None else torch.empty(0) + self = torch.Tensor._make_subclass(cls, data, requires_grad) + return self + + def __init__(self, *args, quant_type=None, **kwargs): + self.quant_type = quant_type + + def as_tensor(self): + return torch.Tensor(self.data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + result = super().__torch_function__(func, types, args, kwargs) + + quant_type = None + for arg in args: + if isinstance(arg, list) and isinstance(arg[0], GGUFParameter): + quant_type = arg[0].quant_type + break + if isinstance(arg, GGUFParameter): + quant_type = arg.quant_type + break + if isinstance(result, torch.Tensor): + return cls(result, quant_type=quant_type) + # Handle tuples and lists + elif isinstance(result, (tuple, list)): + # Preserve the original type (tuple or list) + wrapped = [ + cls(x, quant_type=quant_type) + if isinstance(x, torch.Tensor) + else x + for x in result + ] + return type(result)(wrapped) + else: + return result + + def f(x): + tmp = x * 2 + tmp = tmp + tmp.quant_type + tmp = tmp.as_tensor() + return tmp * 3 + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + + x = GGUFParameter(torch.ones(2), quant_type=42) + with traceable_subclass(GGUFParameter): + res = f(x) + ref = opt_f(x) + self.assertEqual(res, ref) + + def test_newly_constructed_tensor_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from constructor call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. 
+ class MySubclass(torch.Tensor): + pass + + def f(): + x = MySubclass(torch.ones(2)) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_as_subclass_attr_mutation(self): + # Make sure the attribute mutation for newly constructed tensor subclass + # object (from as_subclass call) is handled both during Dynamo tracing + # and codegen-ed to be visible outside `torch.compile`. + class MySubclass(torch.Tensor): + pass + + def f(): + x = torch.ones(2).as_subclass(MySubclass) + x.bar = 42 + return x, x * x.bar + + opt_f = compile_full_eager(f) + + with traceable_subclass(MySubclass): + res = f() + ref = opt_f() + + self.assertEqual(res, ref) + self.assertEqual(res[0].bar, ref[0].bar) + + def test_tensor_subclass_attr_codegen_tos(self): + # This repros a very subtle interaction between + # `TensorWithTFOverrideVariable` attribute mutation codegen and + # `PyCodegen.top_of_stack`. It was uncovered from + # `test_tensor_subclass_deepcopy`. + class MySubclass(torch.Tensor): + def __new__(cls, elem, *args, **kwargs): + r = torch.Tensor._make_subclass(cls, torch.ones(0)) + r.elem = elem + return r + + def f(t): + return MySubclass(t.elem.clone()) + + opt_f = compile_full_eager(f) + + t = MySubclass(torch.ones(2)) + with traceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + # TODO uncomment once we trace into `__new__`. + # self.assertEqual(res, ref) + # self.assertEqual(res.elem, ref.elem) + self.assertEqual(type(res), type(ref)) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear new file mode 100644 index 000000000000..276a4f74bbca --- /dev/null +++ b/test/dynamo_expected_failures/TestTorch.test_tensor_ressurecting_clear @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149881 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_Embedding_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_PReLU_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 b/test/inductor_expected_failures/TestModuleCPU.test_to_nn_RMSNorm_swap_True_set_grad_True_cpu_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_Embedding_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_PReLU_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git 
a/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 b/test/inductor_expected_failures/TestModuleCUDA.test_to_nn_RMSNorm_swap_True_set_grad_True_cuda_float32 deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 4c85d98cfd16..1deb09e2cc1e 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -295,9 +295,7 @@ def _track_obj( variable: VariableTracker, mutation_type_cls=ValueMutationExisting, ): - """Start tracking a new variable for mutation""" - assert variable.source is not None - + """Start tracking an existing or new variable for mutation""" if id(item) in self.id_to_variable: raise AssertionError( f"{variable} is already tracked for mutation. This could be " @@ -576,12 +574,18 @@ def _get_modified_vars(self): return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen): - # Make sure we codegen these modified VT to their source by default, so - # that mutation and aliasing are properly accounted for. + # We must codegen modified VT to their source by default, so that + # mutation and aliasing are properly accounted for. + # + # Since newly constructed objects don't have a source, we manually + # codegen their construction and store them to a newly assigned local + # source. Note that `ValueMutationNew` isn't tracked by SideEffects. for var in self._get_modified_vars(): - if isinstance(var.mutation_type, AttributeMutationNew) and isinstance( - var, variables.CellVariable - ): + if not isinstance(var.mutation_type, AttributeMutationNew): + assert var.source is not None + continue + + if isinstance(var, variables.CellVariable): # Cells created in the root frame are created either by # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit # `make_cell` for the non-root-frame cells here. @@ -595,18 +599,38 @@ def codegen_save_tempvars(self, cg: PyCodegen): var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: var.source = LocalCellSource(var.local_name) - elif isinstance(var.mutation_type, AttributeMutationNew): - if isinstance(var, variables.AutogradFunctionContextVariable): - unimplemented_v2( - gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", - context="", - explanation="We cannot reconstruct a torch.autograd.Function's context object.", - hints=[], - ) - + elif isinstance(var, variables.TensorVariable): + # NOTE: for historical reasons we never assigned local sources + # to newly constructed tensor object, so we keep it that way. + # They are always loaded from output of the fx graph, so one can + # think of it as having a "OutputGraphSource" for codegen + # purposes. + # + # However, tensor subclass objects are different, because the + # reconstruction logic in `PyCodegen` loads the data tensor from + # graph output and then calls `as_subclass`, meaning we must + # assign a source to it to ensure we only reconstruct one + # subclass instance. + if isinstance( + var, variables.torch_function.TensorWithTFOverrideVariable + ): + # Don't codegen from temp source assigned from the 1st pass. + cg(var, allow_cache=False) + cg.add_cache(var) + # `add_cache` generates STORE and consumes TOS, but we never + # cleared it. 
TODO move this call into `add_cache` + cg.clear_tos() + var.source = LocalSource(cg.tempvars[var]) + elif isinstance(var, variables.AutogradFunctionContextVariable): + unimplemented_v2( + gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", + context="", + explanation="We cannot reconstruct a torch.autograd.Function's context object.", + hints=[], + ) + else: # Reconstruct the bytecode for # base_cls.__new__(user_cls, *args) - if isinstance(var, variables.UserDefinedObjectVariable): def load_new_method(): @@ -630,10 +654,6 @@ def load_new_method(): cg.add_cache(var) var.source = LocalSource(cg.tempvars[var]) - else: - # The remaning cases here are `AttributeMutationExisting` and - # `MutableSideEffects`, which have sources already. - assert var.source is not None for ctx, args in self.save_for_backward: cg(ctx.source) @@ -993,7 +1013,7 @@ def codegen_update_mutated(self, cg: PyCodegen): else: cg.tx.output.update_co_names(name) cg(value) - cg(var.source) + cg(var) suffixes.append([create_instruction("STORE_ATTR", argval=name)]) elif isinstance(var, variables.ListIteratorVariable): for _ in range(var.index): diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 05739259dc5b..4b3eb10d09e7 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -510,7 +510,6 @@ "torch._C._debug_set_fusion_group_inlining", "torch._C._demangle", "torch._C._disabled_torch_dispatch_impl", - "torch._C._disabled_torch_function_impl", "torch._C._dispatch_call_boxed", "torch._C._dispatch_check_all_invariants", "torch._C._dispatch_check_invariants", diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 49d0c162d68a..1cd8001e4c34 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -140,6 +140,7 @@ wrap_fake_exception, ) from .base import ( + AttributeMutationNew, typestr, ValueMutationExisting, ValueMutationNew, @@ -2470,7 +2471,9 @@ def _wrap_fx_preexisting_tensor( f"wrapped by this instance of Dynamo. Found: {tensor}" ) - return handle_traced_output(tensor, tx, proxy, options, subclass_type, target_cls) + return construct_tensor_variable( + target_cls, tx, proxy, tensor, subclass_type, options + ) # This is 2 in the above comment (wrapping the output of a traced op) @@ -2504,36 +2507,23 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe import torch._utils if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - is_buffer = isinstance(example_value, torch.nn.Buffer) - - # NB: In most (all?) cases, this does not actually do a clone. - # (WARNING: this means that if we mutate metadata on the fake - # tensor, the stored example value will update too!) - example_value = _clone_input(example_value, tx.fake_mode) - set_example_value(proxy.node, example_value) - # We bind the unbacked symints in sizes/trdies of tensor lazily. - # So that subgraphs can access the unbacked symbol's proxy in parent graph - # when lifting unbacked symbols of input tensors to subgraph inputs. - # We do it lazily because the tensor may not be used in subgraphs. 
- tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) - specialized_props = target_cls.specialize(example_value) - # TODO: not sure about this fake mode test - if ( - isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) - and example_value.fake_mode is tx.fake_mode - ): - tensor_type = subclass_type if subclass_type else torch.Tensor - specialized_props["class_type"] = ( - torch.nn.Parameter - if is_parameter - else torch.nn.Buffer - if is_buffer - else tensor_type - ) - - options.update(specialized_props) - return target_cls(proxy, **options) + var = construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options + ) + # NOTE: [Side effect tracking for newly constructed tensor] + # For newly constructed objects that have mutable attributes, we usually + # construct their VariableTracker via `track_object_new`, but since + # tensor variable construction is a bit different, we handle them + # speically here. This ensures that codegen will actually generate the + # attribute mutations on this tensor. + # + # NOTE we pass a dummy object as the `item` argument to avoid + # constructing a dummy _tensor_ object. The object isn't used for + # newly constructed VTs anyways. + tx.output.side_effects._track_obj( + proxy, var, mutation_type_cls=AttributeMutationNew + ) + return var elif ( hasattr(proxy.node.target, "__name__") and proxy.node.target.__name__ == "set_state" @@ -2702,6 +2692,43 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe ) +def construct_tensor_variable( + target_cls, tx, proxy, example_value, subclass_type, options +): + """ + Actually construct a tensor variable after all the pre-processing from + wrapping a pre-existing or newly created tensor value. + """ + # NB: In most (all?) cases, this does not actually do a clone. + # (WARNING: this means that if we mutate metadata on the fake + # tensor, the stored example value will update too!) + example_value = _clone_input(example_value, tx.fake_mode) + set_example_value(proxy.node, example_value) + # We bind the unbacked symints in sizes/trdies of tensor lazily. + # So that subgraphs can access the unbacked symbol's proxy in parent graph + # when lifting unbacked symbols of input tensors to subgraph inputs. + # We do it lazily because the tensor may not be used in subgraphs. + tx.output.current_tracer.track_unbacked_symbols(example_value, proxy) + specialized_props = target_cls.specialize(example_value) + # TODO: not sure about this fake mode test + if ( + isinstance(example_value, torch._subclasses.fake_tensor.FakeTensor) + and example_value.fake_mode is tx.fake_mode + ): + if subclass_type: + tensor_type = subclass_type + elif isinstance(example_value, torch.nn.Parameter): + tensor_type = torch.nn.Parameter + elif isinstance(example_value, torch.nn.Buffer): + tensor_type = torch.nn.Buffer + else: + tensor_type = torch.Tensor + specialized_props["class_type"] = tensor_type + + options.update(specialized_props) + return target_cls(proxy, **options) + + def get_automatic_dynamic_shapes_mark_as(): if config.automatic_dynamic_shapes_mark_as == "dynamic": return DimDynamic.DYNAMIC diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c66c369876b9..9c11423162d3 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1933,6 +1933,20 @@ def call_setattr( "the middle of the graph, which aot_autograd does not currently know how to handle. 
" ) elif name == "data": + # See comments on `test_set_data_on_scoped_tensor` for plans + # to support this. + if obj.source is None: + unimplemented_v2( + gb_type="Failed to mutate tensor data attribute", + context=f"setattr({obj}, {name}, {val})", + explanation="Dyanmo only supports mutating `.data`" + " of tensor created outside `torch.compile` region", + hints=[ + "Don't mutate `.data` on this tensor, or move " + "the mutation out of `torch.compile` region", + ], + ) + # Remove the old reference in tracked fakes - if we don't do this # new .data value size and shape differences will cause # tracked fakes to produce incorrect guards. This is sound because the TensorVariable diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 7eaa01c2a5da..c6a2124c871a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -161,6 +161,14 @@ def call_method( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": inner_fn, source = self._resolved_getattr_and_source(self, name) + # This essentially simulates CPython's `super_getattro`: + # https://github.com/python/cpython/blob/a1c52d1265c65bcf0d9edf87e143843ad54f9b8f/Objects/typeobject.c#L11138-L11168 + # where `inner_fn` is the VT for `res = _super_lookup_descr(...)`. + # + # However, `res`'s type needs to be checked for `tp_descr_get`, and + # applied if it has one. We currently don't have polyfills for all the + # relevant `tp_descr_get`, so we explicitly handle the cases we care + # about here (e.g., note the staticmethod, classmethod cases). if inner_fn is object.__init__: return LambdaVariable(identity) elif inner_fn is torch.nn.Module.__init__: @@ -266,6 +274,29 @@ def call_method( source = self.source and AttrSource(self.source, attr_name) return VariableTracker.build(tx, attr_value, source) + elif inner_fn is torch._C._disabled_torch_function_impl: + # See `THPModule_disable_torch_function` for the C impl. 
+ # The signature of _disabled_torch_function_impl is similar to + # `__torch_function__`, just without the first `cls` argument: + # * (func, types, args, kwargs) + func = args[0] + tf_kwargs = {} + tf_args = args[2].items + for hash_key_vt, value_vt in args[3].items.items(): + key_str = hash_key_vt.vt.as_python_constant() + tf_kwargs[key_str] = value_vt + + output_old = tx.output.torch_function_enabled + tx_old = tx.symbolic_torch_function_state.torch_function_subclass_enabled + tx.output.torch_function_enabled = False + tx.symbolic_torch_function_state.torch_function_subclass_enabled = False + try: + return func.call_function(tx, tf_args, tf_kwargs) + finally: + tx.output.torch_function_enabled = output_old + tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( + tx_old + ) unimplemented(f"non-function or method super: {inner_fn}") diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5b10a643ad94..99bd1f3eb552 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -67,7 +67,7 @@ set_example_value, tensortype_to_dtype, ) -from .base import VariableTracker +from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable @@ -789,9 +789,14 @@ def method_as_subclass(self, cls): tx = InstructionTranslator.current_tx() py_cls = cls.as_python_constant() - return TensorWithTFOverrideVariable.from_tensor_var( + var = TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, cls.source ) + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var def method_get_device(self): if isinstance(self.device, torch.device): @@ -1443,14 +1448,37 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable + # Handle `Subclass(existing_tensor)` calls. + def impl(): + if len(args) == 1 and isinstance(args[0], TensorVariable): + from .torch_function import TensorWithTFOverrideVariable + + # This simulates `__new__` and _assumes_ it doesn't have + # side-effects that matters to Dynamo tracing. TODO trace through + # `__new__`. + var = TensorWithTFOverrideVariable.from_tensor_var( + tx, args[0], self.value, self.source + ) - return TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source - ) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. 
+ if init_func is not torch.Tensor.__init__: + cls_kwargs = kwargs or {} + VariableTracker.build(tx, init_func).call_function( + tx, [var], cls_kwargs + ) + return var + + return super().call_function(tx, args, kwargs) - return super().call_function(tx, args, kwargs) + var = impl() + # See NOTE [Side effect tracking for newly constructed tensor] + tx.output.side_effects._track_obj( + object(), var, mutation_type_cls=AttributeMutationNew + ) + return var def as_python_constant(self): return self.value diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index e51f6ccd6c9d..3946dccd8dc7 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -62,6 +62,7 @@ from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable +from .functions import UserMethodVariable from .lazy import LazyVariableTracker from .lists import TupleVariable from .tensor import TensorSubclassVariable, TensorVariable @@ -592,12 +593,9 @@ def __init__(self, *args, **kwargs) -> None: def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # [Note: __torch_function__] coerce `tensor_var` into a # TensorWithTFOverrideVariable. In eager, this is just a type change. - # This isn't sound if a __torch_function__ tensor subclass defines a - # constructor, but if only a __torch_function__ impl is defined, this is - # okay to call. It is up to the user whether this is correct behavior - # or not. import torch + # This simulates shallow-copying the tensor object. kwargs = dict(tensor_var.__dict__) assert kwargs.pop("class_type") is torch.Tensor, ( "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" @@ -640,30 +638,48 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + if hasattr(torch.Tensor, name): + if _is_attr_overidden(tx, self, name): + unimplemented( + f"Accessing overridden method/attribute {name} on a tensor" + " subclass with a __torch_function__ override is not supported" + ) - if tx.output.torch_function_enabled and hasattr(torch.Tensor, name): - if self.source: - install_guard( - AttrSource(AttrSource(self.source, "__class__"), name).make_guard( - GuardBuilder.FUNCTION_MATCH + if tx.output.torch_function_enabled: + if self.source: + install_guard( + AttrSource( + AttrSource(self.source, "__class__"), name + ).make_guard(GuardBuilder.FUNCTION_MATCH) ) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) + + return self.call_torch_function( + tx, + get_fn, + TupleVariable([self.class_type_var(tx)]), + [self], + {}, ) - get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) - - return self.call_torch_function( - tx, - get_fn, - TupleVariable([self.class_type_var(tx)]), - [self], - {}, - ) else: - return super().var_getattr(tx, name) + # `TensorVariable.var_getattr` doesn't handle user-defined + # function/attribute well, so we explicitly handle them here. + # + # TODO move this logic into `TensorVariable`, or try to merge it + # with similar logic in `UserDefinedObjectVariable`. 
+ try: + attr = inspect.getattr_static(self.class_type, name) + except AttributeError: + pass + else: + import types + + if isinstance(attr, types.FunctionType): + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + func_source = AttrSource(cls_source, name) + install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + return UserMethodVariable(attr, self) + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): return call_torch_function( From 0d4dbfd9edee1403a08a6784cc1566f9fd00f184 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:40 -0700 Subject: [PATCH 122/332] [dynamo] Support `torch.Tensor._make_subclass` and tracing through tensor subclass `__new__` (#149483) This builds off the previous patch in the stack, and fully fixes https://github.com/huggingface/diffusers/issues/10795. Essentially, tensor subclass in the issue uses `torch.Tensor._make_subclass`, which has a pretty simple shallow-copy plus type change semantics, as far as Dynamo is concerned. So this patch adds a polyfill for it. As a result, this allows us to trace through many user-defined `__new__` in tensor subclass (it's similar to how we trace through user-defined `__new__` for `UserDefinedClassVariable`), so this patch also faithfully trace through these `__new__` methods. Differential Revision: [D71906139](https://our.internmc.facebook.com/intern/diff/D71906139) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149483 Approved by: https://github.com/zou3519, https://github.com/mlazos ghstack dependencies: #149482 --- test/dynamo/test_subclasses.py | 37 ++++++++-- .../TestGradNewOnesOverride.test_newones | 0 .../TestIterator.test_iterator | 0 .../TestNamedTuple.test_max | 0 .../TestPickle.test_pickle | 0 torch/_dynamo/polyfills/loader.py | 1 + torch/_dynamo/polyfills/tensor.py | 37 ++++++++++ torch/_dynamo/variables/tensor.py | 67 +++++++++++++------ 8 files changed, 116 insertions(+), 26 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones delete mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator delete mode 100644 test/dynamo_expected_failures/TestNamedTuple.test_max delete mode 100644 test/dynamo_expected_failures/TestPickle.test_pickle create mode 100644 torch/_dynamo/polyfills/tensor.py diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 7fefc281089b..99b7ab9784ae 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -954,6 +954,34 @@ def fn(x): res_act = fn_opt(input) self.assertEqual(res_exp, res_act) + def test_make_subclass(self): + # Make sure `torch.Tensor._make_subclass` is traceable, and Dynamo + # models its aliasing relationships correctly. + class MySubclass(torch.Tensor): + pass + + def fn(x): + # Downcast then upcast + y = torch.Tensor._make_subclass(MySubclass, x) + z = torch.Tensor._make_subclass(torch.Tensor, x) + # Now `x, y, z` should have the same underlying data. 
+ x += 1 + y += 2 + z += 3 + res = x * y + z + return res + + with traceable_subclass(MySubclass): + x0 = torch.randn(2, 2) + x1 = x0.clone() + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 @@ -974,7 +1002,9 @@ def __init__(self, *args, quant_type=None, **kwargs): self.quant_type = quant_type def as_tensor(self): - return torch.Tensor(self.data) + return torch.Tensor._make_subclass( + torch.Tensor, self, self.requires_grad + ) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1083,9 +1113,8 @@ def f(t): res = f(t) ref = opt_f(t) - # TODO uncomment once we trace into `__new__`. - # self.assertEqual(res, ref) - # self.assertEqual(res.elem, ref.elem) + self.assertEqual(res, ref) + self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) def test_compile_with_fake_tensor_dynamic_dim(self): diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestNamedTuple.test_max b/test/dynamo_expected_failures/TestNamedTuple.test_max deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestPickle.test_pickle b/test/dynamo_expected_failures/TestPickle.test_pickle deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index d9be4e9febc9..f60aa57a5d40 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -21,6 +21,7 @@ "pytree", "sys", "fx", + "tensor", ) POLYFILLED_MODULES: tuple["ModuleType", ...] = tuple( importlib.import_module(f".{submodule}", package=polyfills.__name__) diff --git a/torch/_dynamo/polyfills/tensor.py b/torch/_dynamo/polyfills/tensor.py new file mode 100644 index 000000000000..002ccf5d1d4f --- /dev/null +++ b/torch/_dynamo/polyfills/tensor.py @@ -0,0 +1,37 @@ +from typing import Any + +import torch + +from ..decorators import substitute_in_graph + + +@substitute_in_graph( # type: ignore[arg-type] + torch.Tensor._make_subclass +) +def make_subclass( + cls: type[Any], data: torch.Tensor, requires_grad: bool = False, **kwargs: Any +) -> Any: + # This is a rough approximation of `THPVariable_make_subclass`. It should + # suffice for most of Dynamo tracing purposes. + # https://github.com/pytorch/pytorch/blob/ccfde4dadfa3c342076a1ee387017f84dd4ad2f7/torch/csrc/autograd/python_variable.cpp#L597-L650 + assert len(kwargs) == 0, "_make_subclass only supports requires_grad as keyword arg" + data = data.detach() + + # Avoid unnecessary `requires_grad` mutation, which isn't supported in Dynamo. + if data.requires_grad != requires_grad: + data.requires_grad = requires_grad + + # Dynamo can't yet handle upcasting to base tensor type via `as_subclass`. + if cls is torch.Tensor: + return torch.Tensor(data) + + # Calling `as_subclass` because + # 1. 
Dynamo knows how to handle it + # 2. the C impls match at this point -- both `THPVariable_make_subclass` and + # `THPVariable_as_subclass` calls `THPVariable_NewWithVar`. + return data.as_subclass(cls) + + +__all__ = [ + "make_subclass", +] diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 99bd1f3eb552..44b3ffc27689 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -797,6 +797,15 @@ def method_as_subclass(self, cls): object(), var, mutation_type_cls=AttributeMutationNew ) return var + unimplemented_v2( + gb_type="Argument of `as_subclass` must be a non-dispatcher-style tensor subclass", + context=f"{self}.as_subclass({cls})", + explanation="Currently not supported", + hints=[ + "Avoid this call or move it outside `torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) def method_get_device(self): if isinstance(self.device, torch.device): @@ -1448,32 +1457,46 @@ def call_function( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - # Handle `Subclass(existing_tensor)` calls. - def impl(): - if len(args) == 1 and isinstance(args[0], TensorVariable): - from .torch_function import TensorWithTFOverrideVariable - - # This simulates `__new__` and _assumes_ it doesn't have - # side-effects that matters to Dynamo tracing. TODO trace through - # `__new__`. + # Handle `Subclass(existing_tensor, ...)` calls. + from .torch_function import TensorWithTFOverrideVariable + + new_func = self.value.__new__ + if new_func is torch.Tensor.__new__: + if ( + len(args) == 1 + and isinstance(args[0], TensorVariable) + and len(kwargs) == 0 + ): + data = args[0] + # Simulate `torch.Tensor.__new__` as shallow-copying the input + # tensor data with a new type. TODO polyfill? var = TensorWithTFOverrideVariable.from_tensor_var( - tx, args[0], self.value, self.source + tx, data, self.value, self.source ) + else: + unimplemented_v2( + gb_type="Calling subclass default constructor with more than tensor argument", + context=f"{self.value}(args={args}, kwargs={kwargs})", + explanation="Currently not supported", + hints=[ + "Avoid this constructor call or move it outside " + "`torch.compile` regione", + *graph_break_hints.SUPPORTABLE, + ], + ) + else: + # Let Dynamo trace through custom `__new__` + var = VariableTracker.build(tx, new_func).call_function( + tx, [self] + args, kwargs + ) - # Let Dynamo trace through custom `__init__` - init_func = self.value.__init__ - # TODO builder should be able to handle `torch.Tensor.__init__`, - # which is `object.__init__`, so that we can remove this check. - if init_func is not torch.Tensor.__init__: - cls_kwargs = kwargs or {} - VariableTracker.build(tx, init_func).call_function( - tx, [var], cls_kwargs - ) - return var - - return super().call_function(tx, args, kwargs) + # Let Dynamo trace through custom `__init__` + init_func = self.value.__init__ + # TODO builder should be able to handle `torch.Tensor.__init__`, + # which is `object.__init__`, so that we can remove this check. 
+ if init_func is not torch.Tensor.__init__: + VariableTracker.build(tx, init_func).call_function(tx, [var], kwargs) - var = impl() # See NOTE [Side effect tracking for newly constructed tensor] tx.output.side_effects._track_obj( object(), var, mutation_type_cls=AttributeMutationNew From 3463ea1059ee7e7fa1d4d10c3eb372275bec8403 Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Tue, 1 Apr 2025 17:29:40 -0700 Subject: [PATCH 123/332] [dynamo] Support tensor subclass with overriden tensor methods and properties (#149484) This fixes most of the "torch.compile X tensor-subclass" issues encountered in https://github.com/city96/ComfyUI-GGUF/issues/118. The relevant tensor subclass definition is here: https://github.com/city96/ComfyUI-GGUF/blob/298192ed60f8ca821c6fe5f8030cae23424cada5/ops.py#L18-L65. A few things to note about the tensor subclass: 1. it overrides a lot of the `torch.Tensor` methods (e.g., `to`, `clone`), so this patch updates `TensorWithTFOverrideVariable.var_getattr` to support that. 2. it overrides the `shape` property, so this patch updates `TensorWithTFOverrideVariable.var_getattr` to support property as well. 3. it has calls to `torch.Tensor.size`, which returns `torch.Size`, which gets reconstructed in `torch.Tensor.__torch_function__`, so this patch adds support for calling `torch.Size(...)` on non-constant inputs. Differential Revision: [D71906137](https://our.internmc.facebook.com/intern/diff/D71906137) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149484 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: #149482, #149483 --- test/dynamo/test_subclasses.py | 131 +++++++++++++++++----- torch/_dynamo/variables/misc.py | 15 ++- torch/_dynamo/variables/torch_function.py | 33 ++++-- torch/_dynamo/variables/user_defined.py | 5 + 4 files changed, 137 insertions(+), 47 deletions(-) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 99b7ab9784ae..df6397df8257 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -757,26 +757,22 @@ def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) def sigmoid(self): return None - @torch.compile(backend="eager", fullgraph=True) def fn(x): x.sigmoid() - msg = ( - "Accessing overridden method/attribute sigmoid on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = torch.ones(2, 2).as_subclass(LocalSubclass) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_user_overidden_attr_unsupported(self): class LocalSubclass(torch.Tensor): @@ -792,10 +788,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def fn(x): return x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) + msg = "Currently only support accessing overridden attributes that are functions or properties, but got " with 
torch._dynamo.config.patch( "traceable_tensor_subclasses", {LocalSubclass} ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): @@ -804,13 +797,11 @@ def fn(x): def test_user_overidden_property_unsupported(self): class LocalSubclass(torch.Tensor): - def __init__(self) -> None: + def __init__(self, *args, **kwargs) -> None: self._ndim = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): - if kwargs is None: - kwargs = {} return super().__torch_function__(func, types, args, kwargs) @property @@ -821,19 +812,17 @@ def ndim(self): def ndim(self, value): self._ndim = value - @torch.compile(backend="eager", fullgraph=True) def fn(x): - return x.ndim + return x + x.ndim - msg = ( - "Accessing overridden method/attribute ndim on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - with torch._dynamo.config.patch( - "traceable_tensor_subclasses", {LocalSubclass} - ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg): - x = torch.ones(2, 2).as_subclass(LocalSubclass) - fn(x) + x = LocalSubclass(torch.ones(2, 2)) + fn_opt = compile_full_eager(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x) + res_act = fn_opt(x) + + self.assertEqual(res_exp, res_act) def test_overridden_method_guarding(self): class LocalSubclass(torch.Tensor): @@ -982,6 +971,88 @@ def fn(x): self.assertEqual(res_exp, res_act) self.assertEqual(x0, x1) + def test_subclass_override_shape_and_to(self): + # This is a slight variabtion of + # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 + class MySubclass(torch.Tensor): + def to(self, *args, **kwargs): + new = super().to(*args, **kwargs) + new.tensor_shape = getattr(self, "tensor_shape", new.data.shape) + return new + + @property + def shape(self): + if not hasattr(self, "tensor_shape"): + self.tensor_shape = self.size() + return self.tensor_shape + + def fn(x): + x_shape = x.shape + y = x.to("cpu") + return x + 1, y + 2, x_shape, x.tensor_shape, y.tensor_shape + + with traceable_subclass(MySubclass): + x0 = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + x1 = torch.nn.Parameter(x0.clone().as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x0) + res_act = fn_opt(x1) + self.assertEqual(res_exp, res_act) + self.assertEqual(x0, x1) + self.assertEqual(x0.tensor_shape, x1.tensor_shape) + + def test_subclass_dont_invoke_torch_function_on_overriden_method(self): + # We shouldn't fire `__torch_function__` for overriden tensor methods. + class MySubclass(torch.Tensor): + def to(self, device): + return self * len(device) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func is torch.Tensor.to: + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x.to("cpu") + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + + def test_subclass_dont_invoke_torch_function_on_overriden_attr(self): + from types import MethodWrapperType + + # We shouldn't fire `__torch_function__` for overriden tensor attrs. 
+ class MySubclass(torch.Tensor): + def ndim(self): + return 42 + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if type(func) is MethodWrapperType and func.__name__ == "ndim": + torch._dynamo.graph_break() + return super().__torch_function__(func, types, args, kwargs) + + def fn(x): + return x + x.ndim() + + with traceable_subclass(MySubclass): + x = torch.nn.Parameter(torch.randn(2, 2).as_subclass(MySubclass)) + + fn_opt = compile_full_eager(fn) + + res_exp = fn(x) + res_act = fn_opt(x) + self.assertEqual(res_exp, res_act) + def test_parameter_subclass_custom_torch_func_and_dynamic_attr(self): # This is a slight variation of # https://github.com/huggingface/diffusers/blob/fbf6b856cc61fd22ad8635547bff4aafe05723f3/src/diffusers/quantizers/gguf/utils.py#L398-L435 diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index c6a2124c871a..2c92599a8b28 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -32,7 +32,7 @@ import torch._numpy as tnp import torch.utils._pytree as pytree -from .. import config, variables +from .. import config, trace_rules, variables from ..bytecode_transformation import create_call_function, create_instruction from ..create_parameter_op import do_not_convert_to_tracable_parameter from ..exc import raise_observed_exception, unimplemented, unimplemented_v2 @@ -297,6 +297,14 @@ def call_method( tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( tx_old ) + elif ( + isinstance(inner_fn, types.MethodDescriptorType) + and inner_fn in trace_rules.get_tensor_method() + ): + # FunctionType but implementation is in C, we support some of these, + # e.g., tensor ops like `torch.Tensor.to`. + fn_var = VariableTracker.build(tx, inner_fn, source) + return fn_var.call_function(tx, [self.objvar] + args, kwargs) unimplemented(f"non-function or method super: {inner_fn}") @@ -669,11 +677,10 @@ def call_method( args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", ): - from ..trace_rules import is_callable_allowed from .builder import wrap_fx_proxy if name == "apply": - if is_callable_allowed(self.fn_cls): + if trace_rules.is_callable_allowed(self.fn_cls): trampoline_autograd_apply = produce_trampoline_autograd_apply( self.fn_cls ) @@ -691,8 +698,6 @@ def call_method( elif name == "backward": return self.call_backward(tx, args, kwargs) else: - from .. import trace_rules - source = AttrSource(self.source, name) if self.source is not None else None try: obj = inspect.getattr_static(self.fn_cls, name) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 3946dccd8dc7..9f24f669e398 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -597,8 +597,9 @@ def from_tensor_var(cls, tx, tensor_var, class_type, cls_source): # This simulates shallow-copying the tensor object. 
kwargs = dict(tensor_var.__dict__) - assert kwargs.pop("class_type") is torch.Tensor, ( - "invalid class type in TensorWithTFOverrideVariable.from_tensor_var" + input_tensor_type = kwargs.pop("class_type") + assert input_tensor_type in (torch.Tensor, torch.nn.Parameter), ( + f"invalid class type {input_tensor_type} in TensorWithTFOverrideVariable.from_tensor_var" ) torch_fn_var = build_torch_function_fn(tx, class_type, cls_source) var = cls(torch_function_fn=torch_fn_var, class_type=class_type, **kwargs) @@ -638,13 +639,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" ) - if hasattr(torch.Tensor, name): - if _is_attr_overidden(tx, self, name): - unimplemented( - f"Accessing overridden method/attribute {name} on a tensor" - " subclass with a __torch_function__ override is not supported" - ) - + # Handle non-overriden attributes inherited from `torch.Tensor`. + attr_is_overriden = _is_attr_overidden(tx, self, name) + if hasattr(torch.Tensor, name) and not attr_is_overriden: if tx.output.torch_function_enabled: if self.source: install_guard( @@ -674,11 +671,23 @@ def var_getattr(self, tx: "InstructionTranslator", name): else: import types + cls_source = GlobalSource(self.global_mangled_class_name(tx)) + attr_source = AttrSource(cls_source, name) if isinstance(attr, types.FunctionType): - cls_source = GlobalSource(self.global_mangled_class_name(tx)) - func_source = AttrSource(cls_source, name) - install_guard(func_source.make_guard(GuardBuilder.FUNCTION_MATCH)) + install_guard(attr_source.make_guard(GuardBuilder.FUNCTION_MATCH)) return UserMethodVariable(attr, self) + + elif isinstance(attr, property): + getter_source = AttrSource(attr_source, "fget") + getter = attr.fget + getter_var = UserMethodVariable(getter, self, source=getter_source) + return getter_var.call_function(tx, [], {}) + + elif attr_is_overriden: + unimplemented( + f"Currently only support accessing overridden attributes that are functions or properties, but got {type(attr)}" # noqa: B950 + ) + return super().var_getattr(tx, name) def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index b842a552649f..2d22e0d35805 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -82,6 +82,7 @@ ) from .base import AttributeMutationExisting, ValueMutationNew, VariableTracker from .dicts import DefaultDictVariable +from .lists import SizeVariable try: @@ -579,6 +580,10 @@ def call_function( assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) + elif self.value is torch.Size: + # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. + tup = variables.BuiltinVariable(tuple).call_function(tx, args, kwargs) + return SizeVariable(tup.items) elif is_frozen_dataclass(self.value) and self.is_standard_new(): fields = dataclasses.fields(self.value) items = list(args) From bb987492302e18ed5a8324a6e18d99553947367b Mon Sep 17 00:00:00 2001 From: Ryan Guo Date: Wed, 2 Apr 2025 13:36:46 -0700 Subject: [PATCH 124/332] [dynamo] Always trace into tensor subclass `__torch_function__` (#149792) This patch effectively ignores traceable_tensor_subclasses, allowing Dynamo to always try tracing into the `__torch_function__` of tensor subclass. This helps us with 2 things: 1. 
allowing users to directly benefit from better compilation of tensor subclass, by just upgrading pytorch, without having to change legacy library code (see earlier patches in the stack for examples). 2. potentially exposing more issues in compiling tensor subclass, so we can get signals and improve them. As a consequence, it exposed and fixes 2 subtle bugs: 1. In `build_torch_function_fn`, we could get `torch._C._disabled_torch_function_impl` because we have a `Parameter` subclass without `__torch_function__` override or if we have a tensor subclass with `__torch_dispatch__` override. We graph break on this for now, and plan to add support -- the logic for simulating `torch._C._disabled_torch_function_impl` is already in `SuperVariable`, we just need to reuse it. 2. Sometimes we create `SyntheticLocalSource` and need to remove all the guards installed on it, but we only removed the ones whose source _is_ the created synthetic source `s`, but forgot about chained source like `s.foo`, this showed up as `SYNTHETIC_LOCAL['tmp_0'].__torch_function__.__func__`. Differential Revision: [D71906141](https://our.internmc.facebook.com/intern/diff/D71906141) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149792 Approved by: https://github.com/jansel, https://github.com/mlazos ghstack dependencies: #149482, #149483, #149484 --- test/dynamo/test_subclasses.py | 28 ++++++ ..._preserve_torch_function_when_return_as_is | 10 ++ .../TestGradNewOnesOverride.test_newones | 1 + .../TestIterator.test_iterator | 1 + .../TestLazyModules.test_lazy_module_buffer | 1 + ...estLazyModules.test_lazy_module_jit_buffer | 1 + .../TestTorchFunctionMode.test_subclass_hash | 10 ++ ..._on_invalid_torch_function_tensor_subclass | 3 + test/profiler/test_profiler_tree.py | 1 + torch/_dynamo/config.py | 26 ++--- torch/_dynamo/source.py | 6 ++ torch/_dynamo/utils.py | 1 + torch/_dynamo/variables/builder.py | 99 ++++++++++++------- torch/_dynamo/variables/tensor.py | 11 +-- torch/_dynamo/variables/torch.py | 5 +- torch/_dynamo/variables/torch_function.py | 3 - torch/_guards.py | 8 +- torch/_inductor/fuzzer.py | 2 + torch/nn/attention/bias.py | 8 +- 19 files changed, 153 insertions(+), 72 deletions(-) create mode 100644 test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is create mode 100644 test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones create mode 100644 test/dynamo_expected_failures/TestIterator.test_iterator create mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer create mode 100644 test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer create mode 100644 test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash create mode 100644 test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index df6397df8257..0e7d54c28448 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -40,6 +40,10 @@ def traceable_subclass(c): return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) +def nontraceable_subclass(c): + return torch._dynamo.config.patch("nontraceable_tensor_subclasses", {c}) + + def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles): actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2) self.assertEqual(actual_recompiles, expected_recompiles) @@ -1188,6 +1192,30 
@@ def f(t): self.assertEqual(res.elem, ref.elem) self.assertEqual(type(res), type(ref)) + def test_nontraceable_tensor_subclass(self): + # This will error if Dynamo tries to wrap it as a tensor variable, + # because that involves calling certain methods to inspect the tensor + # property, which will blow up in the overriden `__torch_function__`. + class MySubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("one shall not pass") + + def f(t): + return t.foo + torch.ones(10) + + opt_f = torch.compile(f, backend="eager", fullgraph=False) + + t = MySubclass(torch.ones(2)) + t.foo = 42 + # Make sure the `nontraceable_tensor_subclasses` config prevents Dynamo + # from wrapping `t`. + with nontraceable_subclass(MySubclass): + res = f(t) + ref = opt_f(t) + + self.assertEqual(res, ref) + def test_compile_with_fake_tensor_dynamic_dim(self): x = torch.randn([3, 4]) diff --git a/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is new file mode 100644 index 000000000000..f243ff1904b0 --- /dev/null +++ b/test/dynamo_expected_failures/TestAutograd.test_custom_function_preserve_torch_function_when_return_as_is @@ -0,0 +1,10 @@ +- Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._data = data` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones new file mode 100644 index 000000000000..24f34ca8e8e6 --- /dev/null +++ b/test/dynamo_expected_failures/TestGradNewOnesOverride.test_newones @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/149975 diff --git a/test/dynamo_expected_failures/TestIterator.test_iterator b/test/dynamo_expected_failures/TestIterator.test_iterator new file mode 100644 index 000000000000..880a24b122bb --- /dev/null +++ b/test/dynamo_expected_failures/TestIterator.test_iterator @@ -0,0 +1 @@ +https://github.com/pytorch/pytorch/issues/150005 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer new file mode 100644 index 000000000000..89dda61098d2 --- /dev/null +++ b/test/dynamo_expected_failures/TestLazyModules.test_lazy_module_jit_buffer @@ -0,0 +1 @@ +Related to `_BufferMeta.__instancecheck__`: https://github.com/pytorch/pytorch/issues/149991 diff --git 
a/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash new file mode 100644 index 000000000000..beb4bf5d003a --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionMode.test_subclass_hash @@ -0,0 +1,10 @@ +Need to handle `class` block inside `torch.compile` region (`LOAD_BUILD_CLASS`) +or properly graph break on it rather than skipping the frame altogether. +https://github.com/pytorch/pytorch/issues/128942 + +Fundamental issue is Dynamo tries to probe tensor object properties, but that +could trigger user-defined `__torch_function__` for tensor subclass objects. + +In this case the `LOAD_BUILD_CLASS` error caused Dynamo to start tracing in the +`__init__` of the following class, but `self._diag = _diag` hasn't fired yet, and +its `__torch_function__` errors when Dynamo is probing tensor property diff --git a/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass new file mode 100644 index 000000000000..c2ddc08d1e40 --- /dev/null +++ b/test/dynamo_expected_failures/TestTorchFunctionWarning.test_warn_on_invalid_torch_function_tensor_subclass @@ -0,0 +1,3 @@ +Dynamo cannot query properties of the tensor subclass object when wrapping it +into a VT, because it has a `__torch_function__` that only allows limited +torch ops. diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 7dac5fb70905..48bbbf01727f 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -690,6 +690,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ...""", ) + @skipIfTorchDynamo("segfaults in 3.13+") @unittest.skipIf( TEST_WITH_CROSSREF, "crossref intercepts calls and changes the callsite." ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5d58efdeed09..b59e1c49e607 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -152,26 +152,16 @@ # Non-Inductor backends can use this list for graph freezing. prepare_freezing = os.environ.get("TORCHDYNAMO_PREPARE_FREEZING", "0") == "1" - -# This feature doesn't really work. We offer this flag for experimental -# purposes / if you want to help us build out support. -# -# torchdynamo has limited support for tensor subclasses that implement -# __torch_function__ see [Note: __torch_function__] in torch_function.py. -# Our current support is limited to tensor subclasses -# that DO NOT store metadata on the tensor (in general, dynamo does not -# support Python code that stores extra attributes on tensors at present). -# If your tensor subclass purely changes function call behavior via -# __torch_function__, you can allow torchdynamo to trace into it by -# adding it to traceable_tensor_subclasses. We don't do any safety checks, -# so it is up to you to ensure that your subclass is well behaved. See also -# https://github.com/pytorch/torchdynamo/issues/1948 -# -# We do NOT currently support __torch_dispatch__. The implementation is -# currently buggy, the main show stopper for nontrivial use is -# https://github.com/pytorch/torchdynamo/issues/1952 +# NOTE this has been deprecated, it does nothing now. 
traceable_tensor_subclasses: set[type[Any]] = set() +# If a tensor subclass is put into this set, Dynamo will model its instasnces in +# a very conservative and limited way (most likely causing lots of graph breaks +# if one apply tensor ops on these instances). This is useful if you encounter +# internal compiler errors from Dynamo which are caused by tensor subclasses, +# and you are willing to tolerate potential graph breaks rather than hard error. +nontraceable_tensor_subclasses: set[type[Any]] = set() + # Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager. # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index e01c166c97d2..4116f110b21d 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -842,6 +842,12 @@ def is_from_local_source(source: Source, *, only_allow_input=False): return True +def is_from_source(source: Source, target: Source): + if isinstance(source, ChainedSource): + return is_from_source(source.base, target) + return source == target + + def is_from_unspecialized_param_buffer_source(source: Source): if isinstance(source, UnspecializedParamBufferSource): return True diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8ee9289633b1..8fa038ce7116 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1408,6 +1408,7 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: "reorderable_logging_functions", "ignore_logger_methods", "traceable_tensor_subclasses", + "nontraceable_tensor_subclasses", "_custom_ops_profile", } diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 1cd8001e4c34..d5cea823b7f6 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -68,7 +68,11 @@ SymbolicContext, ) from torch.fx.immutable_collections import immutable_dict, immutable_list -from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.nn.utils._expanded_weights import ExpandedWeight +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, +) from torch.utils._sympy.value_ranges import ValueRanges from torch.utils.weak import TensorWeakRef @@ -612,11 +616,30 @@ def create_2d_tma_descriptor(): return id_dispatch(self, value) # Everything else (NB: order matters!) - if is_traceable_wrapper_subclass(value) or istype( - value, config.traceable_tensor_subclasses + if ( + isinstance(value, torch.Tensor) + and type(value) + not in ( + # These torch-native subclasses have overly restrictive + # `__torch_function__` which prevents Dynamo from reading their + # tensor attributes like `is_nested` or calling methods like + # `_is_view`. + torch.nn.parameter.UninitializedBuffer, + torch.nn.parameter.UninitializedParameter, + ExpandedWeight, + ) + and type(value) not in config.nontraceable_tensor_subclasses ): - return self.wrap_tensor(value) - elif is_namedtuple(value): + if type(value).__torch_dispatch__ is torch.Tensor.__torch_dispatch__: + # This case it's either tensor or subclass with default + # torch_dispatch (they might override torch_function or not), + # and we can always trace into them. + return self.wrap_tensor(value) + elif is_traceable_wrapper_subclass(value): + # For non-default torch_dispatch, we have more requirements. 
+ return self.wrap_tensor(value) + + if is_namedtuple(value): self.install_guards(GuardBuilder.SEQUENCE_LENGTH) output = [ LazyVariableTracker.create( @@ -930,11 +953,6 @@ def build_key_value(i, k, v): value, source=self.source, ) - elif ( - isinstance(value, torch._C._TensorMeta) - and value in config.traceable_tensor_subclasses - ): - return TensorSubclassVariable(value, source=self.source) elif ( istype(value, contextlib.nullcontext) and inspect.getattr_static(value, "enter_result", None) is None @@ -1187,6 +1205,20 @@ def build_key_value(i, k, v): if value is torch.autograd._unsafe_preserve_version_counter: self.install_guards(GuardBuilder.FUNCTION_MATCH) return PreserveVersionContextVariable.constructor(self.tx) + if ( + # `value` must be a strict subclass of `torch.Tensor` + issubclass(value, torch.Tensor) + and value is not torch.Tensor + # `TensorSubclassVariable` is not for subclass that overrides + # `torch_dispatch`. + and value.__torch_dispatch__ is torch.Tensor.__torch_dispatch__ + # `TensorSubclassVariable` would lead to construction of + # `TensorWithTFOverrideVariable`, but we don't want that for + # traceable wrapper subclasses (we wrap those subclass instances + # into `TensorVariable`). + and not is_traceable_wrapper_subclass_type(value) + ): + return TensorSubclassVariable(value, source=self.source) # This is a userdefined class, so install an ID_MATCH even if its a # global variable. self.install_guards(GuardBuilder.ID_MATCH) @@ -1729,7 +1761,22 @@ def wrap_tensor(self, value: torch.Tensor): # Guards are added inside register_attr_or_module ) - if type(value) in config.traceable_tensor_subclasses: + # NB: this just says we accessed a tensor from the same source again + # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). + # This is distinct from two distinct sources mapping to the same + # Tensor (per id())! No guard is necessary here. See below for the + # other case. + is_duplicate_tensor = source in self.tx.output.input_source_to_var + if is_duplicate_tensor: + return self.tx.output.input_source_to_var[source] + + options = {} + if type(value) in ( + torch.Tensor, + torch.nn.Parameter, + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ) or is_traceable_wrapper_subclass(value): # Ordinarily, we would fakeify a tensor so that it can get dynamic # shapes and be computed on without triggering actual operations. # However, how can we fakeify a tensor subclass? Ordinary @@ -1747,24 +1794,13 @@ def wrap_tensor(self, value: torch.Tensor): # To simplify things for now, the __dict__ tracking bits haven't # been implemented yet, but they can be added into this design at # a later point in time. - subclass_type = type(value) - else: - assert type(value) in ( - torch.Tensor, - torch.nn.Parameter, - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ) or is_traceable_wrapper_subclass(value), type(value) subclass_type = None - - # NB: this just says we accessed a tensor from the same source again - # (e.g., a tensor lives in a global foo, and we LOAD_GLOBAL it twice). - # This is distinct from two distinct sources mapping to the same - # Tensor (per id())! No guard is necessary here. See below for the - # other case. 
- is_duplicate_tensor = source in self.tx.output.input_source_to_var - if is_duplicate_tensor: - return self.tx.output.input_source_to_var[source] + else: + subclass_type = type(value) + options["torch_function_fn"] = build_torch_function_fn( + self.tx, value, self.source + ) + self.install_guards(GuardBuilder.TYPE_MATCH) if get_static_address_type(value) == "guarded": self.install_guards(GuardBuilder.ID_MATCH) @@ -1772,13 +1808,6 @@ def wrap_tensor(self, value: torch.Tensor): # By this point, we should have deduplicated all tensors self.assert_not_wrapped_by_this_graph(value) - options = {} - if type(value) in config.traceable_tensor_subclasses: - options["torch_function_fn"] = build_torch_function_fn( - self.tx, value, self.source - ) - self.install_guards(GuardBuilder.TYPE_MATCH) - if ( isinstance(value, torch.Tensor) and value.is_nested diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 44b3ffc27689..c477979fa9e3 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -70,6 +70,7 @@ from .base import AttributeMutationNew, VariableTracker from .constant import ConstantVariable from .lists import SizeVariable +from .user_defined import UserDefinedClassVariable try: @@ -410,8 +411,6 @@ def call_obj_hasattr(self, tx: "InstructionTranslator", name): return ConstantVariable(ret_val) def var_getattr(self, tx: "InstructionTranslator", name): - from . import UserDefinedClassVariable - if self.is_strict_mode(tx): if name in self._strict_mode_banned_ops(): unimplemented( @@ -614,7 +613,7 @@ def call_method( """ # This is seen in inspect signature where we check if the value is a default value - if name == "__eq__" and isinstance(args[0], variables.UserDefinedClassVariable): + if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): return variables.ConstantVariable(False) try: @@ -1446,11 +1445,7 @@ def from_tensor_variable(cls, tensor_variable): return FakeItemVariable(**dict(tensor_variable.__dict__)) -class TensorSubclassVariable(VariableTracker): - def __init__(self, value, *args, **kwargs) -> None: - self.value = value - super().__init__(*args, **kwargs) - +class TensorSubclassVariable(UserDefinedClassVariable): def call_function( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1e7a9baf9494..40821a16e5e5 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -76,6 +76,7 @@ from .torch_function import ( can_dispatch_torch_function, dispatch_torch_function, + TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) @@ -1350,7 +1351,9 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): if data.source: return cls._nn_param_via_prefix_insert(tx, data, requires_grad) - if is_traceable_wrapper_subclass_type(data.class_type): + if isinstance( + data, TensorWithTFOverrideVariable + ) or is_traceable_wrapper_subclass_type(data.class_type): unimplemented("Parameter constructor with tensor subclass NYI") if not can_convert_to_tracable_parameter(): diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 9f24f669e398..330faf9bf902 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -24,9 +24,6 @@ See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w for more information on the design. 
- -To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses -in torch/_dynamo/config.py """ import collections diff --git a/torch/_guards.py b/torch/_guards.py index ad5f4a7b130a..b6b36f637101 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -631,8 +631,12 @@ def update(self, *others: set[Guard]): self.add(g, skip=1) def remove_guards_with_source(self, source): - """Delete all guards with a given source""" - self.inner = {g for g in self.inner if g.originating_source != source} + """Delete all guards that contains a given source""" + from ._dynamo.source import is_from_source + + self.inner = { + g for g in self.inner if not is_from_source(g.originating_source, source) + } class GuardsContext(Checkpointable[GuardsCheckpointState]): diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index 4a42b71559c5..a95d8419d662 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -174,6 +174,7 @@ def failing(self) -> bool: "autoheuristic_collect": ["pad_mm", "mixed_mm"], "autoheuristic_use": ["pad_mm", "mixed_mm"], "traceable_tensor_subclasses": [OrderedSet()], + "nontraceable_tensor_subclasses": [OrderedSet()], } SamplingType = Callable[[str, type[Any], Any], Any] @@ -499,6 +500,7 @@ def keys(self) -> KeysView[ComboType]: }, "torch._dynamo.config": { "traceable_tensor_subclasses": DEFAULT, # Typing + "nontraceable_tensor_subclasses": DEFAULT, # Typing "compiled_autograd_kwargs_override": DEFAULT, # Typing "fail_on_recompile_limit_hit": DEFAULT, # fails in combo with suppress_errors "suppress_errors": DEFAULT, diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index da7acb957d96..36c0a18cdd12 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -283,11 +283,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): """Defines the behavior of torch.nn.functional.scaled_dot_product_attention when the attn_bias is an AttnBias""" if kwargs is None: kwargs = {} - if func != torch.nn.functional.scaled_dot_product_attention: - raise NotImplementedError( - "CausalBias only supports scaled_dot_product_attention" - ) - return cls._dispatch(*args, **kwargs) + if func is torch.nn.functional.scaled_dot_product_attention: + return cls._dispatch(*args, **kwargs) + return super().__torch_function__(func, types, args, kwargs) def __repr__(self): # type:ignore[override] return self._materialize().__repr__() From 1017927c83dd95a4be6074c48e0fb38f0a1bd8f3 Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Wed, 2 Apr 2025 20:57:12 +0000 Subject: [PATCH 125/332] multidimensional slicing (#150104) Differential Revision: D71962884 Fixes #150057 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150104 Approved by: https://github.com/angelayi --- test/export/test_export.py | 18 ++++++++++++++++ torch/_export/non_strict_utils.py | 35 ++++++++++++++++++++++++++++--- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index d92ef65fb743..c56f7af7d47f 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -4795,6 +4795,24 @@ def forward(self, scores, score_thr, topk: torch.Tensor, results=None): self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) + def test_multidimensional_slicing(self): + class M(torch.nn.Module): + def forward(self, x, y): + b = x.item() + torch._check(b >= 0) + torch._check(b < y.shape[0]) + return y[0, b] + + 
if is_non_strict_test(self._testMethodName): + m = M() + inp = (torch.tensor(4), torch.ones(10, 10)) + r = m(*inp) + + epm = export(m, inp).module() + er = epm(*inp) + + self.assertTrue(torch.allclose(er, r)) + def test_sequential_slicing(self): # See https://github.com/pytorch/pytorch/issues/137455 diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 63c0cf5d30ec..6e65141acfac 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -699,9 +699,38 @@ def _override(self, func, args, kwargs): ): return torch._refs.tensor, args, kwargs if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor): - # Redirect to torch.select for indexing with symint. - if isinstance(args[1], torch.SymInt): - return torch.select, [args[0], 0, args[1]], {} + + def rewrite(dim, item): + # Redirect to torch.select for indexing. + if isinstance(item, (int, torch.SymInt)): + return dim, (torch.select, [dim, item]) + # Redirect to torch.ops.aten.slice for slicing. + if isinstance(item, slice): + return dim + 1, ( + torch.ops.aten.slice, + [dim, item.start, item.stop, item.step or 1], + ) + # Otherwise do nothing. + + items = args[1] if isinstance(args[1], tuple) else (args[1],) + dim = 0 + # Sequence rewrites. + sequence = [] + for item in items: + if (r := rewrite(dim, item)) is None: + return func, args, kwargs + dim, call_spec = r + sequence.append(call_spec) + + def run(): + # Run sequence. + t = args[0] + for _method, _args in sequence: + t = _method(t, *_args) + return t + + return run, [], {} + return func, args, kwargs def __torch_function__(self, func, types, args=(), kwargs=None): From 74aa9f571c23dfdb047997d82e6eb5a0a92f0148 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 2 Apr 2025 18:31:20 +0000 Subject: [PATCH 126/332] ci: Use cache / progress when local docker build (#150551) It's a bit annoying to try and work on these locally when the cache / progress isn't being used so let's just set it so that those flags are only valid when in CI directly. `${CI}` is a default environment variable that's defined by actions itself. 
See https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/150551 Approved by: https://github.com/clee2000, https://github.com/ZainRizvi, https://github.com/atalman --- .ci/docker/build.sh | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 07e991658b7a..1e1ec8b491ae 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -460,10 +460,18 @@ if [[ "$image" == *cuda* && ${OS} == "ubuntu" ]]; then fi fi +no_cache_flag="" +progress_flag="" +# Do not use cache and progress=plain when in CI +if [[ -n "${CI:-}" ]]; then + no_cache_flag="--no-cache" + progress_flag="--progress=plain" +fi + # Build image docker build \ - --no-cache \ - --progress=plain \ + ${no_cache_flag} \ + ${progress_flag} \ --build-arg "BUILD_ENVIRONMENT=${image}" \ --build-arg "PROTOBUF=${PROTOBUF:-}" \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ From a677b491c9459913e0bba9e43d9e7191ba9d72b8 Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Wed, 2 Apr 2025 22:25:46 +0000 Subject: [PATCH 127/332] [Profiler] Fix Empty C Call Queue (#150370) Summary: My commandeer of https://github.com/pytorch/pytorch/pull/150102 Based on description of PR it seems that we need to add C calls for each starting python event with a callable such that when the tracing exits we will have a matching enter for any given exit. It adds some unnecessary events at worst but prevents segfaults/failures. My PR just cleans up some refcount impl and logging. Contributors: @arjun-choudhry Test Plan: Ran resnet test internally. Will check CI and ask reviewers to make sure it resolves their issues. Differential Revision: D72207570 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150370 Approved by: https://github.com/aaronenyeshi --- torch/csrc/autograd/profiler_python.cpp | 54 ++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 6 deletions(-) diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a98d1a8b7934..17da6cf3d70b 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -705,10 +705,13 @@ class PythonTracer final : public python_tracer::PythonTracerBase { void recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, - PyObject* arg); + PyObject* arg, + bool start_frame = false); const std::vector interpreterThreads() const; + PyObject* get_callable_from_frame(PyFrameObject* frame); + std::atomic active_lock_{false}; bool active_{false}; @@ -787,6 +790,16 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { recordPyCall(thread_local_results_.back(), it->get(), true); + PyFrameObject* frame = it->get(); + PyObject* callable = get_callable_from_frame(frame); + if (callable) { + // If the frame has a callable, record it as a C call since + // PyEval_GetFrame only gets the python frame. We need to record this C + // call so that when exiting the profiler we don't have a mismatched C + // call. 
+ recordCCall(thread_local_results_.back(), it->get(), callable, true); + } + auto frame_refcount = Py_REFCNT(it->get()); // We hold one reference in `current_stack`, and the interpreter holds @@ -890,8 +903,13 @@ void PythonTracer::recordPyCall( void PythonTracer::recordCCall( ThreadLocalResults& tls, PyFrameObject* frame, - PyObject* arg) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg)); + PyObject* arg, + bool start_frame) { + // for starting frames we duplicate callable python functions to avoid having + // empty C frames in trace when exiting + if (!start_frame) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(PyCFunction_Check(arg)); + } auto fn = reinterpret_cast(arg); // NB: For C calls a new frame is not created, so we use `frame` rather than @@ -901,6 +919,26 @@ void PythonTracer::recordCCall( queue_->getSubqueue()->emplace_py_call(key, c10::getApproximateTime()); } +PyObject* PythonTracer::get_callable_from_frame(PyFrameObject* frame) { + if (frame == nullptr) { + return nullptr; + } + // Get the code object associated with the frame + auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); + if (code == nullptr) { + return nullptr; + } + // Get the function name (if needed) + auto name = THPUtils_unpackStringView(code->co_name).data(); + // To get the function object, you will need to look in the globals or the + // frame's f_globals + PyObject* func = PyDict_GetItemString(PyFrame_GetGlobals(frame), name); + if (func) { + Py_INCREF(func); // Make sure the returned function has a reference + } + return func; // Returns a PyObject* (the function) +} + // ============================================================================ // == Post processing ========================================================= // ============================================================================ @@ -983,9 +1021,13 @@ class PostProcess { using stack_t = std::vector>; const auto initial_size = out.size(); auto pop = [](stack_t& stack, c10::time_t t) { - TORCH_INTERNAL_ASSERT(!stack.empty(), "Python replay stack is empty."); - std::get>(stack.back()->extra_fields_).end_time_ns_ = t; - stack.pop_back(); + if (!stack.empty()) { + std::get>(stack.back()->extra_fields_).end_time_ns_ = t; + stack.pop_back(); + } else { + TORCH_WARN_ONCE( + "Python replay stack is empty during pop operation! 
May result in incorrect stack tracing."); + } }; ska::flat_hash_map stacks; From 0bacb90a9c95d35388baaaf357a86c148c4c5add Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 2 Apr 2025 11:35:04 -0700 Subject: [PATCH 128/332] [invoke_subgraph][min-cut partitioner] Fix bug to use the correct root module (#150556) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150556 Approved by: https://github.com/bdhirsh, https://github.com/zou3519 ghstack dependencies: #150082, #150450, #150486 --- test/higher_order_ops/test_invoke_subgraph.py | 16 ++++++++++++++++ .../jit_compile_runtime_wrappers.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index cd6c97fcf878..db585afaafd7 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -1061,6 +1061,22 @@ def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: """, ) + def test_const_tensor(self): + @mark_compile_region + def gn(x): + return torch.tensor(64, dtype=torch.float32) * x + + def fn(x): + return gn(x) + + x = torch.randn(64, requires_grad=True) + + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + @parameterized_class( [ diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 0ac2144cd77a..aab77b80d40b 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -486,7 +486,7 @@ def prepare_for_partitioner(mod, num_primals, num_fw_outputs): new_graph.lint() - out = torch.fx.GraphModule(joint_gm, new_graph) + out = torch.fx.GraphModule(mod, new_graph) return out new_hop_graphs: dict[str, InvokeSubgraphHopGraphs] = defaultdict( From 8667a00979ab4dc9513951c630b59bc21c0fefa7 Mon Sep 17 00:00:00 2001 From: PaulZhang12 Date: Tue, 1 Apr 2025 09:17:45 -0700 Subject: [PATCH 129/332] Add stride + dtype to autotune results (#150419) Add stride/dtype info to autotune gemm results. 
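For illustration only, here is a minimal sketch (not part of this patch) of a compile run that would exercise this logging; the shapes, strides, and dtypes are chosen to mirror the sample header above, and the exact log text may vary:

```python
import torch

# max-autotune makes inductor benchmark the mm choices and emit the
# AUTOTUNE summary (now including strides and dtypes) on stderr.
@torch.compile(mode="max-autotune")
def f(a, b):
    return a @ b

# a.t() has strides [1, 1024]; b is contiguous with strides [7680, 1].
a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16).t()
b = torch.randn(1024, 7680, device="cuda", dtype=torch.bfloat16)
f(a, b)
# Expected to print something along the lines of the sample above:
#   AUTOTUNE mm(1024x1024, 1024x7680)
#   strides: [1, 1024], [7680, 1]
#   dtypes: torch.bfloat16, torch.bfloat16
```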
New output header: `AUTOTUNE mm(1024x1024, 1024x7680)` `strides: [1, 1024], [7680, 1]` `dtypes: torch.bfloat16, torch.bfloat16` Differential Revision: [D72253313](https://our.internmc.facebook.com/intern/diff/D72253313) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150419 Approved by: https://github.com/eellison --- torch/_inductor/select_algorithm.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 558e59af81c7..fed0f9ebebd7 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2205,6 +2205,9 @@ def log_results( for n in input_nodes ] ) + + strides = ", ".join([str(n.get_stride()) for n in input_nodes]) + dtypes = ", ".join([str(n.get_dtype()) for n in input_nodes]) if config.autotune_num_choices_displayed == 0: return # when autotune_num_choices_displayed is None, [:None] means all @@ -2252,6 +2255,9 @@ def get_choice_info(choice): best_time = timings[best] sys.stderr.write(f"AUTOTUNE {name}({sizes})\n") + sys.stderr.write(f"strides: {strides}\n") + sys.stderr.write(f"dtypes: {dtypes}\n") + for choice in top_k: result = timings[choice] if result: From 0198e44f3741a8228e776ee126e82554d47d0667 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Wed, 2 Apr 2025 22:42:18 +0000 Subject: [PATCH 130/332] Update torch-xpu-ops commit pin to 98c808d (#150554) Update the torch-xpu-ops commit to [98c808dea6de7330c415aa777d6921944cf79887](https://github.com/intel/torch-xpu-ops/commit/98c808dea6de7330c415aa777d6921944cf79887), include - Fixes #150001 by removing pre-CXX11 ABI logic from build script for XPU - Fixes #150430 - Fixes XCCL build issue caused by PR #150398 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150554 Approved by: https://github.com/EikanWang, https://github.com/malfet --- third_party/xpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 53b3ef7e4560..5bdc7353dfe7 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -3ee2bd2f13e1ed17a685986ff667a58bed5f2aa5 +98c808dea6de7330c415aa777d6921944cf79887 From de15ef0ee82dfc881ca27eb961ef464753bca02a Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 2 Apr 2025 11:42:21 -0700 Subject: [PATCH 131/332] [invoke_subgraph] Force grad_outs to be contiguous at tracing time (#150561) I am unable to come up with a testcase. It passes many end-to-end tests that fail with ReshapeError at https://ossci-raw-job-status.s3.amazonaws.com/log/39717218372 ![image](https://github.com/user-attachments/assets/8509b485-3897-4538-968b-bbe05af63a59) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150561 Approved by: https://github.com/zou3519, https://github.com/bdhirsh ghstack dependencies: #150082, #150450, #150486, #150556 --- torch/_higher_order_ops/invoke_subgraph.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 840e32e29dfc..d508e8ffc0da 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -228,6 +228,12 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): grad_outputs = [grad for grad in grad_outputs if grad is not None] grad_outputs = [grad for grad in grad_outputs if grad.requires_grad] + # Force grad_out to be contiguous. 
This is because at runtime, + # grad_out could have different strides than fw_outs. So, we + # force the grad_outs to be contiguous for both tracing and + # runtime. + grad_outputs = [grad.contiguous() for grad in grad_outputs] + if any( not isinstance(out, torch.Tensor) for out in grad_outputs From 61a1f09b5b160a088cbbe180df0477dab21f4d41 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 2 Apr 2025 23:14:11 +0000 Subject: [PATCH 132/332] Revert "[cuda] Add new faster gammabeta backward kernel (#148605)" This reverts commit 114d404b0720e8073748690faeb96449e5c0b229. Reverted https://github.com/pytorch/pytorch/pull/148605 on behalf of https://github.com/drisspg due to See https://github.com/pytorch/pytorch/issues/150266#issuecomment-2773907902 for more details ([comment](https://github.com/pytorch/pytorch/pull/148605#issuecomment-2773928838)) --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 526 +++++++----------- test/test_nn.py | 20 - 2 files changed, 195 insertions(+), 351 deletions(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 0d63a2f979c9..9feb30c21941 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( } } + template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, @@ -539,364 +540,191 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } } -template -__device__ -__forceinline__ -void -blockReduceGammaBetaBackwardsHelper( - int64_t M_start, - int64_t M, - int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db, - T_ACC &dg_sum, - T_ACC &db_sum -) { - constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; - int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; - - int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); - int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; - T_ACC warp_mean = 0, warp_rstd = 0; - if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; - warp_rstd = rstd[mean_index + lane_id]; - } - // We do a WARP_SYNC() here because we use WARP_SHFL below to access - // warp_mean and warp_rstd. 
- WARP_SYNC(); - - T_ACC dY_regs[rows_per_thread_y] = {0}; - T_ACC X_regs[rows_per_thread_y] = {0}; - #pragma unroll - for (int i = 0; i < rows_per_thread_y; ++i) { - int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i; - bool active = true; - if (check_x && thread_x >= N) { - active = false; - } - if (check_y && current_y >= M) { - active = false; - } - if (active) { - dY_regs[i] = dY[current_y * N + thread_x]; - X_regs[i] = X[current_y * N + thread_x]; - } - } - - #pragma unroll - for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); - T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; - } -} - -template -__device__ -__forceinline__ -void -blockReduceGammaBetaBackwardsWithChecks( - int64_t M, - int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db, - T_ACC &dg_sum, - T_ACC &db_sum -) { - for (int64_t M_start = blockIdx.y * rows_per_block_y; - M_start < M; - M_start += rows_per_block_y * gridDim.y) { - int64_t M_end = M_start + rows_per_block_y - 1; - if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper - (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - blockReduceGammaBetaBackwardsHelper - (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } - } -} +// This implementation gets called if M and N divide with 32. This case should +// be the most common. We can then make better use of warp level intrinsics +// to improve performance. -// block_dim_x is the number of threads in the x dimension per block. -// block_dim_y is the number of threads in the y dimension per block. -// rows_per_block_y is the size of the tile (number of data elements) -// in the y dimension per block. -// partial_reduction indicates whether we need to reduce across threads -// or not. If set to true, we will not reduce across threads. This can -// be faster in the M >> N case but requires another kernel to do a full -// final reduction. -// aligned_grid means the data size is a multiple of tile size. In that -// case we don't need to check for boundary conditions which can provide -// a further speedup by not needing instructions to check for edge cases -// and not needing predicate registers. -template -__global__ -void - GammaBetaBackwardCUDAKernelTemplate( +template +__global__ void GammaBetaBackwardCUDAKernel_32x32( int64_t M, int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db) { - // This assert is a compile-time check only. - constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; - static_assert(rows_per_thread_y <= kWarpSize); + const T* dY, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + T* dg, + T* db) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; T_ACC dg_sum = 0; T_ACC db_sum = 0; - if (aligned_grid) { - // When N and M align perfectly with block_dim_x and block_dim_y, we - // can skip boundary condition checks that waste instruction issue slots. - blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - // In the general case we need to check boundary conditions in the M - // dimension. 
However, we can still avoid boundary checks in the N dimension - // for the inner blocks. So try to avoid those checks when possible. - if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { - blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } - } + const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; - int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x; + if (j < N) { + constexpr int unroll_factor = 8; + int laneId = threadIdx.x & (C10_WARP_SIZE - 1); + + T_ACC mean_reg, mean_reg_tmp; + T_ACC rstd_reg, rstd_reg_tmp; + T dY_reg; + T X_reg; + + // Main loop + int bcounter; + for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); + bcounter++) { + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + + if (laneId < unroll_factor) { + mean_reg_tmp = mean[offset + laneId]; + rstd_reg_tmp = rstd[offset + laneId]; + } + WARP_SYNC(); - // When partial_reduction is requested, we don't reduce within a block. - // We also don't reduce if we are only a single block in the y dimension. - if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { - if (aligned_grid || thread_x < N) { - int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y; - if (dg) { - dg[thread_y * N + thread_x] = dg_sum; + #pragma unroll + for (int ii = 0; ii < unroll_factor; ++ii) { + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); + rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; } - if (db) { - db[thread_y * N + thread_x] = db_sum; + } + + // Remainder loop + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + for (int ii = 0; ii < unroll_factor; ii++) { + if ((offset + ii) < M) { + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; } } - } else { - // The caller requested a full reduction so we must reduce across - // warps using shared memory and warp shuffles. - static_assert(rows_per_thread_y <= C10_WARP_SIZE); - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; - int padded_bx = (block_dim_x + 1); - // Transpose dg and db. + + // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and + // gets called when M; N divide by 32. We can use warp shuffles + // for the final reduction step. This removes 4 shmem loads and + // stores with their corresponding __syncthreads() + + // This greatly reduces bank conflicts at the expense of a little + // extra shared memory. It does not impact occupancy + int padded_bx = (1 + blockDim.x); + s_dg = s_data_typed; - s_db = s_data_typed + (padded_bx * block_dim_y); + s_db = s_data_typed + (padded_bx * blockDim.y); s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; __syncthreads(); // Load transposed so that a warp holds an entire column - // Because block_dim_x != block_dim_y in the general case, we need - // some code to handle the general case. 
- static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); - constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; - int thread_id = threadIdx.y * block_dim_x + threadIdx.x; - int warp_id = thread_id / C10_WARP_SIZE; - int lane_id = thread_id & (C10_WARP_SIZE - 1); - #pragma unroll - for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { - T_ACC reg_db, reg_dg; - if (lane_id < block_dim_y) { - reg_dg = s_dg[lane_id * padded_bx + i]; - reg_db = s_db[lane_id * padded_bx + i]; - } - #pragma unroll - for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; + T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; + for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + + if (threadIdx.x == 0) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + if (dg) { + dg[j] = reg_dg; } - // Reduce is done. Now write it out to global memory. - int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; - if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { - if (dg) { - dg[out_index] = reg_dg; - } - if (db) { - db[out_index] = reg_db; - } + if (db) { + db[j] = reg_db; } } } } -template -void LaunchAndCheckGammaBetaBackwardKernel( - bool aligned_grid, - dim3 blocks, - dim3 threads, - size_t shmem_sz, - cudaStream_t cuda_stream, - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, - int64_t M, - int64_t N, - T* dgamma_data, - T* dbeta_data) { -if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - } else { - GammaBetaBackwardCUDAKernelTemplate - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void ConfigureAndLaunchGammaBetaBackwardKernel( - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, +template +__global__ void GammaBetaBackwardCUDAKernel( int64_t M, int64_t N, - Tensor* dgamma, - Tensor* dbeta, - cudaStream_t cuda_stream) { - T* dgamma_data = - dgamma->defined() ? dgamma->template data_ptr() : nullptr; - T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; - bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); - dim3 threads{block_dim_x, block_dim_y}; - dim3 blocks; - blocks.x = (N + block_dim_x - 1) / block_dim_x; - blocks.y = 1; - size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; - if (blocks.y == 1 && threads.y == 1) { - // Optimization: since there is just one thread doing all the summation, we don't need a reduction - // across threads. So we set partial_reduction to true. 
- LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); - } else { - LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); - } + const T* dY, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + T* dg, + T* db) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; -} + const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; -template -void LaunchGammaBetaBackwardCUDAKernel( - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, - int64_t M, - int64_t N, - Tensor* dgamma, - Tensor* dbeta, - cudaStream_t cuda_stream) { - constexpr int block_dim_x = 32; - const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { - // We have a situation where M >> N and N is small. - // In this case we can speed up the computation by parallelizing in the M dimension. - // We launch multiple blocks in the y-dimension, and compute partial sums for the - // gradient in the first pass. Then we do a .sum(0) to do a final reduction. - // Although we launch 2 kernels, we can get up to a 10x speedup for large M. - constexpr int block_dim_y = 1; - constexpr int rows_per_block_y = 32; - bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); - dim3 threads{block_dim_x, block_dim_y}; - dim3 blocks; - blocks.x = (N + block_dim_x - 1) / block_dim_x; - // int rows_per_block = my_gamma_beta_unroll_factor * - blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; - constexpr int max_grid_size = 64 * 1024 / 2; - blocks.y = std::min(max_grid_size / blocks.x, blocks.y); - Tensor dgamma_blocks; - Tensor dbeta_blocks; - T * dgamma_blocks_ptr = nullptr; - T * dbeta_blocks_ptr = nullptr; - if (dgamma->defined()) { - auto options = dgamma->options(); - dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); - dgamma_blocks_ptr = dgamma_blocks.data_ptr(); + T_ACC dg_sum = 0; + T_ACC db_sum = 0; + + if (j < N) { + constexpr int unroll_factor = 8; + + T_ACC mean_reg; + T_ACC rstd_reg; + T dY_reg; + T X_reg; + + // Main Loop + int bcounter; + for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){ + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + + #pragma unroll + for (int ii = 0; ii < unroll_factor; ++ii) { + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } } - if (dbeta->defined()) { - auto options = dbeta->options(); - dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); - dbeta_blocks_ptr = dbeta_blocks.data_ptr(); + + // Remainder loop + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + for (int ii = 0; ii < unroll_factor; ii++ ){ + if ((offset + ii) < M) { + dY_reg = dY[(offset + ii) * N + j ]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } } - LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, 
dgamma_blocks_ptr, dbeta_blocks_ptr); - *dgamma = dgamma_blocks.sum(0); - *dbeta = dbeta_blocks.sum(0); - } else { - // We are in the normal case where M is not that large. - // We can change the tile shape (which is the last template parameter) in accordance with M. - // For small M it is faster to have a smaller tile, otherwise we could have idle threads. - // For larger M we use a bigger tile size. - if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + // Do the final reduction in shared memory + s_dg = s_data_typed; + s_db = s_data_typed + blockDim.x * blockDim.y; + s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; + s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum; + __syncthreads(); + + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + if (threadIdx.y < offset) { + s_dg[threadIdx.y * blockDim.x + threadIdx.x] += + s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + s_db[threadIdx.y * blockDim.x + threadIdx.x] += + s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + } + __syncthreads(); + } + + if (threadIdx.y == 0) { + if (dg) { + dg[j] = s_dg[threadIdx.x]; + } + if (db) { + db[j] = s_db[threadIdx.x]; + } } } } @@ -1422,7 +1250,6 @@ void LayerNormBackwardKernelImplInternal( dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; -#if defined(USE_ROCM) if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; @@ -1438,6 +1265,7 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { +#if defined(USE_ROCM) // For small batch size, do colwise reduce directly. const int part_size = warp_size; const dim3 threads2(warp_size, 4, 1); @@ -1472,11 +1300,47 @@ void LayerNormBackwardKernelImplInternal( dgamma_data, dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - } #else - LaunchGammaBetaBackwardCUDAKernel( - dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { + // This implementation relies on warp primitives and requires that M and N divide + // exactly to warp size. + dim3 threads{kWarpSize, kWarpSize}; + int blocks = (N + threads.x - 1) / threads.x; + + // If M and N divide by warp_size, we can use warp shuffles for the final reduction. + // That requires transposing values in shared memory, so we apply a padding to + // reduce bank conflicts. 
+ + size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; + GammaBetaBackwardCUDAKernel_32x32 + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dim3 threads{16, 32}; + int blocks = (N + threads.x - 1) / threads.x; + size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; + GammaBetaBackwardCUDAKernel + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } #endif + } } } diff --git a/test/test_nn.py b/test/test_nn.py index 72c440ca5ec5..30fe71b4162e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7195,26 +7195,6 @@ def test_layer_norm_eps(self): ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) - @unittest.skipIf(not TEST_CUDA, "CUDA not available") - def test_layer_norm_backwards_eps(self): - dtype = torch.float - m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), - (32, 32), (1024, 32), (1024, 1024), - (33, 33), (1025, 33), (1025, 1025)] - for m, n in m_x_n_list: - x = torch.randn((m, n), dtype=dtype, requires_grad=True) - grad_output = torch.rand_like(x) - x_cuda = x.clone().detach().to("cuda").requires_grad_() - grad_output_cuda = grad_output.clone().detach().to("cuda") - ln = nn.LayerNorm(n, dtype=dtype) - ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) - ln_out = ln(x) - ln_out_cuda = ln_cuda(x_cuda) - ln_out.backward(grad_output) - ln_out_cuda.backward(grad_output_cuda) - self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) - self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) - @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291 From 24f50653c8a5544ada6488df76ffb26fb77ef64a Mon Sep 17 00:00:00 2001 From: Gabriel Ferns Date: Wed, 2 Apr 2025 23:39:06 +0000 Subject: [PATCH 133/332] fix bug in logging code (#150518) Fixes https://github.com/pytorch/pytorch/issues/150379 ```python >>> key = "aten._int_mm_1_2_3" >>> m, n, k = key.split("_")[-3:] >>> m, n, k ('1', '2', '3') >>> name = "_".join(key.split("_")[:-3]) >>> name 'aten._int_mm' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150518 Approved by: https://github.com/xmfan --- torch/_inductor/compile_fx.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 86c79ed4a3bc..ba77e78240a2 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -870,7 +870,8 @@ def _compile_fx_inner( if log.isEnabledFor(logging.INFO): mm_table_data = [] for key, value in counters["aten_mm_info"].items(): - name, m, n, k = key.split("_") + m, n, k = key.split("_")[-3:] + name = "_".join(key.split("_")[:-3]) mm_table_data.append([name, m, n, k, value]) log.info("Overview info of inductor aten mms: ") log.info( From f363fe616d808d82f01faa2a4b2c8dae87907dd9 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Thu, 3 Apr 2025 00:08:19 +0000 Subject: [PATCH 134/332] [AOTInductor] Fix autotuning code's codegen (#150522) Summary: Codegen used to generate tmp_arg_{index} as temporary args, and index is the position of the caller. We changed the logic of codegen such that we can reuse previous generated samples, and only delete after arg is no longer used. 
In this case, we need to make {index} unique, since different functions could reuse the same "tmp_arg_{index}" name string, but corresponds to different args. Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse` Differential Revision: D72297084 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522 Approved by: https://github.com/desertfire, https://github.com/22quinn --- test/inductor/test_aot_inductor.py | 37 +++++++++++++++++++++++++ torch/_inductor/codegen/wrapper.py | 10 ++++--- torch/testing/_internal/triton_utils.py | 26 +++++++++++++++++ 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index ce661145d9c7..d1fe8b7cfbc1 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -70,6 +70,7 @@ add_kernel_with_tma_2d, mul2_inplace_kernel, strange_config_matmul_kernel, + sub_kernel_autotuned, ) if IS_WINDOWS and IS_CI: @@ -4662,6 +4663,42 @@ def forward(self, x): model, example_inputs, "aoti_torch_clone_preserve_strides", 0 ) + def test_autotuning_args_reuse(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires GPU") + + class Model(torch.nn.Module): + def forward(self, x, y): + x_out = torch.empty_strided( + (x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE + ) + x_out = torch.permute(x_out, [0, 1]) + add_kernel_autotuned[(4,)](x, x, x_out, 16) + + y_out = torch.empty_strided( + (y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE + ) + y_out = torch.permute(y_out, [0, 1]) + add_kernel_autotuned[(64,)](y, y, y_out, 64) + + sub_kernel_autotuned[(4,)](x, x, x_out, 16) + + return x_out, y_out + + example_inputs = ( + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(8, 8, device=GPU_TYPE), + ) + dim0_x = Dim("dim0_x", min=1, max=2048) + dim0_y = Dim("dim0_y", min=1, max=2048) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}} + self.check_model( + Model(), + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"max_autotune": True}, + ) + @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_stft(self): N_FFT = 400 diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index bfb78977dba4..d7de4b4f24a6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -620,6 +620,7 @@ def __init__(self): # Map key is the kernel argument name; value is a tuple of the resulting example # tensor name with the kernel where that tensor was most recently used. 
self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {} + self.kernel_autotune_tmp_arg_idx: int = 0 # If the generated source code is exactly the same, reuse the # pre-existing kernel for it self.src_to_kernel: dict[str, str] = {} @@ -1991,7 +1992,7 @@ def wrap_arg(arg): return [wrap_arg(arg) for arg in call_args] - def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): + def generate_example_arg_value(self, arg, arg_type, raw_arg=None): if isinstance(arg_type, torch_dtype): if isinstance(raw_arg, ir.TMADescriptor): # first we generate the underlying buffer @@ -2004,8 +2005,9 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): assert raw_arg is not None, ( "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" ) - buf_name = f"tmp_arg_{index}" + buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}" buf = raw_arg + self.kernel_autotune_tmp_arg_idx += 1 size = tuple( V.graph.sizevars.atomically_apply_size_hint( @@ -2182,13 +2184,13 @@ def get_autotune_deletion_call() -> str: arg_str = arg elif arg not in self.kernel_autotune_example_args: arg_str = self.generate_example_arg_value( - arg, arg_type, raw_arg, i + arg, arg_type, raw_arg ) else: arg_str = self.kernel_autotune_example_args[arg][0] self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) else: - arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i) + arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) all_args.append(arg_str if key is None else f"{key}={arg_str}") self.kernel_autotune_calls.writeline( diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index 433a518feb15..608a6f14389b 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -117,6 +117,32 @@ def add_kernel_autotuned( output = x + y tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=[], + ) + @triton.jit + def sub_kernel_autotuned( + in_ptr0, + in_ptr1, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + y = tl.load(in_ptr1 + offsets, mask=mask) + output = x - y + tl.store(out_ptr + offsets, output, mask=mask) + @triton.autotune( configs=[ triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2), From 13f48197d2acce3f8f43da3687d422633acd99e5 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 2 Apr 2025 15:44:16 -0700 Subject: [PATCH 135/332] Add Chillee as core reviewer (#150579) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150579 Approved by: https://github.com/albanD, https://github.com/drisspg, https://github.com/malfet --- .github/merge_rules.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/merge_rules.yaml b/.github/merge_rules.yaml index 0a091ecadbe5..bae188d2a335 100644 --- a/.github/merge_rules.yaml +++ b/.github/merge_rules.yaml @@ -540,6 +540,7 @@ - bdhirsh - zou3519 - isuruf + - Chillee mandatory_checks_name: - EasyCLA - Lint From 77dca3947ebf0d60bec4e1966da77b12ad3f7798 Mon Sep 17 00:00:00 2001 From: 
Shangdi Yu Date: Thu, 3 Apr 2025 00:55:35 +0000 Subject: [PATCH 136/332] [aoti] make a check function for each input (#150553) Summary: make a check function for each input to avoid too large to optimize error on `__check_inputs_outputs` Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r runtime_checks ``` Differential Revision: D72286280 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150553 Approved by: https://github.com/desertfire --- test/inductor/test_aot_inductor.py | 22 ++++++++++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 22 ++++++++++++++++------ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index d1fe8b7cfbc1..905a3d2850c9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3518,6 +3518,28 @@ def forward(self, x0, x1): dynamic_shapes=dynamic_shapes, ) + def test_runtime_checks_large(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, *inputs): + result = inputs[0] + for i in range(1, len(inputs)): + result = result + inputs[i] + return result + + inputs = [] + for i in range(1000): + inputs.append(torch.ones(8, 8, 8, dtype=torch.float16, device=self.device)) + inputs = tuple(inputs) + model = Model() + with torch.no_grad(): + AOTIRunnerUtil.compile( + model, + inputs, + ) + def test_runtime_checks_complex(self): class Model(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index fc30b0f3e437..9f163256b311 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -403,6 +403,19 @@ def gen_check(handle_kind, idx, name, tensor): """ ) + # Create a separate function for each input check to avoid "too big to optimize" error + for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): + self.prefix.splice( + f""" + AOTI_NOINLINE static void check_input_{idx}( + AtenTensorHandle* input_handles + ) {{ + """ + ) + with self.prefix.indent(): + gen_check("input_handles", idx, name, tensor) + self.prefix.writeline("}") + # force noinline to avoid any potential compilation slowdown due to aggressive # inline done by the host compiler self.prefix.splice( @@ -422,8 +435,8 @@ def gen_check(handle_kind, idx, name, tensor): """ ) with self.prefix.indent(): - for idx, (name, tensor) in enumerate(V.graph.graph_inputs.items()): - gen_check("input_handles", idx, name, tensor) + for idx in range(len(V.graph.graph_inputs)): + self.prefix.writeline(f"check_input_{idx}(input_handles);") self.prefix.writeline("}") def write_wrapper_decl(self): @@ -475,13 +488,10 @@ def write_wrapper_decl(self): DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor ) { + __check_inputs_outputs(input_handles, output_handles); """ self.generate_input_output_runtime_checks() - run_impl_proto += """ - __check_inputs_outputs(input_handles, output_handles); - """ - self.prefix.splice(run_impl_proto) else: # cpp entry function for JIT with cpp wrapper From 2e5d95a0828060f816251671e8e59f2680f9f9be Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 2 Apr 2025 15:06:18 -0700 Subject: [PATCH 137/332] [FlexAttention] Remove dead code (#150575) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150575 Approved by: https://github.com/Chillee, 
https://github.com/BoyuanFeng --- torch/nn/attention/_utils.py | 10 +--------- torch/nn/attention/flex_attention.py | 12 +----------- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index 7ec94e8189f7..5b09a2c14c24 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs """Defines utilities for interacting with scaled_dot_product_attention""" import math -from typing import Optional, Union +from typing import Optional import torch @@ -31,14 +31,6 @@ def _calculate_scale(head_dim_size: int, scale: Optional[float]) -> float: return 1.0 / math.sqrt(head_dim_size) -_SUPPORTED_HEAD_DIMS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] - - -def _supported_head_dim(n: Union[int, torch.SymInt]) -> bool: - """Returns true if the head dim is supported by FlexAttention""" - return n in _SUPPORTED_HEAD_DIMS - - def _validate_sdpa_input( query: torch.Tensor, key: torch.Tensor, diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 6bf74aab2029..8bf87c60d411 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -19,7 +19,7 @@ _temp_remove_metadata_torch_function_mode, _temp_remove_pre_dispatch_torch_function_mode, ) -from torch.nn.attention._utils import _supported_head_dim, _validate_sdpa_input +from torch.nn.attention._utils import _validate_sdpa_input from torch.utils._pytree import tree_map_only @@ -1118,16 +1118,6 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): f"Expect query and key/value to have the same embedding dimension " f"but got E={query.size(-1)} and E={key.size(-1)}." ) - return - # TODO this config segfaults with Triton without: - # https://github.com/triton-lang/triton/pull/4540 - if not ( - _supported_head_dim(query.size(-1)) and _supported_head_dim(value.size(-1)) - ): - raise ValueError( - f"NYI: Currently non power of 2 embedding dimension are not supported. " - f"Got E={query.size(-1)} and Ev={value.size(-1)}." - ) def _validate_device(query: Tensor, key: Tensor, value: Tensor): From 90ddb33141b8aecbe0da979d284fff7fa9f93bca Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Thu, 3 Apr 2025 05:20:10 +0000 Subject: [PATCH 138/332] [export] specialize for aten.to (#149235) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes decomposition behavior of `aten.to` to respect the aliasing/non-aliasing behavior in eager, and to specialize to the input/conversion dtype & device. Before change: we always decompose `aten.to` into `_to_copy`, regardless of aliasing behavior. This leads us to ban mutations on the result of `_to_copy` when aliased, since we can't guarantee correct program semantics. This meant users had to explicitly call `.clone()` before mutating. In the special cases where we don’t ban mutations (e.g. dtype conversion), we add runtime assertions on the input & conversion dtype/devices in the decomposed program (see https://github.com/pytorch/pytorch/pull/142420). After change: we decompose to the aliasing/non-aliasing behavior that matches eager, allowing mutations in all cases. We also add dtype/device assertions for all `aten.to` ops, starting in the pre-dispatch graph, basically specializing the program to the dtype/devices. 
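For illustration (not part of this diff), a minimal sketch of the new behavior using the public `torch.export.export` entry point; it mirrors the updated export tests further down in this patch:

```python
# Illustrative sketch mirroring the updated export tests in this diff.
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x.float()

ep = torch.export.export(M(), (torch.tensor(1, dtype=torch.int32),))
# The exported graph now starts with an aten._assert_tensor_metadata guard on the
# input's dtype/device, so the dtype seen at export time is enforced at runtime:
# ep.module()(torch.tensor(1, dtype=torch.float32))  # RuntimeError: Tensor dtype mismatch!
```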
Differential Revision: D71229547 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149235 Approved by: https://github.com/tugsbayasgalan --- aten/src/ATen/native/ComparisonUtils.cpp | 8 +- .../tensor/test_dtensor_compile.py | 4 +- test/export/test_export.py | 294 ++++++++++++++++-- test/functorch/test_control_flow.py | 1 + torch/_export/utils.py | 59 ++++ torch/_subclasses/functional_tensor.py | 58 +--- .../pt2e/representation/rewrite.py | 36 ++- torch/export/_trace.py | 11 +- torch/export/exported_program.py | 41 +-- 9 files changed, 374 insertions(+), 138 deletions(-) diff --git a/aten/src/ATen/native/ComparisonUtils.cpp b/aten/src/ATen/native/ComparisonUtils.cpp index 4019cf2ff9b1..415b8cab1364 100644 --- a/aten/src/ATen/native/ComparisonUtils.cpp +++ b/aten/src/ATen/native/ComparisonUtils.cpp @@ -30,7 +30,9 @@ void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalS _assert_match(tensor.sym_sizes(), sizes, "sizes"); _assert_match(tensor.sym_strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); - _assert_match(tensor.device(), device, "device"); + if (tensor.device().type() != DeviceType::Meta) { + _assert_match(tensor.device(), device, "device"); + } _assert_match(tensor.layout(), layout, "layout"); } @@ -38,7 +40,9 @@ void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef s _assert_match(tensor.sizes(), sizes, "sizes"); _assert_match(tensor.strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); - _assert_match(tensor.device(), device, "device"); + if (tensor.device().type() != DeviceType::Meta) { + _assert_match(tensor.device(), device, "device"); + } _assert_match(tensor.layout(), layout, "layout"); } diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index 8de5a4db0a98..162acbd000e9 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -157,6 +157,7 @@ def forward(self, x): str(ep.graph_module.code).strip(), """\ def forward(self, b_buffer, x): + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(x, dtype = torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None to = torch.ops.aten.to.dtype_layout(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None view_as = torch.ops.aten.view_as.default(to, to); to = None dtensor___init__0 = self.dtensor___init__0 @@ -172,7 +173,8 @@ def forward(self, b_buffer, x): str(ep.run_decompositions({}).graph_module.code).strip(), """\ def forward(self, b_parametrizations_buffer_original0, x): - _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda')); x = None + _assert_tensor_metadata = torch.ops.aten._assert_tensor_metadata.default(x, None, None, torch.float64, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata = None + _to_copy = torch.ops.aten._to_copy.default(x, dtype = torch.float64, layout = torch.strided, device = device(type='cuda', index=0)); x = None view = torch.ops.aten.view.default(_to_copy, [4, 4]); _to_copy = None add = torch.ops.aten.add.Tensor(b_parametrizations_buffer_original0, view); b_parametrizations_buffer_original0 = view = None view_1 = torch.ops.aten.view.default(add, [4, 4]); add = None diff --git a/test/export/test_export.py b/test/export/test_export.py index 
c56f7af7d47f..988e2fae81c6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -227,6 +227,12 @@ def is_non_strict_legacy_test(test_name): return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) +def is_legacy_test(test_name): + return test_name.endswith(LEGACY_EXPORT_NONSTRICT_SUFFIX) or test_name.endswith( + LEGACY_EXPORT_STRICT_SUFFIX + ) + + def is_retracebility_test(test_name): return test_name.endswith(RETRACEABILITY_STRICT_SUFFIX) or test_name.endswith( RETRACEABILITY_NON_STRICT_SUFFIX @@ -5855,14 +5861,37 @@ class Module(torch.nn.Module): def forward(self, x): return x.to("cpu") - ep = export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({}) + ep = export(Module(), (torch.tensor(1, device="cpu"),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to will just specialize by decomposing to a no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) def test_device_to_dynamic(self): class Module(torch.nn.Module): @@ -5873,14 +5902,37 @@ def forward(self, x): Module(), (torch.tensor([1, 2], device="cpu"),), dynamic_shapes={"x": {0: Dim("i")}}, - ).run_decompositions({}) + ) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to will just specialize by decomposing to a no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) def test_device_to_mutation(self): class Module(torch.nn.Module): @@ -5889,10 +5941,102 @@ def forward(self, x): y.add_(1) return y, x - with self.assertRaisesRegex( - RuntimeError, "cannot mutate tensors with frozen storage" + ep = export(Module(), (torch.tensor(1, device="cpu"),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # aten.to decomposes to no-op, add_ decomposes to functional variant + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.add.Tensor, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + torch.ops.aten.add_.Tensor, + ], + ) + + # test mutation + x = torch.tensor(2, device="cpu") + y, _ = ep.module()(x) + self.assertEqual(x.item(), 3) + self.assertEqual(id(y), id(x)) + + # test decomp ep + ep = ep.run_decompositions({}) + for node in ep.graph.nodes: + if node.op == 
"call_function": + self.assertNotEqual(node.target, torch.ops.aten.to.dtype_layout) + + # test mutation for decomposed program + y, _ = ep.module()(x) + self.assertEqual(x.item(), 4) + self.assertEqual(id(y), id(x)) + + @requires_gpu + @testing.expectedFailureCppRuntime + def test_device_to_gpu(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x.to("cpu") + + ep = export(Foo(), (torch.randn(64).cuda(),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName ): - export(Module(), (torch.tensor(1, device="cpu"),)).run_decompositions({}) + # aten.to decomposes to _to_copy + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype_layout, + ], + ) + + # Check device assertion + with self.assertRaisesRegex(RuntimeError, "Tensor device mismatch!"): + ep.module()(torch.randn(64)) + + ep = ep.run_decompositions() + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 2) + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + + # Check device assertion again after decomp + with self.assertRaisesRegex(RuntimeError, "Tensor device mismatch!"): + ep.module()(torch.randn(64)) def test_tensor_constant_aten_to(self): class Module(torch.nn.Module): @@ -5920,40 +6064,96 @@ class Module(torch.nn.Module): def forward(self, x): return x.float() - ep = export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions( - {} - ) + ep = export(Module(), (torch.tensor(1, dtype=torch.float),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - for op in ops: - self.assertIn(op, (torch.ops.aten._to_copy.default,)) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # .float() decomposes to no-op + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + ], + ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + ], + ) + + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual(len(ops), 1) + + # test aliasing + x = torch.tensor(1, dtype=torch.float) + out = ep.module()(x) + self.assertEqual(id(x), id(out)) def test_float_conversion_from_int(self): class Module(torch.nn.Module): def forward(self, x): return x.float() - ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)).run_decompositions( - {} - ) + ep = export(Module(), (torch.tensor(1, dtype=torch.int32),)) ops = [] for node in ep.graph.nodes: if node.op == "call_function": ops.append(node.target) - self.assertGreater(len(ops), 0) - self.assertIn(torch.ops.aten._to_copy.default, ops) - self.assertIn(torch.ops.aten._assert_tensor_metadata.default, ops) - - self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName + ): + # .float() decomposes to _to_copy() + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) 
+ else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + ], + ) # Raises error because the input dtype is not the same as the input # tensor when exporting. with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"): ep.module()(torch.tensor(1, dtype=torch.float32)) + ep = ep.run_decompositions({}) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten._to_copy.default, + ], + ) + + # Check dtype assertion again after decomp + with self.assertRaisesRegex(RuntimeError, "Tensor dtype mismatch!"): + ep.module()(torch.tensor(1, dtype=torch.float32)) + + self.assertEqual(ep.module()(torch.tensor(1, dtype=torch.int32)), 1) + def test_device_to_mutation_float(self): class Module(torch.nn.Module): def forward(self, x): @@ -5961,12 +6161,48 @@ def forward(self, x): y.add_(1) return y, x - with self.assertRaisesRegex( - RuntimeError, "cannot mutate tensors with frozen storage" + ep = export(Module(), (torch.tensor(1, dtype=torch.float),)) + ops = [] + for node in ep.graph.nodes: + if node.op == "call_function": + ops.append(node.target) + if is_legacy_test(self._testMethodName) or is_training_ir_test( + self._testMethodName ): - export(Module(), (torch.tensor(1, dtype=torch.float),)).run_decompositions( - {} + # aten.to decomposes to no-op, add_ decomposes to functional variant + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.add.Tensor, + ], ) + else: + self.assertEqual( + ops, + [ + torch.ops.aten._assert_tensor_metadata.default, + torch.ops.aten.to.dtype, + torch.ops.aten.add_.Tensor, + ], + ) + + # test mutation + x = torch.tensor(2, dtype=torch.float) + y, _ = ep.module()(x) + self.assertEqual(x.item(), 3.0) + self.assertEqual(id(y), id(x)) + + # test decomp ep + ep = ep.run_decompositions({}) + for node in ep.graph.nodes: + if node.op == "call_function": + self.assertNotEqual(node.target, torch.ops.aten.to.dtype) + + # test mutation for decomposed program + y, _ = ep.module()(x) + self.assertEqual(x.item(), 4.0) + self.assertEqual(id(y), id(x)) def test_module(self): class MyLinear(torch.nn.Module): diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 0e3f39eb2266..9349e9c103f2 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -6903,6 +6903,7 @@ def forward(self, t): t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec) sum_1: "f32[]" = torch.ops.aten.sum.default(t) + _assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None item: "Sym(u0)" = torch.ops.aten.item.default(to); to = None sin: "f32[2, 3]" = torch.ops.aten.sin.default(t) diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 833fec60cb69..c3562c470f1b 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -58,6 +58,8 @@ InputKind.TOKEN: "token", } +_DISABLE_ATEN_TO_ASSERTION_PASS = False + def _collect_and_set_constant_attrs( graph_signature, constants, mod @@ -577,6 +579,59 @@ def nodes_filter(nodes: list[torch.fx.Node], node_call_back) -> list[torch.fx.No return [node for node in nodes if node_call_back(node)] +@contextmanager +def 
_disable_aten_to_metadata_assertions(): + global _DISABLE_ATEN_TO_ASSERTION_PASS + orig_val = _DISABLE_ATEN_TO_ASSERTION_PASS + _DISABLE_ATEN_TO_ASSERTION_PASS = True + try: + yield + finally: + _DISABLE_ATEN_TO_ASSERTION_PASS = orig_val + + +def _insert_aten_to_metadata_assert_pass(gm: torch.fx.GraphModule) -> None: + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + + if _DISABLE_ATEN_TO_ASSERTION_PASS: + return + + aten_to_variants = [ + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + torch.ops.aten.to.dtype_layout, + ] + for node in gm.graph.nodes: + if node.target in aten_to_variants: + if ( + node.prev.target == torch.ops.aten._assert_tensor_metadata.default + and node.args[0] == node.prev.args[0] + ): + # skip if already guarded + continue + + if (tensor_val := node.args[0].meta.get("val")) is not None: + with gm.graph.inserting_before(node), _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + ), + ): + gm.graph.call_function( + torch.ops.aten._assert_tensor_metadata.default, + args=(node.args[0],), + kwargs={ + "dtype": tensor_val.dtype, + "device": tensor_val.device, + "layout": tensor_val.layout, + }, + ) + + def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): from torch._export.passes._node_metadata_hook import ( _node_metadata_hook, @@ -600,6 +655,10 @@ def apply_runtime_assertion_pass(gm: torch.fx.GraphModule, graph_signature): f"exported program: {first_call_function_nn_module_stack(gm.graph)}", export=True, ) + + # insert runtime assertions for aten.to nodes + _insert_aten_to_metadata_assert_pass(gm) + # update output specs gm.recompile() graph_signature.user_outputs = _graph_output_names(gm) diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index fb272adc7ea3..368e30246091 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -260,9 +260,12 @@ def tolist(self) -> Any: def to(self, *args, **kwargs): if _detect_infra_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL).export: - # If copy is specified as pos arg, it's always the second one. - if len([arg for arg in args if isinstance(arg, bool)]) <= 1: - return super().to(*args, **{**kwargs, "copy": True}) + torch.ops.aten._assert_tensor_metadata( + self, + dtype=self.dtype, + device=self.device, + layout=self.layout, + ) return super().to(*args, **kwargs) def cuda(self, device=None, *args, **kwargs): @@ -354,23 +357,6 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - if self.export: - # We need to make sure that we don't decompose to() as usual in export mode, - # because it can get optimized away. Instead we always replace it with _to_copy(). 
- if func == torch.ops.aten.to.dtype_layout: - kwargs.pop("copy", None) - return self.__torch_dispatch__( - torch.ops.aten._to_copy.default, types, args, kwargs - ) - if func == torch.ops.aten.to.dtype: - schema = tuple(arg.name for arg in func._schema.arguments) - for arg, name in zip(args[1:], schema[1:]): - kwargs[name] = arg - kwargs.pop("copy", None) - return self.__torch_dispatch__( - torch.ops.aten._to_copy.default, types, args[:1], kwargs - ) - unrecognized_types = [ t for t in types @@ -527,36 +513,10 @@ def unwrap(x): *args_unwrapped, **kwargs_unwrapped, ) - # We don't allow any mutation on result of dropout or _to_copy + if self.export: - if func in ( - torch.ops.aten.dropout.default, - torch.ops.aten._to_copy.default, - ): - - def must_copy(): - """ - Return True if the output of the op must be copied, not an alias - """ - # output dtype is different from input - return ( - func == torch.ops.aten._to_copy.default - and "dtype" in kwargs - and kwargs["dtype"] != args_unwrapped[0].dtype - ) - - # `args_unwrapped` might be a tensor constant, not a functional tensor. - if must_copy() and torch._is_functional_tensor( - args_unwrapped[0] - ): - # We can further relax to args_unwrapped[0] != kwargs["dtype"], but I don't think - # we have an aten op for that. - torch.ops.aten._assert_tensor_metadata.default( - torch._from_functional_tensor(args_unwrapped[0]), - dtype=args_unwrapped[0].dtype, - ) - else: - torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] + if func == torch.ops.aten.dropout.default: + torch._freeze_functional_tensor(outs_unwrapped) # type: ignore[attr-defined] outs_wrapped = pytree.tree_map_only( torch.Tensor, wrap, outs_unwrapped ) diff --git a/torch/ao/quantization/pt2e/representation/rewrite.py b/torch/ao/quantization/pt2e/representation/rewrite.py index ed3b30552a1f..ae23b43b9cb0 100644 --- a/torch/ao/quantization/pt2e/representation/rewrite.py +++ b/torch/ao/quantization/pt2e/representation/rewrite.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Optional import torch +from torch._export.utils import _disable_aten_to_metadata_assertions from torch._higher_order_ops.out_dtype import out_dtype from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 from torch.ao.quantization.pt2e.export_utils import _WrapperModule @@ -798,22 +799,23 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule: remove_tensor_overload_for_qdq_ops(model) - for rewrite_info in _REWRITE_INFO_LIST: - example_inputs = rewrite_info.example_inputs - pattern = rewrite_info.pattern - replacement = rewrite_info.replacement - pattern_post_trans = rewrite_info.pattern_post_trans - replacement_post_trans = rewrite_info.replacement_post_trans - pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] - remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] - replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] - remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] - if pattern_post_trans: - pattern = pattern_post_trans(pattern) - if replacement_post_trans: - replacement = replacement_post_trans(replacement) - pattern.recompile() # type: ignore[attr-defined] - replacement.recompile() # type: ignore[attr-defined] - replace_pattern(model, pattern, replacement) + with _disable_aten_to_metadata_assertions(): + for rewrite_info in _REWRITE_INFO_LIST: + example_inputs = 
rewrite_info.example_inputs + pattern = rewrite_info.pattern + replacement = rewrite_info.replacement + pattern_post_trans = rewrite_info.pattern_post_trans + replacement_post_trans = rewrite_info.replacement_post_trans + pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(pattern) # type: ignore[arg-type] + replacement = _get_aten_graph_module_for_pattern(replacement, example_inputs) # type: ignore[arg-type, assignment] + remove_tensor_overload_for_qdq_ops(replacement) # type: ignore[arg-type] + if pattern_post_trans: + pattern = pattern_post_trans(pattern) + if replacement_post_trans: + replacement = replacement_post_trans(replacement) + pattern.recompile() # type: ignore[attr-defined] + replacement.recompile() # type: ignore[attr-defined] + replace_pattern(model, pattern, replacement) return model diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 63c7472f0a7c..72269acc2625 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -500,11 +500,12 @@ def _produce_aten_artifact( It does: 1. Applies runtime assertion pass - 2. Populate meta val when missing - 3. Lift constants as placeholders - 4. Replace raw autograd and autocast ops with HOPs - 5. Prettify names for placeholders - 6. Preserve requires_grad value on node meta val + 2. Recompute unbacked_bindings pass + 3. Populate meta val when missing + 4. Lift constants as placeholders + 5. Replace raw autograd and autocast ops with HOPs + 6. Prettify names for placeholders + 7. Preserve requires_grad value on node meta val """ # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. # Overwrite output specs afterwards. diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index ee8640c3ade7..bcaf8645f795 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -187,7 +187,7 @@ def _fx_collection_equivalence_fn( @contextmanager -def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): +def _override_composite_implicit_decomp(cia_ops_to_callable): # This function overrides CompositeImplicitAutograd decomp for # functional composite ops that user specified. Ideally we want to not-decompose # ALL composite ops but today's C++ functinalization relies on @@ -195,13 +195,6 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): # Hence we can only do it for functional ops. One caveat is that # there are some composite ops that lie about their schema (claimed to be # functional but not really aka dropout), for these cases, we just decompose. - - # When safe=False, we will assume that ops_to_preserve can be mutating/aliasing - # and their usual decompositions need to be shadowed rather than overridden. - # Thus we will avoid asserting that they are valid to preserve, and will not - # replace their CompositeImplicitAutograd kernels with NotImplemented. - # The only current users of this mode are variants of aten::to that we will - # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. 
saved_tables = {} patched_ops = set() for op_overload, decomp_callable in cia_ops_to_callable.items(): @@ -219,10 +212,9 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] - if safe: - op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( - decomp_callable - ) + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) # [NOTE] Directly registering fake tensor rule to CIA ops # The problem we are facing here is if your CIA custom rule @@ -278,21 +270,6 @@ def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs): _deregister_op_impl(op) -@contextmanager -def _override_decomp_aten_to_variants(): - # Preserve variants of aten::to understanding that they are mutating/aliasing - # and their CompositeImplicitAutograd kernels will not become NotImplemented. - # We will later replace them with aten._to_copy when functionalizing. - with _override_composite_implicit_decomp( - { - torch.ops.aten.to.dtype_layout: _special_op_to_preserve_cia, - torch.ops.aten.to.dtype: _special_op_to_preserve_cia, - }, - safe=False, - ): - yield - - def _split_decomp_table_to_cia_and_python_decomp( decomp_table: dict[torch._ops.OperatorBase, Callable] ) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: @@ -465,15 +442,9 @@ def _is_joint_ir_decomp(ep, joint_loss_index): tx = TracingContext(fake_mode) - with ( - fake_mode - ), _override_decomp_aten_to_variants(), _override_composite_implicit_decomp( + with fake_mode, _override_composite_implicit_decomp( cia_to_decomp, - ), _enable_graph_inputs_of_type_nn_module( - ep.example_inputs - ), tracing( - tx - ): + ), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx): retracing_args_unwrapped = pytree.tree_unflatten( retracing_args, mod._in_spec ) From fc674b45d4d8edfd4c630d89f71ea9f85a2f61f2 Mon Sep 17 00:00:00 2001 From: "Junjie Wang (PyTorch)" Date: Thu, 3 Apr 2025 06:42:06 +0000 Subject: [PATCH 139/332] [c10d] Add logging for desync debug report (#150513) Summary: We want to add a logging to first understand what is the distribution of desync debug report. 
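For illustration (not part of this diff), the record emitted when the desync debugger runs roughly carries the fields below; the field names are taken from the code in this patch, the values are made up:

```python
# Conceptual shape of the logged record; field names from this diff, values illustrative.
desync_report_log = {
    "pg_id": 0,
    "rank": 1,
    "global_rank": 1,
    "world_size": 8,
    "flight_recorder_version": "-1",  # sentinel to distinguish this from flight recorder dumps
    "status": "SUCCESS",              # or "EXCEPTION", together with "exception_msg"
}
```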
Test Plan: Test with logger staging Differential Revision: D72249281 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150513 Approved by: https://github.com/kwen2501 --- test/distributed/test_c10d_nccl.py | 2 +- .../csrc/distributed/c10d/FlightRecorder.hpp | 2 +- .../distributed/c10d/ProcessGroupNCCL.cpp | 23 ++++++++++++++++++- .../distributed/c10d/ProcessGroupNCCL.hpp | 9 +++++++- 4 files changed, 32 insertions(+), 4 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index dadb3b0804b3..c8032a89d523 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -4385,7 +4385,7 @@ def started_or_scheduled(self, timing_enabled): class NCCLTraceTest(NCCLTraceTestBase): def _verify_trace(self, t, include_collectives, timing_enabled, is_json): ver = t["version"] - self.assertEqual(ver, "2.4") + self.assertEqual(ver, "2.5") pg_config = t["pg_config"] self.assertEqual(len(pg_config), 1) default_pg_info = pg_config["0"] diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index e15f153d70c7..e134e39cab78 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -24,7 +24,7 @@ namespace c10d { // (minor when adding fields, major when changing existing fields) // Also update both JSON and Pickle dumps to make use of the newly defined // field(s). -DEFINE_CONSTANT(version_val, "2.4") +DEFINE_CONSTANT(version_val, "2.5") DEFINE_CONSTANT(entries_key, "entries") DEFINE_CONSTANT(nccl_comm_key, "nccl_comm_state") DEFINE_CONSTANT(version_key, "version") diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 734705a93cc9..3dc7f23860f3 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1058,7 +1058,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( // Enable Desync Debugger per user setting if (desyncDebug_) { - desyncDebugger_.init(rank, size, store_); + desyncDebugger_.init(rank, size, globalRank(), getUid(), store_); } } @@ -1943,9 +1943,13 @@ void ProcessGroupNCCL::ncclCommWatchdog() { void ProcessGroupNCCL::DesyncDebugger::init( int rank, int size, + int globalRank, + int pgId, c10::intrusive_ptr store) { rank_ = rank; size_ = size; + globalRank_ = globalRank; + pgId_ = pgId; store_ = std::move(store); enabled_ = true; traceKeyStart_ = getTraceStartKey("NCCL", rank); @@ -1957,21 +1961,38 @@ void ProcessGroupNCCL::DesyncDebugger::run() { if (!enabled_) return; auto logPrefix = c10::str("Rank ", rank_); + ::c10d::C10dLoggingData log; + log.integers["pg_id"] = pgId_; + log.integers["rank"] = rank_; + log.integers["global_rank"] = globalRank_; + log.integers["world_size"] = size_; + // Use this to differentiate between flight recorder and desync debug report. + log.strings["flight_recorder_version"] = "-1"; + try { std::string desyncMsg = retrieveDesyncReport(store_, "NCCL", rank_, size_); + log.strings["status"] = "SUCCESS"; LOG(ERROR) << logPrefix << desyncMsg; } catch (const std::exception& e) { + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = e.what(); enabled_ = false; LOG(ERROR) << logPrefix << " Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " << " Please file an issue. Error: " << e.what(); } catch (...) 
{ enabled_ = false; + log.strings["status"] = "EXCEPTION"; + log.strings["exception_msg"] = "Unknown exception"; LOG(ERROR) << logPrefix << " Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." << " Please file an issue."; } + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + logger->log(log); + } } // Log work start to store. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 9f3ad484e55d..6f8b192a1a51 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -555,7 +555,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { class DesyncDebugger { public: // Initialize and enable DesyncDebugger - void init(int rank, int size, c10::intrusive_ptr store); + void init( + int rank, + int size, + int globalRank, + int pgId, + c10::intrusive_ptr store); // Run desync debug. This function is called by watchdog at time of timeout. void run(); @@ -574,6 +579,8 @@ class TORCH_API ProcessGroupNCCL : public Backend { // From ProcessGroupNCCL int rank_; int size_; + int globalRank_; + int pgId_; // Reference to the store so that we can log start/end event. c10::intrusive_ptr store_; From c067127d47fcf0254f38d95e9990f51092fb4fab Mon Sep 17 00:00:00 2001 From: Saagar Jha Date: Thu, 3 Apr 2025 06:50:22 +0000 Subject: [PATCH 140/332] Ensure cuda_dlink_post_cflags are quoted as well (#150151) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150151 Approved by: https://github.com/janeyx99 --- torch/utils/cpp_extension.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 197eba777930..23f1ba6ed559 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -798,6 +798,7 @@ def unix_wrap_ninja_compile(sources, if isinstance(extra_postargs, dict) and 'nvcc_dlink' in extra_postargs: cuda_dlink_post_cflags = unix_cuda_flags(extra_postargs['nvcc_dlink']) + cuda_dlink_post_cflags = [shlex.quote(f) for f in cuda_dlink_post_cflags] else: cuda_dlink_post_cflags = None From 9e106019f64d668f17f0b50dc46192cff7a37dce Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Thu, 3 Apr 2025 08:12:38 +0000 Subject: [PATCH 141/332] [XPU] Add an implict conversion from XPUStream to sycl::queue* (#148646) # Motivation Currently, in Pytorch XPU, `cudaStream_t` is mapped to `sycl::queue&`, so an implicit cast from `XPUStream` to `sycl::queue&` is provided just like `CUDAStream` has an implicit cast to `cudaStream_t`. But on the SYCLomatic side, we migrate `cudaStream_t` to `sycl::queue*` but not `sycl::queue&` (One reason is that `cudaStream_t` is actually a pointer so users can do anything with that integer. Another reason is that the early `sycl::queue` was not impl-ed by a pointer, so copy by value is not desirable.) 
Without this PR: ``` cudaStream_t a = getCurrentCUDAStream(); cudaStream_t b = getCurrentCUDAStream().stream(); ``` need be migrated to: ``` queue_ptr a = &(sycl::queue&)getCurrentXPUStream(); queue_ptr b = &(getCurrentXPUStream().queue()); ``` With this PR: ``` queue_ptr a = getCurrentXPUStream(); queue_ptr b = &(getCurrentXPUStream().queue()); ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/148646 Approved by: https://github.com/guangyey, https://github.com/EikanWang --- c10/xpu/XPUStream.h | 5 +++++ c10/xpu/test/impl/XPUStreamTest.cpp | 3 +++ 2 files changed, 8 insertions(+) diff --git a/c10/xpu/XPUStream.h b/c10/xpu/XPUStream.h index 903986253d23..fea64d7c109e 100644 --- a/c10/xpu/XPUStream.h +++ b/c10/xpu/XPUStream.h @@ -59,6 +59,11 @@ class C10_XPU_API XPUStream { return queue(); } + /// Implicit conversion to sycl::queue*. + operator sycl::queue*() const { + return &queue(); + } + /// Implicit conversion to Stream (a.k.a., forget that the stream is a /// XPU stream). operator Stream() const { diff --git a/c10/xpu/test/impl/XPUStreamTest.cpp b/c10/xpu/test/impl/XPUStreamTest.cpp index 581e7e69c6fa..661022dbe18e 100644 --- a/c10/xpu/test/impl/XPUStreamTest.cpp +++ b/c10/xpu/test/impl/XPUStreamTest.cpp @@ -223,6 +223,9 @@ TEST(XPUStreamTest, ExternalTest) { ASSERT_TRUE(curStream == myStream); ASSERT_TRUE(&(curStream.queue()) == stream); + sycl::queue* q_ptr = curStream; + ASSERT_TRUE(q_ptr == stream); + delete stream; } From e6e07ec1cf0b2770c24a12f57b4654fd7686c541 Mon Sep 17 00:00:00 2001 From: Arash Pakbin Date: Thu, 3 Apr 2025 09:51:06 +0000 Subject: [PATCH 142/332] [ROCm] code cleanup of architecture checks (#150473) This PR replaces several calls to `at::cuda::getCurrentDeviceProperties()->gcnArchName` and `at::cuda::getDeviceProperties(device_index)->gcnArchName` when checking to see if the GPU architecture is in a certain list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150473 Approved by: https://github.com/jeffdaily, https://github.com/cyyever --- aten/src/ATen/Context.cpp | 6 ++--- aten/src/ATen/cuda/CUDABlas.cpp | 4 +-- aten/src/ATen/cuda/CublasHandlePool.cpp | 4 +-- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 10 +++++-- aten/src/ATen/cuda/detail/CUDAHooks.h | 2 +- aten/src/ATen/detail/CUDAHooksInterface.h | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 32 +++-------------------- aten/src/ATen/native/cuda/int4mm.cu | 11 +------- aten/src/ATen/native/hip/ck_gemm_half.hip | 4 +-- 9 files changed, 21 insertions(+), 54 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 2b6cbfa6e7bf..2fdc98318c2e 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -340,7 +340,7 @@ at::BlasBackend Context::blasPreferredBackend() { #endif }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { return false; } } @@ -366,7 +366,7 @@ at::BlasBackend Context::blasPreferredBackend() { #endif }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { TORCH_WARN_ONCE( "Attempting to use hipBLASLt on an unsupported architecture! 
" "Overriding blas backend to hipblas"); @@ -419,7 +419,7 @@ void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) { "gfx90a", "gfx942" }; for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(index, archs)) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { TORCH_WARN_ONCE( "Attempting to use CK on an unsupported architecture! Cannot set backend to CK"); return true; diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 52aee1378c0e..c5dd44dc1edf 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1085,9 +1085,7 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) } #if defined(USE_ROCM) && !defined(_MSC_VER) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { //no CK GEMM version for gfx1100 + if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100 gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } else{ at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(float)); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 9b183848503e..e88c0bd5dab2 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -124,9 +124,7 @@ size_t parseChosenWorkspaceSize() { val = getenv("ROCBLAS_WORKSPACE_CONFIG"); } /* 32MiB default, 128MiB for MI300 */ - cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = properties->gcnArchName; - const bool gfx94 = device_arch.find("gfx94") != std::string::npos; + const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"}); const size_t default_size = gfx94 ? 
1024 * 128 * 1024 : 1024 * 32 * 1024; #else /* :4096:2:16:8 default, 32MiB for Hopper */ diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 9847386c3394..ac5c833070c1 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -448,8 +448,14 @@ DeviceIndex CUDAHooks::getCurrentDevice() const { } #ifdef USE_ROCM -bool CUDAHooks::isGPUArch(DeviceIndex device_index, const std::vector& archs) const { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(device_index); +bool CUDAHooks::isGPUArch(const std::vector& archs, DeviceIndex device_index) const { + hipDeviceProp_t* prop; + if (device_index == -1){ + prop = at::cuda::getCurrentDeviceProperties(); + } else { + prop = at::cuda::getDeviceProperties(device_index); + } + std::string device_arch = prop->gcnArchName; for (std::string arch : archs) { size_t substring = device_arch.find(arch); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index d0be9d5f535c..2b4c11136321 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -57,7 +57,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { DeviceIndex getCurrentDevice() const override; #ifdef USE_ROCM - bool isGPUArch(DeviceIndex device_index, const std::vector& archs) const override; + bool isGPUArch(const std::vector& archs, DeviceIndex device_index = -1) const override; #endif void deviceSynchronize(DeviceIndex device_index) const override; }; diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 9b54a84dd68d..9bc30ba84ea5 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -196,7 +196,7 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { } #ifdef USE_ROCM - virtual bool isGPUArch(DeviceIndex /*device_index*/, const std::vector& /*archs*/) const { + virtual bool isGPUArch(const std::vector& /*archs*/, DeviceIndex = -1 /*device_index*/) const { TORCH_CHECK(false, "Cannot check GPU arch without ATen_cuda library. 
", CUDA_HELP); } #endif diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index eaa90de69570..ef863a219238 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -265,8 +265,6 @@ static bool getDisableAddmmCudaLt() { #ifdef USE_ROCM static bool isSupportedHipLtROCmArch(int index) { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); - std::string device_arch = prop->gcnArchName; static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 @@ -276,13 +274,7 @@ static bool isSupportedHipLtROCmArch(int index) { "gfx950" #endif }; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch(archs, index); } #endif @@ -939,9 +931,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { } static bool _scaled_mm_allowed_device() { - auto dprops = at::cuda::getCurrentDeviceProperties(); #ifdef USE_ROCM - std::string device_arch = dprops->gcnArchName; static const std::vector archs = { "gfx942", #if ROCM_VERSION >= 60300 @@ -951,30 +941,16 @@ static bool _scaled_mm_allowed_device() { "gfx950" #endif }; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch(archs); #else + auto dprops = at::cuda::getCurrentDeviceProperties(); return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); #endif } #ifdef USE_ROCM static bool _scaled_mm_is_fnuz() { - auto dprops = at::cuda::getCurrentDeviceProperties(); - std::string device_arch = dprops->gcnArchName; - static const std::vector archs = {"gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch({"gfx942"}); } #endif diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index dcc9237d737e..7fc3947879f4 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -135,16 +135,7 @@ template using VecT = T __attribute__((ext_vector_type(Rank))); static bool isCDNA2orLater(int index) { - hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); - std::string device_arch = prop->gcnArchName; - static const std::vector archs = {"gfx90a", "gfx942"}; - for (std::string arch : archs) { - size_t substring = device_arch.find(arch); - if (substring != std::string::npos) { - return true; - } - } - return false; + return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942"}, index); } #else diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip index 14756167b142..552f0de84541 100644 --- a/aten/src/ATen/native/hip/ck_gemm_half.hip +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -598,9 +598,7 @@ void dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGTYPES(at::Half)) { template <> void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - c10::string_view arch(dprops->gcnArchName); - if (arch == "gfx1100") { + if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { dispatch_half_gemm_wmma(CUDABLAS_GEMM_ARGS(at::Half)); } else{ dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half)); From 6fa1b171955716002129b2155c79e56e8d9bdf08 Mon Sep 17 
00:00:00 2001 From: Jagadish Krishnamoorthy Date: Thu, 3 Apr 2025 10:58:45 +0000 Subject: [PATCH 143/332] ROCm: Add trailing comma for consistency in gfx architecture list (#150250) Adding trailing comma for consistency. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150250 Approved by: https://github.com/petrex, https://github.com/jeffdaily, https://github.com/cyyever --- aten/src/ATen/Context.cpp | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 2fdc98318c2e..b5ce540b52ab 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -359,7 +359,7 @@ at::BlasBackend Context::blasPreferredBackend() { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201" + "gfx1100", "gfx1101", "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index ef863a219238..906b24652e49 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -268,7 +268,7 @@ static bool isSupportedHipLtROCmArch(int index) { static const std::vector archs = { "gfx90a", "gfx942", #if ROCM_VERSION >= 60300 - "gfx1100", "gfx1101", "gfx1200", "gfx1201" + "gfx1100", "gfx1101", "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" @@ -935,7 +935,7 @@ static bool _scaled_mm_allowed_device() { static const std::vector archs = { "gfx942", #if ROCM_VERSION >= 60300 - "gfx1200", "gfx1201" + "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 60500 "gfx950" From d4c30b4599f5d4541e39afaf62485c088d8772b0 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 2 Apr 2025 06:00:29 -0700 Subject: [PATCH 144/332] [AOTI][dashboard] Update how peak memory is measured (#150534) Summary: In the dashboard measurement script, AOTI needs to run Eager first to register the output pytree, so the peak memory compression ratio on the dashboard is always close to 1. Update AOTI run to use an extra warmup run, so the peak memory compression ratio measures the result at the run time instead of the compile time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150534 Approved by: https://github.com/yushangdi --- benchmarks/dynamo/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index d23c528c9de9..7c8a91de5202 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -3735,6 +3735,10 @@ def run(runner, args, original_dir=None): # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow) runner.skip_models.update(runner.skip_models_due_to_export_not_supported) + + # For AOTI, we only measure the memory compression ratio at the run time + # instead of the compile time, so use a warmup run to trigger AOTI compilation. + args.use_warm_peak_memory = True elif args.backend == "torchao": assert "cuda" in args.devices, "Quantization requires CUDA device." assert args.bfloat16, "Quantization requires dtype bfloat16." 
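For illustration (not part of this diff), the intent of the extra warmup run can be sketched as measuring peak memory only after a first call has already triggered AOTI compilation; the helper name and structure below are made up, and the real logic lives in the benchmark runner:

```python
# Illustrative sketch: exclude compile-time allocations from the measured peak.
import torch

def runtime_peak_bytes(run_once):
    run_once()                               # warm-up call triggers AOTI compilation
    torch.cuda.reset_peak_memory_stats()
    run_once()                               # measured call reflects run-time memory only
    return torch.cuda.max_memory_allocated()
```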
From 5d9c7f78e75910c1a515cfdb2e4f08fccdb74468 Mon Sep 17 00:00:00 2001 From: Danfeng Wang Date: Thu, 3 Apr 2025 12:01:57 +0000 Subject: [PATCH 145/332] [fbcode]Removing `@NoIntBaseDeprecated` annotation in `evaluation.thrift` file (#150271) Summary: #buildall Test Plan: ``` buck test 'fbcode//mode/opt' fbcode//caffe2/torch/fb/training_toolkit/applications/bulk_eval/tests:evaluator_test -- --exact 'caffe2/torch/fb/training_toolkit/applications/bulk_eval/tests:evaluator_test - test_setup_evaluation_utils (caffe2.torch.fb.training_toolkit.applications.bulk_eval.tests.evaluator_test.EvaluatorTest)' ``` Differential Revision: D72028940 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150271 Approved by: https://github.com/huydhn --- torch/testing/_internal/common_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index afbd569b34ba..01232af5d0d5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -2813,7 +2813,10 @@ def _to_number(self, number_like, *, id): elif isinstance(number_like, Enum): return int(number_like) # type: ignore[call-overload] else: - return super()._to_number(number_like, id=id) + number = super()._to_number(number_like, id=id) + if type(number) not in self._TYPE_TO_DTYPE.keys(): + self._inputs_not_supported() + return number class TensorOrArrayPair(TensorLikePair): From e0d19cf6ccb698e1c6081f5f18f555c972fbd9b4 Mon Sep 17 00:00:00 2001 From: LifengWang Date: Thu, 3 Apr 2025 12:17:16 +0000 Subject: [PATCH 146/332] Enable weekly test for operator benchmark (#150502) To regularly track the performance of the operator benchmark, enable the weekly test. Hi, @huydhn, as you mentioned in https://github.com/pytorch/pytorch/pull/143733#issuecomment-2578317520, we could integrate the performance data from the weekly test into the OSS benchmark database for the dashboard. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150502 Approved by: https://github.com/huydhn --- .github/workflows/operator_benchmark.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 7da1b438c7e9..805a7d328575 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -11,6 +11,9 @@ on: type: string default: 'short' description: tag filter for operator benchmarks, options from long, short, all + schedule: + # Run at 07:00 UTC every Sunday + - cron: 0 7 * * 0 concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} From cbc901fac335d2c9e1c6c3e75f541669870c79d7 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 2 Apr 2025 17:02:34 -0300 Subject: [PATCH 147/332] Implement `raise ... 
from ...` (#148766) Pull Request resolved: https://github.com/pytorch/pytorch/pull/148766 Approved by: https://github.com/zou3519 --- test/dynamo/test_exceptions.py | 28 ++ test/dynamo/test_generator.py | 11 +- test/dynamo/test_raise.py | 612 +++++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 46 +- torch/_dynamo/variables/builtin.py | 6 + torch/_dynamo/variables/functions.py | 1 + 6 files changed, 679 insertions(+), 25 deletions(-) create mode 100644 test/dynamo/test_raise.py diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index c2390e8db449..6c82593e6ec3 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -545,6 +545,34 @@ def fn(x, d, key): self.assertEqual(ref[0], res[0]) self.assertEqual(ref[1], res[1]) + @make_dynamo_test + def test_raise_from_None_2(self): + def fn(): + try: + raise ValueError + except Exception: + raise TypeError from None + + try: + fn() + except TypeError as e: + assert e.__cause__ is None + assert e.__suppress_context__ is True + + @make_dynamo_test + def test_raise_from_other(self): + def fn(): + try: + raise ValueError + except Exception as e: + raise TypeError from e + + try: + fn() + except TypeError as e: + assert isinstance(e.__cause__, ValueError) + assert e.__suppress_context__ is True + @make_dynamo_test def test_reraise_first_exc(self): def fn(): diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index d1f4289a5793..03b1cf3e5268 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -12,6 +12,7 @@ from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + make_dynamo_test, parametrize, ) @@ -1069,6 +1070,7 @@ def fn(t): self.assertEqual(L, [1, -123, -1, 456]) @parametrize("exc", [RuntimeError, AttributeError]) + @make_dynamo_test def test_close_capture_and_reraise_exc(self, exc): def whoo(t): try: @@ -1079,7 +1081,6 @@ def whoo(t): finally: pass - @torch.compile(backend="eager", fullgraph=True) def fn(t): gen = whoo(t) i = next(gen) @@ -1087,8 +1088,14 @@ def fn(t): return i t = torch.randn(2) - with self.assertRaises(exc): + + z = 0 + try: fn(t) + except exc: + z = 1 + finally: + assert z == 1 def test_close_with_subgen(self): L = [] diff --git a/test/dynamo/test_raise.py b/test/dynamo/test_raise.py new file mode 100644 index 000000000000..133ebc142fe4 --- /dev/null +++ b/test/dynamo/test_raise.py @@ -0,0 +1,612 @@ +# Owner(s): ["module: dynamo"] + +# ruff: noqa +# flake8: noqa + +import sys +import types +import unittest + +import torch +import torch._dynamo.config +import torch._dynamo.test_case +import torch._functorch.config +import torch.nn +import torch.utils.checkpoint +from torch.testing._internal.common_utils import make_dynamo_test + + +def get_tb(): + try: + raise OSError() + except: + return sys.exc_info()[2] + + +class Context: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + return True + + +class MyException(Exception): + def __init__(self): + raise RuntimeError() + + +class ContextManager: + def __enter__(self): + pass + + def __exit__(self, t, v, tb): + raise NameError + + +class TestRaise(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + + def assertIn(self, member, container, msg=None): + assert member in container, msg + + def assertIs(self, 
expr1, expr2, msg=None): + assert expr1 is expr2, msg + + def assertRaises(self, expected_exception, *args, **kwargs): + z = 0 + try: + yield + except expected_exception: + z = 1 + except Exception: + z = 2 + assert z == 1 + + def assertIsInstance(self, obj, cls, msg=None): + assert isinstance(obj, cls), msg + + def assertIsNone(self, obj, msg=None): + assert obj is None, msg + + @make_dynamo_test + def test_invalid_reraise(self): + try: + raise + except RuntimeError as e: + self.assertIn("No active exception", str(e)) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_reraise(self): + try: + try: + raise IndexError + except IndexError as e: + exc1 = e + raise + except IndexError as exc2: + self.assertIs(exc1, exc2) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_except_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + except KeyError: + pass + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_finally_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + try: + raise KeyError("caught") + finally: + raise + + self.assertRaises(KeyError, reraise) + + @make_dynamo_test + def test_nested_reraise(self): + def nested_reraise(): + raise + + def reraise(): + try: + raise TypeError("foo") + except: + nested_reraise() + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_raise_from_None(self): + try: + try: + raise TypeError("foo") + except: + raise ValueError() from None + except ValueError as e: + self.assertIsInstance(e.__context__, TypeError) + self.assertIsNone(e.__cause__) + + @make_dynamo_test + def test_with_reraise1(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + pass + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_with_reraise2(self): + def reraise(): + try: + raise TypeError("foo") + except: + with Context(): + raise KeyError("caught") + raise + + self.assertRaises(TypeError, reraise) + + @make_dynamo_test + def test_yield_reraise(self): + def reraise(): + try: + raise TypeError("foo") + except: + yield 1 + raise + + g = reraise() + next(g) + self.assertRaises(TypeError, lambda: next(g)) + self.assertRaises(StopIteration, lambda: next(g)) + + @make_dynamo_test + def test_erroneous_exception(self): + try: + raise MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + @unittest.expectedFailure # object + @make_dynamo_test + def test_new_returns_invalid_instance(self): + # See issue #11627. 
+ class MyException2(Exception): + def __new__(cls, *args): + return object() + + with self.assertRaises(TypeError): + raise MyException2 + + @unittest.expectedFailure # Assertion with non-string message + @make_dynamo_test + def test_assert_with_tuple_arg(self): + try: + assert False, (3,) + except AssertionError as e: + self.assertEqual(str(e), "(3,)") + + +class TestCause(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + + def assertIn(self, member, container, msg=None): + assert member in container, msg + + def assertIs(self, expr1, expr2, msg=None): + assert expr1 is expr2, msg + + def assertRaises(self, expected_exception, *args, **kwargs): + z = 0 + try: + yield + except expected_exception: + z = 1 + except Exception: + z = 2 + assert z == 1 + + def assertIsInstance(self, obj, cls, msg=None): + assert isinstance(obj, cls), msg + + def assertIsNone(self, obj, msg=None): + assert obj is None, msg + + def assertTrue(self, expr, msg=None): + assert bool(expr) is True, msg + + def assertFalse(self, expr, msg=None): + assert bool(expr) is False, msg + + @make_dynamo_test + def testCauseSyntax(self): + try: + try: + try: + raise TypeError + except Exception: + raise ValueError from None + except ValueError as exc: + self.assertIsNone(exc.__cause__) + self.assertTrue(exc.__suppress_context__) + exc.__suppress_context__ = False + raise exc + except ValueError as exc: + e = exc + + self.assertIsNone(e.__cause__) + self.assertFalse(e.__suppress_context__) + self.assertIsInstance(e.__context__, TypeError) + + @make_dynamo_test + def test_invalid_cause(self): + try: + raise IndexError from 5 + except TypeError as e: + self.assertIn("exception cause", str(e)) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_class_cause(self): + try: + raise IndexError from KeyError + except IndexError as e: + self.assertIsInstance(e.__cause__, KeyError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_instance_cause(self): + cause = KeyError() + try: + raise IndexError from cause + except IndexError as e: + self.assertIs(e.__cause__, cause) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_erroneous_cause(self): + try: + raise IndexError from MyException + except RuntimeError: + pass + else: + self.fail("No exception raised") + + +class TestTraceback(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_sets_traceback(self): + try: + raise IndexError() + except IndexError as e: + self.assertIsInstance(e.__traceback__, types.TracebackType) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_accepts_traceback(self): + tb = get_tb() + try: + raise IndexError().with_traceback(tb) + except IndexError as e: + self.assertNotEqual(e.__traceback__, tb) + self.assertEqual(e.__traceback__.tb_next, tb) + else: + self.fail("No exception raised") + + +class TestTracebackType(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + + def raiser(self): + raise 
ValueError + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_attrs(self): + try: + self.raiser() + except Exception as exc: + tb = exc.__traceback__ + + self.assertIsInstance(tb.tb_next, types.TracebackType) + self.assertIs(tb.tb_frame, sys._getframe()) + self.assertIsInstance(tb.tb_lasti, int) + self.assertIsInstance(tb.tb_lineno, int) + + self.assertIs(tb.tb_next.tb_next, None) + + # Invalid assignments + with self.assertRaises(TypeError): + del tb.tb_next + + with self.assertRaises(TypeError): + tb.tb_next = "asdf" + + # Loops + with self.assertRaises(ValueError): + tb.tb_next = tb + + with self.assertRaises(ValueError): + tb.tb_next.tb_next = tb + + # Valid assignments + tb.tb_next = None + self.assertIs(tb.tb_next, None) + + new_tb = get_tb() + tb.tb_next = new_tb + self.assertIs(tb.tb_next, new_tb) + + @unittest.expectedFailure # Dynamo doesn't track traceback + @make_dynamo_test + def test_constructor(self): + other_tb = get_tb() + frame = sys._getframe() + + tb = types.TracebackType(other_tb, frame, 1, 2) + self.assertEqual(tb.tb_next, other_tb) + self.assertEqual(tb.tb_frame, frame) + self.assertEqual(tb.tb_lasti, 1) + self.assertEqual(tb.tb_lineno, 2) + + tb = types.TracebackType(None, frame, 1, 2) + self.assertEqual(tb.tb_next, None) + + with self.assertRaises(TypeError): + types.TracebackType("no", frame, 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, "no", 1, 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, "no", 2) + + with self.assertRaises(TypeError): + types.TracebackType(other_tb, frame, 1, "nuh-uh") + + +class TestContext(torch._dynamo.test_case.TestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_raise.py + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + + def assertIn(self, member, container, msg=None): + assert member in container, msg + + def assertIs(self, expr1, expr2, msg=None): + assert expr1 is expr2, msg + + def assertRaises(self, expected_exception, *args, **kwargs): + z = 0 + try: + yield + except expected_exception: + z = 1 + except Exception: + z = 2 + assert z == 1 + + def assertIsInstance(self, obj, cls, msg=None): + assert isinstance(obj, cls), msg + + def assertIsNone(self, obj, msg=None): + assert obj is None, msg + + @unittest.expectedFailure # missing Exception.__eq__ + @make_dynamo_test + def test_instance_context_instance_raise(self): + context = IndexError() + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertEqual(e.__context__, context) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ + @make_dynamo_test + def test_class_context_instance_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError() + except OSError as e: + self.assertNotEqual(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + @unittest.expectedFailure # missing Exception.__eq__ and Exception.__repr__ + @make_dynamo_test + def test_class_context_class_raise(self): + context = IndexError + try: + try: + raise context + except: + raise OSError + except OSError as e: + self.assertNotEqual(e.__context__, context) + self.assertIsInstance(e.__context__, context) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_c_exception_context(self): + try: + try: + raise ZeroDivisionError + except: 
+ raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_c_exception_raise(self): + try: + try: + raise ZeroDivisionError + except: + raise NameError + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_noraise_finally(self): + try: + try: + pass + finally: + raise OSError + except OSError as e: + self.assertIsNone(e.__context__) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_raise_finally(self): + try: + try: + raise ZeroDivisionError + finally: + raise OSError + except OSError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_context_manager(self): + try: + with ContextManager(): + raise ZeroDivisionError + except NameError as e: + self.assertIsInstance(e.__context__, ZeroDivisionError) + else: + self.fail("No exception raised") + + @make_dynamo_test + def test_cycle_broken(self): + # Self-cycles (when re-raising a caught exception) are broken + try: + try: + raise ZeroDivisionError + except ZeroDivisionError as e: + raise e + except ZeroDivisionError as e: + self.assertIsNone(e.__context__) + + @make_dynamo_test + def test_reraise_cycle_broken(self): + # Non-trivial context cycles (through re-raising a previous exception) + # are broken too. + try: + try: + raise NameError + except NameError as a: + try: + raise ZeroDivisionError + except ZeroDivisionError: + raise a + except NameError as e: + self.assertIsNone(e.__context__.__context__) + + @make_dynamo_test + def test_3118(self): + # deleting the generator caused the __context__ to be cleared + def gen(): + try: + yield 1 + finally: + pass + + def f(): + g = gen() + next(g) + try: + try: + raise ValueError + except: + del g + raise KeyError + except Exception as e: + self.assertIsInstance(e.__context__, ValueError) + + f() + + @unittest.expectedFailure # too CPython specific(?) 
+ @make_dynamo_test + def test_3611(self): + # A re-raised exception in a __del__ caused the __context__ + # to be cleared + class C: + def __del__(self): + try: + raise ZeroDivisionError + except: + raise + + def f(): + x = C() + try: + try: + x.x + except AttributeError: + del x + raise TypeError + except Exception as e: + self.assertNotEqual(e.__context__, None) + self.assertIsInstance(e.__context__, AttributeError) + + with support.catch_unraisable_exception() as cm: + f() + + self.assertEqual(ZeroDivisionError, cm.unraisable.exc_type) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index fb38b9e1b664..6b9067a91830 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1757,18 +1757,22 @@ def FOR_ITER(self, inst): self.push(ConstantVariable.create(None)) self.jump(inst) - def _raise_exception_variable(self, val) -> NoReturn: - # User can raise exception in 2 ways - # 1) raise exception type - raise NotImplementedError - # 2) raise execption instance - raise NotImplemetedError("foo") - - # 1) when user raises exception type + def _create_exception_type(self, val): if isinstance( val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable) ): # Create the instance of the exception type # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549 val = val.call_function(self, [], {}) # type: ignore[arg-type] + return val + + def _raise_exception_variable(self, val) -> NoReturn: + # User can raise exception in 2 ways + # 1) raise exception type - raise NotImplementedError + # 2) raise execption instance - raise NotImplemetedError("foo") + + # 1) when user raises exception type + val = self._create_exception_type(val) # Handle https://peps.python.org/pep-0479/ # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this @@ -1795,6 +1799,10 @@ def _raise_exception_variable(self, val) -> NoReturn: def RAISE_VARARGS(self, inst): if inst.arg == 0: + if not len(self.exn_vt_stack): + msg = ConstantVariable("No active exception to reraise") + exc.raise_observed_exception(RuntimeError, self, args=[msg]) + # re-raise the previous exception. Here CPython refers to the exception # on top of the exception stack assert len(self.exn_vt_stack) @@ -1806,24 +1814,16 @@ def RAISE_VARARGS(self, inst): val = self.stack[-1] self._raise_exception_variable(val) else: - # raise .. from None + # raise .. from ... from_vt = self.pop() - if isinstance(from_vt, ConstantVariable) and from_vt.value is None: - val = self.pop() - try: - self._raise_exception_variable(val) - finally: - # Update __cause__/__supppress_context__ in the raised exception - curr_exc = self.exn_vt_stack.get_current_exception() - curr_exc.call_setattr( - self, ConstantVariable("__cause__"), ConstantVariable(None) - ) - unimplemented_v2( - gb_type="Re-raise with 2 arguments", - context=str(from_vt), - explanation="Dynamo does not support `raise ... 
from [not-None]`", - hints=[], - ) + val = self.pop() + try: + self._raise_exception_variable(val) + finally: + # Update __cause__/__supppress_context__ in the raised exception + curr_exc = self.exn_vt_stack.get_current_exception() + cause = self._create_exception_type(from_vt) + curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause) def CLEANUP_THROW(self, inst): # https://github.com/python/cpython/pull/96010 diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 9c11423162d3..383ca5b4c343 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1267,6 +1267,12 @@ def call_str(self, tx: "InstructionTranslator", arg): # Inline the user function return tx.inline_user_function_return(user_func_variable, [arg], {}) + elif isinstance(arg, (variables.ExceptionVariable,)): + if len(arg.args) == 0: + value = f"{arg.exc_type}" + else: + value = ", ".join(a.as_python_constant() for a in arg.args) + return variables.ConstantVariable.create(value=value) def _call_min_max(self, tx: "InstructionTranslator", *args): if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index fc20350dc943..257ccac4d37b 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -520,6 +520,7 @@ def next_variable(self, tx): with patch.dict(counters, {"unimplemented": counters["inline_call"]}): return tracer.inline_call_() except ObservedException as e: + tracer.generator_exhausted = True raise e except InfiniteGeneratorError: # test/dynamo/test_misc.py::test_iterator_limit From 781d28e2655f88ae2fef827ed110f22ed553a0ab Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 3 Apr 2025 13:27:50 +0000 Subject: [PATCH 148/332] add unit test for preferred_blas_library settings (#150581) Follow up to #150212 that was committed without a unit test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150581 Approved by: https://github.com/atalman --- test/test_cuda.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/test/test_cuda.py b/test/test_cuda.py index a3cc62c5e1d4..4f4fb5148a7a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,6 +595,64 @@ def test_serialization_array_with_storage(self): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) + @setBlasBackendsToDefaultFinally + def test_preferred_blas_library_settings(self): + def _check_default(): + default = torch.backends.cuda.preferred_blas_library() + if torch.version.cuda: + # CUDA logic is easy, it's always cublas + self.assertTrue(default == torch._C._BlasBackend.Cublas) + else: + # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else + gcn_arch = str( + torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] + ) + if gcn_arch in ["gfx90a", "gfx942", "gfx950"]: + self.assertTrue(default == torch._C._BlasBackend.Cublaslt) + else: + self.assertTrue(default == torch._C._BlasBackend.Cublas) + + _check_default() + # "Default" can be set but is immediately reset internally to the actual default value. 
+ self.assertTrue( + torch.backends.cuda.preferred_blas_library("default") + != torch._C._BlasBackend.Default + ) + _check_default() + self.assertTrue( + torch.backends.cuda.preferred_blas_library("cublas") + == torch._C._BlasBackend.Cublas + ) + self.assertTrue( + torch.backends.cuda.preferred_blas_library("hipblas") + == torch._C._BlasBackend.Cublas + ) + # check bad strings + with self.assertRaisesRegex( + RuntimeError, + "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", + ): + torch.backends.cuda.preferred_blas_library("unknown") + # check bad input type + with self.assertRaisesRegex(RuntimeError, "Unknown input value type."): + torch.backends.cuda.preferred_blas_library(1.0) + # check env var override + custom_envs = [ + {"TORCH_BLAS_PREFER_CUBLASLT": "1"}, + {"TORCH_BLAS_PREFER_HIPBLASLT": "1"}, + ] + test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())" + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("_BlasBackend.Cublaslt", r) + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @setBlasBackendsToDefaultFinally def test_cublas_workspace_explicit_allocation(self): From 70b34a42c17cecd316487dc574dce3b8121270cc Mon Sep 17 00:00:00 2001 From: FFFrog Date: Thu, 3 Apr 2025 16:11:32 +0800 Subject: [PATCH 149/332] Add new dependences for gen_pyi.py (#150391) As the title stated. When we update some functions in _torch_docs.py or _tensor_docs.py, and execute some commands (like ``python setup.py evolve``) to install the latest version, the description about the function we just changed is not updated. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150391 Approved by: https://github.com/Skylion007, https://github.com/peterbell10 --- torch/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 8b8ebdc6e976..67fe1df8ca87 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -251,6 +251,8 @@ add_custom_command( "${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml" "${TORCH_ROOT}/aten/src/ATen/native/tags.yaml" "${TORCH_ROOT}/tools/autograd/deprecated.yaml" + "${TORCH_ROOT}/torch/_torch_docs.py" + "${TORCH_ROOT}/torch/_tensor_docs.py" ${pyi_python} ${autograd_python} ${torchgen_python} From ff783f062a4ca889dbe1eae1e72e6d20dd3839db Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 2 Apr 2025 22:07:37 +0000 Subject: [PATCH 150/332] Fix shape guard failure to be valid python (#149149) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149149 Approved by: https://github.com/anijain2305 --- torch/csrc/dynamo/guards.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 0b8dec86f98d..9df910662742 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -2047,7 +2047,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { bool result = check_nopybind(value); if (!result) { - std::string msg = "Shape guard failed with values: "; + std::string msg = "\"Shape guard failed with values: "; for (auto v : _args_int) { msg += std::to_string(v) + ","; } @@ -2055,6 +2055,7 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { msg += std::to_string(v) + ","; } msg.pop_back(); + msg += "\""; auto msgs = py::list(); for (auto code_part : verbose_code_parts()) { msgs.append(code_part); From f9a7eac718a5788ed0be23e88772ca833947fba5 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 2 Apr 2025 22:07:38 +0000 Subject: [PATCH 151/332] use python fallback if there are overflows (#149197) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149197 Approved by: https://github.com/anijain2305 ghstack dependencies: #149149 --- torch/_dynamo/guards.py | 148 ++++++++++++++++++++++------------------ 1 file changed, 81 insertions(+), 67 deletions(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 48631a7021f9..6fcaa875c144 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -1984,11 +1984,20 @@ def _get_code_parts(langs): ) if config.enable_cpp_symbolic_shape_guards: - # For exporting we need the python code parts - python_code_parts, verbose_code_parts, cpp_code_parts = _get_code_parts( - ("python", "verbose_python", "cpp") - ) + try: + # For exporting we need the python code parts + python_code_parts, verbose_code_parts, cpp_code_parts = _get_code_parts( + ("python", "verbose_python", "cpp") + ) + python_fallback = False + except OverflowError: + # Cannot use int64_t + python_fallback = True + python_code_parts, verbose_code_parts = _get_code_parts( + ("python", "verbose_python") + ) else: + python_fallback = True python_code_parts, verbose_code_parts = _get_code_parts( ("python", "verbose_python") ) @@ -2004,11 +2013,10 @@ def _get_code_parts(langs): if compile_context := CompileContext.try_get(): compile_context.shape_env_guards.extend(verbose_code_parts.exprs) - if config.enable_cpp_symbolic_shape_guards: - import ctypes - - from torch._inductor.codecache import CppCodeCache + 
int_source_to_symbol = [] + float_source_to_symbol = [] + if not python_fallback: assert cpp_code_parts # type: ignore[possibly-undefined] code_parts, source_to_symbol = ( cpp_code_parts.exprs, @@ -2018,10 +2026,6 @@ def _get_code_parts(langs): if not code_parts: return - int_source_to_symbol = [] - float_source_to_symbol = [] - - python_fallback = False for source, symbol in source_to_symbol.items(): if isinstance(source, ConstantSource): python_fallback = True @@ -2039,62 +2043,72 @@ def _get_code_parts(langs): # int64_t/double in C++ guards for now. python_fallback = True - if not python_fallback: - source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol) - try: - guard_managers = [ - self.get_guard_manager_from_source(IndexedSource(source, i)) - for i, source in enumerate(source_to_symbol) - ] - - int_symbols_str = ", ".join( - f"{symbol} = int_values[{i}]" - for i, (_, symbol) in enumerate(int_source_to_symbol) - ) - float_symbols_str = ", ".join( - f"{symbol} = float_values[{i}]" - for i, (_, symbol) in enumerate(float_source_to_symbol) - ) + if not python_fallback: + import ctypes - if int_symbols_str: - int_symbols_str = f"int64_t {int_symbols_str};" - if float_symbols_str: - float_symbols_str = f"double {float_symbols_str};" - - func_str = textwrap.dedent( - f""" - #include - #include - #include - - extern "C" int8_t guard(int64_t *int_values, double *float_values) {{ - {int_symbols_str} - {float_symbols_str} - return ({") && (".join(code_parts)}); - }} - """ - ) - guards_log.debug( - "C++ shape guard function: %s %s", - func_str, - verbose_code_parts.exprs, - ) - clib = CppCodeCache.load(func_str) - cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value - assert cguard - except torch._inductor.exc.InvalidCxxCompiler: - # No valid C++ compiler to compile the shape guard - pass - else: - install_symbolic_shape_guard( - guard_managers, - len(int_source_to_symbol), - len(float_source_to_symbol), - cguard, - clib, - verbose_code_parts.exprs, - ) - return + from torch._inductor.codecache import CppCodeCache + + assert cpp_code_parts # type: ignore[possibly-undefined] + code_parts, source_to_symbol = ( + cpp_code_parts.exprs, + cpp_code_parts.source_to_symbol, + ) + + source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol) + try: + guard_managers = [ + self.get_guard_manager_from_source(IndexedSource(source, i)) + for i, source in enumerate(source_to_symbol) + ] + + int_symbols_str = ", ".join( + f"{symbol} = int_values[{i}]" + for i, (_, symbol) in enumerate(int_source_to_symbol) + ) + float_symbols_str = ", ".join( + f"{symbol} = float_values[{i}]" + for i, (_, symbol) in enumerate(float_source_to_symbol) + ) + + if int_symbols_str: + int_symbols_str = f"int64_t {int_symbols_str};" + if float_symbols_str: + float_symbols_str = f"double {float_symbols_str};" + + func_str = textwrap.dedent( + f""" + #include + #include + #include + + extern "C" int8_t guard(int64_t *int_values, double *float_values) {{ + {int_symbols_str} + {float_symbols_str} + return ({") && (".join(code_parts)}); + }} + """ + ) + guards_log.debug( + "C++ shape guard function: %s %s", + func_str, + verbose_code_parts.exprs, + ) + clib = CppCodeCache.load(func_str) + cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value + assert cguard + except torch._inductor.exc.InvalidCxxCompiler: + # No valid C++ compiler to compile the shape guard + pass + else: + install_symbolic_shape_guard( + guard_managers, + len(int_source_to_symbol), + len(float_source_to_symbol), + cguard, + clib, + 
verbose_code_parts.exprs, + ) + return # Install all the symbolic guards in one python lambda guard. These are run # at the very end of the RootGuardManager via epilogue guards. From a72b4eb80604f5f7997c7695cc8a63ca3f3c8ff1 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 2 Apr 2025 22:07:38 +0000 Subject: [PATCH 152/332] Support windows in C++ shape guards (#149211) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149211 Approved by: https://github.com/anijain2305 ghstack dependencies: #149149, #149197 --- torch/_dynamo/guards.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 6fcaa875c144..fe3a93be8644 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -2081,7 +2081,13 @@ def _get_code_parts(langs): #include #include - extern "C" int8_t guard(int64_t *int_values, double *float_values) {{ + #if defined(_MSC_VER) + # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport) + #else + # define EXTERN_DLL_EXPORT extern "C" + #endif + + EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{ {int_symbols_str} {float_symbols_str} return ({") && (".join(code_parts)}); From 5314a6fe82f7905e1617de93f08de99df26678dd Mon Sep 17 00:00:00 2001 From: angelayi Date: Thu, 3 Apr 2025 15:27:45 +0000 Subject: [PATCH 153/332] [export] Fix deserialization issue (#150515) An internal model was serialized in 2023, and is now breaking while loading with the following error: ``` File ".1675", line 4 def forward(self, arg1163_1, arg1164_1, , arg1166_1, , arg1168_1, arg1169_1, arg1170_1, , arg1172_1, arg1173_1, arg1174_1, arg1175_1, arg1176_1, arg1177_1, arg1178_1, arg1179_1, arg1180_1, arg1181_1, arg1182_1, arg1183_1, arg1184_1, arg1185_1, arg1186_1, arg1187_1, arg1188_1, arg1189_1, arg1190_1, arg1191_1, arg1192_1, arg1193_1, arg1194_1, arg1195_1, arg1196_1, arg1197_1, arg1198_1, arg1199_1, arg1200_1, arg1201_1, arg1202_1, arg1203_1, arg1204_1, arg1205_1, arg1206_1, arg1207_1, arg1208_1, arg1209_1, arg1210_1, arg1211_1, arg1212_1, arg1213_1, arg1214_1, arg1215_1, arg1216_1, , arg1218_1, arg1219_1, arg1220_1, arg1221_1, arg1222_1, arg1223_1, arg1224_1, , arg1226_1, arg1227_1, arg1228_1, , arg1230_1, , , , , , , , , , , , , , , ): ^ SyntaxError: invalid syntax ``` The syntax errors are due to inputs that are `None` when exporting. Prior to changes in https://github.com/pytorch/pytorch/pull/123590 (landed 4/2024), input specs for none inputs look like `InputSpec(userInput=UserInputSpec(arg=Argument(asNone=True)))`, and during deserialization when creating a node, we would just use a dummy name `arg`. After to those changes, the input specs for none inputs look like `InputSpec(constantInput=InputToConstantInputSpec(name='y', value=ConstantValue(asNone=True)))`, and when creating a node we would use the name `y` as the name. However the PR didn't handle the case if it's loading an old package which doesn't have this name, so ended up putting empty names in the placeholder nodes. This error was uncovered after https://github.com/pytorch/pytorch/pull/149717, where we now use the GraphModule's python codegen to run the UnflattenedModule instead of going through the interpreter path. The placeholder nodes having empty names caused the python codegen to fail. 
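As an illustration of the fix, here is a minimal standalone sketch (using only the public torch.fx API; the serialized names below are made up) of falling back to positional placeholder names so codegen stays valid:

```
import torch.fx

# "" stands in for an old package whose None input was serialized without a name.
serialized_names = ["arg1163_1", "", "arg1166_1"]

graph = torch.fx.Graph()
placeholders = [
    # An empty placeholder name would make fx codegen emit
    # "def forward(self, , ...)" -- the SyntaxError above -- so fall back
    # to a positional name whenever the serialized spec carries none.
    graph.placeholder(name or f"arg{i}")
    for i, name in enumerate(serialized_names)
]
graph.output(placeholders[0])
print(graph.python_code("self").src)  # def forward(self, arg1163_1, arg1, arg1166_1): ...
```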
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150515 Approved by: https://github.com/yushangdi --- test/export/test_serialize.py | 27 +++++++++++++++++++++++++++ torch/_export/serde/serialize.py | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f5a324c7afdb..ae9f45cbeb21 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -16,6 +16,7 @@ import torch import torch._dynamo as torchdynamo +import torch._export.serde.schema as schema import torch.export._trace import torch.utils._pytree as pytree from torch._export.db.case import ExportCase, SupportLevel @@ -918,6 +919,32 @@ def forward(self, a, b, c): inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(2)) self.check_graph(Mod(), inp, use_pre_dispatch=False) + def test_none_input(self): + """ + Testing a backwards-compatibility breakage where old models do not have + an input spec with the node name. + """ + + class M(torch.nn.Module): + def forward(self, x, y, z): + return x + z + + ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) + + serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) + serialized_program.exported_program.graph_module.signature.input_specs[ + 1 + ] = schema.InputSpec.create( + user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) + ) + ep = ExportedProgramDeserializer(None).deserialize( + serialized_program.exported_program, {}, {}, {} + ) + ep.graph_module.recompile() + unflattened = torch.export.unflatten(ep) + inp = (torch.rand(3, 3), None, torch.rand(3, 3)) + self.assertEqual(unflattened(*inp), M()(*inp)) + def test_multi_return(self) -> None: """ Test multiple return from a single node (ex. 
layer_norm has 2 outputs) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 26ae80af1c6a..d630896f69c6 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1861,7 +1861,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: "as_none", "as_string", ): - node_name = self.signature.input_specs[i].arg.name + node_name = self.signature.input_specs[i].arg.name or f"arg{i}" placeholder_node = self.graph.placeholder(node_name) placeholder_node.meta["val"] = self.deserialize_input(input_) else: From 440c07e56aa31764825fa2b81f1eaa1e1466aa65 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Thu, 3 Apr 2025 12:20:13 +0000 Subject: [PATCH 154/332] Fix detection of GPU multicast (#150563) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150563 Approved by: https://github.com/kwen2501 --- c10/cuda/driver_api.h | 2 ++ caffe2/CMakeLists.txt | 1 + torch/csrc/distributed/c10d/cuda/utils.cpp | 2 ++ 3 files changed, 5 insertions(+) diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 65cbdfe878dc..d2eb495e8833 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -3,6 +3,8 @@ #define NVML_NO_UNVERSIONED_FUNC_DEFS #include +#include + #define C10_CUDA_DRIVER_CHECK(EXPR) \ do { \ CUresult __err = EXPR; \ diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index b850644fe977..71cc4b31a995 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -568,6 +568,7 @@ if(USE_CUDA) append_filelist("libtorch_cuda_distributed_extra_sources" Caffe2_GPU_SRCS) set_source_files_properties( ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp + ${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu diff --git a/torch/csrc/distributed/c10d/cuda/utils.cpp b/torch/csrc/distributed/c10d/cuda/utils.cpp index 7884be53a1a7..0072fab983f6 100644 --- a/torch/csrc/distributed/c10d/cuda/utils.cpp +++ b/torch/csrc/distributed/c10d/cuda/utils.cpp @@ -1,3 +1,5 @@ +#include + #include #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) From 5be5cfe4cba63265a8a286f274167cef97ba17ab Mon Sep 17 00:00:00 2001 From: David Berard Date: Thu, 3 Apr 2025 16:01:57 +0000 Subject: [PATCH 155/332] [inductor][autotune cache] add torch_key() to configs hash (#150494) Summary: **Context**: https://github.com/pytorch/pytorch/pull/150122 (D71982587 - let's call this "the WS diff") introduces "bc/fc-breaking" cache changes. In particular, it introduces `num_consumer_groups` and adds it to the cached config. In versions of torch that include the WS diff, `num_consumer_groups` is treated as a class variable on a triton.Config object (i.e. `triton.Config({..kwargs..}, num_consumer_groups=num_consumer_groups, ...`). And in versions of torch that don't include the WS diff, you generally don't expect to see this kwarg. But if a program is run WS-torch (i.e. torch w/ the WS diff), and then later you run the same program with non-WS-torch, then non-WS-torch is going to find this autotune cache entry, and interpret `num_consumer_groups` as a kwarg, because there's no special handling for for num_consumer_groups in this version of torch. Then the program crashes with a triton failure message. 
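A minimal standalone sketch of the versioned-key scheme adopted by the fix described in the next paragraph (pure hashlib; the `torch_key` bytes below stand in for torch's real source hash):

```
import hashlib

def versioned_cache_key(cache_key: str, torch_key: bytes) -> str:
    # Mixing the torch source hash into the lookup key means any change to
    # inductor (such as a new best_config field) lands in a fresh cache
    # entry instead of being misread by an older or newer torch.
    hasher = hashlib.sha256()
    hasher.update(cache_key.encode("utf-8"))
    hasher.update(torch_key)
    return hasher.hexdigest()

# Entries written by different torch builds never collide:
assert versioned_cache_key("kernel0", b"torchkey1") != versioned_cache_key("kernel0", b"torchkey2")
```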
**The fix**: add the torch version / torch key into the hash, so that any changes to inductor will invalidate the cache (ensuring that other changes to triton_heuristics won't cause these bc/fc issues). Test Plan: D72285868 (or https://gist.github.com/davidberard98/2ea697eb550c94d0d1948fedb5c5c7d8, but this doesn't repro in OSS because this version of warp specialization is not available in oss triton) can repro the failure, and the failure is fixed after this PR is patched. Also, added a test in test/inductor/test_codecache.py which verifies that there's no cache hit if the torch_key changes (and verified that without the functional changes in this PR, the test fails). Differential Revision: D72285303 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150494 Approved by: https://github.com/oulgen --- test/inductor/test_codecache.py | 64 +++++++++++++++++++++++ torch/_inductor/runtime/autotune_cache.py | 21 +++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 1cb4b4f96dfc..bb86d143621e 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +import functools import os import pickle import shutil @@ -1770,6 +1771,69 @@ def f(a, b, c, d, e, f): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+") + @requires_triton() + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not SM80OrLater, "Requires SM80+") + @config.patch({"fx_graph_cache": False}) + @config.patch({"fx_graph_remote_cache": False}) + @config.patch({"bundled_autotune_remote_cache": False}) + @config.patch({"max_autotune": True}) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly + @parametrize("remote_cache", (True, False)) + def test_modified_autotune_cache(self, remote_cache): + """ + If a developer changes the way the autotune cache is handled, + there's a chance it'll break the cache. This happened with + #150122. This test ensures that if torch code changes, then + old cache entries will be invalidated. 
+ """ + + def mock_torch_key(value: str) -> bytes: + return value.encode("utf-8") + + def get_autotune_stats(): + if remote_cache: + return global_stats.autotune_remote + return global_stats.autotune_local + + def fn(x, y): + return (x + y).relu() + + x = torch.randn(100, 100).cuda() + y = torch.randn(100, 100).cuda() + + with config.patch( + { + "autotune_local_cache": not remote_cache, + "autotune_remote_cache": remote_cache, + } + ): + with PatchCaches(): + with mock.patch( + "torch._inductor.codecache.torch_key", + functools.partial(mock_torch_key, "torchkey1"), + ): + f_compiled = torch.compile(fn, fullgraph=True) + res1 = f_compiled(x, y) + + self.assertEqual(get_autotune_stats(), Stats(1, 0, 1)) + + torch._dynamo.reset() + PyCodeCache.cache_clear() + + with mock.patch( + "torch._inductor.codecache.torch_key", + functools.partial(mock_torch_key, "torchkey2"), + ): + f_compiled = torch.compile(fn, fullgraph=True) + res2 = f_compiled(x, y) + + self.assertEqual(get_autotune_stats(), Stats(2, 0, 2)) + + self.assertEqual(res1, res2) + class TestRemoteAOTAutogradCache(TestCase): @unittest.skipIf(not HAS_CUDA, "Requires CUDA") diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index d19a96a85604..4988f3780812 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -121,7 +121,21 @@ def _setup_local_cache( if not inductor_meta.get("autotune_local_cache", True): return - cache_filename = f"{dirname}/{cache_key}.best_config" + from ..codecache import torch_key + + """ + [Note: torch_key in autotune cache key] + Include torch_key() in the cache key so that different versions + of torch result in cache invalidation. This is important in case + of changes to the best_config format or other code changes that + are not backward compatible w.r.t. the cache. + """ + hasher = hashlib.sha256() + hasher.update(cache_key.encode("utf-8")) + hasher.update(torch_key()) + updated_cache_key = hasher.hexdigest() + + cache_filename = f"{dirname}/{updated_cache_key}.best_config" local_cache = LocalAutotuneCache() self.local_cache = (local_cache, cache_filename) @@ -139,10 +153,13 @@ def _setup_remote_autotune_cache( return assert isinstance(backend_hash, str) + from ..codecache import torch_key + is_fbcode = bool(inductor_meta.get("is_fbcode", False)) salt = "autotune-best-config-v2" - key = backend_hash + self.configs_hash + salt + # re: torch_key - see [Note: torch_key in autotune cache key] + key = torch_key().hex() + backend_hash + self.configs_hash + salt key = hashlib.sha256(key.encode("utf-8")).hexdigest() remote_cache = create_cache( From fa0fdc0ccac69919f552adb01fdba9d8eb9494de Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 3 Apr 2025 16:18:59 +0000 Subject: [PATCH 156/332] if blaslt fails, fall back to blas (#150147) Fixes #150016. This is implemented for both cublaslt and hipblaslt. gemm_and_bias on failure will fall back to unfused path. lt gemm on failure falls back to gemm even if gemm preference is set to lt. 
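From the user side, a preferred-but-unsupported Lt path now degrades gracefully instead of erroring. A hedged sketch of what that looks like (requires a CUDA or ROCm device; the shapes and dtypes are only illustrative):

```
import torch

torch.backends.cuda.preferred_blas_library("cublaslt")
a = torch.randn(32, 64, device="cuda", dtype=torch.half)
b = torch.randn(64, 16, device="cuda", dtype=torch.half)
bias = torch.randn(16, device="cuda", dtype=torch.half)
# If cublasLt/hipblasLt rejects this GEMM, a warning is emitted and the call
# completes via plain cublas/hipblas instead of raising.
out = torch.addmm(bias, a, b)
```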
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150147 Approved by: https://github.com/malfet --- aten/src/ATen/cuda/CUDABlas.cpp | 91 +++++++++++++++++++----------- aten/src/ATen/cuda/CUDABlas.h | 2 +- aten/src/ATen/native/cuda/Blas.cpp | 18 +++++- 3 files changed, 74 insertions(+), 37 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index c5dd44dc1edf..4f5e511c33bc 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -335,7 +335,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< template -static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { +static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { cudaDataType_t abcType = CUDA_R_32F; cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; cudaDataType_t scaleType = CUDA_R_32F; @@ -426,6 +426,7 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( @@ -440,10 +441,10 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { &heuristicResult, &returnedResult)); if (returnedResult == 0) { - TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); + cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; } - - cublasStatus_t cublasStatus = cublasLtMatmul( + else { + cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), alpha_ptr, @@ -460,9 +461,10 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { workspace.mutable_data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream()); - TORCH_CHECK( - cublasStatus == CUBLAS_STATUS_SUCCESS, - "CUDA error: ", + } + if (cublasStatus != CUBLAS_STATUS_SUCCESS) { + TORCH_WARN( + "bgemm_internal_cublaslt error: ", at::cuda::blas::_cublasGetErrorEnum(cublasStatus), " when calling cublasLtMatmul with transpose_mat1 ", (opa == CUBLAS_OP_T), @@ -485,7 +487,11 @@ static inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { " computeType ", computeType, " scaleType ", - scaleType); + scaleType, + ". 
Will attempt to recover by calling cublas instead."); + return false; + } + return true; } @@ -646,7 +652,9 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)) // hipblaslt does not support double gemm yet bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(double)); #else - bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(double)); + if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(double))) { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(double)); + } #endif } else { @@ -658,7 +666,9 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float)); + if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float))) { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); + } } else { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); @@ -673,7 +683,9 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex gemm yet bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); #else - bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex)); + if (!bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex))) { + bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } #endif } else { @@ -689,7 +701,9 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex gemm yet bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); #else - bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex)); + if (!bgemm_internal_cublaslt>(CUDABLAS_BGEMM_ARGS(c10::complex))) { + bgemm_internal_cublas>(CUDABLAS_BGEMM_ARGS(c10::complex)); + } #endif } else { @@ -701,7 +715,9 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half)); + if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half))) { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); + } } else { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); @@ -712,7 +728,9 @@ template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { - bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + if (!bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16))) { + bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); + } } #if defined(USE_ROCM) && !defined(_MSC_VER) else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { @@ -835,18 +853,11 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { } } -template -inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { - // forward to bgemm implementation but set strides and batches to 0 - bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0); -} - template inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented"); } - template <> void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] @@ -1056,6 +1067,14 @@ void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); } +template +inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // forward to bgemm implementation but set strides and batches to 0 + if (!bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0)) { + 
gemm_internal_cublas(CUDABLAS_GEMM_ARGS(Dtype)); + } +} + template <> void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) { @@ -1270,7 +1289,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { template -void gemm_and_bias( +bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1387,11 +1406,12 @@ void gemm_and_bias( 1, &heuristicResult, &returnedResult)); + cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; if (returnedResult == 0) { - TORCH_CUDABLAS_CHECK(CUBLAS_STATUS_NOT_SUPPORTED); + cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; } - - cublasStatus_t cublasStatus = cublasLtMatmul( + else { + cublasStatus = cublasLtMatmul( ltHandle, computeDesc.descriptor(), alpha_ptr, @@ -1408,9 +1428,10 @@ void gemm_and_bias( workspace.mutable_data_ptr(), workspaceSize, at::cuda::getCurrentCUDAStream()); - TORCH_CHECK( - cublasStatus == CUBLAS_STATUS_SUCCESS, - "CUDA error: ", + } + if (cublasStatus != CUBLAS_STATUS_SUCCESS) { + TORCH_WARN( + "gemm_and_bias error: ", at::cuda::blas::_cublasGetErrorEnum(cublasStatus), " when calling cublasLtMatmul with transpose_mat1 ", transpose_mat1, @@ -1433,10 +1454,14 @@ void gemm_and_bias( " computeType ", computeType, " scaleType ", - scaleType); + scaleType, + ". Will attempt to recover by calling unfused cublas path."); + return false; + } + return true; } -template void gemm_and_bias( +template bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1452,7 +1477,7 @@ template void gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template void gemm_and_bias( +template bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1468,7 +1493,7 @@ template void gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template void gemm_and_bias( +template bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, @@ -1484,7 +1509,7 @@ template void gemm_and_bias( int64_t result_ld, GEMMAndBiasActivationEpilogue activation); -template void gemm_and_bias( +template bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 637b48c797fa..b65a7c79ee10 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -91,7 +91,7 @@ enum GEMMAndBiasActivationEpilogue { // NOTE: GELU activation is not supported prior to CUDA 11.4 and will // do nothing if passed in that case. template -void gemm_and_bias( +bool gemm_and_bias( bool transpose_mat1, bool transpose_mat2, int64_t m, diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 906b24652e49..dd04e58c3721 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -318,7 +318,7 @@ static void launchTunableGemmAndBias(cublasCommonArgs &args, const Scalar& alpha } } -Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None) { +Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Activation activation=Activation::None, bool disable_addmm_cuda_lt_override=false) { // Make sure to keep addmm_cuda below in sync with this code; it // preflights a check to try to avoid actually needing to call // expand(). 
@@ -344,6 +344,8 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma #else static bool disable_addmm_cuda_lt = getDisableAddmmCudaLt(); #endif + // if lt path fails, we recurse back into this function here and force the lt path to off + disable_addmm_cuda_lt |= disable_addmm_cuda_lt_override; at::ScalarType scalar_type = self.scalar_type(); c10::MaybeOwned self_; if (&result != &self) { @@ -438,6 +440,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma if (useLtInterface) { #if defined(USE_ROCM) + bool okay = true; AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -453,7 +456,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_to_gemm_and_blas_arg(activation)); } else { - at::cuda::blas::gemm_and_bias( + okay = at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', args.m, @@ -472,6 +475,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_to_gemm_and_blas_arg(activation) ); }}); + if (!okay) { + // lt path failed; recurse but disable lt path + return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); + } #else auto activation_epilogue = activation_to_gemm_and_blas_arg(activation); #if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080)) @@ -483,6 +490,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None; #endif + bool okay = true; AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, @@ -498,7 +506,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue); } else { - at::cuda::blas::gemm_and_bias( + okay = at::cuda::blas::gemm_and_bias( args.transa == 't', args.transb == 't', args.m, @@ -515,6 +523,10 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue ); }}); + if (!okay) { + // lt path failed; recurse but disable lt path + return addmm_out_cuda_impl(result, self, mat1, mat2, beta, alpha, activation, true); + } #endif } else { From 5d36253a7dd654ad71102a8867b22d37fddc9f19 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Thu, 3 Apr 2025 15:30:15 +0800 Subject: [PATCH 157/332] Refactoring: fix the python constant check (#150608) As the title stated. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150608 Approved by: https://github.com/Skylion007 --- torch/_dynamo/variables/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 383ca5b4c343..5360868dd7e7 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1804,11 +1804,11 @@ def call_getattr( name_var: VariableTracker, default=None, ): - name = name_var.as_python_constant() - if not name_var.is_python_constant(): unimplemented("non-const getattr() name") + name = name_var.as_python_constant() + if tx.output.side_effects.is_attribute_mutation(obj): if isinstance(obj, variables.UnspecializedNNModuleVariable): if ( From 78d1165d7605db7b01962ee8bbc69d6bd5945580 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:02:16 -0700 Subject: [PATCH 158/332] [DTensor][tp] fix errors in FSDP+TP checkpointing test (#150354) ## Summary remove the `tp_parallelize_plan` assignment that accidentally rewrites the previous assignments in `test_fsdp_dsd.py`. ## Test `pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150354 Approved by: https://github.com/wconstab --- test/distributed/checkpoint/fsdp/test_fsdp_dsd.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py index fac49dd2786f..f8d90d3677e1 100644 --- a/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py +++ b/test/distributed/checkpoint/fsdp/test_fsdp_dsd.py @@ -482,15 +482,6 @@ def _get_base_model(mlp_dim: int = 2): tp_parallelize_plan.pop("0.out_proj") with cm: - tp_parallelize_plan = { - "0.in_proj": ColwiseParallel(), - "0.out_proj": RowwiseParallel(), - "1.in_proj": ColwiseParallel(), - "1.out_proj": RowwiseParallel(), - "2.in_proj": ColwiseParallel(), - "2.out_proj": RowwiseParallel(), - } - # init device mesh dp_size = 2 global_mesh_1d = init_device_mesh( From 96f35f55e2676cfa76c28fb8f88e9f3cde08c59c Mon Sep 17 00:00:00 2001 From: ZhaoqiongZ <106125927+ZhaoqiongZ@users.noreply.github.com> Date: Thu, 3 Apr 2025 18:17:02 +0000 Subject: [PATCH 159/332] update get start xpu document for v2.7 (#150397) update get start xpu document for v2.7 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150397 Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/atalman Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- docs/source/notes/get_start_xpu.rst | 51 ++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index dce6d126dce3..d5f140a3db0b 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -4,27 +4,46 @@ Getting Started on Intel GPU Hardware Prerequisite --------------------- +For Intel Data Center GPU + .. 
list-table:: - :widths: 50 50 + :widths: 50 50 50 50 :header-rows: 1 - * - Supported OS - - Validated Hardware - * - Linux - - Intel® Client GPUs / Intel® Data Center GPU Max Series - * - Windows - - Intel® Client GPUs - * - WSL2 (experimental feature) - - Intel® Client GPUs - -Intel GPUs support (Prototype) is ready in PyTorch* 2.6 for Intel® Client GPUs and Intel® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. + * - Device + - Red Hat* Enterprise Linux* 9.2 + - SUSE Linux Enterprise Server* 15 SP5 + - Ubuntu* Server 22.04 (>= 5.15 LTS kernel) + * - Intel® Data Center GPU Max Series (CodeName: Ponte Vecchio) + - yes + - yes + - yes + +For Intel Client GPU + ++-------------------------------------+----------------------------------------------------------------------------------------------+ +| Supported OS | Validated Hardware | ++=====================================+==============================================================================================+ +|| Windows 10/11 & Ubuntu 24.10 || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Arc B-Series Graphics (CodeName: Battlemage) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | +|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | ++-------------------------------------+----------------------------------------------------------------------------------------------+ +|| Ubuntu 24.04 & WSL2 (Ubuntu 24.04) || Intel® Arc A-Series Graphics (CodeName: Alchemist) | +|| || Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake) | +|| || Intel® Core™ Ultra 200V Series with Intel® Arc™ Graphics (CodeName: Lunar Lake) | +|| || Intel® Core™ Ultra Series 2 Processors with Intel® Arc™ Graphics (CodeName: Arrow Lake) | ++-------------------------------------+----------------------------------------------------------------------------------------------+ + +Intel GPUs support (Prototype) is ready from PyTorch* 2.5 for Intel® Client GPUs and Intel® Data Center GPU Max Series on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. Software Prerequisite --------------------- -To use PyTorch on Intel GPUs, you need to install the Intel GPUs driver first. For installation guide, visit `Intel GPUs Driver Installation `_. +To use PyTorch on Intel GPUs, you need to install the Intel GPUs driver first. For installation guide, visit `Intel GPUs Driver Installation `_. -Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. +Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. 
Installation @@ -33,7 +52,7 @@ Installation Binaries ^^^^^^^^ -Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. +Now that we have `Intel GPU Driver `_ installed, use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. For release wheels @@ -52,7 +71,7 @@ For nightly wheels From Source ^^^^^^^^^^^ -Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed. Follow guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. +Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed. Follow guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. Build from source for ``torch`` refer to `PyTorch Installation Build from source `_. @@ -88,7 +107,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. +#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to Use Inductor on Windows with CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples From 3b02f795c5ad2339794b15b370c0e4a235d36adf Mon Sep 17 00:00:00 2001 From: "Jiang, Yanbing" Date: Thu, 3 Apr 2025 19:43:45 +0000 Subject: [PATCH 160/332] Add torch._scaled_mm for CPU (#150410) This PR is the duplicated one for https://github.com/pytorch/pytorch/pull/139975. This PR is to add torch._scaled_mm for CPU backend. _scaled_mm_out_cpu and _scaled_mm_cpu are new added and included in torch._scaled_mm CPU dispatch. We also add _scaled_mm_out_cpu_emulated as a fallback function if the current platform cannot run FP8 matmul using oneDNN. And this PR also updates the various UTs related to FP8 to support CPU tests. 
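A hedged usage sketch of the new CPU path (per-tensor scaling only, matching the kernels below; requires a build that includes this PR, and the shapes/scales are purely illustrative):

```
import torch

x = torch.randn(16, 32).to(torch.float8_e4m3fn)       # mat1: (M, K)
w = torch.randn(64, 32).to(torch.float8_e4m3fn).t()   # mat2: (K, N), column-major
scale_x = torch.tensor(1.0)                           # per-tensor scales
scale_w = torch.tensor(1.0)
out = torch._scaled_mm(x, w, scale_a=scale_x, scale_b=scale_w, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([16, 64])
```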
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150410 Approved by: https://github.com/atalman --- aten/src/ATen/native/Blas.cpp | 96 +++++++++++++ aten/src/ATen/native/mkldnn/Linear.cpp | 126 +++++++++++++++++- aten/src/ATen/native/mkldnn/Linear.h | 12 ++ aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp | 22 ++- aten/src/ATen/native/native_functions.yaml | 2 + test/inductor/test_fp8.py | 113 ++++++++++------ test/test_matmul_cuda.py | 23 ++-- torch/_inductor/codegen/cpp_prefix.h | 4 + .../aoti_torch/generated/c_shim_cpu.h | 2 + torch/testing/_internal/common_device_type.py | 2 + .../_internal/common_methods_invocations.py | 20 ++- 11 files changed, 364 insertions(+), 58 deletions(-) diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index f62c31777822..560a8f7657a8 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -7,6 +7,11 @@ #include #include +#include +#include +#if !defined(__s390x__) && !defined(__powerpc__) +#include +#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -24,6 +29,9 @@ #include #include #include +#include +#include +#include #endif namespace at::meta { @@ -222,4 +230,92 @@ Tensor vdot(const Tensor &self, const Tensor &other){ } +static Tensor& +_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum, + Tensor& out) { + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( + mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); + + TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); + TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], + " but got ", bias->numel()); + + // Check types + TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); + TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); + TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); + + auto mat1_c = mat1.contiguous(); + auto mat2_c = mat2.contiguous(); + IntArrayRef mat1_sizes = mat1_c.sizes(); + IntArrayRef mat2_sizes = mat2_c.sizes(); + at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + + float input_scale = scale_a.item(); + float weight_scale = scale_b.item(); + auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale); + auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale); + auto out_tmp = at::matmul(fp32_mat1, fp32_mat2); + if (bias) { + out_tmp.add_(bias.value()); + } + out_tmp = out_tmp.to(out.scalar_type()); + out.copy_(out_tmp); + return out; +} + +Tensor& +_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum, + Tensor& out) { +#if AT_MKLDNN_ENABLED() + if (at::globalContext().userEnabledMkldnn()) { + bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type(); + if ((!mixed_dtype && cpuinfo_has_x86_amx_int8()) || + (mixed_dtype && cpuinfo_has_x86_amx_fp16())) { + return 
mkldnn_scaled_mm( + mat1, + mat2, + scale_a, + scale_b, + bias, + scale_result, + out_dtype, + use_fast_accum, + out); + } + } +#endif + { + return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + } +} + +Tensor +_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum) { + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); + return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); +} + } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 8153ae8a4d8e..b1175b796224 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -46,8 +47,19 @@ std::tuple mkldnn_linear_backward( TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support"); } -} // namespace at::native +Tensor& +mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum, + Tensor& out) { + TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support"); +} +} // namespace at::native #else // AT_MKLDNN_ENABLED @@ -459,6 +471,118 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { TORCH_FN(mkldnn_linear_pointwise_binary)); } +Tensor& +mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum, + Tensor& out) { + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( + mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); + + TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); + TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], + " but got ", bias->numel()); + + // Check types + TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); + TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); + TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); + + // Validation checks have passed lets resize the output to actual size + auto mat1_c = mat1.contiguous(); + auto mat2_c = mat2.contiguous(); + IntArrayRef mat1_sizes = mat1_c.sizes(); + IntArrayRef mat2_sizes = mat2_c.sizes(); + at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + + float input_scale = scale_a.item(); + float weight_scale = scale_b.item(); + auto src = at::native::itensor_view_from_dense(mat1_c); + auto weight_t = at::native::itensor_view_from_dense(mat2_c); + bool with_bias = bias.has_value(); + int64_t K = mat1_sizes[1], M = mat1_sizes[0], + N = mat2_sizes[1]; + + 
std::vector src_dims = {M, K}; + std::vector weight_dims = {K, N}; + std::vector dst_dims = {M, N}; + + ideep::tensor dst = at::native::itensor_view_from_dense(out); + auto src_desc = ideep::tensor::desc( + src_dims, + get_mkldnn_dtype(mat1.scalar_type()), + ideep::format_tag::any); + auto weights_desc = ideep::tensor::desc( + weight_dims, + get_mkldnn_dtype(mat2.scalar_type()), + ideep::format_tag::any); + auto dst_desc = ideep::tensor::desc( + dst_dims, + get_mkldnn_dtype(out.scalar_type()), + ideep::format_tag::any); + ideep::tensor onednn_bias; + if (with_bias) { + auto bias_value = bias.value(); + if (bias_value.dim() == 1) { + auto b_reshape = bias_value.reshape({1, bias_value.size(0)}); + onednn_bias = at::native::itensor_view_from_dense(b_reshape); + } else { + onednn_bias = at::native::itensor_view_from_dense(bias_value); + } + } + auto bias_desc = ideep::tensor::desc(); + if (with_bias) { + bias_desc = ideep::tensor::desc(onednn_bias.get_dims(), + get_mkldnn_dtype(bias.value().scalar_type()), + ideep::format_tag::any); + } + auto op_attr = ideep::attr_t(); + if (input_scale != 1.0f) { + op_attr.set_scales_mask(DNNL_ARG_SRC, 0); + } + if (weight_scale != 1.0f) { + op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } + + op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + auto engine = ideep::engine::cpu_engine(); + dnnl::matmul::primitive_desc primitive_desc = with_bias + ? dnnl::matmul::primitive_desc( + engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) + : dnnl::matmul::primitive_desc( + engine, src_desc, weights_desc, dst_desc, op_attr); + auto expected_weight = weight_t.reorder_if_differ_in(primitive_desc.weights_desc()); + auto primitive = dnnl::matmul(primitive_desc); + + // Prepare args and execute primitive + ideep::tensor scratchpad(primitive_desc.scratchpad_desc()); + ideep::exec_args args; + args.insert({DNNL_ARG_SRC, src}); + args.insert({DNNL_ARG_WEIGHTS, expected_weight}); + args.insert({DNNL_ARG_DST, dst}); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); + if (with_bias) { + args.insert({DNNL_ARG_BIAS, onednn_bias}); + } + ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale)); + ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale)); + + if (input_scale != 1.0f) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t}); + } + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t}); + + primitive.execute(ideep::stream::default_stream(), args); + return out; +} + } // namespace at #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/Linear.h b/aten/src/ATen/native/mkldnn/Linear.h index 6a7fcd60b0e6..1dc50c7c5416 100644 --- a/aten/src/ATen/native/mkldnn/Linear.h +++ b/aten/src/ATen/native/mkldnn/Linear.h @@ -35,3 +35,15 @@ C10_API Tensor mkl_linear( } // namespace at #endif // AT_MKLDNN_ENABLED() + +namespace at::native { +Tensor& +mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + bool use_fast_accum, + Tensor& out); +} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index 32daef37a563..f26427a981f7 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -57,6 +57,10 @@ ideep::tensor::data_type get_mkldnn_dtype(ScalarType type) { return ideep::tensor::data_type::bf16; case 
ScalarType::Half: return ideep::tensor::data_type::f16; + case ScalarType::Float8_e4m3fn: + return ideep::tensor::data_type::f8_e4m3; + case ScalarType::Float8_e5m2: + return ideep::tensor::data_type::f8_e5m2; default: TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type"); } @@ -161,8 +165,24 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data const_cast(tensor.const_data_ptr()) : tensor.data_ptr()}; } + else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) { + return {{tensor.sizes().vec(), + ideep::tensor::data_type::f8_e4m3, + tensor.strides().vec()}, + from_const_data_ptr ? + const_cast(tensor.const_data_ptr()) : + tensor.data_ptr()}; + } + else if (tensor.scalar_type() == ScalarType::Float8_e5m2) { + return {{tensor.sizes().vec(), + ideep::tensor::data_type::f8_e5m2, + tensor.strides().vec()}, + from_const_data_ptr ? + const_cast(tensor.const_data_ptr()) : + tensor.data_ptr()}; + } else { - TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input"); + TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input"); } } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e3a1cd175c86..c574130ac43d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7063,11 +7063,13 @@ - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: + CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: + CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index e208565081a1..8f36b2930f00 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -13,7 +13,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA from torch.utils._triton import has_triton_tma_device @@ -116,9 +116,9 @@ def _fix_fp8_dtype_for_rocm( @instantiate_parametrized_tests class TestFP8Types(TestCase): - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) - def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): + @parametrize("device", ("cuda", "cpu")) + def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): """ TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 depends on the variant of fp8 type. 
@@ -128,29 +128,33 @@ def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): We should not pick a XBLOCK larger than xnumel """ float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) def f(x): return x.to(dtype=float8_dtype) - x = torch.randn(1, device="cuda") + x = torch.randn(1, device=device) expected = f(x) actual = torch.compile(f)(x) torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.float16, torch.bfloat16)) - def test_eager_fallback(self, dtype: torch.dtype): + @parametrize("device", ("cuda", "cpu")) + def test_eager_fallback(self, dtype: torch.dtype, device: torch.device): + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) weight_shape = (32, 16) e4m3_type = torch.float8_e4m3fn e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda") def fp8_matmul_unwrapped(x): - a_scale = torch.Tensor([1.0]).to(device="cuda") - b_scale = torch.Tensor([1.0]).to(device="cuda") + a_scale = torch.Tensor([1.0]).to(device=device) + b_scale = torch.Tensor([1.0]).to(device=device) output_scale = None - input_bias = torch.rand(32, device="cuda", dtype=dtype) - weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to( + input_bias = torch.rand(32, device=device, dtype=dtype) + weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to( e4m3_type ) a_inverse_scale = 1 / a_scale @@ -171,19 +175,24 @@ def fp8_matmul_unwrapped(x): ) x_shape = (16, 16) - x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 x_shape = (15, 16) - x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("shape", ("15,3,13", "4,2048,4096")) @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)]) - def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple): - dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") + @parametrize("device", ("cuda", "cpu")) + def test_valid_cast( + self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device + ): + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) + if device == "cuda": + dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") e4m3, e5m2 = dst_types def fp8_cast(x): @@ -194,7 +203,7 @@ def fp8_cast(x): compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device="cuda", dtype=dtype) + x = torch.rand(*shape, device=device, dtype=dtype) y0_fp8, y1_fp8 = compiled_fp8_cast(x) torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1) @@ -223,14 +232,21 @@ def fp8_cast(x, dtype): x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2) compiled_fp8_cast(x, torch.float8_e4m3fn) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("16,16,16", "4,2048,4096")) + 
@parametrize("device", ("cuda", "cpu")) def test_to_fp8_saturated( - self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str + self, + src_dtype: torch.dtype, + dst_dtype: torch.dtype, + shape: str, + device: torch.device, ): - dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) + if device == "cuda": + dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") def fp8_saturated(x, dtype): return _to_fp8_saturated(x, dtype) @@ -239,17 +255,22 @@ def fp8_saturated(x, dtype): fp8_saturated, backend="inductor", dynamic=True ) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device="cuda", dtype=src_dtype) + x = torch.rand(*shape, device=device, dtype=src_dtype) y_compiled = compiled_fp8_cast(x, dst_dtype) y = fp8_saturated(x, dst_dtype) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str): - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + @parametrize("device", ("cuda", "cpu")) + def test_amax_fp8_quant( + self, float8_dtype: torch.dtype, shape: str, device: torch.device + ): + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest( + "FP8 is only supported on H100+ and sm_89 and MI300+ devices" + ) shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -262,19 +283,24 @@ def amax_fp8(x: Tensor, scale: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device="cuda", dtype=torch.half) - scale = torch.tensor(0.2, device="cuda", dtype=torch.float) + x = torch.rand(*x_shape, device=device, dtype=torch.half) + scale = torch.tensor(0.2, device=device, dtype=torch.float) y_compiled = compiled_amax_fp8_quant(x, scale) y = amax_fp8(x, scale) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str): - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + @parametrize("device", ("cuda", "cpu")) + def test_amax_along_with_fp8_quant( + self, float8_dtype: torch.dtype, shape: str, device: torch.device + ): + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) + if device == "cuda": + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -287,12 +313,12 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device="cuda", dtype=torch.half) - scale = torch.tensor(1.0, device="cuda", dtype=torch.float) + x = torch.rand(*x_shape, device=device, dtype=torch.half) + scale = torch.tensor(1.0, device=device, dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), 
device="cuda", dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half) y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) + amax_buffer = torch.zeros((1), device=device, dtype=torch.half) y = amax_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -300,14 +326,21 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) + @parametrize("device", ("cuda", "cpu")) def test_layernorm_fp8_quant( - self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str + self, + float8_dtype: torch.dtype, + amax_keep_dim: bool, + shape: str, + device: torch.device, ): - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + if device == "cuda" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest( + "FP8 is only supported on H100+ and sm_89 and MI300+ devices" + ) shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -329,12 +362,12 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device="cuda", dtype=torch.half) - scale = torch.tensor(0.2, device="cuda", dtype=torch.float) + x = torch.rand(*x_shape, device=device, dtype=torch.half) + scale = torch.tensor(0.2, device=device, dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half) y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) + amax_buffer = torch.zeros((1), device=device, dtype=torch.half) y = ln_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -750,5 +783,5 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": - if HAS_CUDA: + if HAS_CUDA or HAS_CPU: run_tests() diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 49da165ca20e..17ece41af239 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -484,15 +484,15 @@ def _bfloat16_to_float4_e2m1fn_x2(x): return x -@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") -class TestFP8MatmulCuda(TestCase): +class TestFP8Matmul(TestCase): - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, y_dtype: torch.dtype = e4m3_type, out_dtype: Optional[torch.dtype] = None, size: int = 16) -> None: + if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) x_fp8 = torch.rand(size, size, device=device).to(x_dtype) y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) @@ -503,12 +503,13 @@ def _test_tautological_mm(self, device: str = "cuda", self.assertEqual(out_dtype, out_fp8.dtype) self.assertEqual(out_fp32, out_fp8.to(torch.float)) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, 
f8_msg) def test_float8_basics(self, device) -> None: + if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported # supported on ROCm but fails on CUDA - ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext() + ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext() with ctx: self._test_tautological_mm(device, e5m2_type, e5m2_type) @@ -519,11 +520,12 @@ def test_float8_basics(self, device) -> None: self._test_tautological_mm(device, size=96, out_dtype=torch.float32) self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) - with self.assertRaises(AssertionError if torch.version.hip else RuntimeError): + with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError): self._test_tautological_mm(device, out_dtype=e5m2_type) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_scale(self, device) -> None: + if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) size = (16, 16) x = torch.full(size, .5, device=device, dtype=e4m3_type) # hipblaslt does not yet support mixed e4m3_type input @@ -638,8 +640,9 @@ def test_scaled_mm_change_stride(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias(self, device) -> None: + if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest(f8_msg) (k, l, m) = (16, 48, 32) x = torch.ones((k, l), device=device).to(e4m3_type) y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() @@ -692,7 +695,7 @@ def test_float32_output_errors_with_bias(self, device) -> None: lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), ) - @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg) + @unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg) def test_error_message_fp8_pre_sm89(self, device) -> None: (k, l, m) = (16, 48, 32) x = torch.rand((k, l), device=device).to(e4m3_type) @@ -1548,8 +1551,8 @@ def run_test( ) instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu") -instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu") instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu") +instantiate_device_type_tests(TestFP8Matmul, globals()) if __name__ == '__main__': TestCase._default_dtype_check_enabled = True diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 8254363cbdcb..3a00ce1e3015 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -21,6 +21,8 @@ #include #include +#include +#include #include #include #include @@ -48,6 +50,8 @@ typedef at::BFloat16 bfloat16; typedef at::Float8_e4m3fn float8_e4m3fn; typedef at::Float8_e5m2 float8_e5m2; +typedef at::Float8_e4m3fnuz float8_e4m3fnuz; +typedef at::Float8_e5m2fnuz float8_e5m2fnuz; template struct Welford { diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 682364e950c4..55085ee1be7b 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ 
b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -37,6 +37,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attent AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 9cd0661cac15..4ec7eb34a5dc 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1003,6 +1003,8 @@ class OpDTypes(Enum): torch.int8, torch.uint8, torch.bool, + torch.float8_e4m3fn, + torch.float8_e5m2, ) diff --git a/torch/testing/_internal/common_methods_invocations.py 
b/torch/testing/_internal/common_methods_invocations.py index 24f651020d75..d16d31d42684 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -22,7 +22,7 @@ from torch.testing._internal.common_dtype import ( _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, - empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, + empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, ) from torch.testing._internal.common_device_type import \ (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -16217,7 +16217,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): OpInfo( 'torch._scaled_mm', sample_inputs_func=sample_inputs_scaled_mm, - dtypes=empty_types(), + dtypes=float8_types(), dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,), supports_out=True, supports_forward_ad=False, @@ -16225,12 +16225,20 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')], skips=( # Sample inputs isn't really parametrized on dtype - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes', - device_type='cuda'), + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'), + # "add_stub" not implemented for 'Float8_e4m3fn' + # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn' + # https://github.com/pytorch/pytorch/issues/107256 + DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), # "mul_cuda" not implemented for float8_e4m3fn + # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn' # https://github.com/pytorch/pytorch/issues/107256 - DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', - dtypes=(torch.float8_e4m3fn,)), + DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'), + # aten::_scaled_mm hit the vmap fallback which is currently disabled + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"), + DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"), + DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness', + dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)), ) ), OpInfo( From 1843ad458d7a511fa4e6aa9520ece133d52fefd2 Mon Sep 17 00:00:00 2001 From: Kai Londenberg Date: Thu, 3 Apr 2025 19:47:23 +0000 Subject: [PATCH 161/332] [Inductor] Cache CUDA compilation errors (#149716) Summary: Add support for caching of CUDA (nvcc) compilation errors to codecache.py Test Plan: CI ( for example Cutlass backend unit tests ) Reviewed By: ColinPeppler Differential Revision: D71562040 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149716 Approved by: https://github.com/ColinPeppler --- torch/_inductor/codecache.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index aaad5a53486a..177b53e3e999 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3164,6 +3164,7 
@@ class CUDACodeCache: class CacheEntry: input_path: str output_path: str + error_json: Optional[str] = None cache: dict[str, CacheEntry] = {} cache_clear = staticmethod(cache.clear) @@ -3200,6 +3201,14 @@ def compile( lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT) with lock: output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext + if os.path.exists(output_path + ".error"): + with open(output_path + ".error", encoding="utf-8") as fh: + error_json = fh.read() + cmd_parts, error_output = json.loads(error_json) + cls.cache[key] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + raise exc.CUDACompileError(cmd_parts, error_output) if not os.path.exists(output_path): cmd = cuda_compile_command( [input_path], output_path, dst_file_ext, extra_args @@ -3215,6 +3224,14 @@ def compile( cmd_parts, stderr=subprocess.STDOUT, env=os.environ ) except subprocess.CalledProcessError as error: + error_json = json.dumps( + [cmd_parts, error.output.decode("utf-8")] + ) + cls.cache[key] = CUDACodeCache.CacheEntry( + input_path, output_path, error_json + ) + with open(output_path + ".error", "w", encoding="utf-8") as fh: + fh.write(error_json) raise exc.CUDACompileError(cmd_parts, error.output) from error end_time = time() log_duration_msg = f"CUDA Compilation took {end_time - start_time} seconds. Compile command: {cmd}" @@ -3224,8 +3241,12 @@ def compile( "CUDA Compilation skipped: %s since output already exists", input_path, ) - cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path) - + cls.cache[key] = CUDACodeCache.CacheEntry(input_path, output_path, None) + cache_entry: CUDACodeCache.CacheEntry = cls.cache[key] + if cache_entry.error_json is not None: + # Restore cached Exception and raise it as if we had compiled + cmd_parts, error_output = json.loads(cache_entry.error_json) + raise exc.CUDACompileError(cmd_parts, error_output.encode("utf-8")) return (cls.cache[key].output_path, key, input_path) @classmethod From c1d503529d23f33bc0819286df8d0ecbe31b559f Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Wed, 2 Apr 2025 22:07:39 +0000 Subject: [PATCH 162/332] Enable C++ dynamic shape guards by default (#140756) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140756 Approved by: https://github.com/anijain2305 ghstack dependencies: #149149, #149197, #149211 --- torch/_dynamo/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b59e1c49e607..1829eaef63ff 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -396,7 +396,7 @@ enable_cpp_guard_manager = True # Use C++ guard manger for symbolic shapes -enable_cpp_symbolic_shape_guards = False +enable_cpp_symbolic_shape_guards = not is_fbcode() # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True From 51da241c0a478ae3324ff074517dd7d54305d62b Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 3 Apr 2025 20:06:12 +0000 Subject: [PATCH 163/332] [aoti] Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering (#150570) Summary: Fix cannot determine truth value of Relation error when propagating unbacked symint in lowering Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts ``` Differential Revision: D72331070 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150570 Approved by: 
https://github.com/angelayi, https://github.com/henryoier --- test/inductor/test_aot_inductor.py | 50 ++++++++++++++++++++++++++++++ torch/_subclasses/fake_tensor.py | 3 +- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 905a3d2850c9..501fbb49c2b4 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3369,6 +3369,56 @@ def forward(self, q, k, v, attn_bias): ) self.check_model(Model(), example_inputs) + def test_aoti_runtime_asserts(self): + from torch._dispatch.python import enable_python_dispatcher + from torch.export._draft_export import draft_export + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor a, Tensor b) -> Tensor", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return a[: b.item()] + + @torch.library.impl_abstract("mylib::foo", lib=lib) + def foo_fake_impl(a, b): + ctx = torch.library.get_ctx() + u = ctx.new_dynamic_size() + return torch.empty(u) + + class M(torch.nn.Module): + def forward(self, a, b): + res = torch.ops.mylib.foo(a, b) + s = res.shape[0] + torch._check(s > 3) + torch._check(s < a.shape[0]) + return a[s - 3] + + example_inputs = (torch.randn(100), torch.tensor(10)) + ep = draft_export(M(), example_inputs) + m = ep.module() + from torch.fx.passes.fake_tensor_prop import FakeTensorProp + + example_inputs = [ + node.meta["val"] for node in m.graph.nodes if node.op == "placeholder" + ] + fake_mode = example_inputs[0].fake_mode + with enable_python_dispatcher(), fake_mode: + FakeTensorProp(m, mode=fake_mode).propagate_dont_convert_inputs( + *example_inputs + ) + + # TODO: change to the tests below after MetadataMismatchError is fixed + # pt2_file = torch._inductor.aoti_compile_and_package(ep) + # optimized = torch._inductor.aoti_load_package(pt2_file) + + # self.assertTrue(same(optimized(example_inputs), m(example_inputs))) + def test_index_put_with_none_index(self): # index_put falls back in the deterministic mode with DeterministicGuard(True): diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 1328d5233d36..f8cb248e0a60 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -2308,10 +2308,9 @@ def maybe_to_real_tensor( if ( self.propagate_real_tensors and all(e.real_tensor is not None for e in flat_arg_fake_tensors) - # TODO: Handle SymFloat/SymBool and not any( ( - isinstance(a, SymInt) + isinstance(a, py_sym_types) and (syms := free_unbacked_symbols(a)) and self.shape_env is not None and any(s not in self.shape_env.unbacked_var_to_val for s in syms) From a3f9e04656867c9e20b5c088bd66b913d1d8cde6 Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Thu, 3 Apr 2025 20:44:31 +0000 Subject: [PATCH 164/332] [export] Make aoti_call_delegate hop traceable (#148804) Summary: The `aoti_call_delegate` hop now uses a stateless `original_gm` for tracing with fake tensors and the OSS AOTI Runner for running with real tensors Differential Revision: D70738393 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148804 Approved by: https://github.com/SherlockNoMad --- torch/_export/passes/lift_constants_pass.py | 3 +- torch/_export/verifier.py | 2 + torch/_higher_order_ops/aoti_call_delegate.py | 104 +++++++++++++----- 3 files changed, 82 insertions(+), 27 
deletions(-) diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 734f8cd33786..8ecb84b7adf4 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -178,6 +178,8 @@ def lift_constants_pass( continue if "LoweredBackendModule" in type(constant_val).__name__: continue + if "AOTInductorRunnerWrapper" in type(constant_val).__name__: + continue if isinstance(constant_val, torch.utils._pytree.TreeSpec): continue @@ -237,7 +239,6 @@ def lift_constants_pass( constant_name = f"lifted_tensor_{num_tensor_constants}" constant_fqn = get_constant_fqn(node, constant_name) num_tensor_constants += 1 - else: raise SpecViolationError( f"getattr node {node} referencing unsupported type {type(constant_val)}" diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 4940973c5f0d..8ba1132ca668 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -271,6 +271,8 @@ def _is_type(name, ty): elif type(attr).__name__ == "AOTInductorEPModule": continue + elif type(attr).__name__ == "AOTInductorRunnerWrapper": + continue if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)): raise SpecViolationError( diff --git a/torch/_higher_order_ops/aoti_call_delegate.py b/torch/_higher_order_ops/aoti_call_delegate.py index 286575726dc2..d90586f8950d 100644 --- a/torch/_higher_order_ops/aoti_call_delegate.py +++ b/torch/_higher_order_ops/aoti_call_delegate.py @@ -1,20 +1,25 @@ +# mypy: allow-untyped-defs + # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-strict - from __future__ import annotations import torch import torch.utils._pytree as pytree from torch._ops import HigherOrderOperator from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) -AOTI_LOWERED_MODULE = "AOTInductorEPModule" +AOTI_LOWERED_MODULE = "AOTInductorEPModule/AOTInductorRunnerWrapper" class AOTICallDelegate(HigherOrderOperator): @@ -22,7 +27,7 @@ class AOTICallDelegate(HigherOrderOperator): It has the following signature: aoti_call_delegate( - lowered_module: AOTInductorEPModule, + lowered_module: Union[AOTInductorEPModule, AOTInductorRunnerWrapper] original_gm:fx.GraphModule, weight_args: List[Tensor], input_args: List[Tensor], @@ -30,15 +35,9 @@ class AOTICallDelegate(HigherOrderOperator): where, - lowered_module is the AOTInductor lowered submodule, backed by compiled .so file, supporting real tensor inputs - - original_gm is the original GraphModule before lowering, allowing FakeTensor propagation + - original_gm is the stateless version of the original GraphModule before lowering, allowing FakeTensor propagation - weight_args is the list of weights in original GraphModule, including parameters and buffers - input_args is the list of flatten inputs - - NOTE: aoti_call_delegate doesn't support retracing yet, as original_gm is currently stateful with weight as get_attr nodes. - This will fail functionalization during retrace. When we move AOTI to accept stateless GraphModule, we can enable retracing. - - When serialization, we have special hanlding for aoti_call_delegate, as AOTInductorEPModule is not serializable - and stateful original_gm is failing the verifier. 
""" def __init__(self) -> None: @@ -62,7 +61,6 @@ def __call__( @aoti_call_delegate.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) -# pyre-ignore def call_delegate_cpu( lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] original_gm: torch.fx.GraphModule, @@ -77,27 +75,60 @@ def call_delegate_cpu( new_args = pytree.tree_map_only( tuple(map_types.keys()), lambda a: map_types[type(a)](a), - input_args, + weight_args + input_args, lambda a: isinstance(a, tuple(map_types.keys())), ) - - has_fake_input_args = any(isinstance(arg, FakeTensor) for arg in new_args) - has_fake_params = any( - isinstance(param, FakeTensor) for param in original_gm.parameters() - ) - has_fake_buffers = any( - isinstance(buffer, FakeTensor) for buffer in original_gm.buffers() + has_fake_args = any(isinstance(arg, FakeTensor) for arg in new_args) + if has_fake_args: + # use stateless original_gm for tracing with fake tensors + fake_out = original_gm(*new_args) + return fake_out + else: + # use AOTI Runner for real tensors + new_input_args = new_args[len(weight_args) :] + if type(lowered_module).__name__ == "AOTInductorRunnerWrapper": + return lowered_module(*new_input_args) # type: ignore[misc] + elif type(lowered_module).__name__ == "AOTInductorEPModule": + return lowered_module(new_input_args) # type: ignore[misc] + else: + raise RuntimeError( + f"Unexpected lowered_module type: {type(lowered_module)}." + ) + + +def trace_aoti_call_delegate( + proxy_mode, func_overload, lowered_module, original_gm, weight_args, input_args +): + proxy_mode.tracer.root.register_module("lowered_module", lowered_module) + proxy_mode.tracer.root.register_module("original_gm", original_gm) + + node_args = (lowered_module, original_gm, weight_args, input_args) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", func_overload, proxy_args, {}, name="aoti_call_delegate" ) + with disable_proxy_modes_tracing(): + out = call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) - if has_fake_input_args or has_fake_params or has_fake_buffers: - # aoti lowered module doesn't support fake tensor - return original_gm(*new_args) - else: - return lowered_module(new_args) # type: ignore[misc] + return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) + + +@aoti_call_delegate.py_impl(ProxyTorchDispatchMode) +def call_delegate_proxy_torch_dispatch_mode( + mode: ProxyTorchDispatchMode, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + res = trace_aoti_call_delegate( + mode, aoti_call_delegate, lowered_module, original_gm, weight_args, input_args + ) + return res @aoti_call_delegate.py_impl(FakeTensorMode) -# pyre-ignore def call_delegate_fake_tensor_mode( mode: FakeTensorMode, lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] @@ -107,3 +138,24 @@ def call_delegate_fake_tensor_mode( ) -> list[torch.Tensor]: with mode: return call_delegate_cpu(lowered_module, original_gm, weight_args, input_args) + + +@aoti_call_delegate.py_functionalize_impl +def call_delegate_functionalize( + ctx, + lowered_module: AOTI_LOWERED_MODULE, # type: ignore[valid-type] + original_gm: torch.fx.GraphModule, + weight_args: list[torch.Tensor], + input_args: list[torch.Tensor], +): + unwrapped_weight_args = tuple( + ctx.unwrap_tensors(weight_arg) for weight_arg in weight_args + ) + 
unwrapped_input_args = tuple( + ctx.unwrap_tensors(input_arg) for input_arg in input_args + ) + with ctx.redispatch_to_next(): + res = aoti_call_delegate( + lowered_module, original_gm, unwrapped_weight_args, unwrapped_input_args # type: ignore[arg-type] + ) + return ctx.wrap_tensors(res) From 277369ac1690a9c83dca2bc827d9b61efd694c31 Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Thu, 3 Apr 2025 20:47:35 +0000 Subject: [PATCH 165/332] Move formulas on separate line in loss.py (#150565) Move formulas on separate line in loss.py for better readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150565 Approved by: https://github.com/mikaylagawarecki --- torch/nn/modules/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index de8b55575fb3..75d5c91756df 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -154,8 +154,8 @@ class NLLLoss(_WeightedLoss): The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as: .. math:: - \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad - l_n = - w_{y_n} x_{n,y_n}, \quad + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \\ + l_n = - w_{y_n} x_{n,y_n}, \\ w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\}, where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and From d41c22b5781bbaa4a97b7759ec1451348fd26aef Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Thu, 3 Apr 2025 21:15:38 +0000 Subject: [PATCH 166/332] Revert "[fx] Move Node._prepend/Node._remove_from_list to C++ (#148261)" (#150542) Reverts #148261 due to possible memory leak This reverts commit 5d4e7d58b42623a9024a84f0050967ff0318dcdb. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150542 Approved by: https://github.com/clee2000 --- .../pr_time_benchmarks/expected_results.csv | 39 ++-- torch/_C/__init__.pyi.in | 6 - torch/csrc/fx/node.cpp | 205 +----------------- torch/fx/node.py | 42 +++- 4 files changed, 66 insertions(+), 226 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 47fe5eafcd0c..934e10e5c275 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,32 +1,32 @@ -add_loop_eager,compile_time_instruction_count,2866000000,0.015 +add_loop_eager,compile_time_instruction_count,2926000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5460000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,5637000000,0.025 -add_loop_inductor,compile_time_instruction_count,27660000000,0.015 +add_loop_inductor,compile_time_instruction_count,28680000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,40640000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42170000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,23970000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,24980000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,953800000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,969300000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17190000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17840000000,0.015 
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15410000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15990000000,0.015 @@ -34,43 +34,44 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,97140000 -update_hint_regression,compile_time_instruction_count,1523000000,0.02 +update_hint_regression,compile_time_instruction_count,1593000000,0.02 -float_args,compile_time_instruction_count,413700000,0.015 +float_args,compile_time_instruction_count,416400000,0.015 -sum_floordiv_regression,compile_time_instruction_count,970100000,0.015 +sum_floordiv_regression,compile_time_instruction_count,989900000,0.015 -symint_sum,compile_time_instruction_count,3080000000,0.015 +symint_sum,compile_time_instruction_count,3164000000,0.015 -symint_sum_loop,compile_time_instruction_count,3988000000,0.015 +symint_sum_loop,compile_time_instruction_count,4142000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1989000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2034000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5759000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5880000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,7873000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8419000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1746000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1838000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3579000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3742000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,9830000000,0.015 + +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10190000000,0.015 diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 893df5db74f8..c6003fe63fcc 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -2539,12 +2539,6 @@ class _NodeBase: return_type: Any, ) -> None: ... def _update_args_kwargs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): ... - def _prepend(self, n: FxNode) -> None: ... - def _remove_from_list(self) -> None: ... - def __lt__(self, n: Self) -> _bool: ... - def __gt__(self, n: Self) -> _bool: ... - def __le__(self, n: Self) -> _bool: ... - def __ge__(self, n: Self) -> _bool: ... class _NodeIter(Iterator): def __init__(self, root: FxNode, reversed: _bool) -> None: ... diff --git a/torch/csrc/fx/node.cpp b/torch/csrc/fx/node.cpp index 425a28393113..d3244441da16 100644 --- a/torch/csrc/fx/node.cpp +++ b/torch/csrc/fx/node.cpp @@ -1,14 +1,11 @@ #include -#include #include #include #include -#include namespace { -using NodeSortKey = c10::SmallVector; struct NodeBase; // Thrown to exit out of a C++ function and return an error to Python. 
@@ -166,22 +163,7 @@ struct NodeBase { PyObject* users; PyObject* _repr_fn; PyObject* meta; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - alignas(NodeSortKey) char sort_key_buf[sizeof(NodeSortKey)]; - - inline NodeSortKey& sort_key() { - return *reinterpret_cast(sort_key_buf); - } - - // Equivalent to: - // p, n = self._prev, self._next - // p._next, n._prev = n, p - inline void remove_from_list() { - NodeBase* p = this->_prev; - NodeBase* n = this->_next; - p->_next = n; - n->_prev = p; - } + PyObject* _sort_key; }; static PyObject* NodeBase_new( @@ -191,8 +173,6 @@ static PyObject* NodeBase_new( PyObject* self = type->tp_alloc(type, 0); if (!self) return nullptr; - new (reinterpret_cast(self)->sort_key_buf) - NodeSortKey(); // placement new does not allocate return self; } @@ -221,6 +201,7 @@ static int NodeBase_init_fn(NodeBase* self, PyObject* args, PyObject* kwds) { self->users = PyDict_New(); self->_repr_fn = Py_NewRef(Py_None); self->meta = PyDict_New(); + self->_sort_key = PyTuple_New(0); return 0; } @@ -240,6 +221,7 @@ static struct PyMemberDef NodeBase_members[] = { {"users", T_OBJECT_EX, offsetof(NodeBase, users), 0, nullptr}, {"_repr_fn", T_OBJECT_EX, offsetof(NodeBase, _repr_fn), 0, nullptr}, {"meta", T_OBJECT_EX, offsetof(NodeBase, meta), 0, nullptr}, + {"_sort_key", T_OBJECT_EX, offsetof(NodeBase, _sort_key), 0, nullptr}, {nullptr} /* Sentinel */ }; @@ -257,6 +239,7 @@ static int NodeBase_traverse(NodeBase* self, visitproc visit, void* arg) { Py_VISIT(self->users); Py_VISIT(self->_repr_fn); Py_VISIT(self->meta); + Py_VISIT(self->_sort_key); return 0; } @@ -274,12 +257,12 @@ static int NodeBase_clear(NodeBase* self) { Py_CLEAR(self->users); Py_CLEAR(self->_repr_fn); Py_CLEAR(self->meta); + Py_CLEAR(self->_sort_key); return 0; } static void NodeBase_dealloc(PyObject* self) { PyObject_GC_UnTrack(self); - reinterpret_cast(self)->sort_key().~NodeSortKey(); (void)NodeBase_clear((NodeBase*)self); Py_TYPE(self)->tp_free(self); } @@ -338,191 +321,15 @@ static PyObject* NodeBase__update_args_kwargs( } } -static PyObject* NodeBase__remove_from_list( - PyObject* self, - PyObject* _ignored) { - reinterpret_cast(self)->remove_from_list(); - Py_RETURN_NONE; -} - -static PyObject* NodeBase__prepend(PyObject* self_, PyObject* arg) { - if (self_ == arg) { - Py_RETURN_NONE; - } - if (!is_node(arg)) { - PyErr_SetString(PyExc_TypeError, "_prepend() argument must be a Node"); - return nullptr; - } - NodeBase* self = reinterpret_cast(self_); - NodeBase* x = reinterpret_cast(arg); - if (self->graph != x->graph) { - PyErr_SetString( - PyExc_AssertionError, - "Attempting to move a Node into a different Graph"); - return nullptr; - } - - x->remove_from_list(); - NodeBase* p = self->_prev; - p->_next = x; - x->_prev = p; - x->_next = self; - self->_prev = x; - - // Now compute x.sort_key() - const NodeSortKey& psk = x->_prev->sort_key(); - const NodeSortKey& nsk = x->_next->sort_key(); - if (psk.size() > nsk.size()) { - // prefix = psk[: len(nsk)+1] - size_t slice_len = nsk.size() + 1; - NodeSortKey prefix(psk.begin(), psk.begin() + slice_len); - // last element is idx => increment by 1 - prefix.back()++; - x->sort_key() = std::move(prefix); - } else if (psk.size() < nsk.size()) { - // prefix = nsk[: len(psk)+1] - size_t slice_len = psk.size() + 1; - NodeSortKey prefix(nsk.begin(), nsk.begin() + slice_len); - // last element is idx => decrement by 1 - prefix.back()--; - x->sort_key() = std::move(prefix); - } else { - // same length => add a 0 - x->sort_key() = psk; - 
x->sort_key().emplace_back(0); - } - Py_RETURN_NONE; -} - -// __lt__(self, other): Return self.sort_key < other.sort_key -static PyObject* NodeBase___lt__(PyObject* self, PyObject* other) { - // METH_O => one argument: 'other' - if (!is_node(other)) { - Py_RETURN_NOTIMPLEMENTED; - } - const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); - const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); - bool less = std::lexicographical_compare( - lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); - if (less) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} - -// __gt__(self, other): Return self.sort_key() > other.sort_key -static PyObject* NodeBase___gt__(PyObject* self, PyObject* other) { - if (!is_node(other)) { - Py_RETURN_NOTIMPLEMENTED; - } - const NodeSortKey& lhs = reinterpret_cast(self)->sort_key(); - const NodeSortKey& rhs = reinterpret_cast(other)->sort_key(); - // "a > b" is equivalent to "b < a" - bool greater = std::lexicographical_compare( - rhs.begin(), rhs.end(), lhs.begin(), lhs.end()); - if (greater) - Py_RETURN_TRUE; - Py_RETURN_FALSE; -} - -static PyObject* NodeBase___ge__(PyObject* self, PyObject* other) { - if (self == other) { - Py_RETURN_TRUE; - } - return NodeBase___gt__(self, other); -} - -// __le__(self, other): Return not (self > other) -static PyObject* NodeBase___le__(PyObject* self, PyObject* other) { - if (self == other) { - Py_RETURN_TRUE; - } - return NodeBase___lt__(self, other); -} - -// Convert the NodeBase::sort_key vector into a Python tuple of ints -// Only used by pickle/__getstate__ -static PyObject* NodeBase_get_sort_key(PyObject* self, void* /*closure*/) { - NodeBase* node = reinterpret_cast(self); - const NodeSortKey& vec = node->sort_key(); - Py_ssize_t n = static_cast(vec.size()); - THPObjectPtr tuple(PyTuple_New(n)); - if (!tuple) { - return nullptr; // Out of memory - } - for (Py_ssize_t i = 0; i < n; i++) { - PyTuple_SET_ITEM(tuple.get(), i, PyLong_FromSsize_t(vec[i])); - } - return tuple.release(); -} - -// Setter for NodeBase::sort_key: expects a Python tuple of ints, e.g. 
-// node._sort_key = (1,2,3) Only used by pickle/__setstate__ -static int NodeBase_set_sort_key( - PyObject* self, - PyObject* value, - void* /*closure*/) { - NodeBase* node = reinterpret_cast(self); - if (!PyTuple_Check(value)) { - PyErr_SetString(PyExc_TypeError, "_sort_key must be an tuple of ints"); - return -1; - } - Py_ssize_t size = PyTuple_GET_SIZE(value); - NodeSortKey new_vec; - new_vec.reserve(size); - for (Py_ssize_t i = 0; i < size; i++) { - int64_t val = PyLong_AsSsize_t(PyTuple_GET_ITEM(value, i)); - if (val == -1 && PyErr_Occurred()) { - return -1; - } - new_vec.emplace_back(val); - } - node->sort_key() = std::move(new_vec); - return 0; -} - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) static PyMethodDef NodeBase_methods[] = { {"_update_args_kwargs", (PyCFunction)(void*)(NodeBase__update_args_kwargs), METH_FASTCALL, "Internal method: do not call directly."}, - {"_remove_from_list", - (PyCFunction)(void*)(NodeBase__remove_from_list), - METH_NOARGS, - "Internal method: do not call directly."}, - {"_prepend", - (PyCFunction)(void*)(NodeBase__prepend), - METH_O, - "Internal method: do not call directly."}, - {"__lt__", - (PyCFunction)(void*)NodeBase___lt__, - METH_O, - "Return True if self.sort_key < other.sort_key"}, - {"__gt__", - (PyCFunction)(void*)NodeBase___gt__, - METH_O, - "Return True if self.sort_key > other.sort_key"}, - {"__ge__", - (PyCFunction)(void*)NodeBase___ge__, - METH_O, - "Return True if self.sort_key >= other.sort_key"}, - {"__le__", - (PyCFunction)(void*)NodeBase___le__, - METH_O, - "Return True if self.sort_key <= other.sort_key"}, {nullptr, nullptr, 0, nullptr} // Sentinel }; -// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) -static PyGetSetDef NodeBase_getset[] = { - {"_sort_key", // attribute name in Python - (getter)NodeBase_get_sort_key, // C getter function - (setter)NodeBase_set_sort_key, // C setter function - (char*)"The sort key as a tuple of ints", // docstring - nullptr}, - {nullptr, nullptr, nullptr, nullptr, nullptr} // Sentinel -}; - PyTypeObject NodeBaseType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._NodeBase", /* tp_name */ @@ -554,7 +361,7 @@ PyTypeObject NodeBaseType = { nullptr, /* tp_iternext */ NodeBase_methods, /* tp_methods */ NodeBase_members, /* tp_members */ - NodeBase_getset, /* tp_getset */ + nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ diff --git a/torch/fx/node.py b/torch/fx/node.py index 722de170bfd5..8433e9ea651b 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -375,7 +375,41 @@ def prepend(self, x: "Node") -> None: Args: x (Node): The node to put before this node. Must be a member of the same graph. """ - self._prepend(x) + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + if self == x: + log.debug( + "Trying to prepend a node to itself. This behavior has no effect on the graph." 
+ ) + return + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + # compute x._sort_key + psk = x._prev._sort_key + nsk = x._next._sort_key + if len(psk) > len(nsk): + idx: int + *prefix, idx = psk[: len(nsk) + 1] + x._sort_key = (*prefix, idx + 1) + elif len(psk) < len(nsk): + *prefix, idx = nsk[: len(psk) + 1] + x._sort_key = (*prefix, idx - 1) + else: # same length, increase length by 1 + x._sort_key = (*psk, 0) + + def __gt__(self, other: "Node") -> bool: + return self._sort_key > other._sort_key + + def __lt__(self, other: "Node") -> bool: + return self._sort_key < other._sort_key + + def __ge__(self, other: "Node") -> bool: + return self > other or self == other + + def __le__(self, other: "Node") -> bool: + return self < other or self == other @compatibility(is_backward_compatible=True) def append(self, x: "Node") -> None: @@ -386,7 +420,11 @@ def append(self, x: "Node") -> None: Args: x (Node): The node to put after this node. Must be a member of the same graph. """ - self._next._prepend(x) + self._next.prepend(x) + + def _remove_from_list(self) -> None: + p, n = self._prev, self._next + p._next, n._prev = n, p @property def args(self) -> tuple[Argument, ...]: From 5a654deb408d48e7ec244bb560724ff0a62bfcb6 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 3 Apr 2025 21:44:41 +0000 Subject: [PATCH 167/332] Revert "Enable C++ dynamic shape guards by default (#140756)" This reverts commit c1d503529d23f33bc0819286df8d0ecbe31b559f. Reverted https://github.com/pytorch/pytorch/pull/140756 on behalf of https://github.com/isuruf due to new test test_runtime_checks_large hangs on CI ([comment](https://github.com/pytorch/pytorch/pull/140756#issuecomment-2776979814)) --- torch/_dynamo/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 1829eaef63ff..b59e1c49e607 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -396,7 +396,7 @@ enable_cpp_guard_manager = True # Use C++ guard manger for symbolic shapes -enable_cpp_symbolic_shape_guards = not is_fbcode() +enable_cpp_symbolic_shape_guards = False # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True From 941090a791f39bb1bbd0e4741e10718c011e050f Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 3 Apr 2025 22:02:25 +0000 Subject: [PATCH 168/332] Make sure torch.compiler._is_compiling_flag=True in aoti (#150588) Summary: See internal Diff summary Differential Revision: D72355449 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150588 Approved by: https://github.com/angelayi --- torch/_inductor/__init__.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index a2acd6570a20..f9d05e24fff7 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -283,12 +283,14 @@ def aot_compile( flat_example_inputs, options = _aoti_flatten_inputs( gm, args, kwargs, options=options ) + from torch._export.utils import _compiling_state_context - return compile_fx_aot( - gm, - flat_example_inputs, # type: ignore[arg-type] - config_patches=options, - ) + with _compiling_state_context(): + return compile_fx_aot( + gm, + flat_example_inputs, # type: ignore[arg-type] + config_patches=options, + ) def list_mode_options( From 2abd81402fa28341dd16b1de55bf95e05d0727f4 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Thu, 3 Apr 
2025 22:06:58 +0000 Subject: [PATCH 169/332] [validations] Run nccl version check on Linux only (#150635) Followup https://github.com/pytorch/pytorch/pull/150194 to disable nccl version print on OS's other then Linux Pull Request resolved: https://github.com/pytorch/pytorch/pull/150635 Approved by: https://github.com/clee2000 --- .ci/pytorch/smoke_test/smoke_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index acc69e36a5a5..24d1d64dd205 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -271,8 +271,10 @@ def smoke_test_cuda( print(f"cuDNN enabled? {torch.backends.cudnn.enabled}") torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") - torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) - print(f"Torch nccl; version: {torch_nccl_version}") + + if sys.platform in ["linux", "linux2"]: + torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) + print(f"Torch nccl; version: {torch_nccl_version}") # Pypi dependencies are installed on linux ony and nccl is availbale only on Linux. if pypi_pkg_check == "enabled" and sys.platform in ["linux", "linux2"]: From c6defa9443d241dd7a0baac4e708b6e906bd012c Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Thu, 3 Apr 2025 22:07:39 +0000 Subject: [PATCH 170/332] [cuda] Add new faster gammabeta backward kernel (#148605) (Reapply with launch bounds) (#150625) # Changes over the previous PR This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel. Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266. This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added: ``` git diff HEAD^ diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 0d63a2f979c..3ce2c24c18e 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -657,6 +657,7 @@ bool aligned_grid > __global__ void +__launch_bounds__(block_dim_x * block_dim_y) GammaBetaBackwardCUDAKernelTemplate( int64_t M, int64_t N, ``` I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg
Repro script that fails on Blackwell ``` import torch from torch.nn import init # from transformer_nuggets import init_logging # from transformer_nuggets.utils.benchmark import profiler # from pathlib import Path # init_logging() class PermuteModule(torch.nn.Module): def __init__(self, permutation): super(PermuteModule, self).__init__() self.permutation = permutation def forward(self, x:torch.Tensor) -> torch.Tensor: assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!" return x.permute(*self.permutation) def test(n_layers:int, conv_stride:int): _sequence = [] for _ in range(n_layers): # Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules. _sequence += [ PermuteModule((0,2,1)), torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False), PermuteModule((0,2,1)), torch.nn.LayerNorm(512), torch.nn.ReLU() ] model = torch.nn.Sequential(*_sequence).to(device="cuda") data = torch.randn((100,2048,512), device="cuda") out = model(data) loss = torch.nn.functional.mse_loss(out, torch.rand_like(out)) loss.backward() torch.autograd.set_detect_anomaly(True) print(f"Torch version: {torch.__version__}") # with profiler(Path("conv")): # # print(f"layers=1, stride=1") # # test(n_layers=1, conv_stride=1) # # print(f"layers=2, stride=1") # # test(n_layers=2, conv_stride=1) # # print(f"layers=1, stride=2") # # test(n_layers=1, conv_stride=2) # print(f"layers=2, stride=2") # test(n_layers=2, conv_stride=2) print(f"layers=2, stride=2") test(n_layers=2, conv_stride=2) # we will not reach this print statement. print("DONE.") ```
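For a quicker sanity check than the full Conv1d model above, a stripped-down LayerNorm-only backward pass reaches the same gamma/beta backward path (any CUDA `LayerNorm` with `elementwise_affine=True` does). This is only an editor's sketch for convenience, not the author's verification script, and the shapes are chosen arbitrarily:

```
import torch

x = torch.randn((32, 2048, 512), device="cuda", requires_grad=True)
ln = torch.nn.LayerNorm(512, device="cuda")
out = ln(x)
loss = torch.nn.functional.mse_loss(out, torch.rand_like(out))
loss.backward()               # dgamma/dbeta come from the GammaBetaBackward kernel discussed here
torch.cuda.synchronize()      # surfaces any kernel-launch failure immediately
print("grads:", ln.weight.grad.shape, ln.bias.grad.shape)
```

Without the `__launch_bounds__` fix, the Blackwell failure described above would be expected to surface at the `backward()`/`synchronize()` step.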
I also re-ran my performance benchmark and found no regressions over the previous PR. # Full description of the old PR Original PR: https://github.com/pytorch/pytorch/pull/148605 This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way. To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables: 1. dtype in {half, float} 2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 4. Whether we flush the L2 cache before running the backward pass Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster). In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*. Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old: M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.35 Elapsed Cycles cycle 27,526 Memory Throughput % 2.21 DRAM Throughput % 0.54 Duration us 20.42 L1/TEX Cache Throughput % 4.31 L2 Cache Throughput % 2.62 SM Active Cycles cycle 1,475.02 Compute (SM) Throughput % 0.29 ----------------------- ----------- ------------ ``` M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.34 Elapsed Cycles cycle 10,920 Memory Throughput % 5.64 DRAM Throughput % 1.35 Duration us 8.13 L1/TEX Cache Throughput % 1.92 L2 Cache Throughput % 6.89 SM Active Cycles cycle 3,554.41 Compute (SM) Throughput % 0.67 ----------------------- ----------- ------------ ``` Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following: image There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made: ![image](https://github.com/user-attachments/assets/90c26f7c-e3ad-46d2-a6ce-fe4b5fb3d738) For dtype=float32, we get a similar chart: image The new code performs especially well for m >> n cases, and also where m and n are small. 
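For readers who want to reproduce the sweep, a minimal sketch of the timing loop is below. It is an editor's approximation, not the script used to generate the heatmaps above: the M/N grid is abbreviated and the flush-L2-cache variant is omitted.

```
import itertools
import torch
from torch.utils import benchmark

def median_backward_time(M, N, dtype):
    # Time only the backward pass of an affine LayerNorm on CUDA.
    x = torch.randn(M, N, device="cuda", dtype=dtype, requires_grad=True)
    ln = torch.nn.LayerNorm(N, device="cuda", dtype=dtype)
    y = ln(x)
    g = torch.rand_like(y)
    timer = benchmark.Timer(
        stmt="y.backward(g, retain_graph=True)",
        globals={"y": y, "g": g},
    )
    return timer.blocked_autorange().median

rows = []
for dtype, km, kn in itertools.product((torch.half, torch.float), range(5, 12), range(5, 12)):
    for M in (2**km - 1, 2**km, 2**km + 1):
        for N in (2**kn - 1, 2**kn, 2**kn + 1):
            rows.append((str(dtype), M, N, median_backward_time(M, N, dtype)))
```

Comparing the per-shape medians from a baseline build against this PR gives the speedup ratios shown in the heatmaps.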
The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension). The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough. I am including the regressions here for completeness' sake: image To see this better: 1. Click the image 2. Right click the expanded image and open in a new tab 3. Go to that tab and left click once to zoom in If you want to see the full data, here it is: ![image](https://github.com/user-attachments/assets/54fb60c9-8c0c-4530-a1dd-79ecda1a69a1) I also measured binary size and compile time since those are important for developers: Binary size comparison ![image](https://github.com/user-attachments/assets/ceef5073-1036-47f6-b9dc-cea088beda51) ``` # Original -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so # This PR -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so ``` The diff in bytes is 302kB which is about a 0.1% increase. Compile time difference: ``` # Original real 0m10.931s user 0m9.676s sys 0m1.004s # this PR real 0m16.720s user 0m15.514s sys 0m1.066s # Command I ran time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. 
-I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o ``` So the new PR is 6 seconds longer compile time. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625 Approved by: https://github.com/ngimel --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 527 +++++++++++------- test/test_nn.py | 20 + 2 files changed, 352 insertions(+), 195 deletions(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 9feb30c21941..3ce2c24c18e6 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -508,7 +508,6 @@ __global__ void layer_norm_grad_input_kernel_vectorized( } } - template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, @@ -540,191 +539,365 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } } -// This implementation gets called if M and N divide with 32. This case should -// be the most common. We can then make better use of warp level intrinsics -// to improve performance. +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsHelper( + int64_t M_start, + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; + + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; + T_ACC warp_mean = 0, warp_rstd = 0; + if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { + warp_mean = mean[mean_index + lane_id]; + warp_rstd = rstd[mean_index + lane_id]; + } + // We do a WARP_SYNC() here because we use WARP_SHFL below to access + // warp_mean and warp_rstd. 
+ WARP_SYNC(); + + T_ACC dY_regs[rows_per_thread_y] = {0}; + T_ACC X_regs[rows_per_thread_y] = {0}; + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i; + bool active = true; + if (check_x && thread_x >= N) { + active = false; + } + if (check_y && current_y >= M) { + active = false; + } + if (active) { + dY_regs[i] = dY[current_y * N + thread_x]; + X_regs[i] = X[current_y * N + thread_x]; + } + } -template -__global__ void GammaBetaBackwardCUDAKernel_32x32( + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } +} + +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsWithChecks( int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + for (int64_t M_start = blockIdx.y * rows_per_block_y; + M_start < M; + M_start += rows_per_block_y * gridDim.y) { + int64_t M_end = M_start + rows_per_block_y - 1; + if (!check_y || M_end < M) { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } +} + +// block_dim_x is the number of threads in the x dimension per block. +// block_dim_y is the number of threads in the y dimension per block. +// rows_per_block_y is the size of the tile (number of data elements) +// in the y dimension per block. +// partial_reduction indicates whether we need to reduce across threads +// or not. If set to true, we will not reduce across threads. This can +// be faster in the M >> N case but requires another kernel to do a full +// final reduction. +// aligned_grid means the data size is a multiple of tile size. In that +// case we don't need to check for boundary conditions which can provide +// a further speedup by not needing instructions to check for edge cases +// and not needing predicate registers. +template +__global__ +void +__launch_bounds__(block_dim_x * block_dim_y) + GammaBetaBackwardCUDAKernelTemplate( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db) { + // This assert is a compile-time check only. + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + static_assert(rows_per_thread_y <= kWarpSize); T_ACC dg_sum = 0; T_ACC db_sum = 0; - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + if (aligned_grid) { + // When N and M align perfectly with block_dim_x and block_dim_y, we + // can skip boundary condition checks that waste instruction issue slots. + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + // In the general case we need to check boundary conditions in the M + // dimension. 
However, we can still avoid boundary checks in the N dimension + // for the inner blocks. So try to avoid those checks when possible. + if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } - if (j < N) { - constexpr int unroll_factor = 8; - int laneId = threadIdx.x & (C10_WARP_SIZE - 1); - - T_ACC mean_reg, mean_reg_tmp; - T_ACC rstd_reg, rstd_reg_tmp; - T dY_reg; - T X_reg; - - // Main loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); - bcounter++) { - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - - if (laneId < unroll_factor) { - mean_reg_tmp = mean[offset + laneId]; - rstd_reg_tmp = rstd[offset + laneId]; - } - WARP_SYNC(); + int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x; - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); - rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; + // When partial_reduction is requested, we don't reduce within a block. + // We also don't reduce if we are only a single block in the y dimension. + if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { + if (aligned_grid || thread_x < N) { + int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y; + if (dg) { + dg[thread_y * N + thread_x] = dg_sum; } - } - - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++) { - if ((offset + ii) < M) { - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; + if (db) { + db[thread_y * N + thread_x] = db_sum; } } - - // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and - // gets called when M; N divide by 32. We can use warp shuffles - // for the final reduction step. This removes 4 shmem loads and - // stores with their corresponding __syncthreads() - - // This greatly reduces bank conflicts at the expense of a little - // extra shared memory. It does not impact occupancy - int padded_bx = (1 + blockDim.x); - + } else { + // The caller requested a full reduction so we must reduce across + // warps using shared memory and warp shuffles. + static_assert(rows_per_thread_y <= C10_WARP_SIZE); + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; + int padded_bx = (block_dim_x + 1); + // Transpose dg and db. 
s_dg = s_data_typed; - s_db = s_data_typed + (padded_bx * blockDim.y); + s_db = s_data_typed + (padded_bx * block_dim_y); s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; __syncthreads(); // Load transposed so that a warp holds an entire column - T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; - T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; - for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); - } - - if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; - if (dg) { - dg[j] = reg_dg; + // Because block_dim_x != block_dim_y in the general case, we need + // some code to handle the general case. + static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); + constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; + int thread_id = threadIdx.y * block_dim_x + threadIdx.x; + int warp_id = thread_id / C10_WARP_SIZE; + int lane_id = thread_id & (C10_WARP_SIZE - 1); + #pragma unroll + for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { + T_ACC reg_db, reg_dg; + if (lane_id < block_dim_y) { + reg_dg = s_dg[lane_id * padded_bx + i]; + reg_db = s_db[lane_id * padded_bx + i]; } - if (db) { - db[j] = reg_db; + #pragma unroll + for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + // Reduce is done. Now write it out to global memory. + int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; + if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { + if (dg) { + dg[out_index] = reg_dg; + } + if (db) { + db[out_index] = reg_db; + } } } } } -template -__global__ void GammaBetaBackwardCUDAKernel( +template +void LaunchAndCheckGammaBetaBackwardKernel( + bool aligned_grid, + dim3 blocks, + dim3 threads, + size_t shmem_sz, + cudaStream_t cuda_stream, + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + T* dgamma_data, + T* dbeta_data) { +if (aligned_grid) { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } else { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void ConfigureAndLaunchGammaBetaBackwardKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; - - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; - - T_ACC dg_sum = 0; - T_ACC db_sum = 0; - - if (j < N) { - constexpr int unroll_factor = 8; - - T_ACC mean_reg; - T_ACC rstd_reg; - T dY_reg; - T X_reg; - - // Main Loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){ - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + T* dgamma_data = + dgamma->defined() ? 
dgamma->template data_ptr() : nullptr; + T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = 1; + size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; + if (blocks.y == 1 && threads.y == 1) { + // Optimization: since there is just one thread doing all the summation, we don't need a reduction + // across threads. So we set partial_reduction to true. + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } else { + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } - } +} - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++ ){ - if ((offset + ii) < M) { - dY_reg = dY[(offset + ii) * N + j ]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } +template +void LaunchGammaBetaBackwardCUDAKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + constexpr int block_dim_x = 32; + const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { + // We have a situation where M >> N and N is small. + // In this case we can speed up the computation by parallelizing in the M dimension. + // We launch multiple blocks in the y-dimension, and compute partial sums for the + // gradient in the first pass. Then we do a .sum(0) to do a final reduction. + // Although we launch 2 kernels, we can get up to a 10x speedup for large M. 
+ constexpr int block_dim_y = 1; + constexpr int rows_per_block_y = 32; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + // int rows_per_block = my_gamma_beta_unroll_factor * + blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; + constexpr int max_grid_size = 64 * 1024 / 2; + blocks.y = std::min(max_grid_size / blocks.x, blocks.y); + Tensor dgamma_blocks; + Tensor dbeta_blocks; + T * dgamma_blocks_ptr = nullptr; + T * dbeta_blocks_ptr = nullptr; + if (dgamma->defined()) { + auto options = dgamma->options(); + dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - - // Do the final reduction in shared memory - s_dg = s_data_typed; - s_db = s_data_typed + blockDim.x * blockDim.y; - s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; - s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum; - __syncthreads(); - - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - if (threadIdx.y < offset) { - s_dg[threadIdx.y * blockDim.x + threadIdx.x] += - s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - s_db[threadIdx.y * blockDim.x + threadIdx.x] += - s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - } - __syncthreads(); + if (dbeta->defined()) { + auto options = dbeta->options(); + dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); - if (threadIdx.y == 0) { - if (dg) { - dg[j] = s_dg[threadIdx.x]; - } - if (db) { - db[j] = s_db[threadIdx.x]; - } + *dgamma = dgamma_blocks.sum(0); + *dbeta = dbeta_blocks.sum(0); + } else { + // We are in the normal case where M is not that large. + // We can change the tile shape (which is the last template parameter) in accordance with M. + // For small M it is faster to have a smaller tile, otherwise we could have idle threads. + // For larger M we use a bigger tile size. + if (M < 64) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 128) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 256) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } @@ -1250,6 +1423,7 @@ void LayerNormBackwardKernelImplInternal( dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; +#if defined(USE_ROCM) if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; @@ -1265,7 +1439,6 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { -#if defined(USE_ROCM) // For small batch size, do colwise reduce directly. 
const int part_size = warp_size; const dim3 threads2(warp_size, 4, 1); @@ -1300,47 +1473,11 @@ void LayerNormBackwardKernelImplInternal( dgamma_data, dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); + } #else - if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { - // This implementation relies on warp primitives and requires that M and N divide - // exactly to warp size. - dim3 threads{kWarpSize, kWarpSize}; - int blocks = (N + threads.x - 1) / threads.x; - - // If M and N divide by warp_size, we can use warp shuffles for the final reduction. - // That requires transposing values in shared memory, so we apply a padding to - // reduce bank conflicts. - - size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; - GammaBetaBackwardCUDAKernel_32x32 - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - dim3 threads{16, 32}; - int blocks = (N + threads.x - 1) / threads.x; - size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; - GammaBetaBackwardCUDAKernel - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } + LaunchGammaBetaBackwardCUDAKernel( + dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif - } } } diff --git a/test/test_nn.py b/test/test_nn.py index 30fe71b4162e..72c440ca5ec5 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7195,6 +7195,26 @@ def test_layer_norm_eps(self): ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_layer_norm_backwards_eps(self): + dtype = torch.float + m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), + (32, 32), (1024, 32), (1024, 1024), + (33, 33), (1025, 33), (1025, 1025)] + for m, n in m_x_n_list: + x = torch.randn((m, n), dtype=dtype, requires_grad=True) + grad_output = torch.rand_like(x) + x_cuda = x.clone().detach().to("cuda").requires_grad_() + grad_output_cuda = grad_output.clone().detach().to("cuda") + ln = nn.LayerNorm(n, dtype=dtype) + ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) + ln_out = ln(x) + ln_out_cuda = ln_cuda(x_cuda) + ln_out.backward(grad_output) + ln_out_cuda.backward(grad_output_cuda) + self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291 From 9e55dae2a69b32accb1d64986364ee6504d99d3e Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 3 Apr 2025 22:33:45 +0000 Subject: [PATCH 171/332] CUDA CachingHostAllocator tracks registrations to call correct free (#146520) Allocations using cudaHostRegister should use corresponding cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost. In test_cuda.py, the allocator config will change from test to test but the cache is not emptied prior to changing the config. This results in the wrong free being called later. Unit test sharding is avoiding this issue, but running the test_cuda.py with a single shard will fail. The following reproducer demonstrates the problem. 
```C++ int main(int argc, char **argv) { void *ptr; assert(cudaSuccess == cudaHostAlloc(&ptr, 1024, cudaHostAllocDefault)); assert(cudaSuccess == cudaHostUnregister(ptr)); std::free(ptr); return 0; } ``` The above code results in the following failure because the ptr is an invalid argument to cudaHostUnregister. ``` a.out: test.cpp:53: int main(int, char**): Assertion `cudaSuccess == cudaHostUnregister(ptr)' failed. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/146520 Approved by: https://github.com/ngimel --- aten/src/ATen/cuda/CachingHostAllocator.cpp | 22 +++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index 8e084aec2a0c..ce1ef86d5091 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -9,6 +9,7 @@ #include #include +#include namespace at::cuda { namespace { @@ -71,6 +72,8 @@ using Block = HostBlock; struct CUDACachingHostAllocatorImpl : public CachingHostAllocatorImpl { private: + std::unordered_map use_host_register; + void allocate_host_memory(size_t size, void** ptr) override { // Pinned memory pointers allocated by any device can be directly used by // any other device, regardless of the current device at the time of @@ -89,13 +92,16 @@ struct CUDACachingHostAllocatorImpl } auto start = std::chrono::system_clock::now(); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - pinned_use_cuda_host_register()) { + bool use_register = c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::pinned_use_cuda_host_register(); + if (use_register) { allocWithCudaHostRegister(ptr, size); } else { // Use cudaHostAlloc for allocating pinned memory (global lock in driver) C10_CUDA_CHECK(cudaHostAlloc(ptr, size, cudaHostAllocDefault)); } + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(use_host_register.count(*ptr) == 0); + use_host_register[*ptr] = use_register; + auto end = std::chrono::system_clock::now(); auto duration = std::chrono::duration_cast(end - start); @@ -108,15 +114,19 @@ struct CUDACachingHostAllocatorImpl void free_block(Block* block) override { auto start = std::chrono::system_clock::now(); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - pinned_use_cuda_host_register()) { - void* ptr = block->ptr_; + // Users may change the allocator config at will. torch unit tests do this. + // However, allocations using cudaHostRegister should use corresonding + // cudaHostUnregister and similarly for cudaHostAlloc / cudaFreeHost. + void* ptr = block->ptr_; + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(use_host_register.count(ptr) == 1); + if (use_host_register[ptr]) { AT_CUDA_CHECK(cudaHostUnregister(ptr)); // NOLINTNEXTLINE(cppcoreguidelines-no-malloc) std::free(ptr); } else { - AT_CUDA_CHECK(cudaFreeHost(block->ptr_)); + AT_CUDA_CHECK(cudaFreeHost(ptr)); } + use_host_register.erase(ptr); auto end = std::chrono::system_clock::now(); auto duration = std::chrono::duration_cast(end - start); From 76994d48f4d138fc5f88247a37d9021033bbe4de Mon Sep 17 00:00:00 2001 From: Richard Howell Date: Thu, 3 Apr 2025 22:36:14 +0000 Subject: [PATCH 172/332] [pytorch] add experimental TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT (#150537) Summary: Add an experimental feature to defer pytorch library initialization cost to post startup. As noted this feature is not thread safe, it requires the client to maintain thread safety at library load time. 
Reviewed By: zou3519 Differential Revision: D71917841 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150537 Approved by: https://github.com/zou3519 --- aten/src/ATen/core/library.cpp | 12 +++++++++++ torch/csrc/jit/mobile/import.cpp | 4 ++++ torch/library.h | 35 ++++++++++++++++++++++++++++++++ 3 files changed, 51 insertions(+) diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp index b8a5b418bbc0..bdc525dca08c 100644 --- a/aten/src/ATen/core/library.cpp +++ b/aten/src/ATen/core/library.cpp @@ -58,6 +58,18 @@ void Library::reset() { #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")" +#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +namespace detail { + std::vector torch_library_initializers; +} // namespace detail +void initialize_torch_libraries() { + for (auto* initializer : detail::torch_library_initializers) { + initializer->initialize(); + } + detail::torch_library_initializers.clear(); +} +#endif + Library::Library(Kind kind, std::string ns, std::optional k, const char* file, uint32_t line) : kind_(kind) , ns_(ns == "_" ? std::nullopt : std::make_optional(std::move(ns))) diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 6c1bfd0ec3ec..94f49ac67dc2 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -646,6 +647,9 @@ mobile::Module _load_for_mobile( std::optional device, ExtraFilesMap& extra_files, uint64_t module_load_options) { +#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT + torch::initialize_torch_libraries(); +#endif auto observer = torch::observerConfig().getModuleObserver(); if (observer) { extra_files.insert(std::make_pair("model_path", filename)); diff --git a/torch/library.h b/torch/library.h index ef92bee6c93b..653a45361a1b 100644 --- a/torch/library.h +++ b/torch/library.h @@ -884,8 +884,42 @@ class TORCH_API Library final { at::OperatorName _parseNameForLib(const char* name_str) const; }; +#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +void initialize_torch_libraries(); +#endif + namespace detail { +#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +extern std::vector torch_library_initializers; +class TorchLibraryInit final { + private: + using InitFn = void(Library&); + Library::Kind kind; + InitFn* init_function; + const char* ns; + std::optional key; + const char* file; + uint32_t line; + std::unique_ptr lib = nullptr; + + public: + TorchLibraryInit( + Library::Kind kind, + InitFn* fn, + const char* ns, + std::optional k, + const char* file, + uint32_t line) : kind(kind), init_function(fn), ns(ns), key(k), file(file), line(line) { + torch_library_initializers.push_back(this); + } + + void initialize() { + lib = std::unique_ptr(new Library(kind, ns, key, file, line)); + init_function(*lib); + } +}; +#else class TorchLibraryInit final { private: using InitFn = void(Library&); @@ -903,6 +937,7 @@ class TorchLibraryInit final { fn(lib_); } }; +#endif } // namespace detail From c0618a3957f890112dbf64c814080d131793ab52 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Thu, 3 Apr 2025 11:45:02 -0700 Subject: [PATCH 173/332] Update commitlist.py instructions for the GitHub repo regime (#149535) Pull Request resolved: https://github.com/pytorch/pytorch/pull/149535 Approved by: https://github.com/albanD --- scripts/release_notes/commitlist.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 
deletions(-) diff --git a/scripts/release_notes/commitlist.py b/scripts/release_notes/commitlist.py index bff5b3ec3e50..916d87ce0da0 100644 --- a/scripts/release_notes/commitlist.py +++ b/scripts/release_notes/commitlist.py @@ -32,6 +32,9 @@ """ +# Increase the allowed size of a CSV field to 1mil bytes for long files changed +csv.field_size_limit(1000000) + @dataclasses.dataclass(frozen=False) class Commit: @@ -490,15 +493,15 @@ def get_markdown_header(category): header = f""" # Release Notes worksheet {category} -The main goal of this process is to rephrase all the commit messages below to make them clear and easy to read by the end user. You should follow the following instructions to do so: +The main goal of this process is to rephrase all the commit messages below to make them **clear and easy to read** by the end user. You should follow the following instructions to do so: -* **Please cleanup, and format commit titles to be readable by the general pytorch user.** [Detailed instructions here](https://docs.google.com/document/d/14OmgGBr1w6gl1VO47GGGdwrIaUNr92DFhQbY_NEk8mQ/edit) +* **Please clean up and format commit titles to be readable by the general PyTorch user.** Make sure you're [following the guidance here](https://docs.google.com/document/d/14OmgGBr1w6gl1VO47GGGdwrIaUNr92DFhQbY_NEk8mQ/edit)! Your resulting notes must be consistent and easy to read. * Please sort commits into the following categories (you should not rename the categories!), I tried to pre-sort these to ease your work, feel free to move commits around if the current categorization is not good. -* Please drop any commits that are not user-facing. -* If anything is from another domain, leave it in the UNTOPICED section at the end and I'll come and take care of it. -* Please use markdown format -* Please use #PR_NUM to link to the PR, instead of `[#PR_NUM](https://github.com/pytorch/pytorch/pull/#PR_NUM)` to reduce the length of the release notes -* We place a lot of emphasis on the “BC-breaking” and “deprecation” sections. Those should be where the most effort goes in. The “improvements” and “bug fixes” for Python API should be nice as well. Everything else doesn’t matter too much so feel free to cut corners if time is short. +* Anything that is not public facing needs to be removed. +* If anything is miscategorized/belongs to another domain, move it to `miscategorized.md`. +* Please scan through `miscategorized.md` and handle any commits that belong within your domain according to these instructions. +* We place a lot of emphasis on the “BC-breaking” and “deprecation” sections. Those should be where the most effort goes in. The “improvements” and “bug fixes” for Python API should be nice as well. +* Once you are finished, move this very file from `todo/` to `done/` and submit a pull request. 
The categories below are as follows: @@ -510,6 +513,7 @@ def get_markdown_header(category): * performance: All commits that are added mainly for performance (we separate this from improvements above to make it easier for users to look for it) * documentation: All commits that add/update documentation * Developers: All commits that are not end-user facing but still impact people that compile from source, develop into pytorch, extend pytorch, etc +* not user facing: All commits that are not public end-user facing and hence should be dropped from the release notes """ return [header] From a2dce426544282b448875b7d590a843661ef2a3d Mon Sep 17 00:00:00 2001 From: Tovly Deutsch Date: Thu, 3 Apr 2025 23:04:17 +0000 Subject: [PATCH 174/332] Split up cub-RadixSortPairs.cu to parallelize compilation (#148936) Summary: `cub-RadixSortPairs.cu` has slow compilation times, especially on Windows. These changes split up the file into smaller components to allow each component to compile in parallel. On Windows, I observed a compile time drop from about 20 minutes to 6 minutes. Differential Revision: D70539649 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148936 Approved by: https://github.com/suo, https://github.com/eqy, https://github.com/malfet --- .lintrunner.toml | 1 + aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu | 7 +++++++ aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu | 7 +++++++ ...b-RadixSortPairs.cu => cub-RadixSortPairs.cuh} | 15 ++------------- 13 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu create mode 100644 aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu rename aten/src/ATen/cuda/{cub-RadixSortPairs.cu => cub-RadixSortPairs.cuh} (82%) diff --git a/.lintrunner.toml b/.lintrunner.toml index e7541e6dabe5..376d916e3c65 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -271,6 +271,7 @@ exclude_patterns = [ 'torch/csrc/utils/generated_serialization_types.h', 'torch/csrc/utils/pythoncapi_compat.h', 'torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h', + 'aten/src/ATen/ExpandBase.h', ] init_command = [ 'python3', diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu new file mode 100644 index 000000000000..6c20daed2e02 --- /dev/null +++ 
b/aten/src/ATen/cuda/cub-RadixSortPairs-f16-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu new file mode 100644 index 000000000000..2adb6a519882 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-1.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu new file mode 100644 index 000000000000..39e29b7668c9 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-2.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu new file mode 100644 index 000000000000..3ad1ebd2a56a --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int32-4.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu new file mode 100644 index 000000000000..098615b68345 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-1.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu new file mode 100644 index 000000000000..d58e0c8d5ce7 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-2.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu new file mode 100644 index 000000000000..fe24f72151fb --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-int64-4.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu new file mode 100644 index 000000000000..1373668316c2 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-scalars.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu new file mode 100644 index 000000000000..f52f97fe588a --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint16-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu new file mode 100644 index 000000000000..db28bb602acc --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint32-8.cu @@ -0,0 +1,7 @@ +#include + +namespace 
at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu b/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu new file mode 100644 index 000000000000..7ad51b90b834 --- /dev/null +++ b/aten/src/ATen/cuda/cub-RadixSortPairs-uint64-8.cu @@ -0,0 +1,7 @@ +#include + +namespace at::cuda::cub::detail { + +AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8) + +} // namespace at::cuda::cub::detail diff --git a/aten/src/ATen/cuda/cub-RadixSortPairs.cu b/aten/src/ATen/cuda/cub-RadixSortPairs.cuh similarity index 82% rename from aten/src/ATen/cuda/cub-RadixSortPairs.cu rename to aten/src/ATen/cuda/cub-RadixSortPairs.cuh index 0eefb0824e59..bd40deb4125b 100644 --- a/aten/src/ATen/cuda/cub-RadixSortPairs.cu +++ b/aten/src/ATen/cuda/cub-RadixSortPairs.cuh @@ -1,3 +1,5 @@ +#pragma once + #define TORCH_ASSERT_NO_OPERATORS #include #include @@ -66,20 +68,7 @@ void radix_sort_pairs_impl( int64_t begin_bit, \ int64_t end_bit); -AT_INSTANTIATE_SORT_PAIRS(int32_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int32_t, 4) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 1) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 2) -AT_INSTANTIATE_SORT_PAIRS(int64_t, 4) - #define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \ AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8) -AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8) -AT_INSTANTIATE_SORT_PAIRS(uint16_t, 8) -AT_INSTANTIATE_SORT_PAIRS(uint32_t, 8) -AT_INSTANTIATE_SORT_PAIRS(uint64_t, 8) -AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8) - } // namespace at::cuda::cub::detail From 118e3862bc5ffae70f9e8d52df6657afc590012f Mon Sep 17 00:00:00 2001 From: William Wen Date: Thu, 3 Apr 2025 11:50:11 -0700 Subject: [PATCH 175/332] [dynamo] disable new test_assert_failure_in_generic_ctx_mgr internally (#150631) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150631 Approved by: https://github.com/clee2000 ghstack dependencies: #150471 --- test/dynamo/test_error_messages.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 1098b1bfbb2f..eef2512bcfe5 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -835,6 +835,7 @@ def fn(x): """, ) + @unittest.skipIf(IS_FBCODE, "assert gets patched in internal pytest") @make_logging_test(graph_breaks=True) def test_assert_failure_in_generic_ctx_mgr(self, records): def fn(x): From 5cf3029503e98af3267ccf517aae272b39caefbe Mon Sep 17 00:00:00 2001 From: Henry Hu Date: Thu, 3 Apr 2025 23:26:59 +0000 Subject: [PATCH 176/332] Remove unused rand call if not fallback to eager for rand (#147790) Fixes #147171 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147790 Approved by: https://github.com/eellison --- test/dynamo/test_repros.py | 3 +++ test/fx/test_dce_pass.py | 13 +++++++++++++ test/inductor/test_compiled_autograd.py | 3 +++ torch/fx/graph.py | 6 +++++- torch/fx/node.py | 12 ++++++++---- 5 files changed, 32 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 6d8c86923cf3..e03e14b78799 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6650,6 +6650,9 @@ def f(image_latent): torch.cuda.manual_seed_all(54321) expected = f(torch.randn((2, 12, 16, 32, 32))).sum() + # https://github.com/pytorch/pytorch/issues/147171 + torch._inductor.config.fallback_random = True + for backend 
in ["eager", "aot_eager"]: torch.manual_seed(54321) torch.cuda.manual_seed_all(54321) diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index e74b90f268da..4e11ed562254 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -232,6 +232,19 @@ def forward(self, a: torch.Tensor) -> torch.Tensor: # %add_ node should not be removed because it has side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False) + def test_impure_random(self): + """ + Test that DCE doesn't remove call_function for torch.rand. + """ + + class TestModule(torch.nn.Module): + def forward(self, a: torch.Tensor) -> torch.Tensor: + x = torch.rand([10]) # noqa: F841 + return a * 2 + + # %torch.rand should not be removed because it has side effects. + self._run_dce_and_test(TestModule(), expect_dce_changes=False) + def test_impure_kwargs(self): """ Test that DCE doesn't remove call_function nodes with side effects on kwargs. diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index 7294417ad08e..f10bf940e711 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -3924,6 +3924,9 @@ def backward(ctx, gO): x = torch.randn(10, 10, requires_grad=True) + # https://github.com/pytorch/pytorch/issues/147171 + torch._inductor.config.fallback_random = True + @torch.compile(backend="aot_eager") def fn(x): return SideEffectfulBackward.apply(x).sum() diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 541a76942739..4a156dba0463 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1810,10 +1810,14 @@ def forward(self, x): # DCE below will not behave as expected. self.lint() + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + def has_side_effect(node): if is_impure_node is not None: return is_impure_node(node) - return node.is_impure() + return node.is_impure(impure_random) # Reverse iterate so that when we remove a node, any nodes used as an # input to that node have an updated user count that no longer reflects diff --git a/torch/fx/node.py b/torch/fx/node.py index 8433e9ea651b..59a946deec23 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -714,11 +714,14 @@ def maybe_replace_node(n: Node) -> Node: return [n for n in to_process if n not in skipped] @compatibility(is_backward_compatible=False) - def is_impure(self) -> bool: + def is_impure(self, impure_random: bool = True) -> bool: """ Returns whether this op is impure, i.e. if its op is a placeholder or output, or if a call_function or call_module which is impure. + Args: + impure_random (bool): Whether to treat rand op as impure. + Returns: bool: If the op is impure or not. @@ -732,9 +735,10 @@ def is_impure(self) -> bool: # impure since it mutates inputs return True - if getattr(self.target, "_nondeterministic_seeded", False): - # impure since it mutates RNG state - return True + if impure_random: + if getattr(self.target, "_nondeterministic_seeded", False): + # impure since it mutates RNG state + return True return self.target in _side_effectful_functions From 8878289f89ae59402f84e5de0ed06f0ca608ea41 Mon Sep 17 00:00:00 2001 From: Zhao Zhu Date: Thu, 3 Apr 2025 23:40:15 +0000 Subject: [PATCH 177/332] [aten] 8 bytes aligned vector loads for bf16 and fp16 dtypes in torch.cat (#150233) Enable aligned vector loading for 2 bytes datatypes in torch.cat. Specifically: 1. reduce the vector length to 8 bytes for 2-byte types (fp16, bf16 etc) 2. 
enable through a conditional template The reason why 8-byte vector loading was chosen for fp16 and bf16: 16-byte load results in heavier register overheads (i.e. 4 register per load for fp32 -> 8 register per load for fp16). Therefore, to employ the benefits of vectorized loading, we reduced ALIGNED_VEC_LOAD_BYTES to 8 for fp16 and bf16 ### perf testing: before: ``` torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32: B pt_eager copy 0 100.0 0.022621 0.036162 1 1000.0 0.133616 0.207051 2 10000.0 1.326848 1.848768 3 20000.0 2.744544 3.692128 torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16: B pt_eager copy 0 100.0 0.022434 0.035477 1 1000.0 0.140608 0.144518 2 10000.0 1.303792 1.229584 3 20000.0 2.668288 2.436160 ``` after: ``` torch-cat-D1-30108-D2-624-D3-772-dtype-torch.float32: B pt_eager copy 0 100.0 0.022608 0.036328 1 1000.0 0.133861 0.207399 2 10000.0 1.325120 1.847136 3 20000.0 2.726528 3.693184 torch-cat-D1-30108-D2-624-D3-772-dtype-torch.bfloat16: B pt_eager copy 0 100.0 0.019942 0.035482 1 1000.0 0.084858 0.144544 2 10000.0 0.924384 1.230672 3 20000.0 1.944448 2.436480 ``` ### bw analysis: bw on fp16/bf16 got increased by 40%-50% for large tensors before: ``` Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|869.87|1382.74|1956.46|1952.73|1969.03|1963.66 Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|568.43|926.53|1589.20|1567.52|1771.54|1783.68 Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|752.07|1269.50|1894.86|1900.85|1954.10|1955.08 Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|807.08|1354.69|1960.48|1962.45|1972.73|1973.85 Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|864.02|1398.02|1963.43|1955.32|1963.37|1969.96 ``` after: ``` Bandwidth (GB/s) for ((16384, 16384), 1) int8;fp16;fp32;int32;fp64;long|873.08|1892.16|1954.35|1962.51|1962.03|1965.98 Bandwidth (GB/s) for ((4194304,), 0) int8;fp16;fp32;int32;fp64;long|575.13|1242.45|1576.37|1571.30|1769.94|1790.22 Bandwidth (GB/s) for ((16777216,), 0) int8;fp16;fp32;int32;fp64;long|742.92|1734.57|1887.99|1897.62|1940.99|1959.25 Bandwidth (GB/s) for ((33554432,), 0) int8;fp16;fp32;int32;fp64;long|802.60|1865.45|1952.64|1947.53|1974.47|1973.48 Bandwidth (GB/s) for ((134217728,), 0) int8;fp16;fp32;int32;fp64;long|865.32|1939.07|1965.72|1963.25|1969.06|1968.72 ``` ### Perf testing code: ``` # pyre-strict from typing import List, Optional, Tuple import click import pandas as pd import torch # @manual=//triton:triton import triton # CUDA_VISIBLE_DEVICEs=7 buck2 run @mode/opt //scripts/zhaozhu:cat_bench @click.command() @click.option("--data-type", type=str, default="bf16") @click.option("--return-result", type=bool, default=False) def main( data_type: str, return_result: bool, ) -> Optional[Tuple[List[triton.testing.Benchmark], List[pd.DataFrame]]]: torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True if data_type == "fp32": dtype = torch.float32 elif data_type == "fp16": dtype = torch.float16 elif data_type == "bf16": dtype = torch.bfloat16 else: raise ValueError(f"Unsupported data type: {data_type}.") D1 = int(torch.randint(low=10000, high=50000, size=(1,)).item()) D2 = int(torch.randint(low=100, high=1000, size=(1,)).item()) D3 = int(torch.randint(low=500, high=1000, size=(1,)).item()) configs: List[triton.testing.Benchmark] = [ triton.testing.Benchmark( x_names=["B"], x_vals=[100, 1000, 10000, 20000], line_arg="provider", line_vals=["pt_eager", "copy"], line_names=["pt_eager", "copy"], 
styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="ms", plot_name=f"torch-cat-D1-{D1}-D2-{D2}-D3-{D3}-dtype-{dtype}", args={ "D1": D1, "D2": D2, "D3": D3, "dtype": dtype, }, ) ] @triton.testing.perf_report(configs) def bench_cat( B: int, D1: int, D2: int, D3: int, dtype: torch.dtype, provider: str, ) -> float: warmup = 10 rep = 3 tensors = [] a = torch.empty( # (B, 30108), (B, D1), dtype=dtype, device=torch.device("cuda"), ).uniform_(-1.0, 1.0) b = torch.empty( # (B, 624), (B, D2), dtype=dtype, device=torch.device("cuda"), ).uniform_(-1.0, 1.0) c = torch.empty( # (B, 772), (B, D3), dtype=dtype, device=torch.device("cuda"), ).uniform_(-1.0, 1.0) tensors = [a, b, c] total_cols: int = int(a.shape[1] + b.shape[1] + c.shape[1]) def torch_copy( tensors: List[torch.Tensor], is_inplace: bool = True ) -> torch.Tensor: f = torch.zeros([B, total_cols], dtype=dtype, device=torch.device("cuda")) col_idx = 0 for t in tensors: temp = f[:, col_idx : col_idx + t.shape[1]] if is_inplace: temp.copy_(t) else: f[:, col_idx : col_idx + t.shape[1]] = t col_idx += t.shape[1] return f def torch_cat(tensors: List[torch.Tensor]) -> torch.Tensor: return torch.cat(tensors, dim=1) ref = torch_cat(tensors) real = torch_copy(tensors, is_inplace=False) torch.testing.assert_allclose(ref, real) if provider == "pt_eager": fn = lambda: torch_cat(tensors) # noqa E731 ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms elif provider == "stack": def torch_stack(tensors: List[torch.Tensor]) -> torch.Tensor: return torch.stack(tensors, dim=1).view(-1, total_cols) fn = lambda: torch_stack(tensors) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms elif provider == "copy": fn = lambda: torch_copy(tensors) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms else: raise ValueError(f"unsupported provider: {provider}") df = bench_cat.run(print_data=True, return_df=return_result) if return_result: return configs, df if __name__ == "__main__": main() ``` and bw analysis code is from: https://github.com/pytorch/pytorch/pull/102815?fbclid=IwZXh0bgNhZW0CMTEAAR1Rwclp_O1fknl1Litpm9GeY0ZZZovdCv8_kQfGf6Zy8LaoP9JhO0ZsutM_aem_BPCZEZda5OOMnzI9Mrlapg#issue-1737409146 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150233 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cuda/Shape.cu | 33 ++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index b2fd2dc85895..e2eb2226acf4 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -27,7 +27,8 @@ namespace at::native { constexpr int CAT_ARRAY_BATCH_SIZE = 128; constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4; -constexpr int ALIGNED_VEC_LOAD_BYTES = 16; +constexpr int ALIGNED_VEC_LOAD_BYTES_16 = 16; +constexpr int ALIGNED_VEC_LOAD_BYTES_8 = 8; namespace { @@ -72,14 +73,14 @@ inline std::tuple getCatGridRocm(unsigned int max_elements_per_tenso return std::make_tuple(grid, block); } -template +template inline std::tuple getCatGridContig(unsigned int max_elements_per_tensor, ptrdiff_t nTensors) { constexpr unsigned int threads_per_block = 128; constexpr unsigned int min_aligned_vec_per_thread = 1; constexpr unsigned int max_tb_per_sm = 32; - unsigned int elements_per_thread = ALIGNED_VEC_LOAD_BYTES / sizeof(T) * + unsigned int elements_per_thread = aligned_vec_load_bytes / sizeof(T) * min_aligned_vec_per_thread; unsigned int max_threads = 
ceil_div(max_elements_per_tensor, elements_per_thread); unsigned int thread_blocks = ceil_div(max_threads, threads_per_block); @@ -230,16 +231,19 @@ __global__ void CatArrayBatchedCopy_contig( to improve memory bandwidth throughput. */ -template -__global__ void CatArrayBatchedCopy_aligned16_contig( +template +__global__ void CatArrayBatchedCopy_alignedK_contig( T* output, CatArrInputTensorMetadata inputs, TensorSizeStride os, const int concatDim, IndexType dimStride) { - // This kernel tries to use 128 bit loads - constexpr int kILP = ALIGNED_VEC_LOAD_BYTES / sizeof(T); + // This kernel tries to use aligned_vec_load_bytes*8 bit loads + // Special case 2-byte types to use 8-byte vec loads to reduce register pressure + // The below lambda is to allow cc compiler to pass kILP>0 checks for large types (e.g. ComplexDouble, 16 bytes) + constexpr int kILP = aligned_vec_load_bytes / sizeof(T) > 0 ? aligned_vec_load_bytes / sizeof(T) : ALIGNED_VEC_LOAD_BYTES_16/sizeof(T); + IndexType inputOffset = (blockIdx.x * blockDim.x + threadIdx.x) * kILP; IndexType inputStride = gridDim.x * blockDim.x * kILP; @@ -349,7 +353,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i isAligned = false; #else // If at least one of the inputs is not aligned, we can't call the - // CatArrayBatchedCopy_aligned16_contig + // CatArrayBatchedCopy_alignedK_contig isAligned &= is_aligned_vec4(catMetaData.input[batchCounter]); #endif @@ -385,7 +389,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i #else dim3 applyBlock, catGrid; if (isContig && sizeof(scalar_t) > 2) { - std::tie(catGrid, applyBlock) = getCatGridContig( + std::tie(catGrid, applyBlock) = getCatGridContig( + max_elements_per_tensor, batchCounter); + } else if (isContig && sizeof(scalar_t) == 2) { + std::tie(catGrid, applyBlock) = getCatGridContig( max_elements_per_tensor, batchCounter); } else { applyBlock = dim3(32 * 16); @@ -406,8 +413,12 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i } // Template Declarations for dim = 1, 2, 3, 4 #define HANDLE_CASE(DIMS) \ - if (isContig && isAligned && sizeof(scalar_t) >= 4 && sizeof(scalar_t) <= 8) {\ - CatArrayBatchedCopy_aligned16_contig<<<\ + if (isContig && isAligned && sizeof(scalar_t) > 2 && sizeof(scalar_t) <= 8) {\ + CatArrayBatchedCopy_alignedK_contig<<<\ + catGrid, applyBlock, 0, stream.stream()>>>(\ + data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ + } else if (isContig && isAligned && sizeof(scalar_t) == 2) { \ + CatArrayBatchedCopy_alignedK_contig<<<\ catGrid, applyBlock, 0, stream.stream()>>>(\ data, catMetaData, outputParam, dimension, outputParam.tensorStride[dimension]);\ } else if (isContig) {\ From 1ab6c4ff0417f3c3dcce973d6f28b3895881540c Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Thu, 3 Apr 2025 23:50:13 +0000 Subject: [PATCH 178/332] [Codemod][AddExplicitStrictExportForTrainingInferenceArg] caffe2/ (#149595) internal diff: D71497480 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149595 Approved by: https://github.com/Skylion007 --- test/fx/test_matcher_utils.py | 12 +- test/inductor/test_aot_inductor.py | 4 +- test/quantization/pt2e/test_duplicate_dq.py | 5 +- .../pt2e/test_metadata_porting.py | 5 +- .../pt2e/test_numeric_debugger.py | 22 +- test/quantization/pt2e/test_quantize_pt2e.py | 54 +- .../pt2e/test_quantize_pt2e_qat.py | 27 +- test/quantization/pt2e/test_representation.py | 5 +- 
.../pt2e/test_x86inductor_quantizer.py | 7 +- .../pt2e/test_xnnpack_quantizer.py | 18 +- test/test_model_exports_to_core_aten.py | 4 +- torch/ao/quantization/pt2e/utils.py | 1 + torch/distributed/pipelining/_IR.py | 4 +- .../testing/_internal/common_quantization.py | 1113 +++++++++++------ 14 files changed, 797 insertions(+), 484 deletions(-) diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py index 26caf91485e2..578e0ab07a6a 100644 --- a/test/fx/test_matcher_utils.py +++ b/test/fx/test_matcher_utils.py @@ -173,7 +173,7 @@ def pattern(x, weight): torch.randn(3, 3, 3, 3), ) pattern_gm = export_for_training( - WrapperModule(pattern), example_inputs + WrapperModule(pattern), example_inputs, strict=True ).module() before_split_res = pattern_gm(*example_inputs) pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm) @@ -204,11 +204,11 @@ def pattern(x, weight): torch.randn(3, 3, 3, 3), ) pattern_gm = export_for_training( - WrapperModule(pattern), example_inputs + WrapperModule(pattern), example_inputs, strict=True ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) target_gm = export_for_training( - WrapperModule(target_graph), example_inputs + WrapperModule(target_graph), example_inputs, strict=True ).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: @@ -248,9 +248,11 @@ def forward(self, x): return linear, {"linear": linear, "x": x} example_inputs = (torch.randn(3, 5),) - pattern_gm = export_for_training(Pattern(), example_inputs).module() + pattern_gm = export_for_training( + Pattern(), example_inputs, strict=True + ).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) - target_gm = export_for_training(M(), example_inputs).module() + target_gm = export_for_training(M(), example_inputs, strict=True).module() internal_matches = matcher.match(target_gm.graph) for internal_match in internal_matches: name_node_map = internal_match.name_node_map diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 501fbb49c2b4..973e720c7eb9 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1837,7 +1837,9 @@ def forward(self, x): with config.patch( {"freezing": True, "aot_inductor.force_mmap_weights": True} ), torch.no_grad(): - exported_model = export_for_training(model, example_inputs).module() + exported_model = export_for_training( + model, example_inputs, strict=True + ).module() quantizer = X86InductorQuantizer() quantizer.set_global( xiq.get_default_x86_inductor_quantization_config(reduce_range=True) diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index 54456ab37b15..4a5cb6edaeb6 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -101,10 +101,7 @@ def _test_duplicate_dq( # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 4f6eb4f56d3a..96eff3a789f2 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -98,10 +98,7 @@ def _test_metadata_porting( # program capture m = copy.deepcopy(m_eager) - m = torch.export.export_for_training( - m, - example_inputs, - ).module() + m = 
torch.export.export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Calibrate diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index b5ada0cc3d59..deff8e4987e5 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -81,7 +81,7 @@ def _extract_debug_handles_with_prev_decomp_op_from_node(node): def test_simple(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) debug_handle_map = self._extract_debug_handles(ep) @@ -91,7 +91,7 @@ def test_simple(self): def test_control_flow(self): m = TestHelperModules.ControlFlow() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -102,7 +102,7 @@ def test_control_flow(self): def test_quantize_pt2e_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() @@ -162,14 +162,14 @@ def test_deepcopy_preserve_handle(self): def test_re_export_preserve_handle(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() self._assert_each_node_has_debug_handle(ep) debug_handle_map_ref = self._extract_debug_handles(ep) - ep_reexport = export_for_training(m, example_inputs) + ep_reexport = export_for_training(m, example_inputs, strict=True) self._assert_each_node_has_debug_handle(ep_reexport) debug_handle_map = self._extract_debug_handles(ep_reexport) @@ -179,7 +179,7 @@ def test_re_export_preserve_handle(self): def test_run_decompositions_same_handle_id(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -204,7 +204,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): for m in test_models: example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) self._assert_each_node_has_debug_handle(ep) @@ -227,7 +227,7 @@ def test_run_decompositions_map_handle_to_new_nodes(self): def test_prepare_for_propagation_comparison(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() m_logger = prepare_for_propagation_comparison(m) @@ -244,7 +244,7 @@ def test_prepare_for_propagation_comparison(self): def test_extract_results_from_loggers(self): m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = 
ep.module() m_ref_logger = prepare_for_propagation_comparison(m) @@ -269,7 +269,7 @@ def test_extract_results_from_loggers(self): def test_extract_results_from_loggers_list_output(self): m = TestHelperModules.Conv2dWithSplit() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) m = ep.module() m_ref_logger = prepare_for_propagation_comparison(m) @@ -299,7 +299,7 @@ def test_extract_results_from_loggers_list_output(self): def test_added_node_gets_unique_id(self) -> None: m = TestHelperModules.Conv2dThenConv1d() example_inputs = m.example_inputs() - ep = export_for_training(m, example_inputs) + ep = export_for_training(m, example_inputs, strict=True) generate_numeric_debug_handle(ep) ref_handles = self._extract_debug_handles(ep) ref_counter = Counter(ref_handles.values()) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 2bc87f72fc25..08ffecc3aabd 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -767,10 +767,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5)) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, BackendAQuantizer()) # make sure the two observers for input are shared conv_output_obs = [] @@ -830,10 +827,7 @@ def _test_transitive_sharing_with_cat_helper(self, quantizer): ) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) # make sure the two input observers and output are shared @@ -1152,10 +1146,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: ) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = BackendAQuantizer() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -1305,7 +1296,7 @@ def validate(self, model: torch.fx.GraphModule) -> None: m = M().eval() example_inputs = torch.randn(1, 2, 3, 3) - m = export_for_training(m, (example_inputs,)).module() + m = export_for_training(m, (example_inputs,), strict=True).module() with self.assertRaises(Exception): m = prepare_pt2e(m, BackendAQuantizer()) @@ -1428,10 +1419,7 @@ def forward(self, x): quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() weight_meta = None for n in m.graph.nodes: if ( @@ -1518,7 +1506,7 @@ def forward(self, x): m = M().eval() quantizer = TestQuantizer() example_inputs = (torch.randn(1, 2, 3, 3),) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1569,7 +1557,7 @@ def forward(self, x, y, z): torch.randn(1, 2, 3, 3), torch.randn(1, 2, 3, 3), ) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) node_occurrence = { @@ -1824,7 +1812,7 @@ def forward(self, x): example_inputs = 
(torch.randn(1),) m = M().train() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() if inplace: target = torch.ops.aten.dropout_.default else: @@ -1889,7 +1877,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() # Assert that batch norm op exists and is in train mode bn_node = self._get_node(m, bn_train_op) @@ -1920,7 +1908,7 @@ def test_disallow_eval_train(self): m.train() # After export: this is not OK - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -1961,7 +1949,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) bn_train_op, bn_eval_op = self._get_bn_train_eval_ops() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool): targets = [n.target for n in m.graph.nodes] @@ -2027,7 +2015,7 @@ def forward(self, x): m = M().train() example_inputs = (torch.randn(1, 3, 3, 3),) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() torch.ao.quantization.allow_exported_model_train_eval(m) # Mock m.recompile() to count how many times it's been called @@ -2059,7 +2047,7 @@ def _fake_recompile(): def test_model_is_exported(self): m = TestHelperModules.ConvWithBNRelu(relu=True) example_inputs = (torch.rand(3, 3, 5, 5),) - exported_gm = export_for_training(m, example_inputs).module() + exported_gm = export_for_training(m, example_inputs, strict=True).module() fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs) self.assertTrue( torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm) @@ -2077,7 +2065,9 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_global( get_symmetric_quantization_config(is_per_channel=True, is_qat=True) ) - m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs, strict=True + ).module() m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer) m(*example_inputs) m.conv_bn_relu = convert_pt2e(m.conv_bn_relu) @@ -2085,7 +2075,7 @@ def test_reentrant(self): quantizer = XNNPACKQuantizer().set_module_type( torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False) ) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m = convert_pt2e(m) @@ -2257,7 +2247,7 @@ def test_speed(self): def dynamic_quantize_pt2e(model, example_inputs): torch._dynamo.reset() - model = export_for_training(model, example_inputs).module() + model = export_for_training(model, example_inputs, strict=True).module() # Per channel quantization for weight # Dynamic quantization for activation # Please read a detail: https://fburl.com/code/30zds51q @@ -2360,7 +2350,7 @@ def forward(self, x): example_inputs = (torch.randn(1, 3, 5, 5),) m = M() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = 
XNNPACKQuantizer().set_global( get_symmetric_quantization_config(), ) @@ -2442,7 +2432,7 @@ def prepare_obs_or_fq_callback( edge_or_node_to_obs_or_fq[x] = new_observer example_inputs = (torch.rand(1, 32, 16, 16),) - gm = export_for_training(Model().eval(), example_inputs).module() + gm = export_for_training(Model().eval(), example_inputs, strict=True).module() gm = prepare_pt2e(gm, BackendAQuantizer()) gm = convert_pt2e(gm) for n in gm.graph.nodes: @@ -2469,7 +2459,9 @@ def check_nn_module(node): "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1] ) - m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module() + m.conv_bn_relu = export_for_training( + m.conv_bn_relu, example_inputs, strict=True + ).module() for node in m.conv_bn_relu.graph.nodes: if node.op not in ["placeholder", "output", "get_attr"]: check_nn_module(node) diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py index abc9849aee82..b52f34c68c5b 100644 --- a/test/quantization/pt2e/test_quantize_pt2e_qat.py +++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py @@ -140,8 +140,7 @@ def _verify_symmetric_xnnpack_qat_numerics_helper( ) ) model_pt2e = export_for_training( - model_pt2e, - example_inputs, + model_pt2e, example_inputs, strict=True ).module() model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer) torch.manual_seed(MANUAL_SEED) @@ -229,10 +228,7 @@ def _verify_symmetric_xnnpack_qat_graph_helper( quantizer.set_global( get_symmetric_quantization_config(is_per_channel, is_qat=True) ) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -621,7 +617,7 @@ def forward(self, x): m = M(self.conv_class, self.bn_class, backbone) quantizer = XNNPACKQuantizer() quantizer.set_global(get_symmetric_quantization_config(is_qat=True)) - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) m = convert_pt2e(m) @@ -679,7 +675,7 @@ def get_source_fn(node: torch.fx.Node): def test_qat_conv_bn_bias_derived_qspec(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -726,7 +722,7 @@ def test_qat_conv_bn_bias_derived_qspec(self): def test_qat_per_channel_weight_custom_dtype(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnInt32WeightQuantizer() m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -780,7 +776,7 @@ def test_qat_conv_transpose_bn_relu(self): def test_qat_conv_bn_per_channel_weight_bias(self): m = self._get_conv_bn_model() example_inputs = self.example_inputs - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True) m = prepare_qat_pt2e(m, quantizer) m(*example_inputs) @@ -837,7 +833,7 @@ def test_fold_bn_erases_bn_node(self): it into conv in `convert_pt2e` even in train mode. 
""" m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = export_for_training(m, self.example_inputs).module() + m = export_for_training(m, self.example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True), @@ -1085,7 +1081,9 @@ def _prepare_qat_linears(self, model): in_channels = child.linear1.weight.size(1) example_input = (torch.rand((1, in_channels)),) - traced_child = export_for_training(child, example_input).module() + traced_child = export_for_training( + child, example_input, strict=True + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=True, is_qat=True @@ -1116,10 +1114,7 @@ def test_mixing_qat_ptq(self): self._convert_qat_linears(model) model(*example_inputs) - model_pt2e = export_for_training( - model, - example_inputs, - ).module() + model_pt2e = export_for_training(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_module_type(torch.nn.Linear, None) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index c6eed1ed8260..3648ac352dc4 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -33,10 +33,7 @@ def _test_representation( ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training( - model, - example_inputs, - ).module() + model = export_for_training(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 51b7ce72f74f..1c14ded72fe9 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -665,10 +665,7 @@ def _test_quantizer( # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -2344,7 +2341,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 36209e5aad10..4e14dfd27ae2 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -361,7 +361,7 @@ def forward(self, x): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. 
@@ -497,10 +497,7 @@ def test_propagate_annotation(self): example_inputs = (torch.randn(1, 3, 5, 5),) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) m(*example_inputs) @@ -766,8 +763,7 @@ def forward(self, input_tensor, hidden_tensor): with torchdynamo.config.patch(allow_rnn=True): model_graph = export_for_training( - model_graph, - example_inputs, + model_graph, example_inputs, strict=True ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( @@ -829,8 +825,7 @@ def forward(self, input_tensor, hidden_tensor): with torchdynamo.config.patch(allow_rnn=True): model_graph = export_for_training( - model_graph, - example_inputs, + model_graph, example_inputs, strict=True ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( @@ -1039,10 +1034,7 @@ def test_resnet18(self): m = torchvision.models.resnet18().eval() m_copy = copy.deepcopy(m) # program capture - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) diff --git a/test/test_model_exports_to_core_aten.py b/test/test_model_exports_to_core_aten.py index aae14c28b8d6..3d1c25939ec4 100644 --- a/test/test_model_exports_to_core_aten.py +++ b/test/test_model_exports_to_core_aten.py @@ -27,7 +27,9 @@ def test_vit_aten_export(self): m = m.eval() input_shape = (1, 3, 224, 224) example_inputs = (torch.randn(input_shape),) - m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() + m = torch.export.export_for_training( + m, copy.deepcopy(example_inputs), strict=True + ).module() m(*example_inputs) m = export.export(m, copy.deepcopy(example_inputs)) ops = _get_ops_list(m.graph_module) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 47e939f7596a..86304247d151 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -355,6 +355,7 @@ def _get_aten_graph_module_for_pattern( pattern, # type: ignore[arg-type] example_inputs, kwargs, + strict=True, ).module() aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr] diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 416965e80ba3..4e1b9676d7ca 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -1003,9 +1003,7 @@ def _trace_with_export( logger.info("Tracing model ...") try: ep = torch.export.export_for_training( - mod, - example_args, - example_kwargs, + mod, example_args, example_kwargs, strict=True ) except Exception as e: raise RuntimeError( diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 07e7da55eafc..e114a37b04df 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -4,24 +4,48 @@ checking quantization api and properties of resulting modules. 
""" -from functorch.experimental import control_flow - import torch -import torch.nn as nn -import torch.nn.functional as F import torch.ao.nn.intrinsic.quantized.dynamic as nniqd import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.dynamic as nnqd -from torch.ao.nn.intrinsic import _FusedModule import torch.distributed as dist -from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM - -from torch.export import export_for_training +import torch.nn as nn +import torch.nn.functional as F +from functorch.experimental import control_flow +from torch.ao.nn.intrinsic import _FusedModule from torch.ao.quantization import ( - QuantType, + convert, default_dynamic_qat_qconfig, + default_dynamic_qconfig, + default_dynamic_quant_observer, default_embedding_qat_qconfig, + default_observer, + default_per_channel_qconfig, + default_qconfig, default_symmetric_qnnpack_qat_qconfig, + default_weight_observer, + DeQuantStub, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qat_qconfig_mapping, + get_default_qconfig, + get_default_qconfig_mapping, + PerChannelMinMaxObserver, + propagate_qconfig_, + QConfig, + QConfigMapping, + quantize, + quantize_dynamic_jit, + quantize_jit, + QuantStub, + QuantType, + QuantWrapper, +) +from torch.ao.quantization.backend_config import get_executorch_backend_config +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, ) from torch.ao.quantization.quantize_pt2e import ( _convert_to_reference_decomposed_fx, @@ -29,83 +53,75 @@ prepare_pt2e, prepare_qat_pt2e, ) -from torch.ao.quantization.backend_config import ( - get_executorch_backend_config, -) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - XNNPACKQuantizer, get_symmetric_quantization_config, + XNNPACKQuantizer, ) -from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \ - default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ - propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \ - QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping -from torch.ao.quantization.quantization_mappings import ( - get_default_dynamic_quant_module_mappings, - get_default_qconfig_propagation_list, - get_default_qat_module_mappings, -) -from torch.testing._internal.common_quantized import ( - override_quantized_engine, -) + +from torch.export import export_for_training from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase try: + from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph + # graph mode quantization based on fx from torch.ao.quantization.quantize_fx import ( - prepare_fx, - prepare_qat_fx, convert_fx, convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, ) - from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph - from torch.fx.graph import Node from torch.fx import GraphModule + from torch.fx.graph import Node + HAS_FX = True except ImportError: HAS_FX = False +import contextlib import copy -import io import functools +import io import os import unittest +from typing 
import Any, Callable, Optional, Union + import numpy as np -from torch.testing import FileCheck -from typing import Callable, Any, Union, Optional import torch._dynamo as torchdynamo import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer -import contextlib +from torch.testing import FileCheck + class NodeSpec: - ''' Used for checking GraphModule Node - ''' + """Used for checking GraphModule Node""" + def __init__(self, op, target): - ''' + """ op: call_function | call_module target: for call_function, target would be a function for call_module, target would be the type of PyTorch module - ''' + """ self.op = op self.target = target @classmethod def call_function(cls, target): - return NodeSpec('call_function', target) + return NodeSpec("call_function", target) @classmethod def call_method(cls, target): - return NodeSpec('call_method', target) + return NodeSpec("call_method", target) @classmethod def call_module(cls, target): - return NodeSpec('call_module', target) + return NodeSpec("call_module", target) def __hash__(self): return hash((self.op, self.target)) @@ -119,8 +135,12 @@ def __eq__(self, other): def __repr__(self): return repr(self.op) + " " + repr(self.target) + def get_supported_device_types(): - return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu'] + return ( + ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"] + ) + def test_only_eval_fn(model, calib_data): r""" @@ -130,7 +150,10 @@ def test_only_eval_fn(model, calib_data): for inp in calib_data: model(*inp) + _default_loss_fn = torch.nn.CrossEntropyLoss() + + def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): r""" Default train function takes a torch.utils.data.Dataset and train the model @@ -153,9 +176,11 @@ def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): correct += (predicted == target).sum().item() return train_loss, correct, total + class AverageMeter: """Computes and stores the average and current value""" - def __init__(self, name, fmt=':f'): + + def __init__(self, name, fmt=":f"): self.name = name self.fmt = fmt self.reset() @@ -173,7 +198,7 @@ def update(self, val, n=1): self.avg = self.sum / self.count def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" return fmtstr.format(**self.__dict__) @@ -193,10 +218,11 @@ def accuracy(output, target, topk=(1,)): res.append(correct_k.mul_(100.0 / batch_size)) return res + def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): model.train() for cnt, (image, target) in enumerate(data_loader, start=1): - print('.', end='') + print(".", end="") image, target = image.to(device), target.to(device) output = model(image) loss = criterion(output, target) @@ -208,16 +234,19 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_bat return return + def ddp_setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" # initialize the process group dist.init_process_group("gloo", rank=rank, world_size=world_size) + def ddp_cleanup(): 
dist.destroy_process_group() + def run_ddp(rank, world_size, prepared): ddp_setup(rank, world_size) prepared.cuda() @@ -232,24 +261,42 @@ def run_ddp(rank, world_size, prepared): def convert_dynamic(module): convert(module, get_default_dynamic_quant_module_mappings(), inplace=True) + def prepare_dynamic(model, qconfig_dict=None): propagate_qconfig_(model, qconfig_dict) + def _make_conv_test_input( - batch_size, in_channels_per_group, input_feature_map_size, - out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale, - W_zero_point, use_bias, use_channelwise, + batch_size, + in_channels_per_group, + input_feature_map_size, + out_channels_per_group, + groups, + kernel_size, + X_scale, + X_zero_point, + W_scale, + W_zero_point, + use_bias, + use_channelwise, ): in_channels = in_channels_per_group * groups out_channels = out_channels_per_group * groups (X_value_min, X_value_max) = (0, 4) X_init = torch.randint( - X_value_min, X_value_max, - (batch_size, in_channels,) + input_feature_map_size) + X_value_min, + X_value_max, + ( + batch_size, + in_channels, + ) + + input_feature_map_size, + ) X = X_scale * (X_init - X_zero_point).float() X_q = torch.quantize_per_tensor( - X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8) + X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8 + ) W_scale = W_scale * out_channels W_zero_point = W_zero_point * out_channels @@ -266,109 +313,132 @@ def _make_conv_test_input( # The operator expects them in the format # (out_channels, in_channels/groups,) + kernel_size W_init = torch.randint( - W_value_min, W_value_max, - (out_channels, in_channels_per_group,) + kernel_size) + W_value_min, + W_value_max, + ( + out_channels, + in_channels_per_group, + ) + + kernel_size, + ) b_init = torch.randint(0, 10, (out_channels,)) if use_channelwise: W_shape = (-1, 1) + (1,) * len(kernel_size) W_scales_tensor = torch.tensor(W_scale, dtype=torch.float) W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float) - W = W_scales_tensor.reshape(*W_shape) * ( - W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + W = ( + W_scales_tensor.reshape(*W_shape) + * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float() + ) b = X_scale * W_scales_tensor * b_init.float() W_q = torch.quantize_per_channel( - W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0, - dtype=torch.qint8) + W, + W_scales_tensor.double(), + W_zero_points_tensor.long(), + 0, + dtype=torch.qint8, + ) else: W = W_scale[0] * (W_init - W_zero_point[0]).float() b = X_scale * W_scale[0] * b_init.float() W_q = torch.quantize_per_tensor( - W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8) + W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8 + ) return (X, X_q, W, W_q, b if use_bias else None) + def _make_conv_add_extra_input_tensor(scale, zero_point, sizes): (X_value_min, X_value_max) = (0, 4) X_init = torch.randint( X_value_min, X_value_max, - sizes # Infer the size of tensor to do the add + sizes, # Infer the size of tensor to do the add ) X = scale * (X_init - zero_point).float() X_q = torch.quantize_per_tensor( - X, scale=scale, zero_point=zero_point, dtype=torch.quint8) + X, scale=scale, zero_point=zero_point, dtype=torch.quint8 + ) return X, X_q + def skipIfNoFBGEMM(fn): - reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.' + reason = "Quantized operations require FBGEMM. 
FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer." if isinstance(fn, type): - if 'fbgemm' not in torch.backends.quantized.supported_engines: + if "fbgemm" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'fbgemm' not in torch.backends.quantized.supported_engines: + if "fbgemm" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoQNNPACK(fn): - reason = 'Quantized operations require QNNPACK.' + reason = "Quantized operations require QNNPACK." if isinstance(fn, type): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def withQNNPACKBackend(fn): # TODO(future PR): consider combining with skipIfNoQNNPACK, # will require testing of existing callsites - reason = 'Quantized operations require QNNPACK.' + reason = "Quantized operations require QNNPACK." if isinstance(fn, type): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'qnnpack' not in torch.backends.quantized.supported_engines: + if "qnnpack" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) - with override_quantized_engine('qnnpack'): + with override_quantized_engine("qnnpack"): fn(*args, **kwargs) return wrapper + def skipIfNoONEDNN(fn): - reason = 'Quantized operations require ONEDNN.' + reason = "Quantized operations require ONEDNN." if isinstance(fn, type): - if 'onednn' not in torch.backends.quantized.supported_engines: + if "onednn" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'onednn' not in torch.backends.quantized.supported_engines: + if "onednn" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoONEDNNBF16(fn): - reason = 'Quantized operations require BF16 support.' + reason = "Quantized operations require BF16 support." if isinstance(fn, type): if not torch.ops.mkldnn._is_mkldnn_bf16_supported(): fn.__unittest_skip__ = True @@ -381,24 +451,28 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoX86(fn): - reason = 'Quantized operations require X86.' + reason = "Quantized operations require X86." 
if isinstance(fn, type): - if 'x86' not in torch.backends.quantized.supported_engines: + if "x86" not in torch.backends.quantized.supported_engines: fn.__unittest_skip__ = True fn.__unittest_skip_why__ = reason return fn @functools.wraps(fn) def wrapper(*args, **kwargs): - if 'x86' not in torch.backends.quantized.supported_engines: + if "x86" not in torch.backends.quantized.supported_engines: raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoDynamoSupport(fn): reason = "dynamo doesn't support." if isinstance(fn, type): @@ -413,8 +487,10 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + def skipIfNoInductorSupport(fn): reason = "inductor doesn't support." if isinstance(fn, type): @@ -429,18 +505,23 @@ def wrapper(*args, **kwargs): raise unittest.SkipTest(reason) else: fn(*args, **kwargs) + return wrapper + try: import torchvision # noqa: F401 + HAS_TORCHVISION = True except ImportError: HAS_TORCHVISION = False skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + def get_script_module(model, tracing, data): return torch.jit.trace(model, data) if tracing else torch.jit.script(model) + def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True): """ Convert lengths to offsets for embedding_bag @@ -464,7 +545,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): max_val = to_quant.amax(dim=1, keepdim=True) min_val = to_quant.amin(dim=1, keepdim=True) - max_int = 2 ** n_bit - 1 + max_int = 2**n_bit - 1 min_int = 0 scales = (max_val - min_val).clamp(min=1e-6) / max_int assert torch.isnan(scales).sum() == 0 @@ -476,7 +557,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): assert torch.isnan(out).sum() == 0 out = out.to(dtype=torch.int32).reshape(w.shape) - if out.device != torch.device('cpu'): + if out.device != torch.device("cpu"): out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8) # Scales and zeros for the same q-group should be contiguous, so we can @@ -490,15 +571,15 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): zeros.reshape(zeros.size(0), zeros.size(1), 1), ], 2, - ).transpose(0, 1).contiguous() + ) + .transpose(0, 1) + .contiguous() ) return out, scales_and_zeros -def _group_quantize_tensor_symmetric( - w, n_bit=4, groupsize=32 -): +def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32): # W is of shape [K x N] # We transpose W as Quantization is applied on [N x K] w = w.transpose(0, 1).contiguous() @@ -566,26 +647,47 @@ class QuantizationTestCase(TestCase): def setUp(self): super().setUp() self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)] - self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)] - self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] - for _ in range(2)] - self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)] - for _ in range(2)] - self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] - for _ in range(2)] - self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float), - torch.randint(0, 1, (1,), dtype=torch.long)] - for _ in range(2)] - - self.img_data_dict = {1 : self.img_data_1d, 
- 2 : self.img_data_2d, - 3 : self.img_data_3d} + self.train_data = [ + [ + torch.rand(2, 5, dtype=torch.float), + torch.randint(0, 1, (2,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)] + self.img_data_2d = [ + [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2) + ] + self.img_data_3d = [ + [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2) + ] + self.img_data_1d_train = [ + [ + torch.rand(2, 3, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_2d_train = [ + [ + torch.rand(1, 3, 10, 10, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + self.img_data_3d_train = [ + [ + torch.rand(1, 3, 5, 5, 5, dtype=torch.float), + torch.randint(0, 1, (1,), dtype=torch.long), + ] + for _ in range(2) + ] + + self.img_data_dict = { + 1: self.img_data_1d, + 2: self.img_data_2d, + 3: self.img_data_3d, + } # Quant types that produce statically quantized ops self.static_quant_types = [QuantType.STATIC, QuantType.QAT] @@ -594,75 +696,92 @@ def setUp(self): def checkNoPrepModules(self, module): r"""Checks the module does not contain child - modules for quantization preparation, e.g. - quant, dequant and observer + modules for quantization preparation, e.g. + quant, dequant and observer """ - self.assertFalse(hasattr(module, 'quant')) - self.assertFalse(hasattr(module, 'dequant')) + self.assertFalse(hasattr(module, "quant")) + self.assertFalse(hasattr(module, "dequant")) def checkNoQconfig(self, module): - r"""Checks the module does not contain qconfig - """ - self.assertFalse(hasattr(module, 'qconfig')) + r"""Checks the module does not contain qconfig""" + self.assertFalse(hasattr(module, "qconfig")) for child in module.children(): self.checkNoQconfig(child) def checkHasPrepModules(self, module): r"""Checks the module contains child - modules for quantization preparation, e.g. - quant, dequant and observer + modules for quantization preparation, e.g. 
+ quant, dequant and observer """ - self.assertTrue(hasattr(module, 'module')) - self.assertTrue(hasattr(module, 'quant')) - self.assertTrue(hasattr(module, 'dequant')) + self.assertTrue(hasattr(module, "module")) + self.assertTrue(hasattr(module, "quant")) + self.assertTrue(hasattr(module, "dequant")) - def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None): + def checkObservers( + self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None + ): r"""Checks the module or module's leaf descendants - have observers in preparation for quantization + have observers in preparation for quantization """ if propagate_qconfig_list is None: propagate_qconfig_list = get_default_qconfig_propagation_list() if prepare_custom_config_dict is None: prepare_custom_config_dict = {} - float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {}) + float_to_observed_module_class_mapping = prepare_custom_config_dict.get( + "float_to_observed_custom_module_class", {} + ) # check if a module is a leaf module, ignoring activation_post_process attribute def is_leaf_module(module): submodule_name_count = 0 for name, _ in module.named_children(): - if name != 'activation_post_process': + if name != "activation_post_process": submodule_name_count += 1 return submodule_name_count == 0 - if hasattr(module, 'qconfig') and module.qconfig is not None and \ - ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential) - and type(module) in propagate_qconfig_list) or - type(module) in float_to_observed_module_class_mapping.keys()) and \ - not isinstance(module, torch.ao.quantization.DeQuantStub): - self.assertTrue(hasattr(module, 'activation_post_process'), - 'module: ' + str(type(module)) + ' do not have observer') + if ( + hasattr(module, "qconfig") + and module.qconfig is not None + and ( + ( + is_leaf_module(module) + and not isinstance(module, torch.nn.Sequential) + and type(module) in propagate_qconfig_list + ) + or type(module) in float_to_observed_module_class_mapping.keys() + ) + and not isinstance(module, torch.ao.quantization.DeQuantStub) + ): + self.assertTrue( + hasattr(module, "activation_post_process"), + "module: " + str(type(module)) + " do not have observer", + ) # we don't need to check observers for child modules of the # qat modules - if type(module) not in get_default_qat_module_mappings().values() and \ - type(module) not in float_to_observed_module_class_mapping.values() and \ - not isinstance(module, _FusedModule): + if ( + type(module) not in get_default_qat_module_mappings().values() + and type(module) not in float_to_observed_module_class_mapping.values() + and not isinstance(module, _FusedModule) + ): for child in module.children(): if type(child) in [nn.Dropout]: continue - self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict) + self.checkObservers( + child, propagate_qconfig_list, prepare_custom_config_dict + ) def checkQuantDequant(self, mod): r"""Checks that mod has nn.Quantize and - nn.DeQuantize submodules inserted + nn.DeQuantize submodules inserted """ self.assertEqual(type(mod.quant), nnq.Quantize) self.assertEqual(type(mod.dequant), nnq.DeQuantize) def checkWrappedQuantizedLinear(self, mod): r"""Checks that mod has been swapped for an nnq.Linear - module, the bias is qint32, and that the module - has Quantize and DeQuantize submodules + module, the bias is qint32, and that the module + has Quantize and DeQuantize submodules """ 
self.assertEqual(type(mod.module), nnq.Linear) self.checkQuantDequant(mod) @@ -672,14 +791,14 @@ def checkQuantizedLinear(self, mod): def checkDynamicQuantizedLinear(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ self.assertEqual(type(mod), nnqd.Linear) self.assertEqual(mod._packed_params.dtype, dtype) def checkDynamicQuantizedLinearRelu(self, mod, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ self.assertEqual(type(mod), nniqd.LinearReLU) self.assertEqual(mod._packed_params.dtype, dtype) @@ -721,25 +840,35 @@ def check_weight_bias_api(self, ref_model, weight_keys, bias_keys): def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.LSTM type - module, the bias is float. + module, the bias is float. """ - wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } self.assertEqual(type(mod), reference_module_type) for packed_params in mod._all_weight_values: - self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) def checkLinear(self, mod): self.assertEqual(type(mod), torch.nn.Linear) def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype): r"""Checks that mod has been swapped for an nnqd.Linear - module, the bias is float. + module, the bias is float. """ - wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'} + wt_dtype_map = { + torch.qint8: "quantized_dynamic", + torch.float16: "quantized_fp16", + } self.assertEqual(type(mod), reference_module_type) - if hasattr(mod, '_all_weight_values'): + if hasattr(mod, "_all_weight_values"): for packed_params in mod._all_weight_values: - self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]) + self.assertEqual( + packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype] + ) def checkScriptable(self, orig_mod, calib_data, check_save_load=False): scripted = torch.jit.script(orig_mod) @@ -770,20 +899,29 @@ def _checkModuleCorrectnessAgainstOrig(self, orig_mod, test_mod, calib_data): scripted_output = test_mod(*inp) self.assertEqual(scripted_output, ref_output) - - def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False, - check=True, eval_mode=True, dynamic=False, qconfig=None): + def checkGraphModeOp( + self, + module, + inputs, + quantized_op, + tracing=False, + debug=False, + check=True, + eval_mode=True, + dynamic=False, + qconfig=None, + ): if debug: - print('Testing:', str(module)) - qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)} + print("Testing:", str(module)) + qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} if eval_mode: module = module.eval() if dynamic: - qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig} + qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig} model = get_script_module(module, tracing, inputs[0]).eval() if debug: - print('input graph:', model.graph) + print("input graph:", model.graph) models = {} outputs = {} for debug in [True, False]: @@ -796,31 +934,37 @@ def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=Fa # input data staying 
constant for comparisons inputs_copy = copy.deepcopy(inputs) models[debug] = quantize_jit( - model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False, - debug=debug) + model, + qconfig_dict, + test_only_eval_fn, + [inputs_copy], + inplace=False, + debug=debug, + ) # make sure it runs outputs[debug] = models[debug](*inputs[0]) if debug: - print('debug graph:', models[True].graph) - print('non debug graph:', models[False].graph) + print("debug graph:", models[True].graph) + print("non debug graph:", models[False].graph) if check: # debug and non-debug option should have the same numerics self.assertEqual(outputs[True], outputs[False]) # non debug graph should produce quantized op - FileCheck().check(quantized_op) \ - .run(models[False].graph) + FileCheck().check(quantized_op).run(models[False].graph) return models[False] def checkGraphModuleNodes( - self, graph_module, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None): - """ Check if GraphModule contains the target node + self, + graph_module, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + ): + """Check if GraphModule contains the target node Args: graph_module: the GraphModule instance we want to check expected_node, expected_node_occurrence, expected_node_list: @@ -831,9 +975,9 @@ def checkGraphModuleNodes( modules = dict(graph_module.named_modules(remove_duplicate=False)) for node in graph_module.graph.nodes: n = None - if node.op == 'call_function' or node.op == 'call_method': + if node.op == "call_function" or node.op == "call_method": n = NodeSpec(node.op, node.target) - elif node.op == 'call_module': + elif node.op == "call_module": n = NodeSpec(node.op, type(modules[node.target])) if n is not None: @@ -844,26 +988,34 @@ def checkGraphModuleNodes( nodes_in_graph[n] = 1 if expected_node is not None: - self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) + - ' not found in the graph module') + self.assertTrue( + expected_node in nodes_in_graph, + "node:" + str(expected_node) + " not found in the graph module", + ) if expected_node_occurrence is not None: for expected_node, occurrence in expected_node_occurrence.items(): if occurrence != 0: self.assertTrue( expected_node in nodes_in_graph, - 'Check failed for node:' + str(expected_node) + - ' not found') + "Check failed for node:" + str(expected_node) + " not found", + ) self.assertTrue( nodes_in_graph[expected_node] == occurrence, - 'Check failed for node:' + str(expected_node) + - ' Expected occurrence:' + str(occurrence) + - ' Found occurrence:' + str(nodes_in_graph[expected_node])) + "Check failed for node:" + + str(expected_node) + + " Expected occurrence:" + + str(occurrence) + + " Found occurrence:" + + str(nodes_in_graph[expected_node]), + ) else: self.assertTrue( expected_node not in nodes_in_graph, - 'Check failed for node:' + str(expected_node) + - ' expected no occurrence but found') + "Check failed for node:" + + str(expected_node) + + " expected no occurrence but found", + ) if expected_node_list is not None: cur_index = 0 @@ -874,20 +1026,21 @@ def checkGraphModuleNodes( cur_index += 1 self.assertTrue( cur_index == len(expected_node_list), - "Check failed for graph:" + - self.printGraphModule(graph_module, print_str=False) + - "Expected ordered list:" + - str(expected_node_list)) + "Check failed for graph:" + + self.printGraphModule(graph_module, print_str=False) + + "Expected ordered list:" + + str(expected_node_list), + ) def printGraphModule(self, graph_module, 
print_str=True): modules = dict(graph_module.named_modules(remove_duplicate=False)) node_infos = [] for n in graph_module.graph.nodes: - node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) - if n.op == 'call_module': - node_info += ' module type: ' + repr(type(modules[n.target])) + node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) + if n.op == "call_module": + node_info += " module type: " + repr(type(modules[n.target])) node_infos.append(node_info) - str_to_print = '\n'.join(node_infos) + str_to_print = "\n".join(node_infos) if print_str: print(str_to_print) return str_to_print @@ -897,7 +1050,9 @@ def printGraphModule(self, graph_module, print_str=True): def assert_types_for_matched_subgraph_pairs( self, matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], - expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]], + expected_types: dict[ + str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] + ], gm_a: GraphModule, gm_b: GraphModule, ) -> None: @@ -917,16 +1072,16 @@ def assert_types_for_matched_subgraph_pairs( def _get_underlying_op_type( node: Node, gm: GraphModule ) -> Union[Callable, str]: - if node.op == 'call_module': + if node.op == "call_module": mod = getattr(gm, node.target) return type(mod) else: - assert node.op in ('call_function', 'call_method') + assert node.op in ("call_function", "call_method") return node.target self.assertTrue( len(matched_subgraph_pairs) == len(expected_types), - f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}' + f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", ) for k, v in expected_types.items(): expected_types_a, expected_types_b = v @@ -938,14 +1093,16 @@ def _get_underlying_op_type( act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) - types_match = (exp_type_start_a is act_type_start_a) and \ - (exp_type_end_a is act_type_end_a) and \ - (exp_type_start_b is act_type_start_b) and \ - (exp_type_end_b is act_type_end_b) + types_match = ( + (exp_type_start_a is act_type_start_a) + and (exp_type_end_a is act_type_end_a) + and (exp_type_start_b is act_type_start_b) + and (exp_type_end_b is act_type_end_b) + ) self.assertTrue( types_match, - f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, ' - f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}' + f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " + f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", ) def assert_ns_compare_dict_valid( @@ -962,48 +1119,53 @@ def assert_ns_compare_dict_valid( for result_type, layer_data in result_type_to_data.items(): self.assertTrue( len(layer_data) == 2, - f"Layer {layer_name} does not have exactly two model results.") + f"Layer {layer_name} does not have exactly two model results.", + ) model_name_0, model_name_1 = layer_data.keys() for res_idx in range(len(layer_data[model_name_0])): layer_data_0 = layer_data[model_name_0][res_idx] layer_data_1 = layer_data[model_name_1][res_idx] self.assertTrue( - layer_data_0['type'] == layer_data_0['type'], - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.") + 
layer_data_0["type"] == layer_data_0["type"], + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.", + ) self.assertTrue( - len(layer_data_0['values']) == - len(layer_data_1['values']), - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.") + len(layer_data_0["values"]) == len(layer_data_1["values"]), + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", + ) # F.conv1d weight has rank 3, and toq.conv1d unpacked weight # has rank 4. For now, skip the length check for conv1d only. is_weight_functional_conv1d = ( - result_type == NSSingleResultValuesType.WEIGHT.value and - ( - 'conv1d' in layer_data_0['prev_node_target_type'] or - 'conv1d' in layer_data_1['prev_node_target_type'] + result_type == NSSingleResultValuesType.WEIGHT.value + and ( + "conv1d" in layer_data_0["prev_node_target_type"] + or "conv1d" in layer_data_1["prev_node_target_type"] ) ) if not is_weight_functional_conv1d: - for idx in range(len(layer_data_0['values'])): - values_0 = layer_data_0['values'][idx] - values_1 = layer_data_1['values'][idx] + for idx in range(len(layer_data_0["values"])): + values_0 = layer_data_0["values"][idx] + values_1 = layer_data_1["values"][idx] if isinstance(values_0, torch.Tensor): self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) elif isinstance(values_0, list): values_0 = values_0[0] values_1 = values_1[0] self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) else: - assert isinstance(values_0, tuple), \ - f"unhandled type {type(values_0)}" + assert isinstance( + values_0, tuple + ), f"unhandled type {type(values_0)}" assert len(values_0) == 2 assert len(values_0[1]) == 2 assert values_0[0].shape == values_1[0].shape @@ -1011,80 +1173,91 @@ def assert_ns_compare_dict_valid( assert values_0[1][1].shape == values_1[1][1].shape # verify that ref_node_name is valid - ref_node_name_0 = layer_data_0['ref_node_name'] - ref_node_name_1 = layer_data_1['ref_node_name'] - prev_node_name_0 = layer_data_0['prev_node_name'] - prev_node_name_1 = layer_data_1['prev_node_name'] - if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value: + ref_node_name_0 = layer_data_0["ref_node_name"] + ref_node_name_1 = layer_data_1["ref_node_name"] + prev_node_name_0 = layer_data_0["prev_node_name"] + prev_node_name_1 = layer_data_1["prev_node_name"] + if ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_OUTPUT.value + ): self.assertTrue(ref_node_name_0 == prev_node_name_0) self.assertTrue(ref_node_name_1 == prev_node_name_1) - elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value: + elif ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_INPUT.value + ): self.assertTrue(ref_node_name_0 != prev_node_name_0) self.assertTrue(ref_node_name_1 != prev_node_name_1) def checkGraphModeFxOp( - self, - model, - inputs, - quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - is_reference=False, - print_debug_info=False, - custom_qconfig_dict=None, - prepare_expected_node=None, - 
prepare_expected_node_occurrence=None, - prepare_expected_node_list=None, - prepare_custom_config=None, - backend_config=None): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node - - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurrences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - is_reference: if True, enables reference mode - print_debug_info: if True, prints debug info - custom_qconfig_dict: overrides default qconfig_dict - prepare_expected_node: same as expected_node, but for prepare - prepare_expected_node_occurrence: same as - expected_node_occurrence, but for prepare - prepare_expected_node_list: same as expected_node_list, but - for prepare - - Returns: - A dictionary with the following structure: - { - "prepared": ..., # the prepared model - "quantized": ..., # the quantized non-reference model - "quantized_reference": ..., # the quantized reference model - "result": ..., # the result for either quantized or - # quantized_reference model depending on the - # is_reference argument - } + self, + model, + inputs, + quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + is_reference=False, + print_debug_info=False, + custom_qconfig_dict=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config=None, + backend_config=None, + ): + """Quantizes model with graph mode quantization on fx and check if the + quantized model contains the quantized_node + + Args: + model: floating point torch.nn.Module + inputs: one positional sample input arguments for model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurrences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. 
[NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + is_reference: if True, enables reference mode + print_debug_info: if True, prints debug info + custom_qconfig_dict: overrides default qconfig_dict + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare + + Returns: + A dictionary with the following structure: + { + "prepared": ..., # the prepared model + "quantized": ..., # the quantized non-reference model + "quantized_reference": ..., # the quantized reference model + "result": ..., # the result for either quantized or + # quantized_reference model depending on the + # is_reference argument + } """ # TODO: make img_data a single example instead of a list if type(inputs) == list: inputs = inputs[0] if quant_type == QuantType.QAT: - qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qat_qconfig_mapping( + torch.backends.quantized.engine + ) model.train() elif quant_type == QuantType.STATIC: - qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qconfig_mapping( + torch.backends.quantized.engine + ) model.eval() else: qconfig = default_dynamic_qconfig @@ -1098,30 +1271,37 @@ def checkGraphModeFxOp( # overwrite qconfig_dict with custom_qconfig_dict if custom_qconfig_dict is not None: - assert type(custom_qconfig_dict) in (QConfigMapping, dict), \ - 'custom_qconfig_dict should be a QConfigMapping or a dict' + assert type(custom_qconfig_dict) in ( + QConfigMapping, + dict, + ), "custom_qconfig_dict should be a QConfigMapping or a dict" if isinstance(custom_qconfig_dict, QConfigMapping): qconfig_mapping = custom_qconfig_dict else: qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) prepared = prepare( - model, qconfig_mapping, + model, + qconfig_mapping, example_inputs=inputs, prepare_custom_config=prepare_custom_config, - backend_config=backend_config) + backend_config=backend_config, + ) if not quant_type == QuantType.DYNAMIC: prepared(*inputs) if print_debug_info: print() - print('quant type:\n', quant_type) - print('original model:\n', model) + print("quant type:\n", quant_type) + print("original model:\n", model) print() - print('prepared model:\n', prepared) + print("prepared model:\n", prepared) self.checkGraphModuleNodes( - prepared, prepare_expected_node, - prepare_expected_node_occurrence, prepare_expected_node_list) + prepared, + prepare_expected_node, + prepare_expected_node_occurrence, + prepare_expected_node_list, + ) prepared_copy = copy.deepcopy(prepared) qgraph = convert_fx(copy.deepcopy(prepared)) @@ -1134,20 +1314,34 @@ def checkGraphModeFxOp( qgraph_to_check = qgraph_reference if is_reference else qgraph if print_debug_info: print() - print('quantized model:\n', qgraph_to_check) + print("quantized model:\n", qgraph_to_check) self.printGraphModule(qgraph_to_check) print() self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) - return {"prepared": prepared_copy, - "quantized": qgraph_copy, - "quantized_reference": qgraph_reference_copy, - "quantized_output": result, - "quantized_reference_output": result_reference} - - - def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, 
offsets, - set_qconfig, is_emb_bag, dtype=torch.quint8): + qgraph_to_check, + expected_node, + expected_node_occurrence, + expected_node_list, + ) + return { + "prepared": prepared_copy, + "quantized": qgraph_copy, + "quantized_reference": qgraph_reference_copy, + "quantized_output": result, + "quantized_reference_output": result_reference, + } + + def checkEmbeddingSerialization( + self, + qemb, + num_embeddings, + embedding_dim, + indices, + offsets, + set_qconfig, + is_emb_bag, + dtype=torch.quint8, + ): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] @@ -1169,33 +1363,49 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic # Check state dict serialization and torch.save APIs if is_emb_bag: - loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum', dtype=dtype) + loaded_qemb = nnq.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + mode="sum", + dtype=dtype, + ) else: - loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) + loaded_qemb = nnq.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype + ) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) - self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight), - embedding_unpack(loaded_qemb._packed_params._packed_weight)) - + self.assertEqual( + embedding_unpack(qemb._packed_params._packed_weight), + embedding_unpack(loaded_qemb._packed_params._packed_weight), + ) # Test JIT serialization self.checkScriptable(qemb, [inputs], check_save_load=True) # Test from_float call if is_emb_bag: - float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + float_embedding = torch.nn.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) else: - float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + float_embedding = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) if set_qconfig: - float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, - qscheme=torch.per_channel_affine_float_qparams, - ch_axis=0) - float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer, - weight=float_qparams_observer) + float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 + ) + float_embedding.qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=float_qparams_observer + ) prepare_dynamic(float_embedding) @@ -1211,6 +1421,7 @@ def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indic self.assertTrue(expected_name in str(q_embeddingbag)) + class QuantizationLiteTestCase(QuantizationTestCase): def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): # Creates quantized model for testing mobile script modules @@ -1223,9 +1434,7 @@ def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): return model - def _compare_script_and_mobile(self, - model: torch.nn.Module, - input: torch.Tensor): + def _compare_script_and_mobile(self, model: 
torch.nn.Module, input: torch.Tensor): # Compares the numerical outputs for script and lite modules qengine = "qnnpack" with override_quantized_engine(qengine): @@ -1236,18 +1445,28 @@ def _compare_script_and_mobile(self, for retry in range(1, max_retry + 1): # retries `max_retry` times; breaks iff succeeds else throws exception try: - buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) + buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter() + ) buffer.seek(0) mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) - torch.testing.assert_close(script_module_result, mobile_module_result) + torch.testing.assert_close( + script_module_result, mobile_module_result + ) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_close(script_module_result, mobile_module_forward_result) - - mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_close(script_module_result, mobile_module_run_method_result) + torch.testing.assert_close( + script_module_result, mobile_module_forward_result + ) + + mobile_module_run_method_result = mobile_module.run_method( + "forward", input + ) + torch.testing.assert_close( + script_module_result, mobile_module_run_method_result + ) except AssertionError as e: if retry == max_retry: raise e @@ -1260,6 +1479,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. """ + _MAP_TO_FX_TRACED_OPS = { torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, @@ -1297,6 +1517,7 @@ def _test_quantizer( m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() if is_qat: @@ -1337,6 +1558,7 @@ def _test_quantizer( m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() node_occurrence = {} for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): @@ -1344,7 +1566,8 @@ def _test_quantizer( node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] if training_ir_node_occurrence is not None: node_occurrence = { - ns.call_function(k): v for k, v in training_ir_node_occurrence.items() + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) @@ -1355,10 +1578,7 @@ def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: @@ -1377,14 +1597,18 @@ def forward(self, x): return self.linear(x) quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel) + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() return self._quantize(m, quantizer, example_inputs) + # Below are a series of toy models to use in testing quantization + class SingleLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ 
-1397,8 +1621,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class AnnotatedSingleLayerLinearModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) @@ -1410,8 +1635,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerLinearDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) @@ -1423,6 +1649,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearAddModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -1438,38 +1665,41 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class RNNDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRU': + if mod_type == "GRU": self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) - if mod_type == 'LSTM': + if mod_type == "LSTM": self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class RNNCellDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRUCell': + if mod_type == "GRUCell": self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) - if mod_type == 'LSTMCell': + if mod_type == "LSTMCell": self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) - if mod_type == 'RNNReLU': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float) - if mod_type == 'RNNTanh': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float) + if mod_type == "RNNReLU": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) + if mod_type == "RNNTanh": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class LSTMwithHiddenDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) @@ -1478,6 +1708,7 @@ def forward(self, x, hid): x, hid = self.lstm(x, hid) return x, hid + class ConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1490,6 +1721,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvTransposeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1502,6 +1734,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1519,6 +1752,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvTransposeModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1536,6 +1770,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return 
(torch.rand(1, 3, 5, 5),) + class ConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1550,6 +1785,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1569,6 +1805,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvBnReLUModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1585,8 +1822,9 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnReLUModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) @@ -1606,13 +1844,18 @@ def forward(self, x): def fuse_model(self): # TODO: remove this check and define two fuse_modules function on this module if self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv", "bn", "relu"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv", "bn", "relu"]], inplace=True + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1627,6 +1870,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1641,6 +1885,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearModelWithSubmodule(nn.Module): def __init__(self) -> None: super().__init__() @@ -1655,6 +1900,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.subm.get_example_inputs() + class AnnotatedTwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1670,6 +1916,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ActivationsTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1686,6 +1933,7 @@ def forward(self, x): x = self.dequant(x) return x + class LinearReluModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1716,6 +1964,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearReluAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1734,6 +1983,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearBnLeakyReluModel(torch.nn.Module): def __init__(self, with_bn=True): super().__init__() @@ -1752,6 +2002,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearTanhModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1766,13 +2017,16 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ConvBnAddReluModel(torch.nn.Module): - def __init__(self, - with_bn=True, - 
with_relu=True, - left_conv=True, - two_conv=True, - use_torch_add=True): + def __init__( + self, + with_bn=True, + with_relu=True, + left_conv=True, + two_conv=True, + use_torch_add=True, + ): super().__init__() self.conv = nn.Conv2d(5, 5, (2, 2)) self.conv2 = nn.Conv2d(5, 5, (2, 2)) @@ -1826,6 +2080,7 @@ def forward(self, x1, x2): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) + # TODO: self.fc should be self.conv class ConvReluModel(torch.nn.Module): def __init__(self) -> None: @@ -1840,6 +2095,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluConvModel(torch.nn.Module): def __init__(self) -> None: @@ -1857,6 +2113,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluAddModel(torch.nn.Module): def __init__(self) -> None: @@ -1876,6 +2133,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class NormalizationTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1897,6 +2155,7 @@ def forward(self, x): x = self.instance_norm3d(x.unsqueeze(-1)) return x + class NestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1910,6 +2169,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedNestedModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1918,7 +2178,7 @@ def __init__(self, qengine): self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) - if qengine == 'fbgemm': + if qengine == "fbgemm": self.sub2.fc1.qconfig = default_per_channel_qconfig else: self.sub2.fc1.qconfig = default_qconfig @@ -1929,6 +2189,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedSubNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1944,6 +2205,7 @@ def forward(self, x): x = self.fc3(x) return x + class AnnotatedCustomConfigNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1953,12 +2215,11 @@ def __init__(self) -> None: self.fc3.qconfig = default_qconfig self.sub2.qconfig = default_qconfig - custom_options = { - 'dtype': torch.quint8, - 'qscheme': torch.per_tensor_affine - } - custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options), - weight=default_weight_observer) + custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} + custom_qconfig = QConfig( + activation=default_observer.with_args(**custom_options), + weight=default_weight_observer, + ) self.sub2.fc1.qconfig = custom_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) @@ -1970,6 +2231,7 @@ def forward(self, x): x = self.fc3(x) return x + class QuantSubModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1985,6 +2247,7 @@ def forward(self, x): x = self.fc3(x) return x + class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2004,14 +2267,14 @@ def fuse_modules(self): if idx >= len(named_children) - 1: break if isinstance(named_children[idx + 1][1], torch.nn.ReLU): - fusable_layers.append([current_name, - named_children[idx + 1][0]]) + fusable_layers.append([current_name, named_children[idx + 1][0]]) # TODO: remove this check and define two fuse_modules function on this module 
if self.training: torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) else: torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) + class FunctionalLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2024,6 +2287,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2036,6 +2300,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class TwoLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2050,6 +2315,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2065,6 +2331,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2078,6 +2345,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear.get_example_inputs() + class FunctionalLinearReluLinearModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2094,6 +2362,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalConv2d(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2105,11 +2374,20 @@ def __init__(self) -> None: self.groups = 1 def forward(self, x): - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class SingleLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2122,6 +2400,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class TwoLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2136,6 +2415,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class FunctionalConvReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2149,6 +2429,7 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv.get_example_inputs() + class FunctionalConvReluConvModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2165,10 +2446,12 @@ def forward(self, x): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class SkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self) -> None: super().__init__() self.sub = InnerModule() @@ -2180,10 +2463,12 @@ def forward(self, x): def fuse_modules(self): self.sub.fuse_modules() + class AnnotatedSkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) @@ -2198,9 +2483,10 @@ def forward(self, x): def 
fuse_modules(self): self.sub.module.fuse_modules() + class QuantStubModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self) -> None: super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") @@ -2213,9 +2499,10 @@ def forward(self, x): x = self.fc(x) return self.dequant(x) + class ManualLinearQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2230,9 +2517,10 @@ def forward(self, x): x = self.fc2(x) return self.dequant(x) + class ManualDropoutQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2247,9 +2535,10 @@ def forward(self, x): x = self.dropout(x) return self.dequant(x) + class ManualLinearDynamicQATModel(torch.nn.Module): - r"""A Module that uses a dynamic QAT by default. - """ + r"""A Module that uses a dynamic QAT by default.""" + def __init__(self, qconfig=None): super().__init__() self.qconfig = qconfig or default_dynamic_qat_qconfig @@ -2261,13 +2550,19 @@ def forward(self, x): x = self.fc2(x) return x + class ManualConvLinearQATModel(torch.nn.Module): r"""A module with manually inserted `QuantStub` and `DeQuantStub` and contains both linear and conv modules """ + def __init__(self, qconfig=None): super().__init__() - self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + self.qconfig = ( + qconfig + if qconfig + else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + ) self.quant = QuantStub() self.dequant = DeQuantStub() self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) @@ -2282,30 +2577,38 @@ def forward(self, x): x = self.fc2(x) return self.dequant(x) + class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. Supported only with qnnpack. 
""" + def __init__(self) -> None: super().__init__(default_symmetric_qnnpack_qat_qconfig) + class ManualEmbeddingBagLinear(nn.Module): def __init__(self) -> None: super().__init__() - self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') + self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") self.emb.qconfig = default_embedding_qat_qconfig self.quant = QuantStub() self.dequant = DeQuantStub() self.linear = nn.Linear(12, 1).to(dtype=torch.float) self.qconfig = get_default_qat_qconfig("qnnpack") - def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, - per_sample_weights: Optional[torch.Tensor] = None): + def forward( + self, + input: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + per_sample_weights: Optional[torch.Tensor] = None, + ): x = self.emb(input, offsets, per_sample_weights) x = self.quant(x) x = self.linear(x) return self.dequant(x) + class DeFusedEmbeddingBagLinear(nn.Module): r"""A module to simulate QAT embedding bag with a linear layer, this module uses a separate embedding and bagging op, similar @@ -2313,6 +2616,7 @@ class DeFusedEmbeddingBagLinear(nn.Module): https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html """ + def __init__(self) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) @@ -2329,6 +2633,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: x = self.linear(x) return self.dequant(x) + class SubModelForFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2350,6 +2655,7 @@ def __init__(self) -> None: def forward(self, x): return self.relu(self.conv(x)) + class ModelForFusion(nn.Module): def __init__(self, qconfig): super().__init__() @@ -2396,14 +2702,14 @@ def forward(self, x): y = self.dequant(y) return x + class ConvBNReLU(nn.Sequential): def __init__(self) -> None: super().__init__( - nn.Conv2d(3, 3, 1, 1, bias=False), - nn.BatchNorm2d(3), - nn.ReLU(inplace=False) + nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) ) + class ModelWithSequentialFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2428,6 +2734,7 @@ def forward(self, x): x = self.dequant(x) return x + class ModelForFusionWithBias(nn.Module): def __init__(self) -> None: super().__init__() @@ -2449,6 +2756,7 @@ def forward(self, x): x = self.dequant(x) return x + class ModelForLinearBNFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2460,6 +2768,7 @@ def __init__(self) -> None: def forward(self, x): return self.bn(self.fc(x)) + class DummyObserver(torch.nn.Module): def calculate_qparams(self): return 1.0, 0 @@ -2543,9 +2852,14 @@ def forward(self, x): def fuse_model(self): # TODO: remove this check and define two fuse_model function on this module if self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + class ModelMultipleOps(torch.nn.Module): def __init__(self) -> None: @@ -2578,6 +2892,7 @@ def forward(self, x): out = self.fc(out) return out + # Model to ensure consistency of fake quant with true quant # Average pooling and mean operations are not modelled # accurately with fake-quant so this model does not @@ -2612,15 +2927,22 @@ def forward(self, 
x): out = self.fc(out) return out + class EmbeddingBagModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + self.emb = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) def forward(self, indices, offsets, per_sample_weights): return self.emb(indices, offsets, per_sample_weights) + class EmbeddingModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2629,6 +2951,7 @@ def __init__(self) -> None: def forward(self, indices): return self.emb(indices) + class EmbeddingWithStaticLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2647,9 +2970,11 @@ def forward(self, indices, offsets, linear_in): features = torch.cat([fc] + [emb], dim=1) return features -class DenseTopMLP(nn.Module): - def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None: +class DenseTopMLP(nn.Module): + def __init__( + self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out + ) -> None: super().__init__() self.dense_mlp = nn.Sequential( @@ -2671,16 +2996,18 @@ def forward( out = self.top_mlp(features) return out + # thin wrapper around embedding bag, because tracing inside nn.Embedding # bag is not supported at the moment and this is top level class EmbBagWrapper(nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() - self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum') + self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") def forward(self, indices, offsets): return self.emb_bag(indices, offsets) + class SparseNNModel(nn.Module): _NUM_EMBEDDINGS = 10 _EMBEDDING_DIM = 5 @@ -2695,8 +3022,12 @@ def __init__(self) -> None: self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) self.dense_top = DenseTopMLP( - self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN, - self._TOP_OUT_OUT) + self._DENSE_DIM, + self._DENSE_OUTPUT, + self._EMBEDDING_DIM, + self._TOP_OUT_IN, + self._TOP_OUT_OUT, + ) def forward( self, @@ -2704,12 +3035,12 @@ def forward( sparse_offsets: torch.Tensor, dense: torch.Tensor, ) -> torch.Tensor: - sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) out = self.dense_top(sparse_feature, dense) return out + class TestHelperModules: class ControlFlow(torch.nn.Module): def forward( @@ -2719,7 +3050,6 @@ def forward( pred2: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: - def true_nested(y: torch.Tensor) -> torch.Tensor: y = y + y y = torch.mm(y, y) @@ -2736,7 +3066,10 @@ def false_fn(x: torch.Tensor, _) -> torch.Tensor: return x.cos() def map_fn( - x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor + x: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, ) -> torch.Tensor: x = x.cos() y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) @@ -2747,7 +3080,12 @@ def map_fn( return control_flow.map(map_fn, xs, pred1, pred2, y) def example_inputs(self): - return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),) + return ( + torch.ones(2, 2), + torch.tensor([False]), + torch.tensor([False]), + torch.ones(2, 2), + ) class Conv2dPropAnnotaton(torch.nn.Module): def __init__(self) -> None: @@ -3029,16 +3367,20 @@ def forward(self, x): x = self.relu(self.fc(x)) return x + def 
_generate_qdq_quantized_model( mod, inputs, is_qat=False, is_dynamic=False, quantizer=None ): - def get_default_quantizer(is_qat, is_dynamic, inputs): - has_xpu = any(isinstance(input, torch.Tensor) and input.device.type == "xpu" - for input in inputs) + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) if has_xpu: quantizer = XPUInductorQuantizer() - assert (not is_qat) and (not is_dynamic), "QAT and dynamic quantization is not supported at XPU backend currently" + assert (not is_qat) and ( + not is_dynamic + ), "QAT and dynamic quantization is not supported at XPU backend currently" quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) else: quantizer = X86InductorQuantizer() @@ -3051,12 +3393,11 @@ def get_default_quantizer(is_qat, is_dynamic, inputs): maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training( - mod, - inputs, - ).module() + export_model = export_for_training(mod, inputs, strict=True).module() quantizer = ( - quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs) + quantizer + if quantizer + else get_default_quantizer(is_qat, is_dynamic, inputs) ) prepare_model = ( prepare_qat_pt2e(export_model, quantizer) From b0e28f60df6906fa75ff99f6ae64be0bba9fbf33 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 3 Apr 2025 23:51:46 +0000 Subject: [PATCH 179/332] Revert "add unit test for preferred_blas_library settings (#150581)" This reverts commit 781d28e2655f88ae2fef827ed110f22ed553a0ab. Reverted https://github.com/pytorch/pytorch/pull/150581 on behalf of https://github.com/clee2000 due to new test broken internally D72395624 ([comment](https://github.com/pytorch/pytorch/pull/150581#issuecomment-2777228731)) --- test/test_cuda.py | 58 ----------------------------------------------- 1 file changed, 58 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index 4f4fb5148a7a..a3cc62c5e1d4 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,64 +595,6 @@ def test_serialization_array_with_storage(self): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) - @setBlasBackendsToDefaultFinally - def test_preferred_blas_library_settings(self): - def _check_default(): - default = torch.backends.cuda.preferred_blas_library() - if torch.version.cuda: - # CUDA logic is easy, it's always cublas - self.assertTrue(default == torch._C._BlasBackend.Cublas) - else: - # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else - gcn_arch = str( - torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] - ) - if gcn_arch in ["gfx90a", "gfx942", "gfx950"]: - self.assertTrue(default == torch._C._BlasBackend.Cublaslt) - else: - self.assertTrue(default == torch._C._BlasBackend.Cublas) - - _check_default() - # "Default" can be set but is immediately reset internally to the actual default value. - self.assertTrue( - torch.backends.cuda.preferred_blas_library("default") - != torch._C._BlasBackend.Default - ) - _check_default() - self.assertTrue( - torch.backends.cuda.preferred_blas_library("cublas") - == torch._C._BlasBackend.Cublas - ) - self.assertTrue( - torch.backends.cuda.preferred_blas_library("hipblas") - == torch._C._BlasBackend.Cublas - ) - # check bad strings - with self.assertRaisesRegex( - RuntimeError, - "Unknown input value. 
Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", - ): - torch.backends.cuda.preferred_blas_library("unknown") - # check bad input type - with self.assertRaisesRegex(RuntimeError, "Unknown input value type."): - torch.backends.cuda.preferred_blas_library(1.0) - # check env var override - custom_envs = [ - {"TORCH_BLAS_PREFER_CUBLASLT": "1"}, - {"TORCH_BLAS_PREFER_HIPBLASLT": "1"}, - ] - test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())" - for env_config in custom_envs: - env = os.environ.copy() - for key, value in env_config.items(): - env[key] = value - r = ( - subprocess.check_output([sys.executable, "-c", test_script], env=env) - .decode("ascii") - .strip() - ) - self.assertEqual("_BlasBackend.Cublaslt", r) - @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @setBlasBackendsToDefaultFinally def test_cublas_workspace_explicit_allocation(self): From 1bc2b2b12ae1ddd27b0401a1baac3b8099b6fc50 Mon Sep 17 00:00:00 2001 From: Avik Chaudhuri Date: Fri, 4 Apr 2025 00:15:32 +0000 Subject: [PATCH 180/332] bound sympy accuracy (#150383) Differential Revision: D72215735 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150383 Approved by: https://github.com/pianpwk --- test/export/test_export.py | 22 ++++++++++++++++++++++ torch/utils/_sympy/value_ranges.py | 17 +++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/test/export/test_export.py b/test/export/test_export.py index 988e2fae81c6..5eefb67c14b6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3105,6 +3105,28 @@ def forward(self, x, y): "dy - 6 = 6" not in exc.args[0] ) # don't suggest fix for non-root dim + @testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1) + @testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0) + def test_bound_sympy_accuracy(self): + class Foo(torch.nn.Module): + def forward(self, x): + expr = x.shape[0] - (x.shape[0] % 8) + return torch.empty(expr) + + ep = export( + Foo(), + (torch.randn(13),), + dynamic_shapes={"x": (Dim("dim", min=2),)}, + ) + + (output,) = ep.graph.output_node().args[0] + sym_node = output.meta["val"].shape[0].node + vr = torch.utils._sympy.value_ranges.bound_sympy( + sym_node.expr, + sym_node.shape_env.var_to_range, + ) + self.assertEqual(vr.lower, 0) + @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") def test_keep_composite_ops_invalid(self): class Foo(torch.nn.Module): diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 784f9e7ba051..118959b8c4db 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1004,6 +1004,22 @@ def trunc(x): return ValueRanges.increasing_map(x, TruncToFloat) +def _rewrite_for_value_range_analysis(expr: sympy.Expr): + """ + Sometimes accuracy of value range analysis can be improved + with simple rewriting rules. + """ + + # Rewrite X - X%Y to (X//Y) * Y. 
+ x, y = sympy.Wild("x"), sympy.Wild("y") + expr = expr.replace( + x - torch.utils._sympy.functions.Mod(x, y), + torch.utils._sympy.functions.FloorDiv(x, y) * y, + ) + + return expr + + def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: @@ -1047,6 +1063,7 @@ def missing_handler(s): vr = ValueRanges.unknown() return vr + expr = _rewrite_for_value_range_analysis(expr) return sympy_interp( SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler ) From d0026fa1383ced00140b1379889d6afcb2d082f6 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 4 Apr 2025 01:11:59 +0000 Subject: [PATCH 181/332] [ROCm][TunableOp] Fix UT race condition and reduce UT duration. (#150463) This PR fixes two race conditions that occur when UT tests are run: - In a particular order within a single shard. - Concurrently in multiple shards. Each test now gets a unique filename that depends on the test name. There were two other minor improvements to the UTs: - matmul_offline_mgpu could occasionally fail if run on 8 GPUs. Criteria was relaxed. - bmm_tunableop_rocm checks that the rotating buffer is not zero. Otherwise, the test is not useful. Additionally, several UTs took over 1 minute to run. Their duration was reduced by a combination of setting max tuning iterations to one, setting the rotating buffer size to zero, and/or reducing the matrix dimensions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150463 Approved by: https://github.com/jeffdaily --- test/test_linalg.py | 243 +++++++++++++++++++++++--------------------- 1 file changed, 127 insertions(+), 116 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index b49bed2a2e93..97c56796bbb9 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -65,22 +65,7 @@ def blaslt_supported_device(): return True return False -def set_tunableop_defaults(): - if not torch.cuda.is_available(): - # TunableOp not supported on CPU at this time. - return - - # disable TunableOp and restore to default values - torch.cuda.tunable.enable(False) - torch.cuda.tunable.record_untuned_enable(False) - torch.cuda.tunable.tuning_enable(True) - torch.cuda.tunable.set_max_tuning_duration(30) - torch.cuda.tunable.set_max_tuning_iterations(100) - torch.cuda.tunable.set_rotating_buffer_size(-1) - ordinal = torch.cuda.current_device() - torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv") - -def tunableop_matmul(device, dtype, offline=False): +def tunableop_matmul(device, dtype, result_filename=None, offline=False): # Helper function to test TunableOp in a subprocess # requires helper function since lambda function # not supported by multiprocessing module @@ -90,6 +75,9 @@ def tunableop_matmul(device, dtype, offline=False): if offline: torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) + else: + if result_filename is not None: + torch.cuda.tunable.set_filename(result_filename) torch.cuda.tunable.set_max_tuning_duration(1) A = torch.randn((17, 17), device=device, dtype=dtype) @@ -109,31 +97,13 @@ def find_tunableop_result(results, OpSig, ParamSig): return inner_tuple return None -def compare_untuned_tuned_entries(untuned_filename, tuned_filename): - # Compare the entries of untuned and tuned Tunableop results - # file. Verify that for each Op+Param Signature in the untuned file - # there is a matching one in the tuned results file. 
- import csv - ok = False - with open(untuned_filename) as file1: - with open(tuned_filename) as file2: - untuned_reader = csv.reader(file1) - untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader} - - tuned_reader = csv.reader(file2) - for _ in range(5): # Skip the first 5 lines for the validator - next(tuned_reader, None) - - result_csv_entries = {(row[0], row[1]) for row in tuned_reader} - - missing = untuned_csv_entries - result_csv_entries - - if missing: - ok = False - else: - ok = True - - return ok +def get_tunableop_untuned_filename(): + import os + ordinal = torch.cuda.current_device() + untuned_filename_env = os.getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME") + untuned_filename_base, _, _ = untuned_filename_env.rpartition('.') + untuned_filename = f"{untuned_filename_base}{ordinal}.csv" + return untuned_filename class TestLinalg(TestCase): @contextlib.contextmanager @@ -165,7 +135,7 @@ def _tunableop_ctx(self): # Inialize and then tear down TunableOp import glob import os - set_tunableop_defaults() + self._set_tunableop_defaults() torch.cuda.tunable.enable(True) try: @@ -175,7 +145,13 @@ def _tunableop_ctx(self): torch.cuda.tunable.enable(False) # clean up, remove any files that were generated - for file in glob.glob("tunableop*.csv"): + results_filename = torch.cuda.tunable.get_filename() + results_filename_pattern, _, _ = results_filename.rpartition('.') + untuned_filename = get_tunableop_untuned_filename() + untuned_filename_pattern, _, _ = untuned_filename.rpartition('.') + patterns = [f"{results_filename_pattern[:-1]}*.csv", f"{untuned_filename_pattern[:-1]}*.csv"] + files = [f for pattern in patterns for f in glob.glob(pattern)] + for file in files: try: os.remove(file) # NB: The file is locked on Windows @@ -194,6 +170,59 @@ def _tunableop_ctx(self): except KeyError: pass + def _set_tunableop_defaults(self): + if not torch.cuda.is_available(): + # TunableOp not supported on CPU at this time. + return + + # disable TunableOp and restore to default values + torch.cuda.tunable.enable(False) + torch.cuda.tunable.record_untuned_enable(False) + torch.cuda.tunable.tuning_enable(True) + torch.cuda.tunable.set_max_tuning_duration(30) + torch.cuda.tunable.set_max_tuning_iterations(100) + torch.cuda.tunable.set_rotating_buffer_size(-1) + ordinal = torch.cuda.current_device() + + # Set filenames to be unique on a per test basis + import os + unique_id = self.id().split(".")[-1] + torch.cuda.tunable.set_filename(f"tunableop_results_{unique_id}_{ordinal}.csv") + # ordinal gets automatically appended + os.environ["PYTORCH_TUNABLEOP_UNTUNED_FILENAME"] = f"tunableop_untuned_{unique_id}_.csv" + + def _compare_untuned_tuned_entries(self, untuned_filename=None, tuned_filename=None): + # Compare the entries of untuned and tuned Tunableop results + # file. Verify that for each Op+Param Signature in the untuned file + # there is a matching one in the tuned results file. 
+ import csv + ok = False + ordinal = torch.cuda.current_device() + if untuned_filename is None: + untuned_filename = get_tunableop_untuned_filename() + if tuned_filename is None: + tuned_filename = torch.cuda.tunable.get_filename() + + with open(untuned_filename) as file1: + with open(tuned_filename) as file2: + untuned_reader = csv.reader(file1) + untuned_csv_entries = {(row[0], row[1]) for row in untuned_reader} + + tuned_reader = csv.reader(file2) + for _ in range(5): # Skip the first 5 lines for the validator + next(tuned_reader, None) + + result_csv_entries = {(row[0], row[1]) for row in tuned_reader} + + missing = untuned_csv_entries - result_csv_entries + + if missing: + ok = False + else: + ok = True + + return ok + exact_dtype = True @dtypes(torch.float, torch.cfloat) @@ -4693,16 +4722,18 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): make_arg = partial(make_tensor, device=device, dtype=dtype) # Using gen_sizes_matmul(2) to ensure we cover # 'NN', 'TN', 'TT', and 'NN' cases. - for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2), (True, False), (True, False)): + for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(2, y_dim=3), + (True, False), (True, False)): x = make_arg(size_x, noncontiguous=nctg_x) y = make_arg(size_y, noncontiguous=nctg_y) self.check_single_matmul(x, y) filename1 = torch.cuda.tunable.get_filename() - filename2 = "tunableop_results_tmp1.csv" - filename3 = "tunableop_results_tmp2.csv" + unique_id = self.id().split(".")[-1] + filename2 = f"{filename1}_tmp1.csv" + filename3 = f"{filename1}_tmp2.csv" ordinal = torch.cuda.current_device() - assert filename1 == f"tunableop_results{ordinal}.csv" + assert filename1 == f"tunableop_results_{unique_id}_{ordinal}.csv" assert len(torch.cuda.tunable.get_results()) > 0 assert torch.cuda.tunable.write_file() # use default filename @@ -4720,6 +4751,10 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): assert file1_contents == file2_contents assert file1_contents == file3_contents + # We need to reset the filename to the default value so we can properly + # clean up intermediate files + self._set_tunableop_defaults() + @onlyCUDA @skipCUDAIfNotRocm @dtypes(torch.half) @@ -4728,7 +4763,6 @@ def test_matmul_offline_tunableop(self, device, dtype): # NOTE: The offline tuning does not support certain tensor # shapes as noted below. Submatrics / matrix slices are # not supported at all. 
- import os def has_any_dim_size_one(tensor: torch.Tensor): """Check if any dimension of a PyTorch tensor has size 1.""" @@ -4750,7 +4784,6 @@ def is_bmm_compatible(A, B): torch.cuda.tunable.set_rotating_buffer_size(0) ordinal = torch.cuda.current_device() - result_filename = f"tunableop_results{ordinal}.csv" # record GEMM torch.cuda.tunable.tuning_enable(False) @@ -4821,8 +4854,7 @@ def is_bmm_compatible(A, B): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -4839,12 +4871,8 @@ def is_bmm_compatible(A, B): self.assertGreater(new_results - ref_results, 0) self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @onlyCUDA @@ -4853,14 +4881,11 @@ def is_bmm_compatible(A, B): @dtypes(torch.torch.float8_e4m3fnuz, torch.float8_e5m2fnuz) def test_scaled_gemm_offline_tunableop(self, device, dtype): # This test is the offline version of test_scaled_gemm_tunableop - import os with self._tunableop_ctx(): ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -4910,8 +4935,7 @@ def test_scaled_gemm_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -4937,12 +4961,8 @@ def test_scaled_gemm_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @unittest.skipIf(not TEST_MULTIGPU, "Requires at least 2 GPUs") @@ -4960,7 +4980,11 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): total_gpus = torch.cuda.device_count() ordinal = torch.cuda.current_device() - untuned_filename = f"tunableop_untuned{ordinal}.csv" + + # Untuned filename has unique id, but results file + # does not because it is executed in a subprocess + untuned_filename = get_tunableop_untuned_filename() + torch.cuda.tunable.set_filename(f"tunableop_results{ordinal}.csv") # turn on untuned GEMM recording and turn off tuning torch.cuda.tunable.tuning_enable(False) @@ -4985,19 +5009,14 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): torch.cuda.tunable.mgpu_tune_gemm_in_file(untuned_filename, total_gpus) # check the results files where written, one per gpu 
- # get the size of the first result and make sure it - # greater than 100. Since the validator text should - # be at least that much. - # The other results file will have - # at least the size of the first results file - 80 + # Check that the results file is not empty and store + # that in a local variable for the next loop. for i in range(total_gpus): result_filename = f"tunableop_results{i}.csv" self.assertTrue(os.path.exists(result_filename)) + self.assertGreater(os.path.getsize(result_filename), 0) if i == 0: # Store for next loop result_size = os.path.getsize(result_filename) - self.assertGreater(os.path.getsize(result_filename), 0) - self.assertGreater(os.path.getsize(result_filename), result_size - 80) - # Check the full results files was written, one per gpu # check that the size of the full results file for @@ -5018,6 +5037,7 @@ def test_matmul_offline_mgpu_tunableop(self, device, dtype): def test_rotating_buffer_tunableop(self, device, dtype): # Test the TunableOp rotating buffer API # Test the default value, will return the l2_cache_size + self._set_tunableop_defaults() l2_cache_size = torch.cuda.tunable.get_rotating_buffer_size() self.assertGreater(l2_cache_size, 0) # Test zero @@ -5038,6 +5058,9 @@ def test_bmm_tunableop_rocm(self, device, dtype): # buffer rotation (on by default) with strided batched gemm tunableop was causing a mem fault with self._tunableop_ctx(): torch.cuda.tunable.set_max_tuning_iterations(10) + # Make sure the rotating buffer is not zero, otherwise this test does nothing useful. + rotating_buffer = torch.cuda.tunable.get_rotating_buffer_size() + self.assertGreater(rotating_buffer, 0) # the following 3 cases cover all previous failure cases and are here to catch regressions B = 16 N = M = K = 256 @@ -5082,21 +5105,21 @@ def test_bmm_tunableop_rocm(self, device, dtype): @onlyCUDA @skipCUDAIfNotRocm - @dtypes(torch.float) + @dtypes(torch.bfloat16) def test_numeric_check_leak_tunableop_rocm(self, device, dtype): import os from torch.testing._internal.common_utils import CudaMemoryLeakCheck # run operator first without tuning to ensure all rocm libs are loaded, # otherwise false positive mem leak - B = 16 - N = M = K = 256 - dtype = torch.bfloat16 + B = 5 + N = M = K = 29 device = torch.device("cuda:0") i1 = torch.randn((B, N, M), device=device, dtype=dtype) i2 = torch.randn((B, M, K), device=device, dtype=dtype) out = torch.bmm(i1, i2) with self._tunableop_ctx(): + torch.cuda.tunable.set_rotating_buffer_size(0) # enable tunableop numeric check via env variable. 
os.environ["PYTORCH_TUNABLEOP_NUMERICAL_CHECK"] = "1" @@ -5213,9 +5236,9 @@ def test_disable_tuning_tunableop(self, device, dtype): ref_num_results = len(torch.cuda.tunable.get_results()) # Tune one GEMMs to make sure TunableOp is enabled - M = 3 - N = 3 - K = 3 + M = 11 + N = 13 + K = 17 A = torch.randn(N, K, device=device, dtype=dtype) B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) @@ -5234,9 +5257,9 @@ def test_disable_tuning_tunableop(self, device, dtype): torch.cuda.tunable.tuning_enable(False) # Try to tune one more GEMM - M = 3 - N = 3 - K = 4 + M = 11 + N = 13 + K = 18 A = torch.randn(N, K, device=device, dtype=dtype) B = torch.randn(K, M, device=device, dtype=dtype) C = torch.matmul(A, B) @@ -5257,8 +5280,7 @@ def test_dump_results_on_exit_tunableop(self, device, dtype): import multiprocessing as mp with self._tunableop_ctx(): - ordinal = torch.cuda.current_device() - filename = f"tunableop_results{ordinal}.csv" + filename = torch.cuda.tunable.get_filename() # force=True needed according to: # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method @@ -5266,7 +5288,7 @@ def test_dump_results_on_exit_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, filename, False)) p.start() p.join() @@ -5305,14 +5327,11 @@ def test_gemm_bias_tunableop(self, device, dtype): @dtypes(torch.bfloat16) def test_gemm_bias_offline_tunableop(self, device, dtype): # This test is the offline version of test_gemm_bias_tunableop - import os ordinal = torch.cuda.current_device() with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -5330,8 +5349,7 @@ def test_gemm_bias_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.is_enabled()) self.assertTrue(torch.cuda.tunable.tuning_is_enabled() is False) - untuned_filename = f"tunableop_untuned{ordinal}.csv" - self.assertTrue(os.path.exists(untuned_filename)) + untuned_filename = get_tunableop_untuned_filename() # tuning the untuned GEMMs in file torch.cuda.tunable.tuning_enable(True) @@ -5353,12 +5371,8 @@ def test_gemm_bias_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) @onlyCUDA @@ -5378,6 +5392,7 @@ def test_scaled_gemm_tunableop(self, device, dtype): # tested by PyTorch with self._tunableop_ctx(): # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_rotating_buffer_size(0) torch.cuda.tunable.set_max_tuning_iterations(1) # Reference number of results @@ -5386,9 +5401,9 @@ def test_scaled_gemm_tunableop(self, device, dtype): # Scaled GEMM parameters fillA = 0.25 fillB = 0.75 - n = 32 - m = 64 - k = 128 + n = 64 + m = 16 + k = 32 scaleA = torch.tensor(0.8, device=device) scaleB = torch.tensor(0.9, device=device) @@ -5519,8 +5534,6 @@ def test_tf32_offline_tunableop(self, device, 
dtype): ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) - result_filename = f"tunableop_results{ordinal}.csv" - # record GEMM torch.cuda.tunable.tuning_enable(False) torch.cuda.tunable.record_untuned_enable(True) @@ -5535,7 +5548,7 @@ def test_tf32_offline_tunableop(self, device, dtype): torch.backends.cuda.matmul.allow_tf32 = False C = torch.matmul(A, B) - untuned_filename = f"tunableop_untuned{ordinal}.csv" + untuned_filename = get_tunableop_untuned_filename() self.assertTrue(os.path.exists(untuned_filename)) # tuning the untuned GEMMs in file @@ -5569,12 +5582,8 @@ def test_tf32_offline_tunableop(self, device, dtype): self.assertTrue(torch.cuda.tunable.write_file()) - # Make sure the results file exists and that it is not zero - self.assertTrue(os.path.exists(result_filename)) - self.assertGreater(os.path.getsize(result_filename), 0) - # Compare Param Signature of untuned and tuned results - ok = compare_untuned_tuned_entries(untuned_filename, result_filename) + ok = self._compare_untuned_tuned_entries() self.assertTrue(ok) finally: @@ -5606,10 +5615,11 @@ def test_blaslog_tunableop(self, device, dtype): with self._tunableop_ctx(): os.putenv("PYTORCH_TUNABLEOP_BLAS_LOG", "1") - ordinal = torch.cuda.current_device() - - result_filename = f"tunableop_results{ordinal}.csv" - untuned_filename = f"tunableop_untuned{ordinal}.csv" + # TunableOp is running in a subprocess + # online tuning needs filename set through API + # offline tuning needs filename set through environment variableq + result_filename = torch.cuda.tunable.get_filename() + untuned_filename = get_tunableop_untuned_filename() # Offline Tuning case in a subprocess @@ -5619,7 +5629,7 @@ def test_blaslog_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype, True)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, None, True)) p.start() p.join() @@ -5646,7 +5656,7 @@ def test_blaslog_tunableop(self, device, dtype): # already set the start method mp.set_start_method("spawn", force=True) - p = mp.Process(target=tunableop_matmul, args=(device, dtype, False)) + p = mp.Process(target=tunableop_matmul, args=(device, dtype, result_filename, False)) p.start() p.join() @@ -6868,7 +6878,8 @@ def test_addmm_relu(self, device, dtype): @bf32_on_and_off(0.05) def test_addmm_relu_tunableop_rocm(self, device, dtype): with self._tunableop_ctx(): - torch.cuda.tunable.set_max_tuning_iterations(10) + torch.cuda.tunable.set_rotating_buffer_size(0) + torch.cuda.tunable.set_max_tuning_iterations(1) self._test_addmm_impl(torch._addmm_activation, "relu", device, dtype) From f9f6c080d8309ac1c5a546a47571389bac0b922c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 28 Mar 2025 04:57:38 -0700 Subject: [PATCH 182/332] support guard or false/true in user code and add tests (#150178) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150178 Approved by: https://github.com/pianpwk --- test/test_dynamic_shapes.py | 98 ++++++++++++++++++++++++++++++++ torch/_dynamo/trace_rules.py | 2 + torch/_dynamo/variables/torch.py | 22 +++++++ 3 files changed, 122 insertions(+) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index a3458efbe65b..6b7a2d3edcfc 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -2815,6 +2815,104 @@ def test_guards_float_print(self): guards = shape_env.produce_guards_expression([s0]) 
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_guard_or_true(self): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + def func(a, b): + x = a.item() + if guard_or_true(x == 1): + return b * 10 + else: + return b * 20 + + # call with guarding. + self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) + self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) + + unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func(a, b) + + # always return b*10 + self.assertEqual( + unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) + ) + self.assertEqual( + unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) + ) + + # Test that statically known true works. + def func2(a, b): + x = a.item() + if guard_or_true(x != x): + return b * 10 + else: + return b * 20 + + unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func2(a, b) + # always return b*20 + self.assertEqual( + unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + self.assertEqual( + unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) + ) + + @skipIfTorchDynamo("Not a TorchDynamo suitable test") + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_guard_or_false(self): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + def func(a, b): + x = a.item() + if guard_or_false(x == 1): + return b * 10 + else: + return b * 20 + + # call with guarding. + self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) + self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) + + unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func(a, b) + + # always return b*20 + self.assertEqual( + unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + self.assertEqual( + unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) + ) + + # Test that statically known true works. 
+ def func2(a, b): + x = a.item() + if guard_or_false(x == x): + return b * 10 + else: + return b * 20 + + unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) + a = torch.tensor([1]) + b = torch.tensor([1]) + unbacked_func2(a, b) + # always return b*10 + self.assertEqual( + unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) + ) + self.assertEqual( + unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) + ) + def test_guards_float_div(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 8) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 4b3eb10d09e7..42bbf9a0623f 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -308,6 +308,8 @@ "torch._dynamo.mark_static": UserFunctionVariable, "torch._dynamo.nonstrict_trace": UserFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, "torch.cuda._get_device_properties": TorchInGraphFunctionVariable, "torch.utils.hooks.BackwardHook": TorchInGraphFunctionVariable, "torch.set_default_device": UserFunctionVariable, diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 40821a16e5e5..13aaf715f8ac 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -897,6 +897,28 @@ def handle_guard_size_oblivious(self, tx: "InstructionTranslator", expr): elif isinstance(expr, ConstantVariable): return expr + @register(torch.fx.experimental.symbolic_shapes.guard_or_true) + def handle_guard_or_true(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_true(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + + @register(torch.fx.experimental.symbolic_shapes.guard_or_false) + def handle_guard_or_false(self, tx: "InstructionTranslator", expr): + if isinstance(expr, SymNodeVariable): + # TODO: this probably should be folded somewhere else but I'm not sure where + # TODO: some of the other symbolic_shapes special tools can also get this treatment too + return variables.ConstantVariable.create( + torch.fx.experimental.symbolic_shapes.guard_or_false(expr.sym_num) + ) + elif isinstance(expr, ConstantVariable): + return expr + @register(torch._C._autograd._unsafe_set_version_counter) def handle_unsafe_set_version_counter( self, tx: "InstructionTranslator", *args, **kwargs From 1979a409e92da533785c1340e14086ead744da43 Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 3 Apr 2025 13:28:40 -0700 Subject: [PATCH 183/332] Make CompileEventLogger more defensive w.r.t to AOTAutogradCache and FXGraphCache (#150423) This PR makes it so that we don't crash due to logging if we invoke AOTAutogradCache/FXGraphCache without using dynamo. This is preparation for supporting certain VLLM use cases where they store graph modules and have special handling in conjunection with the caches. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150423 Approved by: https://github.com/oulgen --- torch/_dynamo/utils.py | 17 ++- .../_aot_autograd/autograd_cache.py | 116 +++++++++--------- torch/_inductor/codecache.py | 37 ++++-- 3 files changed, 102 insertions(+), 68 deletions(-) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 8fa038ce7116..2a09d8943409 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -579,12 +579,27 @@ def instant( @staticmethod def try_add_pt2_compile(event_name: str, **metadata: object): """ - Adds to an existing pt2_compile event, but silently returns if the event doesn't exist. + Adds to an existing pt2_compile event, but silently returns if the event doesn't exist + or ChromiumEventLogger is not initialized. This function is syntactic sugar for chromium_event_logger().try_add_event_data. """ + if CHROMIUM_EVENT_LOG is None: + return chromium_log = get_chromium_event_logger() chromium_log.try_add_event_data(event_name, **metadata) + @staticmethod + def try_(method_fn, *args, **kwargs): + """ + Special function that quietly runs a given method, returning if CHROMIUM_EVENT_LOG is None or metrics context is not set + """ + if CHROMIUM_EVENT_LOG is None: + return + metrics_context = get_metrics_context() + if not metrics_context.in_progress(): + return + method_fn(*args, **kwargs) + @contextmanager def dynamo_timed( diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index f23fbc84bad2..6e31070fd7a5 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -18,7 +18,12 @@ import torch from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions -from torch._dynamo.utils import CompileEventLogger, counters +from torch._dynamo.utils import ( + CHROMIUM_EVENT_LOG, + CompileEventLogger, + counters, + dynamo_timed, +) from torch._functorch import config from torch._inductor.codecache import ( _ident, @@ -549,45 +554,45 @@ def wrap_post_compile( torch._logging.trace_structured( "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str ) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func = self.compiled_fw.load(args) + compiled_bw_func = None + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + else: + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } - compiled_fw_func = self.compiled_fw.load(args) - compiled_bw_func = None - if self.compiled_bw is not None: - compiled_bw_func = self.compiled_bw.load(args) - needs_autograd = True - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - # Now that we've loaded forward and backward, call post compile on both - # This avoids setting things like BoxedBools in 
fx_config until - # after both forward and backward cache hit - fw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - bw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": True, - } - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, fw_fx_config - ) - compiled_bw_func = self.compiled_bw.post_compile( - compiled_bw_func, bw_fx_config - ) - else: - inference_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - - needs_autograd = False - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, inference_fx_config - ) + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) # Wrap the forward function in post compile wrappers compiled_fw_func = AOTDispatchSubclassWrapper( @@ -600,7 +605,7 @@ def wrap_post_compile( ) req_subclass_dispatch = self.maybe_subclass_meta is not None - CompileEventLogger.pt2_compile( + CompileEventLogger.try_add_pt2_compile( "backend_compile", requires_subclass_dispatch=req_subclass_dispatch ) @@ -843,21 +848,22 @@ def load( "components": debug_lines, } ) - CompileEventLogger.instant( - f"autograd_cache_{cache_state}", - metadata=cache_info, - time_ns=cache_event_time, - ) - CompileEventLogger.try_add_pt2_compile( - "backend_compile", - cache_state=cache_state, - cache_event_time=cache_event_time, - key=cache_info.get("key"), - components=cache_info.get("components"), - cache_bypass_reason=cache_info.get("cache_bypass_reason"), - remote_cache_enabled=remote, - local_cache_enabled=local, - ) + if CHROMIUM_EVENT_LOG: + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) torch._logging.trace_structured( "artifact", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 177b53e3e999..8a01ae5d6429 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1047,12 +1047,17 @@ def iterate_over_candidates() -> Generator[ triton_bundler_meta = TritonBundler.read_and_emit(bundle) if (meta := triton_bundler_meta) is not None: cache_info["triton_bundler_meta"] = str(meta) - # TODO: Clean up autograd cache integration CompileEventLogger.try_add_pt2_compile( "inductor_compile", cached_kernel_names=meta.cached_kernel_names ) + CompileEventLogger.try_add_pt2_compile( + "AOTAutogradCache.inductor_load", + cached_kernel_names=meta.cached_kernel_names, + ) if len(meta.cached_kernel_names) > 0: - CompileEventLogger.increment_toplevel("num_triton_bundles") + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, "num_triton_bundles" + ) try: artifact_path = graph.after_deserialization(constants) @@ -1306,17 +1311,22 @@ def load_with_key( cache_info["cache_state"] = "hit" if remote_cache: # Count remote cache hit stats - CompileEventLogger.increment_toplevel( - "inductor_fx_remote_cache_hit_count" + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_hit_count", ) - 
CompileEventLogger.add_to_set_toplevel( - "inductor_fx_remote_cache_hit_keys", key + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_hit_keys", + key, ) if (time_saved_ns := compiled_graph._time_taken_ns) is not None: cache_info["time_saved_ns"] = time_saved_ns - CompileEventLogger.increment_toplevel( - "distributed_ephemeral_timeout_us", time_saved_ns // 1000 + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "distributed_ephemeral_timeout_us", + time_saved_ns // 1000, ) if ( ephemeral_increase @@ -1326,11 +1336,14 @@ def load_with_key( else: if remote_cache: # Count remote cache miss stats - CompileEventLogger.increment_toplevel( - "inductor_fx_remote_cache_miss_count" + CompileEventLogger.try_( + CompileEventLogger.increment_toplevel, + "inductor_fx_remote_cache_miss_count", ) - CompileEventLogger.add_to_set_toplevel( - "inductor_fx_remote_cache_miss_keys", key + CompileEventLogger.try_( + CompileEventLogger.add_to_set_toplevel, + "inductor_fx_remote_cache_miss_keys", + key, ) log.info("fx graph cache miss for key %s", key) counters["inductor"]["fxgraph_cache_miss"] += 1 From a9e2f22405f2ab6eb716ec22381907d8bb8111ff Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Fri, 4 Apr 2025 01:58:29 +0000 Subject: [PATCH 184/332] [Bugfix] Fix compile error with `torch.Tensor.unsqueeze_` and inplace views called from Tensor Class (#150573) Fixes #129673 ### Summary: Modifying a tensor by reshaping in place (such as `unsqueeze_`) should cause a graph break; however, when accessed through `torch.Tensor` api as opposed to as self attribute caused the code to crash with an error (see attached issue) Paths differed when traced due to the stack variable popped, as: * `self.unsqueeze_` pops a `LazyVariableTracker` which gets resolved to `TensorVariable`, so when looking for the method, triggers the fn call `var_getattr` in `_dynamo/variables/tensor.py`; since this is an inplace view (metadata mutation) on graph input, it is not well supported so should fall back (see [L446](https://github.com/pytorch/pytorch/blob/1017927c83dd95a4be6074c48e0fb38f0a1bd8f3/torch/_dynamo/variables/tensor.py#L446) in that file) * `torch.Tensor.unsqueeze` pops a `UserDefinedClassVariable` so when looking for the method, triggers the fn call `var_getattr` in `_dynamo/variables/user_defined.py` on [L273](https://github.com/pytorch/pytorch/blob/a8f6b40e36bc4afe4e58568620a008c9a8a8704e/torch/_dynamo/variables/user_defined.py#L273). This path tries to build a variable tracker from the obj popped, which resolves to a trace_rule , and as a Tensor method, is resolved to `TorchInGraphFunctionVariable` on [L3767](https://github.com/pytorch/pytorch/blob/a8f6b40e36bc4afe4e58568620a008c9a8a8704e/torch/_dynamo/trace_rules.py#L3767) So, one straightforward option is to check if the fn is an inplace_view on a input tensor in `torch.py` when we resolve the `__call__function` for the `TorchInGraphFunctionVariable` instead, which resolves the bug by providing a graph break ### Test ``` pytest test/dynamo/test_functions.py::FunctionTests::test_unsqueeze_inplace ``` Results in ``` Running 1 items in this shard test/dynamo/test_functions.py . 
[100%] =========================================================================================== 1 passed in 9.16s ========================================================================================== ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150573 Approved by: https://github.com/anijain2305 --- test/dynamo/test_functions.py | 15 +++++++++++++++ torch/_dynamo/variables/torch.py | 24 +++++++++++++++++++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 09aee481c0cc..54a0d70727b9 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -3821,6 +3821,21 @@ def test_map_unpack_vars(a, b): x, y = map(lambda x: x + 1, [a, b]) return x + y + def test_unsqueeze_inplace(self): + def fn(x): + return torch.Tensor.unsqueeze_(x, dim=1) + 1 + + def self_fn(x): + return x.unsqueeze_(dim=1) + 1 + + v = torch.ones([3], device="cpu") + # identical tensor since modify inplace + v2 = torch.ones([3], device="cpu") + opt_fn = torch.compile(fn) + opt_self_fn = torch.compile(self_fn) + self.assertEqual(v, v2) + self.assertEqual(opt_fn(v), opt_self_fn(v2)) + def test_enumerate_custom(self): class MyClass: def __iter__(self): diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 13aaf715f8ac..8034f440e775 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -64,7 +64,7 @@ proxy_args_kwargs, unwrap_if_wrapper, ) -from .base import VariableTracker +from .base import typestr, VariableTracker from .ctx_manager import ( AutocastModeVariable, ProfilerContextVariable, @@ -1180,6 +1180,28 @@ def patched_fn(*args, **kwargs): ) if self.is_tensor_method(): + name = self.value.__name__ + # Guard against inplace view op on input tensor (not supported) + if args and isinstance(args[0], variables.TensorVariable): + tensor_var = args[0] + # Check if input tensor and inplace_view op specifcally + if tensor_var.source is not None and hasattr(torch.ops.aten, name): + fn = getattr(torch.ops.aten, name) + if ( + hasattr(fn, "overloads") + and hasattr(fn, fn.overloads()[0]) + and torch.Tag.inplace_view + in getattr(fn, fn.overloads()[0]).tags + ): + unimplemented_v2( + gb_type="Inplace op on input tensor", + context="", + explanation=f"Attempted to trace an inplace view op on input tensor {typestr(self.value)}.", + hints=[ + *graph_break_hints.SUPPORTABLE, + "Ensure you do not modify input tensor in place.", + ], + ) return self.call_tensor_method(tx, args, kwargs) special_handler = self._get_handlers().get(self.value) From bd9c42ebfba23f3610bb788fb0f7b18f83549766 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 3 Apr 2025 15:59:40 -0700 Subject: [PATCH 185/332] [c10d] Surface error type when we unlink and create named pipe for DumpPipe (#150648) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150648 Approved by: https://github.com/fegin, https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 6f8b192a1a51..aa0021d7608f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -195,11 +195,15 @@ struct DumpPipe { TORCH_CHECK( unlink(filename.c_str()) != -1 || errno == ENOENT, "Error removing existing named 
pipe ", - filename); + filename, + ", Error: ", + std::strerror(errno)); TORCH_CHECK( mkfifo(filename.c_str(), 0666) != -1, "Error creating named pipe ", - filename); + filename, + ", Error: ", + std::strerror(errno)); fd_ = open(filename.c_str(), O_RDONLY | O_NONBLOCK); LOG(INFO) << "Pipe file " << filename << " has been opened, write to it to trigger NCCL Debug Dump."; From ed0fd2fa7a1bf827027a0dda5172cea06170ce64 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 2 Apr 2025 14:18:53 -0700 Subject: [PATCH 186/332] clang-format aten/src/ATen/cpu/vec/*.h (#150426) I got a complaint about indentation on #150380. Make the machines fix it for us. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150426 Approved by: https://github.com/aditew01, https://github.com/cyyever, https://github.com/frost-intel, https://github.com/Skylion007 --- .lintrunner.toml | 1 + aten/src/ATen/cpu/vec/functional_base.h | 124 ++-- aten/src/ATen/cpu/vec/functional_bfloat16.h | 244 +++++--- aten/src/ATen/cpu/vec/intrinsics.h | 20 +- aten/src/ATen/cpu/vec/vec.h | 19 +- aten/src/ATen/cpu/vec/vec_base.h | 654 +++++++++++++------- aten/src/ATen/cpu/vec/vec_convert.h | 4 +- aten/src/ATen/cpu/vec/vec_half.h | 19 +- aten/src/ATen/cpu/vec/vec_mask.h | 35 +- 9 files changed, 735 insertions(+), 385 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 376d916e3c65..ec62529d1f49 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -55,6 +55,7 @@ init_command = [ code = 'CLANGFORMAT' include_patterns = [ 'aten/src/ATen/*.h', + 'aten/src/ATen/cpu/vec/*.h', 'aten/src/ATen/mps/**/*.mm', 'aten/src/ATen/mps/**/*.h', 'aten/src/ATen/xpu/**/*.h', diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index 4d1d05ea8d32..e7429d18712d 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -29,16 +29,21 @@ inline scalar_t vec_reduce_all( template struct VecReduceAllSIMD { - static inline scalar_t apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline scalar_t apply( + const Op& vec_fun, + const Vectorized& acc_vec) { return vec_reduce_all(vec_fun, acc_vec, Vectorized::size()); } }; -#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && \ + !defined(C10_MOBILE) #if defined(CPU_CAPABILITY_AVX2) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -57,7 +62,9 @@ struct VecReduceAllSIMD { #if defined(CPU_CAPABILITY_AVX512) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 256-bit shuffle @@ -76,25 +83,33 @@ struct VecReduceAllSIMD { } }; #endif // defined(CPU_CAPABILITY_AVX512) -#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE) +#endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && + // !defined(C10_MOBILE) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + !defined(CPU_CAPABILITY_SVE) template 
struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; - // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -] + // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, + // a4+a8, a1+a5, a2+a6, -, -, -, -] float32x4_t v1_1 = vextq_f32(v, v, 2); Vec v1 = v1_1; // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] v = vec_fun(v, v1); - // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -] + // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, + // -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, + // -] v1_1 = vrev64q_f32(v); v1 = v1_1; - // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] + // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, + // a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] v = vec_fun(v, v1); return v[0]; @@ -102,10 +117,13 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) -#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256) +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && \ + defined(CPU_CAPABILITY_SVE256) template struct VecReduceAllSIMD { - static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + static inline float apply( + const Op& vec_fun, + const Vectorized& acc_vec) { using Vec = Vectorized; Vec v = acc_vec; // 128-bit shuffle @@ -125,15 +143,21 @@ struct VecReduceAllSIMD { }; #endif // defined(__aarch64__) - template -inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized& acc_vec) { +inline scalar_t vec_reduce_all( + const Op& vec_fun, + const Vectorized& acc_vec) { return VecReduceAllSIMD::apply(vec_fun, acc_vec); } -template , int> = 0> -inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> +inline scalar_t reduce_all( + const Op& vec_fun, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) return vec_reduce_all(vec_fun, Vec::loadu(data, size), size); @@ -151,16 +175,22 @@ inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size } // similar to reduce_all, but reduces into two outputs -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using Vec = vec::Vectorized; if (size < Vec::size()) { auto loaded_data = Vec::loadu(data, size); return std::pair( - vec_reduce_all(vec_fun1, loaded_data, size), - vec_reduce_all(vec_fun2, loaded_data, size)); + vec_reduce_all(vec_fun1, loaded_data, size), + vec_reduce_all(vec_fun2, loaded_data, size)); } int64_t d = Vec::size(); Vec acc_vec1 = Vec::loadu(data); @@ -176,12 +206,14 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d); } return std::pair( - 
vec_reduce_all(vec_fun1, acc_vec1), - vec_reduce_all(vec_fun2, acc_vec2)); + vec_reduce_all(vec_fun1, acc_vec1), vec_reduce_all(vec_fun2, acc_vec2)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -205,8 +237,11 @@ inline scalar_t map_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -237,8 +272,11 @@ inline scalar_t map2_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline scalar_t map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -274,8 +312,10 @@ inline scalar_t map3_reduce_all( return vec_reduce_all(red_fun, acc_vec); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -293,8 +333,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -317,8 +359,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -344,8 +388,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, diff --git a/aten/src/ATen/cpu/vec/functional_bfloat16.h b/aten/src/ATen/cpu/vec/functional_bfloat16.h index 3bd22b3820f0..d4a40acaeefd 100644 --- a/aten/src/ATen/cpu/vec/functional_bfloat16.h +++ b/aten/src/ATen/cpu/vec/functional_bfloat16.h @@ -8,86 +8,120 @@ namespace at::vec { // BFloat16 specification -template struct VecScalarType { using type = scalar_t; }; -template <> struct VecScalarType { using type = float; }; -template <> struct VecScalarType { using type = float; }; +template +struct VecScalarType { + using type = scalar_t; +}; +template <> +struct VecScalarType { + using type = float; +}; +template <> +struct VecScalarType { + using type = float; +}; // This is different from at::acc_type since we only need to specialize BFloat16 template using vec_scalar_t = typename VecScalarType::type; // Vector conversion between float and bfloat16/half -template , int> = 0> -inline std::tuple, Vectorized> convert_to_float(const Vectorized&); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline std::tuple, Vectorized> convert_to_float( + const Vectorized&); template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { +inline std::tuple, Vectorized> convert_to_float< + BFloat16>(const Vectorized& a) { return convert_bfloat16_float(a); } template <> -inline std::tuple, Vectorized> convert_to_float (const Vectorized& a) { - return convert_half_float(a); +inline std::tuple, Vectorized> convert_to_float( + const Vectorized& a) { + return convert_half_float(a); } -template , int> = 0> -inline Vectorized convert_from_float(const Vectorized&, const Vectorized&); +template < + typename scalar_t, + 
typename std::enable_if_t, int> = 0> +inline Vectorized convert_from_float( + const Vectorized&, + const Vectorized&); template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_bfloat16(a, b); } template <> -inline Vectorized convert_from_float(const Vectorized& a, const Vectorized& b) { +inline Vectorized convert_from_float( + const Vectorized& a, + const Vectorized& b) { return convert_float_half(a, b); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized &out1, Vectorized &out2); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float( + const scalar_t* data, + Vectorized& out1, + Vectorized& out2); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_bf16(data, out1, out2); } template <> -inline void load_to_float (const Half *data, Vectorized &out1, Vectorized &out2) { +inline void load_to_float( + const Half* data, + Vectorized& out1, + Vectorized& out2) { load_fp32_from_fp16(data, out1, out2); } -template , int> = 0> -inline void load_to_float(const scalar_t *data, Vectorized &out); +template < + typename scalar_t, + typename std::enable_if_t, int> = 0> +inline void load_to_float(const scalar_t* data, Vectorized& out); template <> -inline void load_to_float (const BFloat16 *data, Vectorized &out) { +inline void load_to_float( + const BFloat16* data, + Vectorized& out) { load_fp32_from_bf16(data, out); } template <> -inline void load_to_float (const Half *data, Vectorized &out) { +inline void load_to_float(const Half* data, Vectorized& out) { load_fp32_from_fp16(data, out); } -// Note that we already have specialized member of Vectorized for BFloat16 -// so the following functions would run smoothly: +// Note that we already have specialized member of Vectorized for +// BFloat16 so the following functions would run smoothly: // using Vec = Vectorized; // Vec one = Vec(BFloat16(1)); // vec::map([](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N); // // Then why we still need to specialize "functional"? -// If we do specialization at Vectorized<> level, the above example would need 3 pairs of -// conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and "/". -// If we do specialization at vec::map<>() level, we have only 1 pair of conversion -// of bf16->fp32/fp32->bf16, for the input and output BFloat16 vector only. +// If we do specialization at Vectorized<> level, the above example would need +// 3 pairs of conversion of bf16->fp32/fp32->bf16, each for ".exp()", "+" and +// "/". If we do specialization at vec::map<>() level, we have only 1 pair of +// conversion of bf16->fp32/fp32->bf16, for the input and output BFloat16 +// vector only. // -// The following BFloat16 functionality will only do data type conversion for input -// and output vector (reduce functionality will only convert the final scalar back to bf16). -// Compared to Vectorized<> specialization, +// The following BFloat16 functionality will only do data type conversion for +// input and output vector (reduce functionality will only convert the final +// scalar back to bf16). Compared to Vectorized<> specialization, // 1. better performance since we have less data type conversion; // 2. 
less rounding error since immediate results are kept in fp32; // 3. accumulation done on data type of fp32. @@ -95,8 +129,10 @@ inline void load_to_float (const Half *data, Vectorized &out) { // If you plan to extend this file, please ensure adding unit tests at // aten/src/ATen/test/vec_test_all_types.cpp // -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; @@ -104,7 +140,8 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - data_fvec0 = fVec::set(data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, vec_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(vec_fun, data_fvec0, fVec::size()); } else { return vec_reduce_all(vec_fun, data_fvec0, size); @@ -124,27 +161,37 @@ inline float reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) { auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc_fvec0 = vec_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, vec_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { - acc_fvec0 = fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, vec_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = vec_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(vec_fun, acc_fvec0); } -template , int> = 0> -inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2, - const scalar_t* data, int64_t size) { +template < + typename scalar_t, + typename Op1, + typename Op2, + typename std::enable_if_t, int> = 0> +inline std::pair reduce2_all( + const Op1& vec_fun1, + const Op2& vec_fun2, + const scalar_t* data, + int64_t size) { using bVec = vec::Vectorized; using fVec = vec::Vectorized; if (size < bVec::size()) { bVec data_bvec = bVec::loadu(data, size); auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size > fVec::size()) { - fVec acc1_fvec = fVec::set(data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); - fVec acc2_fvec = fVec::set(data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc1_fvec = fVec::set( + data_fvec0, vec_fun1(data_fvec0, data_fvec1), size - fVec::size()); + fVec acc2_fvec = fVec::set( + data_fvec0, vec_fun2(data_fvec0, data_fvec1), size - fVec::size()); return std::pair( vec_reduce_all(vec_fun1, acc1_fvec, fVec::size()), vec_reduce_all(vec_fun2, acc2_fvec, fVec::size())); @@ -171,12 +218,20 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f auto [data_fvec0, data_fvec1] = convert_to_float(data_bvec); if (size - d > fVec::size()) { acc1_fvec0 = vec_fun1(acc1_fvec0, data_fvec0); - acc1_fvec1 = fVec::set(acc1_fvec1, vec_fun1(acc1_fvec1, data_fvec1), size - d - fVec::size()); + acc1_fvec1 = fVec::set( + acc1_fvec1, + vec_fun1(acc1_fvec1, data_fvec1), + size - d - fVec::size()); acc2_fvec0 = vec_fun2(acc2_fvec0, data_fvec0); - acc2_fvec1 = fVec::set(acc2_fvec1, vec_fun2(acc2_fvec1, data_fvec1), size - d - fVec::size()); + acc2_fvec1 = fVec::set( + acc2_fvec1, + vec_fun2(acc2_fvec1, data_fvec1), + size - d - fVec::size()); } 
else { - acc1_fvec0 = fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); - acc2_fvec0 = fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); + acc1_fvec0 = + fVec::set(acc1_fvec0, vec_fun1(acc1_fvec0, data_fvec0), size - d); + acc2_fvec0 = + fVec::set(acc2_fvec0, vec_fun2(acc2_fvec0, data_fvec0), size - d); } } acc1_fvec0 = vec_fun1(acc1_fvec0, acc1_fvec1); @@ -186,8 +241,11 @@ inline std::pair reduce2_all(const Op1& vec_fun1, const Op2& vec_f vec_reduce_all(vec_fun2, acc2_fvec0)); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -201,7 +259,8 @@ inline float map_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); @@ -228,18 +287,23 @@ inline float map_reduce_all( data_fvec0 = map_fun(data_fvec0); data_fvec1 = map_fun(data_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map2_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -256,7 +320,8 @@ inline float map2_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); @@ -289,18 +354,23 @@ inline float map2_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename MapOp, + typename ReduceOp, + typename std::enable_if_t, int> = 0> inline float map3_reduce_all( const MapOp& map_fun, const ReduceOp& red_fun, @@ -320,7 +390,8 @@ inline float map3_reduce_all( if (size > fVec::size()) { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, 
data2_fvec1, data3_fvec1); - data_fvec0 = fVec::set(data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); + data_fvec0 = fVec::set( + data_fvec0, red_fun(data_fvec0, data_fvec1), size - fVec::size()); return vec_reduce_all(red_fun, data_fvec0, fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); @@ -359,18 +430,22 @@ inline float map3_reduce_all( data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); data_fvec1 = map_fun(data_fvec1, data2_fvec1, data3_fvec1); acc_fvec0 = red_fun(acc_fvec0, data_fvec0); - acc_fvec1 = fVec::set(acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); + acc_fvec1 = fVec::set( + acc_fvec1, red_fun(acc_fvec1, data_fvec1), size - d - fVec::size()); } else { data_fvec0 = map_fun(data_fvec0, data2_fvec0, data3_fvec0); - acc_fvec0 = fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); + acc_fvec0 = + fVec::set(acc_fvec0, red_fun(acc_fvec0, data_fvec0), size - d); } } acc_fvec0 = red_fun(acc_fvec0, acc_fvec1); return vec_reduce_all(red_fun, acc_fvec0); } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -397,8 +472,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map( const Op& vec_fun, scalar_t* output_data, @@ -419,7 +496,8 @@ inline void map( fVec data_fvec0, data_fvec1; if (size - d > fVec::size()) { data_fvec0 = fVec::loadu(input_data + d); - data_fvec1 = fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); + data_fvec1 = + fVec::loadu(input_data + d + fVec::size(), size - d - fVec::size()); } else { // choose to align with behaviour of bVec::loadu(ptr, size), // which leaves data_fvec1 uninitialized @@ -432,8 +510,10 @@ inline void map( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map2( const Op& vec_fun, scalar_t* output_data, @@ -465,8 +545,10 @@ inline void map2( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map3( const Op& vec_fun, scalar_t* output_data, @@ -503,8 +585,10 @@ inline void map3( } } -template , int> = 0> +template < + typename scalar_t, + typename Op, + typename std::enable_if_t, int> = 0> inline void map4( const Op& vec_fun, scalar_t* output_data, @@ -525,8 +609,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d); } @@ -539,8 +625,10 @@ inline void map4( auto [data3_fvec0, data3_fvec1] = convert_to_float(data3_bvec); bVec data4_bvec = bVec::loadu(input_data4 + d, size - d); auto [data4_fvec0, data4_fvec1] = convert_to_float(data4_bvec); - fVec output_fvec0 = vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); - fVec output_fvec1 = vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); + fVec 
output_fvec0 = + vec_fun(data1_fvec0, data2_fvec0, data3_fvec0, data4_fvec0); + fVec output_fvec1 = + vec_fun(data1_fvec1, data2_fvec1, data3_fvec1, data4_fvec1); bVec output_bvec = convert_from_float(output_fvec0, output_fvec1); output_bvec.store(output_data + d, size - d); } diff --git a/aten/src/ATen/cpu/vec/intrinsics.h b/aten/src/ATen/cpu/vec/intrinsics.h index 48b18793b079..f9086f7d3d0b 100644 --- a/aten/src/ATen/cpu/vec/intrinsics.h +++ b/aten/src/ATen/cpu/vec/intrinsics.h @@ -13,10 +13,14 @@ /* Microsoft C/C++-compatible compiler */ #include #if _MSC_VER <= 1900 -#define _mm256_extract_epi64(X, Y) (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) -#define _mm256_extract_epi32(X, Y) (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) -#define _mm256_extract_epi16(X, Y) (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) -#define _mm256_extract_epi8(X, Y) (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) +#define _mm256_extract_epi64(X, Y) \ + (_mm_extract_epi64(_mm256_extractf128_si256(X, Y >> 1), Y % 2)) +#define _mm256_extract_epi32(X, Y) \ + (_mm_extract_epi32(_mm256_extractf128_si256(X, Y >> 2), Y % 4)) +#define _mm256_extract_epi16(X, Y) \ + (_mm_extract_epi16(_mm256_extractf128_si256(X, Y >> 3), Y % 8)) +#define _mm256_extract_epi8(X, Y) \ + (_mm_extract_epi8(_mm256_extractf128_si256(X, Y >> 4), Y % 16)) #endif #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__)) /* GCC-compatible compiler, targeting ARM with NEON */ @@ -25,9 +29,9 @@ /* GCC-compatible compiler, targeting ARM with SVE */ #include #endif -#if defined (MISSING_ARM_VLD1) +#if defined(MISSING_ARM_VLD1) #include -#elif defined (MISSING_ARM_VST1) +#elif defined(MISSING_ARM_VST1) #include #endif #elif defined(__GNUC__) && defined(__IWMMXT__) @@ -36,8 +40,8 @@ #elif defined(__s390x__) // targets Z/architecture // we will include vecintrin later -#elif (defined(__GNUC__) || defined(__xlC__)) && \ - (defined(__VEC__) || defined(__ALTIVEC__)) +#elif (defined(__GNUC__) || defined(__xlC__)) && \ + (defined(__VEC__) || defined(__ALTIVEC__)) /* XLC or GCC-compatible compiler, targeting PowerPC with VMX/VSX */ #include /* We need to undef those tokens defined by to avoid conflicts diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index e4b0c4b95d84..0bfe65cd1959 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -28,21 +28,30 @@ inline Vectorized Vectorized::loadu(const void* ptr) { } template <> -inline Vectorized Vectorized::loadu(const void* ptr, int64_t count) { +inline Vectorized Vectorized::loadu( + const void* ptr, + int64_t count) { // See NOTE [Loading boolean values] return convert_to_bool(Vectorized::loadu(ptr, count)); } template -struct VecHoldType { using hold_type = typename VT::value_type; }; +struct VecHoldType { + using hold_type = typename VT::value_type; +}; template <> -struct VecHoldType> { using hold_type = BFloat16; }; +struct VecHoldType> { + using hold_type = BFloat16; +}; template <> -struct VecHoldType> {using hold_type = Half; }; +struct VecHoldType> { + using hold_type = Half; +}; template using vechold_type = typename VecHoldType::hold_type; -}} // namespace at::vec::CPU_CAPABILITY +} // namespace CPU_CAPABILITY +} // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 2591338881ae..3e6124cbc500 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -1,5 +1,6 @@ #pragma once -#if 
defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && defined(__ARM_FEATURE_SVE) +#if defined(__GNUC__) && __GNUC__ == 10 && __GNUC_MINOR__ <= 2 && \ + defined(__ARM_FEATURE_SVE) // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117161 #pragma GCC optimize("no-tree-vectorize") #endif @@ -18,27 +19,27 @@ // See https://github.com/pytorch/pytorch/issues/37577 for an instance // of this bug in the past. -#include #include +#include #include +#include +#include #include #include -#include #include -#include +#include #include #include -#include -#include -#include -#include -#include #include -#include #include -#include +#include +#include +#include #include +#include +#include +#include #if defined(__GNUC__) #define __FORCE_INLINE __attribute__((always_inline)) inline @@ -66,7 +67,8 @@ Windows llvm will not have this definition. #endif #define VECTOR_WIDTH 64 #define int_vector __m512i -#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +#elif defined(__aarch64__) && \ + !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 // SVE code expects 256-vectors; leave that set for SVE? #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(16))) @@ -93,40 +95,43 @@ namespace at::vec { inline namespace CPU_CAPABILITY { // at::Half and at::BFloat16 should be treated as floating point template -struct is_floating_point: - std::integral_constant || - std::is_same_v || - std::is_same_v> { -}; +struct is_floating_point + : std::integral_constant< + bool, + std::is_floating_point_v || std::is_same_v || + std::is_same_v> {}; -template +template constexpr bool is_floating_point_v = is_floating_point::value; template -struct is_reduced_floating_point: - std::integral_constant || - std::is_same_v> { -}; +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> {}; template -constexpr bool is_reduced_floating_point_v = is_reduced_floating_point::value; +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; template -struct is_8bit_integer: - std::integral_constant || - std::is_same_v> { +struct is_8bit_integer + : std::integral_constant< + bool, + std::is_same_v || std::is_same_v> { }; template constexpr bool is_8bit_integer_v = is_8bit_integer::value; -template struct int_of_size; +template +struct int_of_size; -#define DEFINE_INT_OF_SIZE(int_t) \ -template<> struct int_of_size { using type = int_t; } +#define DEFINE_INT_OF_SIZE(int_t) \ + template <> \ + struct int_of_size { \ + using type = int_t; \ + } DEFINE_INT_OF_SIZE(int64_t); DEFINE_INT_OF_SIZE(int32_t); @@ -142,14 +147,15 @@ using int_same_size_t = typename int_of_size::type; // emulates Vectorized types #if defined(__s390x__) -template +template #else template #endif struct Vectorized { -private: + private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; -public: + + public: using value_type = T; using size_type = int; @@ -163,11 +169,11 @@ struct Vectorized { values[i] = val; } } - template> - Vectorized(Args... vals) : values{vals...}{ - } - Vectorized(const T(&arr)[kSize]) { + template < + typename... Args, + typename = std::enable_if_t<(sizeof...(Args) == size())>> + Vectorized(Args... 
vals) : values{vals...} {} + Vectorized(const T (&arr)[kSize]) { std::memcpy(values, arr, sizeof(values)); } // This also implies const T& operator[](int idx) const @@ -198,20 +204,23 @@ struct Vectorized { } // Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 #if __GNUC__ <= 12 && !defined(__clang__) && defined(__ARM_FEATURE_SVE) - static Vectorized __attribute__ ((optimize("-fno-tree-loop-vectorize"))) blendv(const Vectorized& a, + static Vectorized __attribute__((optimize("-fno-tree-loop-vectorize"))) + blendv( + const Vectorized& a, #else - static Vectorized blendv(const Vectorized& a, + static Vectorized blendv( + const Vectorized& a, #endif - const Vectorized& b, const Vectorized& mask) { + const Vectorized& b, + const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); #if defined(__clang__) && __ARM_FEATURE_SVE - #pragma clang loop vectorize(disable) +#pragma clang loop vectorize(disable) #endif for (const auto i : c10::irange(size())) { - if (buffer[i] & 0x01) - { + if (buffer[i] & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; @@ -219,15 +228,21 @@ struct Vectorized { } return vector; } - template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) - static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { + template // step sometimes requires a higher precision type + // (e.g., T=int, step_t=double) + static Vectorized arange( + T base = static_cast(0), + step_t step = static_cast(1)) { Vectorized vector; for (const auto i : c10::irange(size())) { vector.values[i] = base + i * step; } return vector; } - static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { + static Vectorized set( + const Vectorized& a, + const Vectorized& b, + int64_t count = size()) { Vectorized vector; for (const auto i : c10::irange(size())) { if (i < count) { @@ -249,7 +264,9 @@ struct Vectorized { return vector; } static Vectorized loadu_one_fourth(const void* ptr) { - static_assert(std::is_same_v || std::is_same_v, "For byte types only"); + static_assert( + std::is_same_v || std::is_same_v, + "For byte types only"); return Vectorized::loadu(ptr, 8); } @@ -257,9 +274,10 @@ struct Vectorized { std::memcpy(ptr, values, count * sizeof(T)); } int zero_mask() const { - // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit + // returns an integer mask where all zero elements are translated to 1-bit + // and others are translated to 0-bit int mask = 0; - for (int i = 0; i < size(); ++ i) { + for (int i = 0; i < size(); ++i) { if (values[i] == static_cast(0)) { mask |= (1 << i); } @@ -279,15 +297,18 @@ struct Vectorized { } bool has_inf_nan() const { for (int64_t i = 0; i != size(); i++) { - if(_isnan(values[i]) || _isinf(values[i])) { + if (_isnan(values[i]) || _isinf(values[i])) { return true; } } return false; } -// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows Arm64 -// See https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 -#if defined(_WIN32) && defined(__aarch64__) && ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) +// MSVC versions between 14.36 and 14.42 has a loop unrolling bug on Windows +// Arm64 +// See +// https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) && \ + ((_MSVC_VER >= 1936) && (_MSVC_VER <= 1942)) Vectorized map(T (*const f)(T)) 
const { Vectorized ret; for (int64_t i = 0; i < size(); i++) { @@ -322,38 +343,44 @@ struct Vectorized { return ret; } #endif - Vectorized map(T (*const f)(const T &)) const { + Vectorized map(T (*const f)(const T&)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } - T reduce(T (*const f)(const T &)) const { + T reduce(T (*const f)(const T&)) const { T ret = 0; for (int64_t i = 0; i != size(); i++) { ret = f(ret, values[i]); } return ret; } - template && !c10::is_complex::value, int> = 0> + template < + typename other_t_abs = T, + typename std::enable_if_t< + !is_floating_point_v && + !c10::is_complex::value, + int> = 0> Vectorized abs() const { // other_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_abs must be T"); return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } - template , int> = 0> + template < + typename float_t_abs = T, + typename std::enable_if_t, int> = 0> Vectorized abs() const { // float_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "float_t_abs must be T"); - // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in - // 0.0) properly. + // Specifically deal with floating-point because the generic code above + // won't handle -0.0 (which should result in 0.0) properly. return map([](T x) -> T { return std::abs(x); }); } - template ::value, int> = 0> + template < + typename complex_t_abs = T, + typename std::enable_if_t::value, int> = 0> Vectorized abs() const { // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "complex_t_abs must be T"); @@ -361,66 +388,85 @@ struct Vectorized { return map([](T x) { return static_cast(std::abs(x)); }); } - template ::value, int> = 0> + template < + typename other_t_sgn = T, + typename std::enable_if_t::value, int> = 0> Vectorized sgn() const { return map(at::native::sgn_impl); } - template ::value, int> = 0> + template < + typename other_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_angle must be T"); - return map(at::native::angle_impl); // compiler is unable to resolve the overload without + return map(at::native::angle_impl); // compiler is unable to resolve the + // overload without } - template ::value, int> = 0> + template < + typename complex_t_angle = T, + typename std::enable_if_t::value, int> = + 0> Vectorized angle() const { // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_angle must be T"); + static_assert( + std::is_same_v, "complex_t_angle must be T"); return map([](T x) { return static_cast(std::arg(x)); }); } - template ::value, int> = 0> + template < + typename other_t_real = T, + typename std::enable_if_t::value, int> = 0> Vectorized real() const { // other_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_real must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_real = T, + typename std::enable_if_t::value, int> = + 0> Vectorized real() const { // complex_t_real is for SFINAE and clarity. Make sure it is not changed. 
- static_assert(std::is_same_v, "complex_t_real must be T"); + static_assert( + std::is_same_v, "complex_t_real must be T"); return map([](T x) { return static_cast(x.real()); }); } - template ::value, int> = 0> + template < + typename other_t_imag = T, + typename std::enable_if_t::value, int> = 0> Vectorized imag() const { // other_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_imag must be T"); return Vectorized(0); } - template ::value, int> = 0> + template < + typename complex_t_imag = T, + typename std::enable_if_t::value, int> = + 0> Vectorized imag() const { // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_imag must be T"); + static_assert( + std::is_same_v, "complex_t_imag must be T"); return map([](T x) { return static_cast(x.imag()); }); } - template ::value, int> = 0> + template < + typename other_t_conj = T, + typename std::enable_if_t::value, int> = 0> Vectorized conj() const { // other_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_conj must be T"); return *this; } - template ::value, int> = 0> + template < + typename complex_t_conj = T, + typename std::enable_if_t::value, int> = + 0> Vectorized conj() const { // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. - static_assert(std::is_same_v, "complex_t_conj must be T"); + static_assert( + std::is_same_v, "complex_t_conj must be T"); return map([](T x) { return static_cast(std::conj(x)); }); } Vectorized acos() const { @@ -441,7 +487,7 @@ struct Vectorized { Vectorized atanh() const { return map(std::atanh); } - Vectorized atan2(const Vectorized &exp) const { + Vectorized atan2(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::atan2(values[i], exp[i]); @@ -449,9 +495,9 @@ struct Vectorized { return ret; } template < - typename U = T, - typename std::enable_if_t, int> = 0> - Vectorized copysign(const Vectorized &sign) const { + typename U = T, + typename std::enable_if_t, int> = 0> + Vectorized copysign(const Vectorized& sign) const { Vectorized ret; for (size_type i = 0; i < size(); i++) { ret[i] = c10::copysign(values[i], sign[i]); @@ -483,8 +529,8 @@ struct Vectorized { return *this - this->trunc(); } template < - typename U = T, - typename std::enable_if_t, int> = 0> + typename U = T, + typename std::enable_if_t, int> = 0> Vectorized fmod(const Vectorized& q) const { // U is for SFINAE purposes only. Make sure it is not changed. static_assert(std::is_same_v, "U must be T"); @@ -503,20 +549,24 @@ struct Vectorized { Vectorized log1p() const { return map(std::log1p); } - template ::value, int> = 0> + template < + typename other_t_log2 = T, + typename std::enable_if_t::value, int> = 0> Vectorized log2() const { // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same_v, "other_t_log2 must be T"); return map(std::log2); } - template ::value, int> = 0> + template < + typename complex_t_log2 = T, + typename std::enable_if_t::value, int> = + 0> Vectorized log2() const { // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. 
- static_assert(std::is_same_v, "complex_t_log2 must be T"); + static_assert( + std::is_same_v, "complex_t_log2 must be T"); const T log_2 = T(std::log(2.0)); - return Vectorized(map(std::log))/Vectorized(log_2); + return Vectorized(map(std::log)) / Vectorized(log_2); } Vectorized ceil() const { return map(at::native::ceil_impl); @@ -530,7 +580,7 @@ struct Vectorized { Vectorized floor() const { return map(at::native::floor_impl); } - Vectorized hypot(const Vectorized &b) const { + Vectorized hypot(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::hypot(values[i], b[i]); @@ -546,14 +596,14 @@ struct Vectorized { Vectorized digamma() const { return map(calc_digamma); } - Vectorized igamma(const Vectorized &x) const { + Vectorized igamma(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igamma(values[i], x[i]); } return ret; } - Vectorized igammac(const Vectorized &x) const { + Vectorized igammac(const Vectorized& x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igammac(values[i], x[i]); @@ -566,7 +616,7 @@ struct Vectorized { // promotion return map([](T x) -> T { return -x; }); } - Vectorized nextafter(const Vectorized &b) const { + Vectorized nextafter(const Vectorized& b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::nextafter(values[i], b[i]); @@ -574,7 +624,8 @@ struct Vectorized { return ret; } Vectorized round() const { - // We do not use std::round because we would like to round midway numbers to the nearest even integer. + // We do not use std::round because we would like to round midway numbers to + // the nearest even integer. return map(at::native::round_impl); } Vectorized sin() const { @@ -604,20 +655,21 @@ struct Vectorized { Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); } - Vectorized pow(const Vectorized &exp) const { + Vectorized pow(const Vectorized& exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::pow(values[i], exp[i]); } return ret; } - T reduce_add() const { + T reduce_add() const { return reduce([](T x, T y) -> T { return x + y; }); } T reduce_max() const { return reduce(std::max); } -private: + + private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. 
@@ -632,35 +684,61 @@ struct Vectorized { return vector; } -public: - Vectorized operator==(const Vectorized& other) const { return binary_pred(other, std::equal_to()); } - Vectorized operator!=(const Vectorized& other) const { return binary_pred(other, std::not_equal_to()); } - Vectorized operator>=(const Vectorized& other) const { return binary_pred(other, std::greater_equal()); } - Vectorized operator<=(const Vectorized& other) const { return binary_pred(other, std::less_equal()); } - Vectorized operator>(const Vectorized& other) const { return binary_pred(other, std::greater()); } - Vectorized operator<(const Vectorized& other) const { return binary_pred(other, std::less()); } + public: + Vectorized operator==(const Vectorized& other) const { + return binary_pred(other, std::equal_to()); + } + Vectorized operator!=(const Vectorized& other) const { + return binary_pred(other, std::not_equal_to()); + } + Vectorized operator>=(const Vectorized& other) const { + return binary_pred(other, std::greater_equal()); + } + Vectorized operator<=(const Vectorized& other) const { + return binary_pred(other, std::less_equal()); + } + Vectorized operator>(const Vectorized& other) const { + return binary_pred(other, std::greater()); + } + Vectorized operator<(const Vectorized& other) const { + return binary_pred(other, std::less()); + } -private: + private: template - inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { + inline Vectorized binary_pred_bool(const Vectorized& other, Op op) + const { // 1 if the pred is true, otherwise 0. Vectorized vector; - for (int i = 0; i != size(); ++ i) { + for (int i = 0; i != size(); ++i) { vector[i] = static_cast(op(values[i], other.values[i])); } return vector; } -public: - Vectorized eq(const Vectorized& other) const { return binary_pred_bool(other, std::equal_to()); } - Vectorized ne(const Vectorized& other) const { return binary_pred_bool(other, std::not_equal_to()); } - Vectorized gt(const Vectorized& other) const { return binary_pred_bool(other, std::greater()); } - Vectorized ge(const Vectorized& other) const { return binary_pred_bool(other, std::greater_equal()); } - Vectorized lt(const Vectorized& other) const { return binary_pred_bool(other, std::less()); } - Vectorized le(const Vectorized& other) const { return binary_pred_bool(other, std::less_equal()); } + public: + Vectorized eq(const Vectorized& other) const { + return binary_pred_bool(other, std::equal_to()); + } + Vectorized ne(const Vectorized& other) const { + return binary_pred_bool(other, std::not_equal_to()); + } + Vectorized gt(const Vectorized& other) const { + return binary_pred_bool(other, std::greater()); + } + Vectorized ge(const Vectorized& other) const { + return binary_pred_bool(other, std::greater_equal()); + } + Vectorized lt(const Vectorized& other) const { + return binary_pred_bool(other, std::less()); + } + Vectorized le(const Vectorized& other) const { + return binary_pred_bool(other, std::less_equal()); + } }; -template Vectorized inline operator+(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] + b[i]; @@ -668,7 +746,8 @@ template Vectorized inline operator+(const Vectorized &a, const return c; } -template Vectorized inline operator-(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != 
Vectorized::size(); i++) { c[i] = a[i] - b[i]; @@ -676,7 +755,8 @@ template Vectorized inline operator-(const Vectorized &a, const return c; } -template Vectorized inline operator*(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] * b[i]; @@ -684,7 +764,9 @@ template Vectorized inline operator*(const Vectorized &a, const return c; } -template Vectorized inline operator/(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +template +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] / b[i]; @@ -692,14 +774,16 @@ template Vectorized inline operator/(const Vectorized &a, const return c; } -template , int> = 0> -Vectorized inline operator%(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { +template , int> = 0> +Vectorized inline operator%(const Vectorized& a, const Vectorized& b) + __ubsan_ignore_float_divide_by_zero__ { return a - a / b * b; } -template Vectorized inline operator||( - const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator||( + const Vectorized& a, + const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] || b[i]; @@ -709,9 +793,10 @@ template Vectorized inline operator||( // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; @@ -725,9 +810,10 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; @@ -743,9 +829,10 @@ Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; @@ -759,9 +846,10 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) < std::abs(b[i])) ? 
a[i] : b[i]; @@ -775,9 +863,13 @@ Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { return c; } -template ::value, int> = 0> -Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, const Vectorized &max_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp( + const Vectorized& a, + const Vectorized& min_vec, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); @@ -785,9 +877,12 @@ Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, return c; } -template ::value, int> = 0> -Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_max( + const Vectorized& a, + const Vectorized& max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; @@ -795,9 +890,12 @@ Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_ return c; } -template ::value, int> = 0> -Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_vec) { +template < + class T, + typename std::enable_if_t::value, int> = 0> +Vectorized inline clamp_min( + const Vectorized& a, + const Vectorized& min_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; @@ -809,14 +907,21 @@ struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { int_vector buffer; #if defined(CPU_CAPABILITY_AVX2) - int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm256_load_si256(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm256_load_si256(reinterpret_cast((const T*)b)); #elif defined(CPU_CAPABILITY_AVX512) - int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); - int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); + int_vector a_buffer = + _mm512_load_si512(reinterpret_cast((const T*)a)); + int_vector b_buffer = + _mm512_load_si512(reinterpret_cast((const T*)b)); #endif buffer = op(a_buffer, b_buffer); __at_align__ T results[Vectorized::size()]; @@ -829,31 +934,52 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(results); } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline + // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return 
_mm512_and_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline + // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); #endif } -template>::value, int> = 0> +template < + class T, + typename std::enable_if_t< + !std::is_base_of>::value, + int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { - // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline + // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is + // always_inline #if defined(CPU_CAPABILITY_AVX2) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) - return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); + return bitwise_binary_op( + a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); #endif } @@ -866,12 +992,19 @@ auto load(char const* data) -> T { return ret; } -template -static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { +template +static inline Vectorized bitwise_binary_op( + const Vectorized& a, + const Vectorized& b, + Op op) { static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); __at_align__ intmax_t buffer[element_no]; - static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); - static_assert(sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)"); + static_assert( + VECTOR_WIDTH % sizeof(intmax_t) == 0, + "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); + static_assert( + sizeof(buffer) == sizeof(Vectorized), + "sizeof(buffer) must match sizeof(Vectorized)"); // We should be using memcpy in order to respect the strict aliasing rule // see: https://github.com/pytorch/pytorch/issues/66119 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 @@ -889,34 +1022,50 @@ static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vect return Vectorized::loadu(buffer); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_and()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_or()); } -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return 
bitwise_binary_op(a, b, std::bit_xor()); } #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) -template>, int> = 0> +template < + class T, + typename std:: + enable_if_t>, int> = 0> inline Vectorized operator~(const Vectorized& a) { using int_t = int_same_size_t; - Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 + Vectorized ones(c10::bit_cast((int_t)(~(int_t)0))); // All bits are 1 return a ^ ones; } -template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator<<( + const Vectorized& a, + const Vectorized& b) { constexpr T max_shift = sizeof(T) * CHAR_BIT; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = 0; } else { c[i] = static_cast>(a[i]) << shift; @@ -925,13 +1074,17 @@ template Vectorized inline operator<<(const Vectorized &a, const return c; } -template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { +template +Vectorized inline operator>>( + const Vectorized& a, + const Vectorized& b) { // right shift value to retain sign bit for signed and no bits for unsigned constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v; Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { T shift = b[i]; - if ((static_cast>(shift) < 0) || (shift >= max_shift)) { + if ((static_cast>(shift) < 0) || + (shift >= max_shift)) { c[i] = a[i] >> max_shift; } else { c[i] = a[i] >> shift; @@ -941,50 +1094,56 @@ template Vectorized inline operator>>(const Vectorized &a, const } template -inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator+=(Vectorized& a, const Vectorized& b) { a = a + b; return a; } template -inline Vectorized& operator -= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator-=(Vectorized& a, const Vectorized& b) { a = a - b; return a; } template -inline Vectorized& operator /= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator/=(Vectorized& a, const Vectorized& b) { a = a / b; return a; } template -inline Vectorized& operator %= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator%=(Vectorized& a, const Vectorized& b) { a = a % b; return a; } template -inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator*=(Vectorized& a, const Vectorized& b) { a = a * b; return a; } template -inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator<<=(Vectorized& a, const Vectorized& b) { a = a << b; return a; } template -inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { +inline Vectorized& operator>>=(Vectorized& a, const Vectorized& b) { a = a >> b; return a; } template -inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmadd( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b + c; } template -inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { +inline Vectorized fmsub( + const Vectorized& a, + const Vectorized& b, + const Vectorized& c) { return a * b - c; } @@ -1000,8 +1159,10 @@ Vectorized inline operator&&( } template -std::enable_if_t> -inline gather(T const* base_addr, const Vectorized>& vindex) { +std::enable_if_t< + scale == 1 || scale == 2 || scale == 4 || scale == 8, + 
Vectorized< + T>> inline gather(T const* base_addr, const Vectorized>& vindex) { static constexpr int size = Vectorized::size(); int_same_size_t index_arr[size]; vindex.store(static_cast(index_arr)); @@ -1013,36 +1174,39 @@ inline gather(T const* base_addr, const Vectorized>& vindex) } template -std::enable_if_t> -inline mask_gather(const Vectorized& src, T const* base_addr, - const Vectorized>& vindex, Vectorized& mask) { +std:: + enable_if_t> inline mask_gather( + const Vectorized& src, + T const* base_addr, + const Vectorized>& vindex, + Vectorized& mask) { static constexpr int size = Vectorized::size(); T src_arr[size]; - int_same_size_t mask_arr[size]; // use int type so we can logical and + int_same_size_t mask_arr[size]; // use int type so we can logical and int_same_size_t index_arr[size]; src.store(static_cast(src_arr)); mask.store(static_cast(mask_arr)); vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { - if (mask_arr[i] & 0x01) { // check highest bit + if (mask_arr[i] & 0x01) { // check highest bit buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } else { buffer[i] = src_arr[i]; } } - mask = Vectorized(static_cast(0)); // "zero out" mask + mask = Vectorized(static_cast(0)); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } // Cast a given vector to another type without changing the bits representation. // So a Vectorized of 512 bits containing all ones can be cast to a -// Vectorized of 512 bits containing all ones (i.e., eight negative 1s). -// A Vec of 256 bits containing all ones can be cast to a +// Vectorized of 512 bits containing all ones (i.e., eight negative +// 1s). A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). // There is a struct here because we don't have static_if and I can't // partially specialize a templated function. 
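// [Editor's illustration — not part of the upstream patch] The comment above
// motivates the CastImpl struct that follows: C++ permits partial
// specialization of class templates but not of function templates, so the
// user-facing cast() forwards to a class template that can be partially
// specialized for the src_t == dst_t case. A minimal, self-contained sketch of
// the same idiom with plain scalar types is given here; the names BitCastImpl
// and bit_cast_like are hypothetical and chosen only for this sketch.
#include <cstring>

template <typename dst_t, typename src_t>
struct BitCastImpl {
  static dst_t apply(const src_t& src) {
    static_assert(
        sizeof(dst_t) == sizeof(src_t),
        "a bit-preserving cast needs equally sized types");
    dst_t dst;
    std::memcpy(&dst, &src, sizeof(dst)); // copy the bits, not the value
    return dst;
  }
};

// Partial specialization: casting a type to itself is a no-op. This is the
// step a bare function template could not express.
template <typename T>
struct BitCastImpl<T, T> {
  static T apply(const T& src) {
    return src;
  }
};

// Thin function-template front end; all dispatch happens in the class template.
template <typename dst_t, typename src_t>
dst_t bit_cast_like(const src_t& src) {
  return BitCastImpl<dst_t, src_t>::apply(src);
}
// [End of editor's illustration]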
-template +template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { src_t src_arr[Vectorized::size()]; @@ -1051,44 +1215,51 @@ struct CastImpl { } }; -template +template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { return src; } }; -template +template inline Vectorized cast(const Vectorized& src) { return CastImpl::apply(src); } template > -inline Vectorized convert_to_int_of_same_size(const Vectorized& src) { +inline Vectorized convert_to_int_of_same_size( + const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr; src.store(static_cast(src_arr.data())); std::array buffer; - std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(), - [](const T& x) { return static_cast(x); }); + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const T& x) { + return static_cast(x); + }); return Vectorized::loadu(static_cast(buffer.data())); } template > -inline Vectorized convert_to_fp_of_same_size(const Vectorized& src) { +inline Vectorized convert_to_fp_of_same_size( + const Vectorized& src) { static_assert(sizeof(T) == sizeof(IntType)); static constexpr int size = Vectorized::size(); std::array src_arr; src.store(static_cast(src_arr.data())); std::array buffer; - std::transform(src_arr.cbegin(), src_arr.cend(), buffer.begin(), - [](const IntType& x) { return static_cast(x); }); + std::transform( + src_arr.cbegin(), src_arr.cend(), buffer.begin(), [](const IntType& x) { + return static_cast(x); + }); return Vectorized::loadu(static_cast(buffer.data())); } +// clang-format off // Example inputs for AVX512: // a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} @@ -1099,8 +1270,11 @@ inline Vectorized convert_to_fp_of_same_size(const Vectorized& src) // b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} // returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} +// clang-format on template -inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; @@ -1116,10 +1290,12 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { buffer2[i] = a_arr[i * 2 + 1]; buffer2[half_size + i] = b_arr[i * 2 + 1]; } - return std::make_pair(Vectorized::loadu(static_cast(buffer1)), - Vectorized::loadu(static_cast(buffer2))); + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); } +// clang-format off // inverse operation of deinterleave2 // Example inputs for AVX512: // a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} @@ -1131,8 +1307,11 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} +// clang-format on template -inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> +inline std::enable_if_t< + Vectorized::size() % 2 == 0, + std::pair, Vectorized>> interleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size 
/ 2; @@ -1148,14 +1327,15 @@ interleave2(const Vectorized& a, const Vectorized& b) { buffer2[i * 2] = a_arr[half_size + i]; buffer2[i * 2 + 1] = b_arr[half_size + i]; } - return std::make_pair(Vectorized::loadu(static_cast(buffer1)), - Vectorized::loadu(static_cast(buffer2))); + return std::make_pair( + Vectorized::loadu(static_cast(buffer1)), + Vectorized::loadu(static_cast(buffer2))); } template -inline void convert(const src_T *src, dst_T *dst, int64_t n) { +inline void convert(const src_T* src, dst_T* dst, int64_t n) { #ifndef _MSC_VER -# pragma unroll +#pragma unroll #endif for ([[maybe_unused]] const auto i : c10::irange(n)) { *dst = c10::convert(c10::load(src)); @@ -1165,7 +1345,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) { } template -inline Vectorized flip(const Vectorized & data) { +inline Vectorized flip(const Vectorized& data) { static constexpr int size = Vectorized::size(); T output[size]; T buffer[size]; @@ -1176,25 +1356,37 @@ inline Vectorized flip(const Vectorized & data) { return Vectorized::loadu(static_cast(output)); } -// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading -// dimension of `src` and `ld_dst` is the leading dimension of `dst`. +// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. +// `ld_src` is the leading dimension of `src` and `ld_dst` is the leading +// dimension of `dst`. template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst, int M, int N) { +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst, + int M, + int N) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { - dst[j*ld_dst + i] = src[i*ld_src + j]; + dst[j * ld_dst + i] = src[i * ld_src + j]; } } } template -inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { +inline void transpose_mxn( + const T* src, + int64_t ld_src, + T* dst, + int64_t ld_dst) { transpose_mxn(src, ld_src, dst, ld_dst, M, N); } -}} // namespace at::vec::CPU_CAPABILITY +} // namespace CPU_CAPABILITY +} // namespace at::vec // additional headers for more operations that depend on vec_base -#include -#include #include +#include +#include diff --git a/aten/src/ATen/cpu/vec/vec_convert.h b/aten/src/ATen/cpu/vec/vec_convert.h index a5cee03dabcf..f5e5177908c1 100644 --- a/aten/src/ATen/cpu/vec/vec_convert.h +++ b/aten/src/ATen/cpu/vec/vec_convert.h @@ -28,8 +28,8 @@ struct VecConvert { }; template -inline std::enable_if_t, Vectorized> -convert(const Vectorized& src) { +inline std::enable_if_t, Vectorized> convert( + const Vectorized& src) { return src; } diff --git a/aten/src/ATen/cpu/vec/vec_half.h b/aten/src/ATen/cpu/vec/vec_half.h index c7c90cc95b47..972d3ee3929b 100644 --- a/aten/src/ATen/cpu/vec/vec_half.h +++ b/aten/src/ATen/cpu/vec/vec_half.h @@ -103,7 +103,9 @@ static inline void transpose_pad_2x32_block( _mm512_storeu_si512(reinterpret_cast<__m512i*>(dst + 32), d1); } #else -TORCH_CHECK(false, "transpose_pad_2x32_block is only supported when avx512 is supported") + TORCH_CHECK( + false, + "transpose_pad_2x32_block is only supported when avx512 is supported") #endif } @@ -124,28 +126,31 @@ static inline void pack_vnni2( for (; bk < _K; bk += 2) { int64_t bn = 0; for (; bn < _N; bn += 32) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src); } int64_t nrem = N - bn; if (nrem > 0) { - 
transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 2, nrem); } } if (K % 2 == 1) { int64_t bn = 0; for (; bn < _N; bn += 32) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1); } int64_t nrem = N - bn; if (nrem > 0) { - transpose_pad_2x32_block(src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); + transpose_pad_2x32_block( + src + bk * ld_src + bn, dst + bk * N + bn * 2, ld_src, 1, nrem); } } #else -TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") + TORCH_CHECK(false, "pack_vnni2 is only supported when avx512 is supported") #endif } - } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index c547e5911ecb..e19d7f75388a 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -68,7 +68,12 @@ struct VecMaskTo { } }; -template +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + typename Enabled = void> struct VecMaskCast { static inline VecMask apply( const VecMask& vec_mask) { @@ -88,15 +93,17 @@ struct VecMaskCheck { static inline bool all_zero(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); - return std::all_of( - mask, mask + VectorizedN::size(), [](T m) { return m == static_cast(0); }); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m == static_cast(0); + }); } static inline bool all_masked(const VectorizedN& vec_mask) { __at_align__ T mask[VectorizedN::size()]; vec_mask.store(mask); - return std::all_of( - mask, mask + VectorizedN::size(), [](T m) { return m != static_cast(0); }); + return std::all_of(mask, mask + VectorizedN::size(), [](T m) { + return m != static_cast(0); + }); } static inline bool is_masked(const VectorizedN& vec_mask, int i) { @@ -159,13 +166,11 @@ class VecMask { } static VecMask blendv( - const VecMask& c, - const VecMask& b, - const VecMask& a) { + const VecMask& c, + const VecMask& b, + const VecMask& a) { VectorizedN result = VectorizedN::blendv( - VectorizedN(c), - VectorizedN(b), - VectorizedN(a)); + VectorizedN(c), VectorizedN(b), VectorizedN(a)); return result; } @@ -174,14 +179,14 @@ class VecMask { const VecMask& b, int64_t count = size()) { VectorizedN result = VectorizedN::set( - VectorizedN(a), - VectorizedN(b), - count); + VectorizedN(a), VectorizedN(b), count); return result; } void store(bool* b, int count = size()) { - constexpr int L = (VectorizedN::size() + Vectorized::size() - 1)/ Vectorized::size(); + constexpr int L = + (VectorizedN::size() + Vectorized::size() - 1) / + Vectorized::size(); auto res = this->to(); res.store(b, count); return; From 7df6f930e8fd943f2c3094364e08659888f9ee6b Mon Sep 17 00:00:00 2001 From: Aby Mathew C Date: Fri, 4 Apr 2025 02:47:40 +0000 Subject: [PATCH 187/332] Adapt test_misc.py for HPUs (#149499) This PR is related to https://github.com/pytorch/pytorch/pull/145476 . That PR had two files (test_functions.py and test_misc.py) . test_functions was causing CI/rebase/merge issues and hence removed for now. This PR contains only test_misc.py. This is a continuation of https://github.com/pytorch/pytorch/pull/144387 . 
## MOTIVATION We recently integrated support for Intel Gaudi devices (identified as 'hpu') into the common_device_type framework via the pull request at https://github.com/pytorch/pytorch/pull/126970. This integration allows tests to be automatically instantiated for Gaudi devices upon loading the relevant library. Building on this development, the current pull request extends the utility of these hooks by adapting selected CUDA tests to operate on Gaudi devices. Additionally, we have confirmed that these modifications do not interfere with the existing tests on CUDA devices. Other accelerators can also extend the functionality by adding the device in the devices list. ( For eg: xpu ) ## CHANGES Create a separate class for test functions running on CUDA devices Extend the functionality of these tests to include HPUs Use instantiate_device_type_tests with targeted attributes to generate device-specific test instances within the new classes Apply skipIfHPU decorator to bypass tests that are not yet compatible with HPU devices PS: Most of these changes were initially part of https://github.com/pytorch/pytorch/pull/147609 , but closed that PR due to merge conflicts. The review comments were handled in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149499 Approved by: https://github.com/EikanWang, https://github.com/desertfire, https://github.com/cyyever --- test/dynamo/test_misc.py | 432 ++++++++++++++++++++------------------- 1 file changed, 219 insertions(+), 213 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 8579ee8e1b2e..9050539194bf 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -76,6 +76,7 @@ TEST_CUDA, TEST_MULTIGPU, ) +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import ( sample_inputs_take_along_dim, ) @@ -84,8 +85,10 @@ IS_FBCODE, scoped_load_inline, set_default_dtype, + skipIfHpu, skipIfNNModuleInlined, skipIfWindows, + TEST_HPU, wrapDeterministicFlagAPITest, ) from torch.testing._internal.jit_utils import JitTestCase @@ -4282,27 +4285,6 @@ def test_version_ci(self): # temporary test to check that the ci torch version is set correctly self.assertTrue(hasattr(torch, "_subclasses")) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - def test_rand(self): - cnts = torch._dynamo.testing.CompileCounter() - device = "cuda" - - def fn(): - return torch.randn(10, device=device) - - torch.manual_seed(10) - ref_run1 = fn() - - torch.manual_seed(10) - ref_run2 = fn() - self.assertTrue(same(ref_run1, ref_run2)) - - torch.manual_seed(10) - opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) - res = opt_fn() - - self.assertTrue(same(res, ref_run1)) - def test_slice_input(self): cnts = torch._dynamo.testing.CompileCounter() @@ -5985,57 +5967,6 @@ def fn(param, y): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 3) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FLASH_ATTENTION, - "Can't run fused SDPA on this platform", - ) - def test_parsing_sdpa(self): - class MyModule(torch.nn.Module): - def forward(self, query, key, value): - out = F.scaled_dot_product_attention(query, key, value, None, 0, True) - out = F.scaled_dot_product_attention( - query, key, value, None, 0, True, scale=8 - ) - out = F.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=None, - dropout_p=0, - is_causal=True, - ) - out = 
F.scaled_dot_product_attention( - query, - key=key, - value=value, - attn_mask=None, - dropout_p=0, - is_causal=True, - ) - out = F.scaled_dot_product_attention( - query, key, value, None, dropout_p=0, is_causal=True - ) - out = F.scaled_dot_product_attention(query, key, value, None, scale=8) - return out - - device = "cuda" - dtype = torch.float16 - seq_len_q = 1 - seq_len_k = 1 - head_dim = 8 - query = torch.ones( - 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True - ) - key = torch.ones( - 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True - ) - value = torch.ones( - 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True - ) - module = MyModule() - opt_mod = torch.compile(module, backend="inductor") - opt_mod(query, key, value) - def test_generate_tensor_from_list_of_numpy_primitive_type(self): # Test sth like torch.LongTensor(list(np.int64, np.int64, ...)) def fn(): @@ -6492,19 +6423,6 @@ def fn(x, obj): res = opt_fn(x, obj) self.assertTrue(same(ref, res)) - def test_torch_cuda_is_available(self): - def fn(x): - if torch.cuda.is_available(): - return x + 1 - else: - return x - 1 - - x = torch.rand(4) - ref = fn(x) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - def test_variable_tracker_recursively_contains(self): # VariableTracker.recursively_contains should be updated correctly when mutation happens def fn(x): @@ -6522,61 +6440,6 @@ def fn(x): res = opt_fn(x) self.assertTrue(same(ref, res)) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") - def test_torch_cudnn_is_acceptable(self): - def fn(x): - if torch.backends.cudnn.is_acceptable(tensor=x): - return x + 1 - return x - - x = torch.rand(4).cuda() - ref = fn(x) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x) - self.assertTrue(same(ref, res)) - - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") - def test_torch_cudnn_is_acceptable_bad_inputs(self): - def fn1(x): - if torch.backends.cudnn.is_acceptable("invalid"): - return x + 1 - return x - - def fn2(x): - if torch.backends.cudnn.is_acceptable(x, 3.14): - return x + 1 - return x - - with self.assertRaisesRegex( - AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" - ): - x1 = torch.rand(4).cuda() - opt_fn1 = torch.compile(fn1, backend="eager", fullgraph=True) - res1 = opt_fn1(x1) - - with self.assertRaisesRegex( - AssertionError, "Expect 1 input to cudnn.is_acceptable" - ): - x2 = torch.rand(4).cuda() - opt_fn2 = torch.compile(fn2, backend="eager", fullgraph=True) - res = opt_fn2(x2) - - @unittest.skipIf(not TEST_CUDA, "requires cuda") - def test_get_device(self): - def fn(x, y): - x = x + 1 - y = y + 1 - return x.get_device(), y.get_device() - - x = torch.rand(4, device="cuda") - y = torch.rand(4, device="cpu") - ref = fn(x, y) - opt_fn = torch.compile(fn, backend="eager", fullgraph=True) - res = opt_fn(x, y) - self.assertTrue(same(ref, res)) - def test_disable_flag(self): cnt = torch._dynamo.testing.CompileCounter() @@ -6872,17 +6735,6 @@ def guard_export_print(guards): # This guard was created self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") - @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") - def test_symint_as_device_kwarg(self): - def f(rank): - # -2 to make device id 0 for easier testing on CI - 
return torch.ones(10, device=rank.size(0) - 2) - - x = torch.randn(2) - out = f(torch.randn(2)) - opt_out = torch.compile(backend="eager", dynamic=True, fullgraph=True)(f)(x) - self.assertEqual(out, opt_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") def test_symint_as_device_kwarg_multi_gpu(self): def fn(rank): @@ -8331,21 +8183,6 @@ def func(x): self.assertTrue(isinstance(compile_out, torch.Size)) self.assertEqual(eager_out, compile_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def test_cuda_set_device(self): - def fn(): - a = torch.ones(2, device="cuda") - torch.cuda.set_device(1) - return a + 1 - - with torch.cuda.device(0): - counter = CompileCounter() - opt_fn = torch.compile(fn, backend=counter) - res = opt_fn() - self.assertEqual(res.device.type, "cuda") - self.assertEqual(res.device.index, 0) - self.assertEqual(counter.frame_count, 2) - def test_nested_function_resuming_with_correct_globals(self): # https://github.com/pytorch/pytorch/issues/99665 try: @@ -9584,36 +9421,6 @@ def fn(): res = opt_func() self.assertEqual(ref, res) - def test_torch_device_python_type(self): - for device, device_type, index in [ - ("cpu", "cpu", None), - ("cuda:0", "cuda", 0), - ]: - if device == "cuda:0" and not TEST_CUDA: - continue - - def fn(target): - target_device = target.device - a = torch.zeros(2, 3, device=target_device) - # Constant assert at trace time - assert isinstance(target_device, torch.device) - assert target_device.type == device_type - assert target_device.index == index - b = torch.zeros(2, 3, device=target_device) - c = torch.zeros(2, 3, device=target_device) - return a + b + c - - from torch._dynamo.variables import ConstantVariable - - device = torch.device(device) - expected_variable = ConstantVariable(device) - self.assertEqual(expected_variable.python_type(), type(device)) - - opt_func = torch.compile(fn, backend="eager", fullgraph=True) - a = torch.tensor([2, 3], device=device) - res = opt_func(a) - self.assertIsInstance(res, torch.Tensor) - def test_torch_dtype_python_type(self): def fn(target): target_dtype = target.dtype @@ -11535,23 +11342,6 @@ def fn(x, d): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) - @unittest.skipIf(not TEST_CUDA, "requires cuda") - @torch._dynamo.config.patch( - capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True - ) - @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) - def test_interpolate_propagate_real_tensors(self): - @torch.compile(backend="eager", fullgraph=True) - def f(mask, box): - # u0, u1 = mask.tolist() - mask = torch.randn(1, 1, 30, 30, device="cuda") - h, w = box.tolist() - return torch.nn.functional.interpolate( - mask, (h, w), mode="bilinear", align_corners=False - ) - - f(torch.tensor([30, 30], device="cuda"), torch.tensor([68, 32], device="cuda")) - def test_iter_type(self): @torch.compile(fullgraph=True) def fn(y): @@ -12193,6 +11983,222 @@ def fn(x, y): self.assertTrue(y.grad is not None) +class MiscTestsDevice(torch._inductor.test_case.TestCase): + def test_rand(self, device): + cnts = torch._dynamo.testing.CompileCounter() + device = device + + def fn(): + return torch.randn(10, device=device) + + torch.manual_seed(10) + ref_run1 = fn() + + torch.manual_seed(10) + ref_run2 = fn() + self.assertTrue(same(ref_run1, ref_run2)) + + torch.manual_seed(10) + opt_fn = torch.compile(fn, backend=cnts, fullgraph=True) + res = opt_fn() + + self.assertTrue(same(res, ref_run1)) + + @unittest.skipIf( + not 
PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Can't run fused SDPA on this platform", + ) + def test_parsing_sdpa(self, device): + class MyModule(torch.nn.Module): + def forward(self, query, key, value): + out = F.scaled_dot_product_attention(query, key, value, None, 0, True) + out = F.scaled_dot_product_attention( + query, key, value, None, 0, True, scale=8 + ) + out = F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, + dropout_p=0, + is_causal=True, + ) + out = F.scaled_dot_product_attention( + query, + key=key, + value=value, + attn_mask=None, + dropout_p=0, + is_causal=True, + ) + out = F.scaled_dot_product_attention( + query, key, value, None, dropout_p=0, is_causal=True + ) + out = F.scaled_dot_product_attention(query, key, value, None, scale=8) + return out + + device = device + dtype = torch.float16 + seq_len_q = 1 + seq_len_k = 1 + head_dim = 8 + query = torch.ones( + 1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True + ) + key = torch.ones( + 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True + ) + value = torch.ones( + 1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True + ) + module = MyModule() + opt_mod = torch.compile(module, backend="inductor") + opt_mod(query, key, value) + + def test_torch_device_is_available(self, device): + def fn(x): + if TEST_HPU or TEST_CUDA: + return x + 1 + else: + return x - 1 + + x = torch.rand(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + @unittest.skipIf(not TEST_CUDA, "requires cuda") + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") + def test_torch_cudnn_is_acceptable(self, device): + def fn(x): + if torch.backends.cudnn.is_acceptable(tensor=x): + return x + 1 + return x + + x = torch.rand(4).to(device) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + @unittest.skipIf(not TEST_CUDA, "requires cuda") + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") + def test_torch_cudnn_is_acceptable_bad_inputs(self, device): + def fn1(x): + if torch.backends.cudnn.is_acceptable("invalid"): + return x + 1 + return x + + def fn2(x): + if torch.backends.cudnn.is_acceptable(x, 3.14): + return x + 1 + return x + + with self.assertRaisesRegex( + AssertionError, "Expect input to cudnn.is_acceptable to be a tensor" + ): + x1 = torch.rand(4).to(device) + opt_fn1 = torch.compile(fn1, backend="eager", fullgraph=True) + res1 = opt_fn1(x1) + + with self.assertRaisesRegex( + AssertionError, "Expect 1 input to cudnn.is_acceptable" + ): + x2 = torch.rand(4).to(device) + opt_fn2 = torch.compile(fn2, backend="eager", fullgraph=True) + res = opt_fn2(x2) + + def test_get_device(self, device): + def fn(x, y): + x = x + 1 + y = y + 1 + return x.get_device(), y.get_device() + + x = torch.rand(4, device=device) + y = torch.rand(4, device="cpu") + ref = fn(x, y) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x, y) + self.assertTrue(same(ref, res)) + + def test_symint_as_device_kwarg(self, device): + def f(rank): + # -2 to make device id 0 for easier testing on CI + return torch.ones(10, device=rank.size(0) - 2) + + x = torch.randn(2) + out = f(torch.randn(2)) + opt_out = torch.compile(backend="eager", dynamic=True, fullgraph=True)(f)(x) + self.assertEqual(out, opt_out) + + @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") 
+ def test_cuda_set_device(self, device): + def fn(): + a = torch.ones(2, device=device) + torch.cuda.set_device(1) + return a + 1 + + with torch.cuda.device(0): + counter = CompileCounter() + opt_fn = torch.compile(fn, backend=counter) + res = opt_fn() + self.assertEqual(res.device.type, "cuda") + self.assertEqual(res.device.index, 0) + self.assertEqual(counter.frame_count, 2) + + def test_torch_device_python_type(self): + for device, device_type, index in [ + ("cpu", "cpu", None), + ("cuda:0", "cuda", 0), + ("hpu:0", "hpu", 0), + ]: + if (device == "cuda:0" and not TEST_CUDA) or ( + device == "hpu:0" and not TEST_HPU + ): + continue + + def fn(target): + target_device = target.device + a = torch.zeros(2, 3, device=target_device) + # Constant assert at trace time + assert isinstance(target_device, torch.device) + assert target_device.type == device_type + assert target_device.index == index + b = torch.zeros(2, 3, device=target_device) + c = torch.zeros(2, 3, device=target_device) + return a + b + c + + from torch._dynamo.variables import ConstantVariable + + device = torch.device(device) + expected_variable = ConstantVariable(device) + self.assertEqual(expected_variable.python_type(), type(device)) + + opt_func = torch.compile(fn, backend="eager", fullgraph=True) + a = torch.tensor([2, 3], device=device) + res = opt_func(a) + self.assertIsInstance(res, torch.Tensor) + + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) + def test_interpolate_propagate_real_tensors(self, device): + @torch.compile(backend="eager", fullgraph=True) + def f(mask, box): + # u0, u1 = mask.tolist() + mask = torch.randn(1, 1, 30, 30, device=device) + h, w = box.tolist() + return torch.nn.functional.interpolate( + mask, (h, w), mode="bilinear", align_corners=False + ) + + f(torch.tensor([30, 30], device=device), torch.tensor([68, 32], device=device)) + + +devices = ("cuda", "hpu") +instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices) if __name__ == "__main__": from torch._dynamo.test_case import run_tests From c6d79c163c86ae2442575b07242f8cd61ad1f0e7 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Fri, 4 Apr 2025 03:24:43 +0000 Subject: [PATCH 188/332] [dynamic shapes] allow duck typing for 0/1 (#150222) Fixes #150184 e.g. for config.backed_size_oblivious=True and compile Pull Request resolved: https://github.com/pytorch/pytorch/pull/150222 Approved by: https://github.com/laithsakka --- test/dynamo/test_misc.py | 4 ++-- test/test_dynamic_shapes.py | 16 ++++++++++++++++ torch/fx/experimental/symbolic_shapes.py | 7 ++++--- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 9050539194bf..b91129e6c1c4 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10211,8 +10211,8 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self): > Left: {44, 93} > Right: {} ==> val_to_var: values don't match. - > Left: {0: 0, 1: 1, 2: s44, 3: s93} - > Right: {0: 0, 1: 1} + > Left: {2: s44, 3: s93} + > Right: {} ==> var_to_range: values don't match. 
> Left: {s44: VR[2, int_oo], s93: VR[2, int_oo]} > Right: {} diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 6b7a2d3edcfc..96115b7b37fd 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -1329,6 +1329,22 @@ def test_tensor_factory_with_symint(self): res = Tensor(sym_args) self.assertEqual(res, expected, exact_dtype=False) + def test_backed_size_oblivious_01_spec(self): + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + @torch.compile(dynamic=True, fullgraph=True) + def f(a, b): + if guard_size_oblivious(a.size(0) == 1): + return b * 10 + else: + return b * 20 + + with torch.fx.experimental._config.patch(backed_size_oblivious=True): + # always go to the >= 2 branch. + self.assertEqual( + f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) + ) + @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 95ca2cef4fd2..7b870f298ca0 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -3330,8 +3330,6 @@ def _init( # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable self.val_to_var: dict[int, sympy.Symbol] = {} - if specialize_zero_one: - self.val_to_var = {0: sympy.S.Zero, 1: sympy.S.One} self.unbacked_symfloat_counter = itertools.count() self.unbacked_symint_counter = itertools.count() # Similar to guards, but these MUST evaluate to true and can @@ -4541,7 +4539,10 @@ def create_symbol( sloc = self._get_sloc() if val in (0, 1) and specialize_zero_one: - r = self.val_to_var[val] + if val == 0: + return sympy.S.Zero + else: + return sympy.S.One elif not duck or val not in self.val_to_var: # If we're not duck shaping, we always create a new symbol # Even if we're duck shaping, if we haven't seen this particular From e6e1f8c272002aeea5c8b48e9f2657e18c80ae25 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 4 Apr 2025 04:29:40 +0000 Subject: [PATCH 189/332] [audio hash update] update the pinned audio hash (#150589) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150589 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 71fe6e9fb351..d585cc27cdab 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -318bace01aebc1f82ae13d0d133fcf9fede73383 +bccaa454a54c3c648697cc2f46a4fb0500b1f01b From 98d06b401b113a4acfbc5da86837d69e3be2a826 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Fri, 4 Apr 2025 04:32:13 +0000 Subject: [PATCH 190/332] [Dynamo] Fix `dict.items()` return type (#150112) Fixes #150110 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150112 Approved by: https://github.com/jansel, https://github.com/zou3519 --- test/dynamo/test_dicts.py | 12 ++++++++++++ torch/_dynamo/utils.py | 10 +++++++++- torch/_dynamo/variables/dicts.py | 33 ++++++++++++++++++++++---------- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 61cafbcbda2c..dcecc827cb99 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -21,6 +21,7 @@ import torch.nn import torch.utils.checkpoint from torch._dynamo.testing import same +from torch._dynamo.utils import dict_items from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import TestCase @@ -936,6 +937,17 @@ def fn(x, d): self.assertEqual(ref, res) self.assertEqual(d1.calls, d2.calls) + def test_items_type(self): + def fn(): + d = dict({"a": 1, "b": "2", "c": torch.tensor(3)}) + return d.items() + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + ref = fn() + res = opt_fn() + self.assertEqual(ref, res) + self.assertEqual(type(res), dict_items) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 2a09d8943409..13ee160a93d5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -94,7 +94,14 @@ if typing.TYPE_CHECKING: - from collections.abc import Generator, Iterable, Iterator, KeysView, ValuesView + from collections.abc import ( + Generator, + ItemsView, + Iterable, + Iterator, + KeysView, + ValuesView, + ) try: @@ -2416,6 +2423,7 @@ def check_numpy_ndarray_args(args, kwargs): dict_keys: type[KeysView[Any]] = type({}.keys()) dict_values: type[ValuesView[Any]] = type({}.values()) +dict_items: type[ItemsView[Any, Any]] = type({}.items()) odict_values: type[ValuesView[Any]] = type(OrderedDict().values()) tuple_iterator: type[Iterator[Any]] = type(iter(())) range_iterator: type[Iterator[Any]] = type(iter(range(0))) diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 6ed522f5a874..60ae7744461e 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -32,7 +32,13 @@ from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard from ..source import is_from_local_source -from ..utils import cmp_name_to_op_mapping, dict_keys, dict_values, specialize_symnode +from ..utils import ( + cmp_name_to_op_mapping, + dict_items, + dict_keys, + dict_values, + specialize_symnode, +) from .base import ValueMutationNew, VariableTracker from .constant import ConstantVariable @@ -376,7 +382,7 @@ def call_method( # corresponding value VT. 
For __contains__, we add a DICT_CONTAINS # guard. But for all the other methods, we insert the DICT_KEYS_MATCH # guard to be conservative. - from . import BuiltinVariable, ConstantVariable, TupleVariable + from . import BuiltinVariable, ConstantVariable Hashable = ConstDictVariable._HashableTracker @@ -398,9 +404,7 @@ def call_method( self.install_dict_keys_match_guard() if self.source: tx.output.guard_on_key_order.add(self.source.name()) - return TupleVariable( - [TupleVariable([k.vt, v]) for k, v in self.items.items()] - ) + return DictItemsVariable(self) elif name == "keys": self.install_dict_keys_match_guard() if self.source: @@ -858,7 +862,7 @@ class DictViewVariable(VariableTracker): def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: super().__init__(**kwargs) - assert self.kv in ("keys", "values") + assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @@ -873,10 +877,7 @@ def view_items_vt(self): raise NotImplementedError def unpack_var_sequence(self, tx): - def unwrap(x): - return x.vt if self.kv == "keys" else x - - return [unwrap(x) for x in self.view_items] + return self.view_items_vt def reconstruct(self, codegen): codegen(self.dv_dict) @@ -938,3 +939,15 @@ def view_items_vt(self): def python_type(self): return dict_values + + +class DictItemsVariable(DictViewVariable): + kv = "items" + + @property + def view_items_vt(self): + # Returns an iterable of the unpacked items + return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] + + def python_type(self): + return dict_items From f3cb3557d694e9e1adb570161c18dd4f95fc2041 Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Fri, 4 Apr 2025 05:21:39 +0000 Subject: [PATCH 191/332] [executorch hash update] update the pinned executorch hash (#149817) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned executorch hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149817 Approved by: https://github.com/pytorchbot --- .ci/docker/ci_commit_pins/executorch.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 6e9cfe33fe63..39005b14ab7e 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -ebe8522378c3f9944aaaef44868f5ececdd845fc +7e487c24e1c20c3f4606c2d8aca2778873b00b4c From 4854926aeb539c00c0d0b6e6661e1a675b6a0432 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 4 Apr 2025 06:52:54 +0000 Subject: [PATCH 192/332] Revert "Add torch._scaled_mm for CPU (#150410)" This reverts commit 3b02f795c5ad2339794b15b370c0e4a235d36adf. 
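For readers skimming the revert: the CPU path removed below worked by de-quantizing the fp8 operands to fp32 and falling back to an ordinary matmul. A rough Python sketch of that semantics (per-tensor scales only, an approximation rather than the exact kernel; the helper name is invented):

```python
# Approximate semantics of the reverted _scaled_mm_out_cpu_emulated fallback.
import torch


def scaled_mm_emulated(mat1, mat2, scale_a, scale_b, bias=None, out_dtype=None):
    # Upcast the fp8 inputs to fp32, apply the per-tensor scales, then matmul.
    a = mat1.to(torch.float32) * scale_a
    b = mat2.to(torch.float32) * scale_b
    out = a @ b
    if bias is not None:
        out = out + bias
    # Mirror the removed C++ code: default the output dtype to mat1's dtype.
    return out.to(out_dtype if out_dtype is not None else mat1.dtype)
```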
Reverted https://github.com/pytorch/pytorch/pull/150410 on behalf of https://github.com/malfet due to It breaks ROCM tests ([comment](https://github.com/pytorch/pytorch/pull/150410#issuecomment-2777704212)) --- aten/src/ATen/native/Blas.cpp | 96 ------------- aten/src/ATen/native/mkldnn/Linear.cpp | 126 +----------------- aten/src/ATen/native/mkldnn/Linear.h | 12 -- aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp | 22 +-- aten/src/ATen/native/native_functions.yaml | 2 - test/inductor/test_fp8.py | 113 ++++++---------- test/test_matmul_cuda.py | 23 ++-- torch/_inductor/codegen/cpp_prefix.h | 4 - .../aoti_torch/generated/c_shim_cpu.h | 2 - torch/testing/_internal/common_device_type.py | 2 - .../_internal/common_methods_invocations.py | 20 +-- 11 files changed, 58 insertions(+), 364 deletions(-) diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index 560a8f7657a8..f62c31777822 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -7,11 +7,6 @@ #include #include -#include -#include -#if !defined(__s390x__) && !defined(__powerpc__) -#include -#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -29,9 +24,6 @@ #include #include #include -#include -#include -#include #endif namespace at::meta { @@ -230,92 +222,4 @@ Tensor vdot(const Tensor &self, const Tensor &other){ } -static Tensor& -_scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], - " but got ", bias->numel()); - - // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - - auto mat1_c = mat1.contiguous(); - auto mat2_c = mat2.contiguous(); - IntArrayRef mat1_sizes = mat1_c.sizes(); - IntArrayRef mat2_sizes = mat2_c.sizes(); - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - - float input_scale = scale_a.item(); - float weight_scale = scale_b.item(); - auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale); - auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale); - auto out_tmp = at::matmul(fp32_mat1, fp32_mat2); - if (bias) { - out_tmp.add_(bias.value()); - } - out_tmp = out_tmp.to(out.scalar_type()); - out.copy_(out_tmp); - return out; -} - -Tensor& -_scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { -#if AT_MKLDNN_ENABLED() - if (at::globalContext().userEnabledMkldnn()) { - bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type(); - if 
((!mixed_dtype && cpuinfo_has_x86_amx_int8()) || - (mixed_dtype && cpuinfo_has_x86_amx_fp16())) { - return mkldnn_scaled_mm( - mat1, - mat2, - scale_a, - scale_b, - bias, - scale_result, - out_dtype, - use_fast_accum, - out); - } - } -#endif - { - return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); - } -} - -Tensor -_scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum) { - const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); - Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); - return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); -} - } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index b1175b796224..8153ae8a4d8e 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -47,20 +46,9 @@ std::tuple mkldnn_linear_backward( TORCH_CHECK(false, "mkldnn_linear_backward: ATen not compiled with MKLDNN support"); } -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_INTERNAL_ASSERT(false, "mkldnn_scaled_mm: ATen not compiled with MKLDNN support"); -} - } // namespace at::native + #else // AT_MKLDNN_ENABLED #include @@ -471,118 +459,6 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) { TORCH_FN(mkldnn_linear_pointwise_binary)); } -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out) { - TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); - TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); - TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", - mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); - TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], - " but got ", bias->numel()); - - // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); - TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); - TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - - // Validation checks have passed lets resize the output to actual size - auto mat1_c = mat1.contiguous(); - auto mat2_c = mat2.contiguous(); - IntArrayRef mat1_sizes = mat1_c.sizes(); - IntArrayRef mat2_sizes = mat2_c.sizes(); - at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - - float input_scale = scale_a.item(); - float weight_scale = scale_b.item(); - auto src = at::native::itensor_view_from_dense(mat1_c); - auto weight_t = at::native::itensor_view_from_dense(mat2_c); - bool with_bias = 
bias.has_value(); - int64_t K = mat1_sizes[1], M = mat1_sizes[0], - N = mat2_sizes[1]; - - std::vector src_dims = {M, K}; - std::vector weight_dims = {K, N}; - std::vector dst_dims = {M, N}; - - ideep::tensor dst = at::native::itensor_view_from_dense(out); - auto src_desc = ideep::tensor::desc( - src_dims, - get_mkldnn_dtype(mat1.scalar_type()), - ideep::format_tag::any); - auto weights_desc = ideep::tensor::desc( - weight_dims, - get_mkldnn_dtype(mat2.scalar_type()), - ideep::format_tag::any); - auto dst_desc = ideep::tensor::desc( - dst_dims, - get_mkldnn_dtype(out.scalar_type()), - ideep::format_tag::any); - ideep::tensor onednn_bias; - if (with_bias) { - auto bias_value = bias.value(); - if (bias_value.dim() == 1) { - auto b_reshape = bias_value.reshape({1, bias_value.size(0)}); - onednn_bias = at::native::itensor_view_from_dense(b_reshape); - } else { - onednn_bias = at::native::itensor_view_from_dense(bias_value); - } - } - auto bias_desc = ideep::tensor::desc(); - if (with_bias) { - bias_desc = ideep::tensor::desc(onednn_bias.get_dims(), - get_mkldnn_dtype(bias.value().scalar_type()), - ideep::format_tag::any); - } - auto op_attr = ideep::attr_t(); - if (input_scale != 1.0f) { - op_attr.set_scales_mask(DNNL_ARG_SRC, 0); - } - if (weight_scale != 1.0f) { - op_attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); - } - - op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - auto engine = ideep::engine::cpu_engine(); - dnnl::matmul::primitive_desc primitive_desc = with_bias - ? dnnl::matmul::primitive_desc( - engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) - : dnnl::matmul::primitive_desc( - engine, src_desc, weights_desc, dst_desc, op_attr); - auto expected_weight = weight_t.reorder_if_differ_in(primitive_desc.weights_desc()); - auto primitive = dnnl::matmul(primitive_desc); - - // Prepare args and execute primitive - ideep::tensor scratchpad(primitive_desc.scratchpad_desc()); - ideep::exec_args args; - args.insert({DNNL_ARG_SRC, src}); - args.insert({DNNL_ARG_WEIGHTS, expected_weight}); - args.insert({DNNL_ARG_DST, dst}); - args.insert({DNNL_ARG_SCRATCHPAD, scratchpad}); - if (with_bias) { - args.insert({DNNL_ARG_BIAS, onednn_bias}); - } - ideep::tensor src_scales_t = ideep::tensor(ideep::scale_t(1, input_scale)); - ideep::tensor wei_scales_t = ideep::tensor(ideep::scale_t(1, weight_scale)); - - if (input_scale != 1.0f) { - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, src_scales_t}); - } - args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, wei_scales_t}); - - primitive.execute(ideep::stream::default_stream(), args); - return out; -} - } // namespace at #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/Linear.h b/aten/src/ATen/native/mkldnn/Linear.h index 1dc50c7c5416..6a7fcd60b0e6 100644 --- a/aten/src/ATen/native/mkldnn/Linear.h +++ b/aten/src/ATen/native/mkldnn/Linear.h @@ -35,15 +35,3 @@ C10_API Tensor mkl_linear( } // namespace at #endif // AT_MKLDNN_ENABLED() - -namespace at::native { -Tensor& -mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2, - const Tensor& scale_a, - const Tensor& scale_b, - const std::optional& bias, - const std::optional& scale_result, - std::optional out_dtype, - bool use_fast_accum, - Tensor& out); -} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index f26427a981f7..32daef37a563 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -57,10 +57,6 @@ ideep::tensor::data_type 
get_mkldnn_dtype(ScalarType type) { return ideep::tensor::data_type::bf16; case ScalarType::Half: return ideep::tensor::data_type::f16; - case ScalarType::Float8_e4m3fn: - return ideep::tensor::data_type::f8_e4m3; - case ScalarType::Float8_e5m2: - return ideep::tensor::data_type::f8_e5m2; default: TORCH_CHECK(false, "get_mkldnn_dtype: unsupported data type"); } @@ -165,24 +161,8 @@ ideep::tensor itensor_view_from_dense(const Tensor& tensor, bool from_const_data const_cast(tensor.const_data_ptr()) : tensor.data_ptr()}; } - else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) { - return {{tensor.sizes().vec(), - ideep::tensor::data_type::f8_e4m3, - tensor.strides().vec()}, - from_const_data_ptr ? - const_cast(tensor.const_data_ptr()) : - tensor.data_ptr()}; - } - else if (tensor.scalar_type() == ScalarType::Float8_e5m2) { - return {{tensor.sizes().vec(), - ideep::tensor::data_type::f8_e5m2, - tensor.strides().vec()}, - from_const_data_ptr ? - const_cast(tensor.const_data_ptr()) : - tensor.data_ptr()}; - } else { - TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8/fp8 tensor input"); + TORCH_CHECK(false, "itensor_view_from_dense expects float/bfloat16/half/int8 tensor input"); } } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index c574130ac43d..e3a1cd175c86 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7063,13 +7063,11 @@ - func: _scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor variants: function dispatch: - CPU: _scaled_mm_cpu CUDA: _scaled_mm_cuda - func: _scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: - CPU: _scaled_mm_out_cpu CUDA: _scaled_mm_out_cuda diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 8f36b2930f00..e208565081a1 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -13,7 +13,7 @@ instantiate_parametrized_tests, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import HAS_CUDA from torch.utils._triton import has_triton_tma_device @@ -116,9 +116,9 @@ def _fix_fp8_dtype_for_rocm( @instantiate_parametrized_tests class TestFP8Types(TestCase): + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) - @parametrize("device", ("cuda", "cpu")) - def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): + def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): """ TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 depends on the variant of fp8 type. 
@@ -128,33 +128,29 @@ def test_xblock_for_small_numel(self, float8_dtype: torch.dtype, device: str): We should not pick a XBLOCK larger than xnumel """ float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) def f(x): return x.to(dtype=float8_dtype) - x = torch.randn(1, device=device) + x = torch.randn(1, device="cuda") expected = f(x) actual = torch.compile(f)(x) torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.float16, torch.bfloat16)) - @parametrize("device", ("cuda", "cpu")) - def test_eager_fallback(self, dtype: torch.dtype, device: torch.device): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) + def test_eager_fallback(self, dtype: torch.dtype): weight_shape = (32, 16) e4m3_type = torch.float8_e4m3fn e4m3_type = _fix_fp8_dtype_for_rocm(e4m3_type, device="cuda") def fp8_matmul_unwrapped(x): - a_scale = torch.Tensor([1.0]).to(device=device) - b_scale = torch.Tensor([1.0]).to(device=device) + a_scale = torch.Tensor([1.0]).to(device="cuda") + b_scale = torch.Tensor([1.0]).to(device="cuda") output_scale = None - input_bias = torch.rand(32, device=device, dtype=dtype) - weight = torch.rand(*weight_shape, device=device, dtype=dtype).T.to( + input_bias = torch.rand(32, device="cuda", dtype=dtype) + weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to( e4m3_type ) a_inverse_scale = 1 / a_scale @@ -175,24 +171,19 @@ def fp8_matmul_unwrapped(x): ) x_shape = (16, 16) - x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 x_shape = (15, 16) - x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) + x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type) y_fp8 = compiled_fp8_matmul(x) # noqa: F841 + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("shape", ("15,3,13", "4,2048,4096")) @parametrize("dst_types", [(torch.float8_e4m3fn, torch.float8_e5m2)]) - @parametrize("device", ("cuda", "cpu")) - def test_valid_cast( - self, dtype: torch.dtype, shape: str, dst_types: tuple, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") + def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple): + dst_types = _fix_fp8_dtype_for_rocm(dst_types, device="cuda") e4m3, e5m2 = dst_types def fp8_cast(x): @@ -203,7 +194,7 @@ def fp8_cast(x): compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device=device, dtype=dtype) + x = torch.rand(*shape, device="cuda", dtype=dtype) y0_fp8, y1_fp8 = compiled_fp8_cast(x) torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1) @@ -232,21 +223,14 @@ def fp8_cast(x, dtype): x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2) compiled_fp8_cast(x, torch.float8_e4m3fn) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float)) @parametrize("dst_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("16,16,16", "4,2048,4096")) - 
@parametrize("device", ("cuda", "cpu")) def test_to_fp8_saturated( - self, - src_dtype: torch.dtype, - dst_dtype: torch.dtype, - shape: str, - device: torch.device, + self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") + dst_dtype = _fix_fp8_dtype_for_rocm(dst_dtype, device="cuda") def fp8_saturated(x, dtype): return _to_fp8_saturated(x, dtype) @@ -255,22 +239,17 @@ def fp8_saturated(x, dtype): fp8_saturated, backend="inductor", dynamic=True ) shape = [int(dim) for dim in shape.split(",")] - x = torch.rand(*shape, device=device, dtype=src_dtype) + x = torch.rand(*shape, device="cuda", dtype=src_dtype) y_compiled = compiled_fp8_cast(x, dst_dtype) y = fp8_saturated(x, dst_dtype) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) - def test_amax_fp8_quant( - self, float8_dtype: torch.dtype, shape: str, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest( - "FP8 is only supported on H100+ and sm_89 and MI300+ devices" - ) + def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str): + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -283,24 +262,19 @@ def amax_fp8(x: Tensor, scale: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(0.2, device=device, dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(0.2, device="cuda", dtype=torch.float) y_compiled = compiled_amax_fp8_quant(x, scale) y = amax_fp8(x, scale) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) - def test_amax_along_with_fp8_quant( - self, float8_dtype: torch.dtype, shape: str, device: torch.device - ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) - if device == "cuda": - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str): + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -313,12 +287,12 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(1.0, device=device, dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(1.0, device="cuda", dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), 
device=device, dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) y = amax_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -326,21 +300,14 @@ def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("amax_keep_dim", (True, False)) @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096")) - @parametrize("device", ("cuda", "cpu")) def test_layernorm_fp8_quant( - self, - float8_dtype: torch.dtype, - amax_keep_dim: bool, - shape: str, - device: torch.device, + self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str ): - if device == "cuda" and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest( - "FP8 is only supported on H100+ and sm_89 and MI300+ devices" - ) + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -362,12 +329,12 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device=device, dtype=torch.half) - scale = torch.tensor(0.2, device=device, dtype=torch.float) + x = torch.rand(*x_shape, device="cuda", dtype=torch.half) + scale = torch.tensor(0.2, device="cuda", dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled) - amax_buffer = torch.zeros((1), device=device, dtype=torch.half) + amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) y = ln_fp8(x, scale, amax_buffer) torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1) @@ -783,5 +750,5 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): if __name__ == "__main__": - if HAS_CUDA or HAS_CPU: + if HAS_CUDA: run_tests() diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 17ece41af239..49da165ca20e 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -484,15 +484,15 @@ def _bfloat16_to_float4_e2m1fn_x2(x): return x -class TestFP8Matmul(TestCase): +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found") +class TestFP8MatmulCuda(TestCase): + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def _test_tautological_mm(self, device: str = "cuda", x_dtype: torch.dtype = e4m3_type, y_dtype: torch.dtype = e4m3_type, out_dtype: Optional[torch.dtype] = None, size: int = 16) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) x_fp8 = torch.rand(size, size, device=device).to(x_dtype) y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t() out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float)) @@ -503,13 +503,12 @@ def _test_tautological_mm(self, device: str = "cuda", self.assertEqual(out_dtype, out_fp8.dtype) self.assertEqual(out_fp32, out_fp8.to(torch.float)) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, 
f8_msg) def test_float8_basics(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported # supported on ROCm but fails on CUDA - ctx = self.assertRaises(RuntimeError) if torch.version.hip is None and device != "cpu" else contextlib.nullcontext() + ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext() with ctx: self._test_tautological_mm(device, e5m2_type, e5m2_type) @@ -520,12 +519,11 @@ def test_float8_basics(self, device) -> None: self._test_tautological_mm(device, size=96, out_dtype=torch.float32) self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) - with self.assertRaises(AssertionError if torch.version.hip or device == "cpu" else RuntimeError): + with self.assertRaises(AssertionError if torch.version.hip else RuntimeError): self._test_tautological_mm(device, out_dtype=e5m2_type) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_scale(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) size = (16, 16) x = torch.full(size, .5, device=device, dtype=e4m3_type) # hipblaslt does not yet support mixed e4m3_type input @@ -640,9 +638,8 @@ def test_scaled_mm_change_stride(self, base_dtype): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias(self, device) -> None: - if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8: - raise unittest.SkipTest(f8_msg) (k, l, m) = (16, 48, 32) x = torch.ones((k, l), device=device).to(e4m3_type) y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t() @@ -695,7 +692,7 @@ def test_float32_output_errors_with_bias(self, device) -> None: lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32), ) - @unittest.skipIf(PLATFORM_SUPPORTS_FP8 or not torch.cuda.is_available(), f8_msg) + @unittest.skipIf(PLATFORM_SUPPORTS_FP8, f8_msg) def test_error_message_fp8_pre_sm89(self, device) -> None: (k, l, m) = (16, 48, 32) x = torch.rand((k, l), device=device).to(e4m3_type) @@ -1551,8 +1548,8 @@ def run_test( ) instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu") +instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu") instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu") -instantiate_device_type_tests(TestFP8Matmul, globals()) if __name__ == '__main__': TestCase._default_dtype_check_enabled = True diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 3a00ce1e3015..8254363cbdcb 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -21,8 +21,6 @@ #include #include -#include -#include #include #include #include @@ -50,8 +48,6 @@ typedef at::BFloat16 bfloat16; typedef at::Float8_e4m3fn float8_e4m3fn; typedef at::Float8_e5m2 float8_e5m2; -typedef at::Float8_e4m3fnuz float8_e4m3fnuz; -typedef at::Float8_e5m2fnuz float8_e5m2fnuz; template struct Welford { diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 55085ee1be7b..682364e950c4 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ 
b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -37,8 +37,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attent AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_flash_attention_for_cpu_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, double dropout_p, int32_t is_causal, AtenTensorHandle* attn_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__scaled_mm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, AtenTensorHandle scale_b, AtenTensorHandle* bias, AtenTensorHandle* scale_result, int32_t* out_dtype, int32_t use_fast_accum); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__segment_reduce_backward(AtenTensorHandle grad, AtenTensorHandle output, AtenTensorHandle data, const char* reduce, AtenTensorHandle* lengths, AtenTensorHandle* offsets, int64_t axis, double* initial, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__trilinear(AtenTensorHandle i1, AtenTensorHandle i2, AtenTensorHandle i3, const int64_t* expand1, int64_t expand1_len_, const int64_t* expand2, int64_t expand2_len_, const int64_t* expand3, int64_t expand3_len_, const int64_t* sumdim, int64_t sumdim_len_, int64_t unroll_dim, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 4ec7eb34a5dc..9cd0661cac15 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -1003,8 +1003,6 @@ class OpDTypes(Enum): torch.int8, torch.uint8, torch.bool, - torch.float8_e4m3fn, - torch.float8_e5m2, ) diff --git a/torch/testing/_internal/common_methods_invocations.py 
b/torch/testing/_internal/common_methods_invocations.py
index d16d31d42684..24f651020d75 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -22,7 +22,7 @@ from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
     floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and,
-    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types,
+    empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and,
 )
 from torch.testing._internal.common_device_type import \
     (onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
@@ -16217,7 +16217,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
     OpInfo(
         'torch._scaled_mm',
         sample_inputs_func=sample_inputs_scaled_mm,
-        dtypes=float8_types(),
+        dtypes=empty_types(),
         dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
         supports_out=True,
         supports_forward_ad=False,
@@ -16225,20 +16225,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
         decorators=[skipCUDAIf(not SM89OrLater or TEST_WITH_ROCM, 'Requires CUDA SM >= 8.9')],
         skips=(
             # Sample inputs isn't really parametrized on dtype
-            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes'),
-            # "add_stub" not implemented for 'Float8_e4m3fn'
-            # "ufunc_add_CUDA" not implemented for 'Float8_e4m3fn'
-            # https://github.com/pytorch/pytorch/issues/107256
-            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_dtypes',
+                         device_type='cuda'),
             # "mul_cuda" not implemented for float8_e4m3fn
-            # "mul_cpu_reduced_float" not implemented for 'Float8_e4m3fn'
             # https://github.com/pytorch/pytorch/issues/107256
-            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness'),
-            # aten::_scaled_mm hit the vmap fallback which is currently disabled
-            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_op_has_batch_rule"),
-            DecorateInfo(unittest.skip("Skipped!"), "TestVmapOperatorsOpInfo", "test_vmap_exhaustive"),
-            DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness',
-                         dtypes=(torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz)),
+            DecorateInfo(unittest.skip("Skipped!"), 'TestSchemaCheckModeOpInfo', 'test_schema_correctness',
+                         dtypes=(torch.float8_e4m3fn,)),
         )
     ),
     OpInfo(

From 73358d37dab22a9d080de3e29a576dbab775d15f Mon Sep 17 00:00:00 2001
From: Jakub Grzybek
Date: Fri, 4 Apr 2025 09:59:59 +0000
Subject: [PATCH 193/332] Fix codegen, change str comparison operator to == for proper equality … (#150611)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150611
Approved by: https://github.com/Skylion007, https://github.com/cyyever
---
 test/test_fx.py   | 17 +++++++++++++++++
 torch/fx/graph.py |  2 +-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/test/test_fx.py b/test/test_fx.py
index 5b54025d8d32..58e925f8633e 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -1271,6 +1271,23 @@ def
forward(self, x: torch.Tensor, y: int = 2): "call_module" ).check("clamp").check("call_method").run(all_formatted) + def test_print_graph(self): + op: torch._ops.OpOverload = torch.ops.aten.relu.default + type_name: str = torch.typename(op) + + graph: torch.fx.Graph = torch.fx.Graph() + a: torch.fx.Node = graph.create_node("placeholder", "x") + b: torch.fx.Node = graph.create_node("call_function", op, (a,), type_expr=type_name) + c: torch.fx.Node = graph.create_node("call_function", op, (b,), type_expr=type_name) + graph.output((b, c)) + + gm: torch.fx.GraphModule = torch.fx.GraphModule( + torch.nn.Module(), graph + ) + gm.graph.lint() + text = gm.print_readable(False) + assert 2 == text.count("_torch__ops_aten_aten_relu_") + def test_script_tensor_constant(self): # TorchScript seems to ignore attributes that start with `__`. # We used to call anonymous Tensor values `__tensor_constant*`, but diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 4a156dba0463..75c0eb8081fb 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -439,7 +439,7 @@ def add_global(name_hint: str, obj: Any): global_name = namespace.create_name(name_hint, obj) if global_name in globals_: - assert globals_[global_name] is obj + assert globals_[global_name] == obj return global_name globals_[global_name] = obj return global_name From 09c4da9325595f0091c81f5c47fc4ee1df0c4094 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 4 Apr 2025 13:05:40 +0000 Subject: [PATCH 194/332] [CUDA][avgpool2d] Fix backward launch bounds again for `sm100`, `sm120` (#150640) `__CUDA_ARCH__` is not visible in host code, which causes incorrect launch bounds and `too many resources requested for launch` on blackwell CC @atalman @malfet as we would want this in 2.7 @nWEIdia Pull Request resolved: https://github.com/pytorch/pytorch/pull/150640 Approved by: https://github.com/malfet, https://github.com/drisspg, https://github.com/atalman --- aten/src/ATen/native/cuda/AveragePool2d.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/cuda/AveragePool2d.cu b/aten/src/ATen/native/cuda/AveragePool2d.cu index 41fbddb3c583..25eda2b6eabb 100644 --- a/aten/src/ATen/native/cuda/AveragePool2d.cu +++ b/aten/src/ATen/native/cuda/AveragePool2d.cu @@ -402,11 +402,12 @@ TORCH_IMPL_FUNC(avg_pool2d_backward_out_cuda) ( bool use_divisor = divisor_override.has_value(); const auto divisor_override_value = use_divisor ? divisor_override.value() : 0; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 - constexpr int double_threads = 768; -#else - constexpr int double_threads = 1024; -#endif + cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); + const bool gesm10x = properties->major >= 10; + int double_threads = 1024; + if (gesm10x) { + double_threads = 768; + } AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_backward_out_cuda_frame", From 295b7e21eba104f00eb09c60f91d0f711408652e Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Fri, 4 Apr 2025 13:14:52 +0000 Subject: [PATCH 195/332] [MPS/inductor] Add support for hermite_polynomial_h. 
(#150664) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150664 Approved by: https://github.com/malfet --- test/inductor/test_mps_basic.py | 1 + torch/_inductor/codegen/mps.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index ee2cc4e4fbba..47dab3ad972c 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -132,6 +132,7 @@ def test_pointwise_polygamma(self): "chebyshev_polynomial_u", "chebyshev_polynomial_v", "chebyshev_polynomial_w", + "hermite_polynomial_h", ], ) def test_pointwise_binary_op(self, op_name): diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index b600721e1a30..1aae913fa0cf 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -449,6 +449,10 @@ def chebyshev_polynomial_v(x: CSEVariable, n: CSEVariable) -> str: def chebyshev_polynomial_w(x: CSEVariable, n: CSEVariable) -> str: return f"c10::metal::chebyshev_polynomial_w_forward({x}, {n})" + @staticmethod + def hermite_polynomial_h(x: CSEVariable, n: CSEVariable) -> str: + return f"c10::metal::hermite_polynomial_h_forward({x}, {n})" + MetalOverrides._initialize_pointwise_overrides("mps") From 1b0a023ddef55dab7ebb0b5c673dc043338706a8 Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Fri, 4 Apr 2025 14:26:22 +0000 Subject: [PATCH 196/332] [Dynamo][Misc] Apply typing hints for `codegen` (#150289) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/150289 Approved by: https://github.com/Skylion007, https://github.com/cyyever --- torch/_dynamo/codegen.py | 8 ++- torch/_dynamo/output_graph.py | 10 ++-- torch/_dynamo/source.py | 61 ++++++++++++----------- torch/_dynamo/variables/base.py | 5 +- torch/_dynamo/variables/builder.py | 5 +- torch/_dynamo/variables/builtin.py | 3 +- torch/_dynamo/variables/constant.py | 2 +- torch/_dynamo/variables/ctx_manager.py | 13 ++--- torch/_dynamo/variables/dicts.py | 11 ++-- torch/_dynamo/variables/functions.py | 9 ++-- torch/_dynamo/variables/iter.py | 15 +++--- torch/_dynamo/variables/misc.py | 21 ++++---- torch/_dynamo/variables/sdpa.py | 3 +- torch/_dynamo/variables/tensor.py | 5 +- torch/_dynamo/variables/torch.py | 2 +- torch/_dynamo/variables/torch_function.py | 5 +- torch/_dynamo/variables/user_defined.py | 3 +- 17 files changed, 101 insertions(+), 80 deletions(-) diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 05dd42866e81..b065c188bcbc 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -18,7 +18,7 @@ import sys import types from collections import Counter -from typing import Optional, Union +from typing import Optional, TYPE_CHECKING, Union import torch.nn from torch.utils._ordered_set import OrderedSet @@ -54,6 +54,10 @@ from .variables.torch_function import TensorWithTFOverrideVariable +if TYPE_CHECKING: + from .symbolic_convert import InstructionTranslatorBase + + @dataclasses.dataclass class GraphOutputEntry: index: int @@ -67,7 +71,7 @@ class PyCodegen: def __init__( self, - tx=None, + tx: "InstructionTranslatorBase", root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, tempvars=None, diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index c11e6deccc7d..92a6ea2f15c8 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -390,7 +390,7 @@ def __init__( # and LOAD_ATTR for same python objects free. 
self.variable_tracker_cache = VariableTrackerCache() self.unique_var_id = itertools.count() - self.code_options = dict(code_options) + self.code_options: dict[str, Any] = dict(code_options) self.output_instructions: list[Instruction] = [] # used to track nodes that are added between calls of copy_graphstate # and restore_graphstate @@ -401,7 +401,7 @@ def __init__( # Not checkpointed self.compiler_fn: Optional[CompilerFn] = compiler_fn - self.global_scope = global_scope + self.global_scope: Scope = global_scope self.local_scope = local_scope self.root_tx = root_tx @@ -462,7 +462,7 @@ def __init__( self.random_calls: list[ tuple[Callable[..., object], tuple[object, ...], dict[str, object]] ] = [] - self.random_values_var = None + self.random_values_var: Any = None # Bytecode to insert right before we call the graph self.pregraph_bytecode: list[Instruction] = [] @@ -888,7 +888,9 @@ def wrap_name(module_key): self.output.update_co_names(module_key) self.global_scope[module_key] = target return VariableTracker.build( - self, target, ConstantSource(source_name=module_key) + self, # type: ignore[arg-type] + target, + ConstantSource(source_name=module_key), ) for k, v in self.nn_modules.items(): diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 4116f110b21d..f31d613170a5 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -21,7 +21,7 @@ import dataclasses import enum -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union from torch._guards import ChainedSource, GuardSource, Source @@ -29,6 +29,9 @@ from .bytecode_transformation import create_call_function, create_instruction +if TYPE_CHECKING: + from .codegen import PyCodegen + # It shouldn't be supported to construct an NNModuleVariable inside an FSDP module, # so those cases are omitted intentionally @@ -120,7 +123,7 @@ class LocalSource(Source): # or `co_freevars`. 
is_derefed_cell_contents: bool = False - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.is_derefed_cell_contents: codegen.load_deref(self.local_name) else: @@ -137,7 +140,7 @@ def name(self): class SyntheticLocalSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(self.local_name)) def guard_source(self): @@ -154,7 +157,7 @@ class RandomValueSource(Source): def guard_source(self): return GuardSource.RANDOM_VALUE - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(create_instruction("BINARY_SUBSCR")) @@ -167,7 +170,7 @@ def name(self): class GlobalSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.global_name, add=True)) def guard_source(self): @@ -181,7 +184,7 @@ def name(self): class GlobalWeakRefSource(Source): global_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.append_output( codegen.create_load_global(self.global_name, add=True) @@ -198,7 +201,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class WeakRefCallSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(self.base)) codegen.extend_output(create_call_function(0, False)) @@ -227,7 +230,7 @@ def __post_init__(self): ) object.__setattr__(self, "member", member_parts[-1]) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -249,7 +252,7 @@ class LocalCellSource(Source): local_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, # Dynamo's bytecode transformation differentiates them slightly, so we # always emit `LOAD_CLOSURE` here. 
@@ -267,7 +270,7 @@ def reconstruct(self, codegen): class GradSource(ChainedSource): member: str = "grad" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.member)) @@ -342,7 +345,7 @@ def __post_init__(self): else: assert self.idx is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, f"call_{self.prop.method_name()}" @@ -378,7 +381,7 @@ class IndexedSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -393,7 +396,7 @@ class NegateSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def guard_source(self): @@ -409,7 +412,7 @@ class ConvertIntSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -424,7 +427,7 @@ class FlattenScriptObjectSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -439,7 +442,7 @@ class ScriptObjectQualifiedNameSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -450,7 +453,7 @@ def name(self): class AttrProxySource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -484,7 +487,7 @@ def __post_init__(self): self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) codegen.extend_output(codegen.create_load_attrs(self.field)) codegen.append_output(codegen.create_load_const(self.idx_key)) @@ -509,7 +512,7 @@ def __post_init__(self): super().__setattr__("index", self.index.__reduce__()) super().__setattr__("index_is_slice", True) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) if self.index_is_slice: codegen.append_output(codegen.create_load_const(self.unpack_slice())) @@ -543,7 +546,7 @@ class ConstDictKeySource(ChainedSource): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") ) @@ -577,7 +580,7 @@ def __post_init__(self): def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # reconstruct dict.__getitem__(dct, key) # Load dict.__getitem__ @@ -609,7 +612,7 @@ class ListGetItemSource(GetItemSource): Same as GetItemSource with reconstruct and name overridden to be list specific. """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Reconstruct list.__getitem__(lst, index) to avoid any side effects # from possibly overridden __getitem__. 
@@ -646,7 +649,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class TupleIteratorGetItemSource(GetItemSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") ) @@ -663,7 +666,7 @@ class TypeSource(ChainedSource): def __post_init__(self): assert self.base is not None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) @@ -677,7 +680,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class OptimizerSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -689,7 +692,7 @@ def name(self): @dataclasses.dataclass(frozen=True) class NNModuleSource(ChainedSource): - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.base) def guard_source(self): @@ -738,7 +741,7 @@ def _get_index(self): return TorchFunctionModeStackVariable.get_mode_index(self.ind) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( utils.__name__, "get_torch_function_mode_stack_at" @@ -755,7 +758,7 @@ def guard_source(self): class ConstantSource(Source): source_name: str - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.append_output(codegen.create_load_global(self.source_name, add=False)) def guard_source(self): @@ -776,7 +779,7 @@ def name(self) -> str: def guard_source(self): return self.base.guard_source() - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen(self.base) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index fbf780bf7fa3..e5274d0f0ce7 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -29,7 +29,8 @@ if TYPE_CHECKING: - from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase + from ..codegen import PyCodegen + from ..symbolic_convert import InstructionTranslator, InstructionTranslatorBase class SourceType(Enum): @@ -399,7 +400,7 @@ def maybe_fx_node(self): except NotImplementedError: return None - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError def unpack_var_sequence(self, tx) -> list["VariableTracker"]: diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d5cea823b7f6..d85885449b06 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -276,6 +276,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -348,7 +349,7 @@ def __post_init__(self): self._example = TensorWeakRef(self._example) assert is_fake(self.fake_tensor) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.source) def erase(self): @@ -369,7 +370,7 @@ def __init__(self) -> None: is_tensor=False, ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert codegen.tx.output.backward_state_var codegen.add_push_null( lambda: 
codegen.load_import_from(BackwardState.__module__, "BackwardState") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5360868dd7e7..2a7d031b7b87 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -87,6 +87,7 @@ if TYPE_CHECKING: # Cyclic dependency... + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator log = logging.getLogger(__name__) @@ -730,7 +731,7 @@ def as_proxy(self): return DTYPE[self.fn] return super().as_proxy() - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen"): + def reconstruct(self, codegen: "PyCodegen"): name = self.fn.__name__ assert self.fn.__module__ == "builtins" assert name not in codegen.tx.f_globals, "shadowed global" diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 6760bd1ff73a..f86d2d2062a7 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -133,7 +133,7 @@ def const_getattr(self, tx: "InstructionTranslator", name): def call_method( self, - tx, + tx: "InstructionTranslator", name, args: "list[VariableTracker]", kwargs: "dict[str, VariableTracker]", diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 04f552c54fa3..7cbed617d823 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -50,6 +50,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -85,12 +86,12 @@ def exit(self, tx: "InstructionTranslator", *args): self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -1057,7 +1058,7 @@ def exit(self, tx: "InstructionTranslator", *args): _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1278,7 +1279,7 @@ def call_method( def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this stream is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1340,7 +1341,7 @@ def call_method( def as_proxy(self): return self.proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # If we got here, this event is fully subsumed by the graph - this means it is # not an input or global assert not self.source @@ -1378,7 +1379,7 @@ def call_function( assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. 
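
Every hunk in this patch applies the same idiom: `PyCodegen` is imported only under `typing.TYPE_CHECKING` and referenced through the string annotation `"PyCodegen"`, so static type checkers see the concrete type while no runtime import of `torch._dynamo.codegen` happens, avoiding a circular import with the variable-tracker modules. A minimal sketch of that idiom is below; the `IllustrativeSource` class is illustrative only and not part of the patch, though the `codegen` calls mirror ones used in the diffs above.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by static type checkers (mypy/pyright); this import never
    # executes at runtime, so it cannot create an import cycle with
    # torch._dynamo.codegen.
    from torch._dynamo.codegen import PyCodegen


class IllustrativeSource:
    """Illustrative stand-in for the Source/VariableTracker classes above."""

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # The string annotation keeps the signature checkable without
        # importing PyCodegen when this module is loaded.
        codegen.append_output(codegen.create_load_const(None))
```
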
diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 60ae7744461e..7c38539bd217 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -263,7 +264,7 @@ def is_new_item(self, value, other): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.add_push_null( @@ -546,7 +547,7 @@ def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: def unpack_var_sequence(self, tx): return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # load types.MappingProxyType if self.source: unimplemented( @@ -681,7 +682,7 @@ def python_type(self): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) @@ -786,7 +787,7 @@ def python_type(self): def as_python_constant(self): return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -879,7 +880,7 @@ def view_items_vt(self): def unpack_var_sequence(self, tx): return self.view_items_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 257ccac4d37b..d8beec6aaeb2 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -75,6 +75,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator from torch._higher_order_ops.triton_kernel_wrap import ( TritonGridType, @@ -470,7 +471,7 @@ def __str__(self): __repr__ = __str__ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): from torch._dynamo.side_effects import disallow_side_effects_in_generator from torch._dynamo.symbolic_convert import ( InstructionTranslator, @@ -1109,7 +1110,7 @@ def bind_args(self, parent, args, kwargs): return result - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from(__name__, "_create_nested_fn") ) @@ -1506,7 +1507,7 @@ def __init__(self, func: VariableTracker, args, keywords, **kwargs) -> None: def python_type(self): return functools.partial - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("functools", "partial")) codegen(self.func) if self.args: @@ -1962,7 +1963,7 @@ def to_metadata(self): self.element_size.as_proxy(), ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from( "triton.tools.experimental_descriptor", diff --git a/torch/_dynamo/variables/iter.py 
b/torch/_dynamo/variables/iter.py index 502616c440e9..3cf9c994ddc2 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -34,6 +34,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -249,7 +250,7 @@ def __init__(self, item: VariableTracker, **kwargs) -> None: def next_variable(self, tx): return self.item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -279,7 +280,7 @@ def next_variable(self, tx): self.item = self.item.call_method(tx, "__add__", [self.step], {}) return old_item - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -425,7 +426,7 @@ def get_item(it): self.index += 1 return variables.TupleVariable(args) - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): for it in self.iterables: if isinstance(it, list): remaining_items = it[self.index :] @@ -436,7 +437,7 @@ def reconstruct_items(self, codegen): else: codegen(it) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True ) @@ -481,7 +482,7 @@ def next_variable(self, tx): args = super().next_variable(tx) return self.fn.call_function(tx, args.items, {}) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", "map"), call_function_ex=True ) @@ -555,7 +556,7 @@ def _next(): if pred_res.as_python_constant(): return item - def reconstruct_items(self, codegen): + def reconstruct_items(self, codegen: "PyCodegen"): if isinstance(self.iterable, list): remaining_items = self.iterable[self.index :] codegen.foreach(remaining_items) @@ -565,7 +566,7 @@ def reconstruct_items(self, codegen): else: codegen(self.iterable) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("builtins", "filter")) codegen(self.fn) self.reconstruct_items(codegen) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 2c92599a8b28..1430dc912cb1 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -57,6 +57,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -81,7 +82,7 @@ def __init__(self, typevar, objvar=None, **kwargs) -> None: # cls for a classmethod) self.objvar = objvar - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) codegen(self.typevar) if self.objvar is not None: @@ -331,7 +332,7 @@ def __init__(self, exc_type, args, **kwargs) -> None: def set_context(self, context: "ExceptionVariable"): self.__context__ = context - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.load_import_from("builtins", self.exc_type.__name__) ) @@ -460,7 +461,7 @@ class ComptimeVariable(VariableTracker): Dynamo compile time """ - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): raise NotImplementedError("comptime is special form") def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": 
@@ -944,7 +945,7 @@ def const_getattr(self, tx: "InstructionTranslator", name): raise NotImplementedError return inspect.getattr_static(step2, name) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.obj) codegen.extend_output(codegen.create_load_attrs(self.name)) @@ -1161,7 +1162,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): def as_python_constant(self): return self.value - def reconstruct(self, codegen: "torch._dynamo.codegen.PyCodegen") -> None: + def reconstruct(self, codegen: "PyCodegen") -> None: # We're just trying to load the type here. Reconstructing the type from # scratch is tricky - for a type like `typing.List[int]` we'd need to # deconstruct the origin and args. The origin for `List[int]` is `list` @@ -1336,7 +1337,7 @@ def __init__(self, **kwargs) -> None: def __repr__(self) -> str: return "NullVariable" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if sys.version_info < (3, 11): unimplemented("cannot reconstruct NullVariable in < Python 3.11") codegen.append_output(create_instruction("PUSH_NULL")) @@ -1377,7 +1378,7 @@ def __init__(self, format_string, sym_args, sym_kwargs, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1426,7 +1427,7 @@ def call_function(self, tx: "InstructionTranslator", args, kwargs): tx.debug_locals.append((self, list(args))) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): return self.source.reconstruct(codegen) @staticmethod @@ -1721,7 +1722,7 @@ def call_random_meth(*args, **kwargs): return call_random_fn(tx, call_random_meth, args, kwargs) return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null( lambda: codegen.extend_output( [ @@ -1762,7 +1763,7 @@ def call_function( ) -> "VariableTracker": return self.referent_vt - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen.add_push_null(lambda: codegen.load_import_from("weakref", "ref")) codegen(self.referent_vt) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 51c1ea6bf141..6edd4a7c8ea4 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split() @@ -36,7 +37,7 @@ def __init__(self, proxy, param_vars, **kwargs) -> None: self.param_vars = param_vars super().__init__(**kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): assert self.source is None assert self.param_vars is not None codegen.add_push_null( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index c477979fa9e3..ef6a69ceee7c 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -80,6 +80,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -1558,7 +1559,7 @@ def call_method( return super().call_method(tx, name, args, kwargs) - def 
reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("untyped_storage") codegen.call_method(0) @@ -1573,7 +1574,7 @@ def __init__( super().__init__(**kwargs) self.from_tensor = from_tensor - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): codegen(self.from_tensor) codegen.load_method("data_ptr") codegen.call_method(0) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 8034f440e775..429b3b572774 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -223,7 +223,7 @@ def __init__(self, value, **kwargs) -> None: super().__init__(**kwargs) self.value = value - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): try: name = f"{self.value.__module__}.{self.value.__name__}" except Exception: diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 330faf9bf902..982a65117717 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -67,6 +67,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -382,7 +383,7 @@ def __init__(self, value, source=None, **kwargs): self.cm_obj = value # needed for BC with calling enter from CM code self.source = source - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): # This shouldn't be called unless we have a source assert self.source self.source.reconstruct(codegen) @@ -426,7 +427,7 @@ def exit(self, tx: "InstructionTranslator", *args): ) return ConstantVariable.create(None) - def reconstruct_type(self, codegen): + def reconstruct_type(self, codegen: "PyCodegen"): ty = NoEnterTorchFunctionMode codegen( AttrSource( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 2d22e0d35805..fc39d238f309 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -97,6 +97,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -1507,7 +1508,7 @@ def call_method(self, tx: "InstructionTranslator", method_name, args, kwargs): return variables.ConstantVariable.create(None) super().call_method(tx, method_name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen"): if self.idx == self.REMOVED: # Hook has already been removed, return a dummy handle codegen.add_push_null( From 07d439e7829c6a769e88ecb3bafd1ac6b8af4e7a Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Fri, 4 Apr 2025 15:48:45 +0000 Subject: [PATCH 197/332] [aoti] Split ConstantType definition out of model.h (#150545) Summary: Splitting the type definition of ConstantType into a separate header because it's needed by Sigmoid OSS but the entire model.h header include cause the following compilation error: ``` 2025-04-01T18:12:42.0391272Z FAILED: caffe2/CMakeFiles/torch_cpu.dir/__/torch/csrc/nativert/kernels/AOTICallDelegateKernel.cpp.o 2025-04-01T18:12:42.0417705Z /opt/cache/bin/sccache /opt/cache/bin/clang++ -DAT_PER_OPERATOR_HEADERS -DBUILD_ONEDNN_GRAPH -DCAFFE2_BUILD_MAIN_LIB -DCPUINFO_SUPPORTED_PLATFORM=1 -DFMT_HEADER_ONLY=1 -DFXDIV_USE_INLINE_ASSEMBLY=0 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DIDEEP_USE_MKL -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DNNP_CONVOLUTION_ONLY=0 
-DNNP_INFERENCE_ONLY=0 -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_ENABLE_LLVM -DUSE_C10D_GLOO -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_RPC -DUSE_TENSORPIPE -DXNN_LOG_LEVEL=0 -D_FILE_OFFSET_BITS=64 -Dtorch_cpu_EXPORTS -I/var/lib/jenkins/workspace/build/aten/src -I/var/lib/jenkins/workspace/aten/src -I/var/lib/jenkins/workspace/build -I/var/lib/jenkins/workspace -I/var/lib/jenkins/workspace/cmake/../third_party/benchmark/include -I/opt/llvm/include -I/var/lib/jenkins/workspace/third_party/onnx -I/var/lib/jenkins/workspace/build/third_party/onnx -I/var/lib/jenkins/workspace/nlohmann -I/var/lib/jenkins/workspace/torch/csrc/api -I/var/lib/jenkins/workspace/torch/csrc/api/include -I/var/lib/jenkins/workspace/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src/TH -I/var/lib/jenkins/workspace/build/caffe2/aten/src -I/var/lib/jenkins/workspace/build/caffe2/../aten/src -I/var/lib/jenkins/workspace/torch/csrc -I/var/lib/jenkins/workspace/third_party/miniz-3.0.2 -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/include -I/var/lib/jenkins/workspace/third_party/kineto/libkineto/src -I/var/lib/jenkins/workspace/third_party/cpp-httplib -I/var/lib/jenkins/workspace/aten/src/ATen/.. -I/var/lib/jenkins/workspace/third_party/FXdiv/include -I/var/lib/jenkins/workspace/c10/.. -I/var/lib/jenkins/workspace/third_party/pthreadpool/include -I/var/lib/jenkins/workspace/third_party/cpuinfo/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/include -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/src -I/var/lib/jenkins/workspace/aten/src/ATen/native/quantized/cpu/qnnpack/deps/clog/include -I/var/lib/jenkins/workspace/third_party/NNPACK/include -I/var/lib/jenkins/workspace/third_party/fbgemm/include -I/ 2025-04-01T18:12:42.0444143Z In file included from /var/lib/jenkins/workspace/torch/csrc/nativert/kernels/AOTICallDelegateKernel.cpp:5: 2025-04-01T18:12:42.0445081Z In file included from /var/lib/jenkins/workspace/torch/csrc/nativert/executor/AOTIDelegateExecutor.h:6: 2025-04-01T18:12:42.0446002Z In file included from /var/lib/jenkins/workspace/torch/csrc/nativert/executor/AOTInductorModelImpl.h:5: 2025-04-01T18:12:42.0447549Z /var/lib/jenkins/workspace/torch/csrc/inductor/aoti_runtime/model.h:78:13: error: function 'RAII_cpuMalloc' is not needed and will not be emitted [-Werror,-Wunneeded-internal-declaration] 2025-04-01T18:12:42.0448656Z RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { ``` model.h defines RAII_malloc functions directly into anonymous namespace which seems pretty sad. we should do something about it but may not in the current diff. Test Plan: CI Differential Revision: D72320413 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150545 Approved by: https://github.com/desertfire --- .../inductor/aoti_runtime/constant_type.h | 20 +++++++++++++++++++ torch/csrc/inductor/aoti_runtime/model.h | 8 +------- 2 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 torch/csrc/inductor/aoti_runtime/constant_type.h diff --git a/torch/csrc/inductor/aoti_runtime/constant_type.h b/torch/csrc/inductor/aoti_runtime/constant_type.h new file mode 100644 index 000000000000..053eed728fb0 --- /dev/null +++ b/torch/csrc/inductor/aoti_runtime/constant_type.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +// WARNING: Be careful when adding new includes here. 
This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. + +namespace torch::aot_inductor { + +enum ConstantType : uint8_t { + Unknown = 0, + Parameter = 1, + Buffer = 2, + TensorConstant = 3, + FoldedConstant = 4, +}; + +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index 617548a53a3c..d3789def392a 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -20,6 +20,7 @@ #else #include #endif +#include #define AOTI_RUNTIME_CHECK(EXPR, MSG) \ do { \ @@ -89,13 +90,6 @@ RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { } // anonymous namespace namespace torch::aot_inductor { -enum ConstantType : uint8_t { - Unknown = 0, - Parameter = 1, - Buffer = 2, - TensorConstant = 3, - FoldedConstant = 4, -}; using ConstantMap = std::unordered_map; From f443035f10db31f8e6cb2dae03f0d83a5099e317 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 4 Apr 2025 16:05:18 +0000 Subject: [PATCH 198/332] Revert "[cuda] Add new faster gammabeta backward kernel (#148605) (Reapply with launch bounds) (#150625)" This reverts commit c6defa9443d241dd7a0baac4e708b6e906bd012c. Reverted https://github.com/pytorch/pytorch/pull/150625 on behalf of https://github.com/atalman due to failing internal build ([comment](https://github.com/pytorch/pytorch/pull/150625#issuecomment-2779183414)) --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 527 +++++++----------- test/test_nn.py | 20 - 2 files changed, 195 insertions(+), 352 deletions(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 3ce2c24c18e6..9feb30c21941 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -508,6 +508,7 @@ __global__ void layer_norm_grad_input_kernel_vectorized( } } + template __global__ void GammaBetaBackwardSimpleCUDAKernel( int64_t M, @@ -539,365 +540,191 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } } -template -__device__ -__forceinline__ -void -blockReduceGammaBetaBackwardsHelper( - int64_t M_start, - int64_t M, - int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db, - T_ACC &dg_sum, - T_ACC &db_sum -) { - constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; - int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; - - int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); - int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; - T_ACC warp_mean = 0, warp_rstd = 0; - if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { - warp_mean = mean[mean_index + lane_id]; - warp_rstd = rstd[mean_index + lane_id]; - } - // We do a WARP_SYNC() here because we use WARP_SHFL below to access - // warp_mean and warp_rstd. 
- WARP_SYNC(); - - T_ACC dY_regs[rows_per_thread_y] = {0}; - T_ACC X_regs[rows_per_thread_y] = {0}; - #pragma unroll - for (int i = 0; i < rows_per_thread_y; ++i) { - int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i; - bool active = true; - if (check_x && thread_x >= N) { - active = false; - } - if (check_y && current_y >= M) { - active = false; - } - if (active) { - dY_regs[i] = dY[current_y * N + thread_x]; - X_regs[i] = X[current_y * N + thread_x]; - } - } - - #pragma unroll - for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); - T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); - dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; - db_sum += dY_regs[i]; - } -} - -template -__device__ -__forceinline__ -void -blockReduceGammaBetaBackwardsWithChecks( - int64_t M, - int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db, - T_ACC &dg_sum, - T_ACC &db_sum -) { - for (int64_t M_start = blockIdx.y * rows_per_block_y; - M_start < M; - M_start += rows_per_block_y * gridDim.y) { - int64_t M_end = M_start + rows_per_block_y - 1; - if (!check_y || M_end < M) { - blockReduceGammaBetaBackwardsHelper - (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - blockReduceGammaBetaBackwardsHelper - (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } - } -} +// This implementation gets called if M and N divide with 32. This case should +// be the most common. We can then make better use of warp level intrinsics +// to improve performance. -// block_dim_x is the number of threads in the x dimension per block. -// block_dim_y is the number of threads in the y dimension per block. -// rows_per_block_y is the size of the tile (number of data elements) -// in the y dimension per block. -// partial_reduction indicates whether we need to reduce across threads -// or not. If set to true, we will not reduce across threads. This can -// be faster in the M >> N case but requires another kernel to do a full -// final reduction. -// aligned_grid means the data size is a multiple of tile size. In that -// case we don't need to check for boundary conditions which can provide -// a further speedup by not needing instructions to check for edge cases -// and not needing predicate registers. -template -__global__ -void -__launch_bounds__(block_dim_x * block_dim_y) - GammaBetaBackwardCUDAKernelTemplate( +template +__global__ void GammaBetaBackwardCUDAKernel_32x32( int64_t M, int64_t N, - const T* __restrict__ dY, - const T* __restrict__ X, - const T_ACC* __restrict__ mean, - const T_ACC* __restrict__ rstd, - T* __restrict__ dg, - T* __restrict__ db) { - // This assert is a compile-time check only. - constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; - static_assert(rows_per_thread_y <= kWarpSize); + const T* dY, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + T* dg, + T* db) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; T_ACC dg_sum = 0; T_ACC db_sum = 0; - if (aligned_grid) { - // When N and M align perfectly with block_dim_x and block_dim_y, we - // can skip boundary condition checks that waste instruction issue slots. 
- blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - // In the general case we need to check boundary conditions in the M - // dimension. However, we can still avoid boundary checks in the N dimension - // for the inner blocks. So try to avoid those checks when possible. - if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { - blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } else { - blockReduceGammaBetaBackwardsWithChecks - - (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); - } - } + const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; - int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x; + if (j < N) { + constexpr int unroll_factor = 8; + int laneId = threadIdx.x & (C10_WARP_SIZE - 1); + + T_ACC mean_reg, mean_reg_tmp; + T_ACC rstd_reg, rstd_reg_tmp; + T dY_reg; + T X_reg; + + // Main loop + int bcounter; + for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); + bcounter++) { + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + + if (laneId < unroll_factor) { + mean_reg_tmp = mean[offset + laneId]; + rstd_reg_tmp = rstd[offset + laneId]; + } + WARP_SYNC(); - // When partial_reduction is requested, we don't reduce within a block. - // We also don't reduce if we are only a single block in the y dimension. - if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { - if (aligned_grid || thread_x < N) { - int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y; - if (dg) { - dg[thread_y * N + thread_x] = dg_sum; + #pragma unroll + for (int ii = 0; ii < unroll_factor; ++ii) { + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); + rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; } - if (db) { - db[thread_y * N + thread_x] = db_sum; + } + + // Remainder loop + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + for (int ii = 0; ii < unroll_factor; ii++) { + if ((offset + ii) < M) { + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; } } - } else { - // The caller requested a full reduction so we must reduce across - // warps using shared memory and warp shuffles. - static_assert(rows_per_thread_y <= C10_WARP_SIZE); - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; - int padded_bx = (block_dim_x + 1); - // Transpose dg and db. + + // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and + // gets called when M; N divide by 32. We can use warp shuffles + // for the final reduction step. This removes 4 shmem loads and + // stores with their corresponding __syncthreads() + + // This greatly reduces bank conflicts at the expense of a little + // extra shared memory. 
It does not impact occupancy + int padded_bx = (1 + blockDim.x); + s_dg = s_data_typed; - s_db = s_data_typed + (padded_bx * block_dim_y); + s_db = s_data_typed + (padded_bx * blockDim.y); s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; __syncthreads(); // Load transposed so that a warp holds an entire column - // Because block_dim_x != block_dim_y in the general case, we need - // some code to handle the general case. - static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); - constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; - int thread_id = threadIdx.y * block_dim_x + threadIdx.x; - int warp_id = thread_id / C10_WARP_SIZE; - int lane_id = thread_id & (C10_WARP_SIZE - 1); - #pragma unroll - for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { - T_ACC reg_db, reg_dg; - if (lane_id < block_dim_y) { - reg_dg = s_dg[lane_id * padded_bx + i]; - reg_db = s_db[lane_id * padded_bx + i]; - } - #pragma unroll - for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; + T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; + for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + + if (threadIdx.x == 0) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; + if (dg) { + dg[j] = reg_dg; } - // Reduce is done. Now write it out to global memory. - int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; - if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { - if (dg) { - dg[out_index] = reg_dg; - } - if (db) { - db[out_index] = reg_db; - } + if (db) { + db[j] = reg_db; } } } } -template -void LaunchAndCheckGammaBetaBackwardKernel( - bool aligned_grid, - dim3 blocks, - dim3 threads, - size_t shmem_sz, - cudaStream_t cuda_stream, - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, - int64_t M, - int64_t N, - T* dgamma_data, - T* dbeta_data) { -if (aligned_grid) { - GammaBetaBackwardCUDAKernelTemplate - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - } else { - GammaBetaBackwardCUDAKernelTemplate - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - } - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void ConfigureAndLaunchGammaBetaBackwardKernel( - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, +template +__global__ void GammaBetaBackwardCUDAKernel( int64_t M, int64_t N, - Tensor* dgamma, - Tensor* dbeta, - cudaStream_t cuda_stream) { - T* dgamma_data = - dgamma->defined() ? dgamma->template data_ptr() : nullptr; - T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; - bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); - dim3 threads{block_dim_x, block_dim_y}; - dim3 blocks; - blocks.x = (N + block_dim_x - 1) / block_dim_x; - blocks.y = 1; - size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; - if (blocks.y == 1 && threads.y == 1) { - // Optimization: since there is just one thread doing all the summation, we don't need a reduction - // across threads. So we set partial_reduction to true. 
- LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); - } else { - LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); - } + const T* dY, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + T* dg, + T* db) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; -} + const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; -template -void LaunchGammaBetaBackwardCUDAKernel( - const T* dY_data, - const T* X_data, - const T_ACC* mean_data, - const T_ACC* rstd_data, - int64_t M, - int64_t N, - Tensor* dgamma, - Tensor* dbeta, - cudaStream_t cuda_stream) { - constexpr int block_dim_x = 32; - const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { - // We have a situation where M >> N and N is small. - // In this case we can speed up the computation by parallelizing in the M dimension. - // We launch multiple blocks in the y-dimension, and compute partial sums for the - // gradient in the first pass. Then we do a .sum(0) to do a final reduction. - // Although we launch 2 kernels, we can get up to a 10x speedup for large M. - constexpr int block_dim_y = 1; - constexpr int rows_per_block_y = 32; - bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); - dim3 threads{block_dim_x, block_dim_y}; - dim3 blocks; - blocks.x = (N + block_dim_x - 1) / block_dim_x; - // int rows_per_block = my_gamma_beta_unroll_factor * - blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; - constexpr int max_grid_size = 64 * 1024 / 2; - blocks.y = std::min(max_grid_size / blocks.x, blocks.y); - Tensor dgamma_blocks; - Tensor dbeta_blocks; - T * dgamma_blocks_ptr = nullptr; - T * dbeta_blocks_ptr = nullptr; - if (dgamma->defined()) { - auto options = dgamma->options(); - dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); - dgamma_blocks_ptr = dgamma_blocks.data_ptr(); + T_ACC dg_sum = 0; + T_ACC db_sum = 0; + + if (j < N) { + constexpr int unroll_factor = 8; + + T_ACC mean_reg; + T_ACC rstd_reg; + T dY_reg; + T X_reg; + + // Main Loop + int bcounter; + for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){ + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + + #pragma unroll + for (int ii = 0; ii < unroll_factor; ++ii) { + dY_reg = dY[(offset + ii) * N + j]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } } - if (dbeta->defined()) { - auto options = dbeta->options(); - dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); - dbeta_blocks_ptr = dbeta_blocks.data_ptr(); + + // Remainder loop + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + for (int ii = 0; ii < unroll_factor; ii++ ){ + if ((offset + ii) < M) { + dY_reg = dY[(offset + ii) * N + j ]; + X_reg = X[(offset + ii) * N + j]; + mean_reg = mean[offset + ii]; + rstd_reg = rstd[offset + ii]; + dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; + db_sum += dY_reg; + } } - LaunchAndCheckGammaBetaBackwardKernel( - aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, 
dgamma_blocks_ptr, dbeta_blocks_ptr); - *dgamma = dgamma_blocks.sum(0); - *dbeta = dbeta_blocks.sum(0); - } else { - // We are in the normal case where M is not that large. - // We can change the tile shape (which is the last template parameter) in accordance with M. - // For small M it is faster to have a smaller tile, otherwise we could have idle threads. - // For larger M we use a bigger tile size. - if (M < 64) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else if (M < 128) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else if (M < 256) { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); - } else { - ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + // Do the final reduction in shared memory + s_dg = s_data_typed; + s_db = s_data_typed + blockDim.x * blockDim.y; + s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; + s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum; + __syncthreads(); + + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + if (threadIdx.y < offset) { + s_dg[threadIdx.y * blockDim.x + threadIdx.x] += + s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + s_db[threadIdx.y * blockDim.x + threadIdx.x] += + s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + } + __syncthreads(); + } + + if (threadIdx.y == 0) { + if (dg) { + dg[j] = s_dg[threadIdx.x]; + } + if (db) { + db[j] = s_db[threadIdx.x]; + } } } } @@ -1423,7 +1250,6 @@ void LayerNormBackwardKernelImplInternal( dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; -#if defined(USE_ROCM) if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; @@ -1439,6 +1265,7 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { +#if defined(USE_ROCM) // For small batch size, do colwise reduce directly. const int part_size = warp_size; const dim3 threads2(warp_size, 4, 1); @@ -1473,11 +1300,47 @@ void LayerNormBackwardKernelImplInternal( dgamma_data, dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); - } #else - LaunchGammaBetaBackwardCUDAKernel( - dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { + // This implementation relies on warp primitives and requires that M and N divide + // exactly to warp size. + dim3 threads{kWarpSize, kWarpSize}; + int blocks = (N + threads.x - 1) / threads.x; + + // If M and N divide by warp_size, we can use warp shuffles for the final reduction. + // That requires transposing values in shared memory, so we apply a padding to + // reduce bank conflicts. 
+ + size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; + GammaBetaBackwardCUDAKernel_32x32 + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + dim3 threads{16, 32}; + int blocks = (N + threads.x - 1) / threads.x; + size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; + GammaBetaBackwardCUDAKernel + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } #endif + } } } diff --git a/test/test_nn.py b/test/test_nn.py index 72c440ca5ec5..30fe71b4162e 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7195,26 +7195,6 @@ def test_layer_norm_eps(self): ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) - @unittest.skipIf(not TEST_CUDA, "CUDA not available") - def test_layer_norm_backwards_eps(self): - dtype = torch.float - m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), - (32, 32), (1024, 32), (1024, 1024), - (33, 33), (1025, 33), (1025, 1025)] - for m, n in m_x_n_list: - x = torch.randn((m, n), dtype=dtype, requires_grad=True) - grad_output = torch.rand_like(x) - x_cuda = x.clone().detach().to("cuda").requires_grad_() - grad_output_cuda = grad_output.clone().detach().to("cuda") - ln = nn.LayerNorm(n, dtype=dtype) - ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) - ln_out = ln(x) - ln_out_cuda = ln_cuda(x_cuda) - ln_out.backward(grad_output) - ln_out_cuda.backward(grad_output_cuda) - self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) - self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) - @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291 From c93e34d7b5690ee77cd29dc7b28a8aa7f61d58aa Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Fri, 4 Apr 2025 16:26:00 +0000 Subject: [PATCH 199/332] Revert "bound sympy accuracy (#150383)" This reverts commit 1bc2b2b12ae1ddd27b0401a1baac3b8099b6fc50. 
Reverted https://github.com/pytorch/pytorch/pull/150383 on behalf of https://github.com/laithsakka due to big regression ([comment](https://github.com/pytorch/pytorch/pull/150383#issuecomment-2779227548)) --- test/export/test_export.py | 22 ---------------------- torch/utils/_sympy/value_ranges.py | 17 ----------------- 2 files changed, 39 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 5eefb67c14b6..988e2fae81c6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3105,28 +3105,6 @@ def forward(self, x, y): "dy - 6 = 6" not in exc.args[0] ) # don't suggest fix for non-root dim - @testing.expectedFailureLegacyExportNonStrict # FIXME constraint violation (guard: s0 - s0%8 != 1) - @testing.expectedFailureCppSerDes # FIXME data-dependent error (hinted: True, unhinted: s0 - s0%8 >= 0) - def test_bound_sympy_accuracy(self): - class Foo(torch.nn.Module): - def forward(self, x): - expr = x.shape[0] - (x.shape[0] % 8) - return torch.empty(expr) - - ep = export( - Foo(), - (torch.randn(13),), - dynamic_shapes={"x": (Dim("dim", min=2),)}, - ) - - (output,) = ep.graph.output_node().args[0] - sym_node = output.meta["val"].shape[0].node - vr = torch.utils._sympy.value_ranges.bound_sympy( - sym_node.expr, - sym_node.shape_env.var_to_range, - ) - self.assertEqual(vr.lower, 0) - @unittest.skip("See https://github.com/pytorch/pytorch/issues/135759") def test_keep_composite_ops_invalid(self): class Foo(torch.nn.Module): diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 118959b8c4db..784f9e7ba051 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -1004,22 +1004,6 @@ def trunc(x): return ValueRanges.increasing_map(x, TruncToFloat) -def _rewrite_for_value_range_analysis(expr: sympy.Expr): - """ - Sometimes accuracy of value range analysis can be improved - with simple rewriting rules. - """ - - # Rewrite X - X%Y to (X//Y) * Y. 
- x, y = sympy.Wild("x"), sympy.Wild("y") - expr = expr.replace( - x - torch.utils._sympy.functions.Mod(x, y), - torch.utils._sympy.functions.FloorDiv(x, y) * y, - ) - - return expr - - def bound_sympy( expr: sympy.Expr, ranges: Optional[dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: @@ -1063,7 +1047,6 @@ def missing_handler(s): vr = ValueRanges.unknown() return vr - expr = _rewrite_for_value_range_analysis(expr) return sympy_interp( SymPyValueRangeAnalysis, ranges, expr, missing_handler=missing_handler ) From c53bc616d59cf654a0b331417ebe7617139d7625 Mon Sep 17 00:00:00 2001 From: Eric Griffith Date: Fri, 4 Apr 2025 17:14:43 +0000 Subject: [PATCH 200/332] caffe2: Fix lint errors in native/xnnpack/Linear.cpp (#150508) Summary: See title Test Plan: Sandcastle Differential Revision: D72275403 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150508 Approved by: https://github.com/malfet, https://github.com/Skylion007, https://github.com/cyyever --- aten/src/ATen/native/xnnpack/Linear.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp index 4d98cd753159..8d50aa66b4d9 100644 --- a/aten/src/ATen/native/xnnpack/Linear.cpp +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -129,6 +129,7 @@ Tensor run( const IntArrayRef input_size = padded_input.sizes(); std::vector output_size(input_size.cbegin(), input_size.cend()); + // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) output_size.back() = context.output_channels; Tensor output = mobile::empty_with_tail_padding( From 861d2cc02cce860d789cfda644a366abb95b53a5 Mon Sep 17 00:00:00 2001 From: Ankita George Date: Fri, 4 Apr 2025 17:52:53 +0000 Subject: [PATCH 201/332] Add a param for save format in Storage Writer (#150025) Summary: add a param to specify to the storage writer how to save tensors. Write now the only options are safetensors and torch.save. Test Plan: (lintrunner) [ankitageorge@devgpu003.cco3 /data/users/ankitageorge/fbsource/fbcode/caffe2 (1d57cb27b)]$ buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/distributed/checkpoint:test_hf_storage File changed: fbcode//caffe2/torch/distributed/checkpoint/filesystem.py Buck UI: https://www.internalfb.com/buck2/e80cc963-e34a-4876-b6f4-7ce2794e48dd Test UI: https://www.internalfb.com/intern/testinfra/testrun/3659174965882569 Network: Up: 32KiB Down: 1.9KiB (reSessionID-ef9fa764-a40a-451b-ab58-08eabe7a9422) Executing actions. Remaining 0/4 3.4s exec time total Command: test. Finished 2 local Time elapsed: 19.6s Tests finished: Pass 4. Fail 0. Fatal 0. Skip 0. 
Build failure 0 Reviewed By: saumishr Differential Revision: D70271943 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150025 Approved by: https://github.com/saumishr --- .../checkpoint/_fsspec_filesystem.py | 3 ++ torch/distributed/checkpoint/_hf_storage.py | 9 ++++- torch/distributed/checkpoint/filesystem.py | 39 +++++++++++++------ 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index b7b71bdf4b2b..8363fcf207a3 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -15,6 +15,7 @@ FileSystemBase, FileSystemReader, FileSystemWriter, + SerializationFormat, ) @@ -115,6 +116,7 @@ def __init__( per_thread_copy_ahead: int = 10_000_000, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, **kwargs, ) -> None: """ @@ -139,6 +141,7 @@ def __init__( per_thread_copy_ahead, overwrite=overwrite, _extensions=_extensions, + serialization_format=serialization_format, ) self.fs = FileSystem() self.path = self.fs.init_path(path, **kwargs) diff --git a/torch/distributed/checkpoint/_hf_storage.py b/torch/distributed/checkpoint/_hf_storage.py index 7b8f2d656e01..6927aed7e570 100644 --- a/torch/distributed/checkpoint/_hf_storage.py +++ b/torch/distributed/checkpoint/_hf_storage.py @@ -11,6 +11,7 @@ _FqnToFileMapping, _HuggingFaceLoadPlanner, ) +from torch.distributed.checkpoint.filesystem import SerializationFormat from torch.distributed.checkpoint.metadata import ( BytesStorageMetadata, Metadata, @@ -64,7 +65,11 @@ def __init__( if HfFileSystem.protocol not in fsspec.available_protocols(): fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) - super().__init__(path=path, token=token) + super().__init__( + path=path, + token=token, + serialization_format=SerializationFormat.SAFETENSORS, + ) self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping def prepare_local_plan(self, plan: SavePlan) -> SavePlan: @@ -99,7 +104,7 @@ def write_data( (self.fs.concat_path(self.path, file_name), file_name, write_items) ) - return super()._write_data(planner, file_queue, safe_tensors=True) + return super()._write_data(planner, file_queue) def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: metadata_to_write = {} diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 89b82e7bc127..76954da21eb0 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -13,6 +13,7 @@ from collections.abc import Generator, Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass +from enum import Enum from io import UnsupportedOperation from pathlib import Path from typing import Any, Callable, cast, IO, Optional, Union @@ -49,7 +50,13 @@ from torch.futures import Future -__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystem", "FileSystemBase"] +__all__ = [ + "FileSystemWriter", + "FileSystemReader", + "FileSystem", + "FileSystemBase", + "SerializationFormat", +] _metadata_fn: str = ".metadata" @@ -72,6 +79,11 @@ class _StoragePrefix: prefix: str +class SerializationFormat(Enum): + TORCH_SAVE = "torch_save" + SAFETENSORS = "safetensors" + + DEFAULT_SUFFIX = ".distcp" @@ -298,7 +310,7 @@ def _write_item( data: 
Union[io.BytesIO, torch.Tensor], write_item: WriteItem, storage_key: str, - safe_tensors: bool = False, + serialization_format: SerializationFormat, ) -> WriteResult: offset = stream.tell() @@ -312,12 +324,14 @@ def _write_item( else: assert isinstance(data, torch.Tensor) assert data.device == torch.device("cpu") - if not safe_tensors: + if serialization_format == SerializationFormat.TORCH_SAVE: torch.save(data, transform_to) transform_to.close() - if not safe_tensors or isinstance(data, io.BytesIO): + if serialization_format == SerializationFormat.TORCH_SAVE or isinstance( + data, io.BytesIO + ): length = stream.tell() - offset else: length = data.numel() * data.element_size() @@ -349,7 +363,7 @@ def _write_files_from_queue( inflight_threshhold: int, use_fsync: bool, thread_count: int, - safe_tensors: bool, + serialization_format: SerializationFormat, ) -> None: try: while True: @@ -397,7 +411,7 @@ def _write_files_from_queue( data, write_item, storage_key, - safe_tensors, + serialization_format, ) ) @@ -411,12 +425,12 @@ def _write_files_from_queue( tensor, write_item, storage_key, - safe_tensors, + serialization_format, ) ) tensor_dict[write_item.index.fqn] = tensor - if safe_tensors: + if serialization_format == SerializationFormat.SAFETENSORS: from safetensors.torch import save # type: ignore[import-not-found] stream.write(save(tensor_dict)) @@ -549,6 +563,7 @@ def __init__( per_thread_copy_ahead: int = 10_000_000, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, *args: Any, **kwargs: Any, ) -> None: @@ -576,6 +591,7 @@ def __init__( self.save_id = _generate_uuid() self.overwrite = overwrite self.transforms = _StorageWriterTransforms(_extensions) + self.serialization_format = serialization_format def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: if checkpoint_id: @@ -638,7 +654,6 @@ def _write_data( self, planner: SavePlanner, file_queue: queue.Queue, - safe_tensors: bool = False, ) -> Future[list[WriteResult]]: result_queue: queue.Queue = queue.Queue() @@ -655,7 +670,7 @@ def _write_data( self.per_thread_copy_ahead, self.sync_files, self.thread_count, - safe_tensors, + self.serialization_format, ), ) t.start() @@ -670,7 +685,7 @@ def _write_data( inflight_threshhold=self.per_thread_copy_ahead, use_fsync=self.sync_files, thread_count=self.thread_count, - safe_tensors=safe_tensors, + serialization_format=self.serialization_format, ) for t in threads: @@ -892,6 +907,7 @@ def __init__( cache_staged_state_dict: bool = False, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, + serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, ) -> None: """ Initialize the writer pointing to `path`. @@ -919,6 +935,7 @@ def __init__( per_thread_copy_ahead=per_thread_copy_ahead, overwrite=overwrite, _extensions=_extensions, + serialization_format=serialization_format, ) BlockingAsyncStager.__init__( self, From 2a2ddff214c4228eb684d0f6edc8c5034bddbfe5 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 4 Apr 2025 22:49:22 +0000 Subject: [PATCH 202/332] [Inductor] Fix consolidating _scaled_mm into mm template TMA error (#150686) Summary: The previous diff broke a few tests that didn't run on internal or GH CI: T220169086, this fixes that issue. The {% if } block is only supposed to support autotuned parameters (constexpr), and should not be used for locals based on other examples. 
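To illustrate the distinction, here is a minimal sketch using plain `jinja2` (not the actual Inductor template machinery; the `USE_FAST_ACCUM` name below is only illustrative): a `{% if %}` block is resolved once, when the template source is rendered, so it can only branch on values known at render time such as autotuned constexpr parameters, whereas a loop counter like `ki` only exists inside the generated kernel and must be guarded by a Triton-level `if`.

```python
# Sketch: render-time vs. runtime conditionals in a templated kernel.
from jinja2 import Template

render_time = Template(
    "{% if USE_FAST_ACCUM %}\n"           # resolved while rendering the source text
    "acc = tl.dot(a, b, acc)\n"
    "{% else %}\n"
    "acc += tl.dot(a, b)\n"
    "{% endif %}\n"
)
print(render_time.render(USE_FAST_ACCUM=True))  # only one branch survives into the source

runtime = Template(
    "if ki == k_tiles - 1:\n"              # `ki` is a kernel-local loop variable,
    "    tl.store(out_ptrs, acc, mask)\n"  # so the branch must execute inside the kernel
)
print(runtime.render())                    # the guard is emitted verbatim into the kernel
```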
Test Plan: buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:fp8 -- --exact 'caffe2/test/inductor:fp8 - test_tensorwise_scaling_bfloat16_shape_16,32,32_has_bias_False_use_fast_accum_True_persistent_matmul_True (caffe2.test.inductor.test_fp8.TestFP8Lowering)' Reviewed By: NikhilAPatel Differential Revision: D72460516 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150686 Approved by: https://github.com/eellison, https://github.com/NikhilAPatel --- torch/_inductor/kernel/mm.py | 71 ++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 3a7d87fc8596..604f4523793a 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -312,18 +312,18 @@ allow_tf32=ALLOW_TF32, ) - {% if ki == k_tiles - 1 %} - # rematerialize rm and rn to save registers - rcm = rm + tl.arange(0, BLOCK_M) - rcn = rn + tl.arange(0, BLOCK_N) - idx_m = rcm[:, None] - idx_n = rcn[None, :] - mask = (idx_m < M) & (idx_n < N) - - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) - {% endif %} + if ki == k_tiles - 1: + # rematerialize rm and rn to save registers + rcm = rm + tl.arange(0, BLOCK_M) + rcn = rn + tl.arange(0, BLOCK_N) + idx_m = rcm[:, None] + idx_n = rcn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}} + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + """, ) @@ -467,31 +467,30 @@ def apply_scaling( else: accumulator += tl.dot(a, b.T) - {% if ki == k_tiles - 1 %} - # Apply inverse scaling - offs_cm = offs_am + tl.arange(0, BLOCK_M) - offs_cn = offs_bn + tl.arange(0, BLOCK_N) - # Apply scaling - accumulator = apply_scaling( - accumulator, - a_scale, - b_scale, - SCALING_ROWWISE, - offs_cm, - offs_cn, - M, - N, - stride_a_scale_m, - stride_b_scale_n, - ) + if ki == k_tiles - 1: + # Apply inverse scaling + offs_cm = offs_am + tl.arange(0, BLOCK_M) + offs_cn = offs_bn + tl.arange(0, BLOCK_N) + # Apply scaling + accumulator = apply_scaling( + accumulator, + a_scale, + b_scale, + SCALING_ROWWISE, + offs_cm, + offs_cn, + M, + N, + stride_a_scale_m, + stride_b_scale_n, + ) - idx_m = offs_cm[:, None] - idx_n = offs_cn[None, :] - mask = (idx_m < M) & (idx_n < N) - # inductor generates a suffix - {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - {% endif %} + idx_m = offs_cm[:, None] + idx_n = offs_cn[None, :] + mask = (idx_m < M) & (idx_n < N) + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "accumulator", "mask", indent_width=12)}} + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) """ From 2e23768d2521acec9314d07a4b418089848a0f26 Mon Sep 17 00:00:00 2001 From: Stepan Hruda Date: Fri, 4 Apr 2025 23:03:16 +0000 Subject: [PATCH 203/332] Expose symbols on macos in the xplat pytorch stack (#150487) Summary: X-link: https://github.com/pytorch/executorch/pull/9819 Had to revert D71321310 because it affected way too many targets and build sizes. These changes should expose just enough symbols to be buildable in arvr mode on macOS. 
Could potentially make narrow it down even more by avoiding eg `get_pt_compiler_flags` Differential Revision: D72255474 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150487 Approved by: https://github.com/drisspg --- buckbuild.bzl | 17 +++++++++++++++-- third_party/xnnpack.buck.bzl | 4 ++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/buckbuild.bzl b/buckbuild.bzl index 29addd3bf724..b208a4d25c18 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -194,6 +194,9 @@ def get_pt_compiler_flags(): return select({ "DEFAULT": _PT_COMPILER_FLAGS, "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS), + }) + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], }) _PT_COMPILER_FLAGS = [ @@ -228,6 +231,9 @@ ATEN_COMPILER_FLAGS = [ # Not supported by clang on Windows "DEFAULT": ["-fPIC"], "ovr_config//compiler:clang-windows": [], +}) + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], }) def get_aten_compiler_flags(): @@ -982,6 +988,10 @@ def define_buck_targets( fb_xplat_cxx_library( name = "torch_mobile_headers", header_namespace = "", + compiler_flags = select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), exported_headers = subdir_glob( [ ("", "torch/csrc/jit/mobile/*.h"), @@ -1185,7 +1195,10 @@ def define_buck_targets( srcs = [ "torch/csrc/jit/mobile/observer.cpp", ] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]), - compiler_flags = ["-fexceptions"], + compiler_flags = ["-fexceptions"] + select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), header_namespace = "", exported_headers = subdir_glob( [ @@ -2035,7 +2048,7 @@ def define_buck_targets( "ovr_config//os:xtensa-xos": [ "-fdata-sections", "-ffunction-sections", - ], + ] }), exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DMIN_EDGE_RUNTIME", diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b20a7be4ed1a..231384bd859a 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -2249,6 +2249,10 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_deps = [ ":subgraph", ], + compiler_flags = select({ + "DEFAULT": [], + "ovr_config//os:macos": ["-fvisibility=default"], + }), platforms = (APPLE, ANDROID, CXX, WINDOWS), preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS + [ "-DXNN_NO_Q8_OPERATORS", From d6887f444fa61b46c9c31114028484e658f3dc99 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 3 Apr 2025 20:40:04 -0700 Subject: [PATCH 204/332] [Inductor] Fallback embedding when sparse is True (#150659) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/150656, fallback `embedding` when sparse is True. 
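For reference, a minimal sketch of the call pattern the new test exercises (shapes picked arbitrarily); with this change the sparse case is routed to the eager `aten.embedding` kernel instead of the dense Inductor lowering:

```python
import torch
import torch.nn.functional as F

def fn(weight, indices):
    # sparse=True used to trip the dense lowering; it now falls back to eager.
    return F.embedding(indices, weight, sparse=True)

weight = torch.randn(10, 3, requires_grad=True)
indices = torch.randint(10, (2, 3))

out = torch.compile(fn)(weight, indices)
print(out.shape)  # torch.Size([2, 3, 3])
```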
**Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_embedding_sparse ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150659 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 13 +++++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 5 +++++ 3 files changed, 19 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 30aafa062068..54524be5314c 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5360,6 +5360,19 @@ def test_embedding(self): (torch.randint(10, [2, 8]),), ) + def test_embedding_sparse(self): + # Fix https://github.com/pytorch/pytorch/issues/150656 + def fn(weight, indices): + return F.embedding(indices, weight, sparse=True) + + indices = torch.randint(10, (2, 3)) + weight = torch.randn(10, 3, requires_grad=True) + + self.common( + fn, + (weight, indices), + ) + def test_mean(self): def fn(x): return ( diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index c090b7b7846f..29d74152bf4e 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -137,6 +137,7 @@ def run(*ex, **kwargs): "test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)), "test_flip_cat_dynamic_shapes": TestFailure(("cpu",)), "test_pad_single_dynamic_shapes": TestFailure(("cpu",)), + "test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), # # Failed to find for loop/triton kernel: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 7fcf79041851..24520887f6aa 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3376,6 +3376,11 @@ def fn(idx): @register_lowering(aten.embedding, type_promotion_kind=None) def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + if sparse: + return fallback_handler(aten.embedding.default)( + weight, indices, padding_idx, scale_grad_by_freq, sparse + ) + assert not sparse assert isinstance(weight, TensorBox) assert isinstance(indices, TensorBox) From 2e4ae2ab41dbe1939bd1ffb427af8e5ea8eaff41 Mon Sep 17 00:00:00 2001 From: eellison Date: Fri, 4 Apr 2025 13:35:25 -0700 Subject: [PATCH 205/332] Fix conv2d strided prologue (#150697) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150697 Approved by: https://github.com/drisspg --- test/inductor/test_max_autotune.py | 38 +++++++++++++++++++++++++++++ torch/_inductor/select_algorithm.py | 11 ++++----- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 499dcbf4ae47..65721ba67a5e 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1380,6 +1380,44 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) + @parametrize("prologue", (False, True)) + def test_conv1x1_cast(self, prologue): + with torch._inductor.config.patch(prologue_fusion=prologue): + conv1x1 = ( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) + .to(memory_format=torch.channels_last) + .to(GPU_TYPE) + .to(dtype=torch.float16) + ) + input_tensor = ( + torch.randn(4, 3, 32, 32) + .contiguous(memory_format=torch.channels_last) + 
.to(GPU_TYPE) + ) + + def foo(mod, input): + return torch.nn.functional.conv2d( + input, + mod.weight.to(input.dtype), + None, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + ) + + with torch.no_grad(): + out_eager = foo(conv1x1, input_tensor) + foo_c = torch.compile(foo) + out, code = run_and_get_code(foo_c, conv1x1, input_tensor) + + FileCheck().check_not("extern_kernels.convolution").run(code[0]) + if prologue: + self.check_code( + code[0], num_kernels=1, num_allocs=1, num_deallocs=2 + ) + self.assertEqual(out_eager, out, atol=1e-2, rtol=0) + @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index fed0f9ebebd7..1fbb9aff8580 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -746,11 +746,10 @@ def load_input( indices, self.range_trees[0].construct_entries(lengths) ): range_tree_entry.set_name(name) - contiguous_index = sympy_dot( - ir.FlexibleLayout.contiguous_strides(lengths), index_symbols - ) - contiguous_index = self.rename_indexing(contiguous_index) - self.body.writeline("xindex = " + texpr(contiguous_index)) + + strided_index = sympy_dot(input_node.get_stride(), index_symbols) + strided_index = self.rename_indexing(strided_index) + self.body.writeline("xindex = " + texpr(strided_index)) xindex_range_root = self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) @@ -823,7 +822,7 @@ def store( output_index = self.rename_indexing(output_index) - if output_index == contiguous_index: + if output_index == strided_index: output_index_str = "xindex" else: out_indexing = self.indexing( From 3320efef6b7f4d80564dcbec900aca3aadb6e564 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Fri, 4 Apr 2025 14:19:16 -0700 Subject: [PATCH 206/332] Refresh expected results. 
(#150264) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150264 Approved by: https://github.com/bobrenjc93 --- .../pr_time_benchmarks/expected_results.csv | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 934e10e5c275..af033eacff97 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,32 +1,32 @@ -add_loop_eager,compile_time_instruction_count,2926000000,0.015 +add_loop_eager,compile_time_instruction_count,2944000000,0.015 -add_loop_eager_dynamic,compile_time_instruction_count,5637000000,0.025 +add_loop_eager_dynamic,compile_time_instruction_count,5633000000,0.025 -add_loop_inductor,compile_time_instruction_count,28680000000,0.015 +add_loop_inductor,compile_time_instruction_count,28810000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42170000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,24980000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,25120000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,969300000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,963100000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17840000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17990000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,15990000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16130000000,0.015 @@ -34,44 +34,44 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,97140000 -update_hint_regression,compile_time_instruction_count,1593000000,0.02 +update_hint_regression,compile_time_instruction_count,1608000000,0.02 -float_args,compile_time_instruction_count,416400000,0.015 +float_args,compile_time_instruction_count,417400000,0.015 -sum_floordiv_regression,compile_time_instruction_count,989900000,0.015 +sum_floordiv_regression,compile_time_instruction_count,985300000,0.015 -symint_sum,compile_time_instruction_count,3164000000,0.015 +symint_sum,compile_time_instruction_count,3189000000,0.015 -symint_sum_loop,compile_time_instruction_count,4142000000,0.015 +symint_sum_loop,compile_time_instruction_count,4180000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2034000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2042000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5880000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5884000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8419000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8501000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1838000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1856000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3742000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3751000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10190000000,0.015 
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10200000000,0.015 From c14977e91c9fe33517fec167dd1592c9fa209579 Mon Sep 17 00:00:00 2001 From: Jithun Nair Date: Sat, 5 Apr 2025 02:09:11 +0000 Subject: [PATCH 207/332] Use 'rocm' naming for rocm-related workflows/jobs (#150555) Reduces number of places in the workflow files needing update for ROCm version update Pull Request resolved: https://github.com/pytorch/pytorch/pull/150555 Approved by: https://github.com/jeffdaily --- .../inductor-perf-test-nightly-rocm.yml | 18 +++++++++--------- .github/workflows/inductor-periodic.yml | 18 +++++++++--------- .github/workflows/inductor-rocm-mi300.yml | 18 +++++++++--------- .github/workflows/inductor-rocm.yml | 18 +++++++++--------- .github/workflows/periodic.yml | 18 +++++++++--------- .github/workflows/pull.yml | 6 +++--- .github/workflows/rocm-mi300.yml | 18 +++++++++--------- .github/workflows/rocm.yml | 18 +++++++++--------- .github/workflows/slow.yml | 18 +++++++++--------- .github/workflows/trunk.yml | 18 +++++++++--------- 10 files changed, 84 insertions(+), 84 deletions(-) diff --git a/.github/workflows/inductor-perf-test-nightly-rocm.yml b/.github/workflows/inductor-perf-test-nightly-rocm.yml index df84b158f1c0..f1ff593161db 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm.yml @@ -78,12 +78,12 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-benchmark-build: + linux-focal-rocm-py3_10-inductor-benchmark-build: if: github.repository_owner == 'pytorch' - name: rocm6_3-py3_10-inductor-benchmark-build + name: rocm-py3_10-inductor-benchmark-build uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -102,18 +102,18 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-benchmark-test: + linux-focal-rocm-py3_10-inductor-benchmark-test: permissions: id-token: write contents: read - name: rocm6_3-py3_10-inductor-benchmark-test + name: rocm-py3_10-inductor-benchmark-test uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-benchmark-build + needs: linux-focal-rocm-py3_10-inductor-benchmark-build with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-cudagraphs_low_precision-true - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-benchmark-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-benchmark-build.outputs.test-matrix }} + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-benchmark-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-benchmark-build.outputs.test-matrix }} timeout-minutes: 720 # Disable monitor in perf tests for more investigation disable-monitor: true diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 3d7a1c7da941..6d08179df512 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -67,12 +67,12 @@ jobs: test-matrix: ${{ 
needs.linux-focal-cuda12_6-py3_10-gcc9-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build: + linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build: if: github.repository_owner == 'pytorch' - name: rocm6_3-py3_10-periodic-dynamo-benchmarks + name: rocm-py3_10-periodic-dynamo-benchmarks uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6_3-py3_10 + build-environment: linux-focal-rocm-py3_10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -95,17 +95,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-test: + linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-test: permissions: id-token: write contents: read - name: rocm6_3-py3_10-periodic-dynamo-benchmarks + name: rocm-py3_10-periodic-dynamo-benchmarks uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build + needs: linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build with: - build-environment: linux-focal-rocm6_3-py3_10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-periodic-dynamo-benchmarks-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3_10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-periodic-dynamo-benchmarks-build.outputs.test-matrix }} secrets: inherit linux-focal-cuda12_6-py3_10-gcc9-inductor-build-gcp: diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index bddb625cbc90..753c30e6427a 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -36,13 +36,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-build: - name: rocm6.3-py3.10-inductor + linux-focal-rocm-py3_10-inductor-build: + name: rocm-py3.10-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -51,15 +51,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-test: + linux-focal-rocm-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.3-py3.10-inductor + name: rocm-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-build + needs: linux-focal-rocm-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index bcbbe0dd85bc..0d21b4570c2e 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml 
@@ -29,13 +29,13 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-inductor-build: - name: rocm6.3-py3.10-inductor + linux-focal-rocm-py3_10-inductor-build: + name: rocm-py3.10-inductor uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -44,15 +44,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-inductor-test: + linux-focal-rocm-py3_10-inductor-test: permissions: id-token: write contents: read - name: rocm6.3-py3.10-inductor + name: rocm-py3.10-inductor uses: ./.github/workflows/_rocm-test.yml - needs: linux-focal-rocm6_3-py3_10-inductor-build + needs: linux-focal-rocm-py3_10-inductor-build with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-inductor-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-inductor-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 8aadcd548b7e..686f5b83a92e 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -140,13 +140,13 @@ jobs: test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-debug-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-build: - name: linux-focal-rocm6.3-py3.10 + linux-focal-rocm-py3_10-build: + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -156,19 +156,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit linux-focal-cuda12_6-py3-gcc11-slow-gradcheck-build: diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index bb967e2f3e82..e4ee18664ba5 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -411,15 +411,15 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: # don't run build twice on main if: github.event_name == 'pull_request' - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: 
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index f1a16ddea234..cce7ff72cdc4 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -36,14 +36,14 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -57,17 +57,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 6ff8667a9d94..063daaf4fe67 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -26,12 +26,12 @@ jobs: id-token: write contents: read - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -45,17 +45,17 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 
1d1b8d5eb567..0a8cf3721e70 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -103,13 +103,13 @@ jobs: test-matrix: ${{ needs.linux-focal-py3_9-clang10-build.outputs.test-matrix }} secrets: inherit - linux-focal-rocm6_3-py3_10-build: - name: linux-focal-rocm6.3-py3.10 + linux-focal-rocm-py3_10-build: + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 test-matrix: | { include: [ @@ -118,19 +118,19 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} secrets: inherit linux-jammy-py3_10-clang15-asan-build: diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index ec98f4faf3c6..15d2b53ed81b 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -165,14 +165,14 @@ jobs: runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" secrets: inherit - linux-focal-rocm6_3-py3_10-build: + linux-focal-rocm-py3_10-build: if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-focal-rocm6.3-py3.10 + build-environment: linux-focal-rocm-py3.10 docker-image-name: pytorch-linux-focal-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -183,20 +183,20 @@ jobs: ]} secrets: inherit - linux-focal-rocm6_3-py3_10-test: + linux-focal-rocm-py3_10-test: if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/trunk') }} permissions: id-token: write contents: read - name: linux-focal-rocm6.3-py3.10 + name: linux-focal-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: - - linux-focal-rocm6_3-py3_10-build + - linux-focal-rocm-py3_10-build - target-determination with: - build-environment: linux-focal-rocm6.3-py3.10 - docker-image: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-rocm6_3-py3_10-build.outputs.test-matrix }} + build-environment: linux-focal-rocm-py3.10 + docker-image: ${{ needs.linux-focal-rocm-py3_10-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-rocm-py3_10-build.outputs.test-matrix }} tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" secrets: inherit From 7ac81868513a212af6be4a05e2f921cafeeb3069 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Fri, 4 Apr 2025 17:20:18 -0700 Subject: [PATCH 208/332] 
[MPSInductor] Speedup `sum`/`prod` reductions (#150566) By using cooperative `simd_sum`/`simd_product` instead of a C-style for loop for threadgroup reductions. This also allows significantly reduce amount of shared memory needed to perform those reductions Using such reduction increases the `torch.compile` performance for gpt-fast using `stories110M` from 29 tokens/sec to 630 tokens/sec on M4 and changes perf of torch.rand as follows: |size| before | after | |------------------------|------------|-------------| | 512x512 | 202.1 | 131.8 | | 1024x1024 | 780.6 | 176.9 | | 2048x2048 | 1423.4 | 339.9 | | 4096x4097 | 2982.2 | 1047.2 | Unfortunately, none of the SIMDgroup operations are available for 64-bit integers, but one can simulate the behavior using using `simd_shuffle_down` of 64-bit values represented as `int2` types, that yields reduction in $log_2(threadgroup\\_size)$ steps. [`mlx/kernels/reduction/ops.h](https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L15-L18) contains an implementation of such algorithm, but alas it yields wrong results on M1/M2(and may be M3 machines) if not all threads in the simdgroup are active which could be observed by running ```python import torch lib=torch.mps.compile_shader(""" kernel void do_sum(device int* out, constant int* in, uint idx [[thread_position_in_grid]]) { out[idx] = metal::simd_shuffle_down(in[idx], 8); } """) x=torch.arange(22, device='mps', dtype=torch.int32) y=torch.empty_like(x) lib.do_sum(y, x) print(y) ``` that returns following on M4 ``` tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 0, 0, 0, 0, 0, 0, 0, 0], device='mps:0', dtype=torch.int32) ``` but same kernel running on M1 returns ``` tensor([ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 14, 15, 16, 17, 18, 19, 20, 21], device='mps:0', dtype=torch.int32) ``` This discrepancy in behavior can be addressed by using `simd_shuffle_and_fill_down`, but any kernels using simd_shuffle_and_fill_down cause an internal compiler error on MacOS-13.2. Considering that OS is to be EOL soon, skip the offending tests. 
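For readers unfamiliar with the shuffle-based scheme, a plain-Python model of the log2-step reduction described above (illustrative only, not the Metal code): each step folds lane `i + offset` onto lane `i`, with out-of-range lanes contributing the fill value, so a 32-lane simdgroup is summed in 5 steps instead of a 31-iteration serial loop.

```python
def simd_sum_model(lanes, width=32):
    # One value per thread of the simdgroup; `width` is the simdgroup size.
    vals = list(lanes)
    vals += [0] * (width - len(vals))   # inactive lanes contribute the identity
    offset = width // 2
    while offset > 0:
        for i in range(width):
            # roughly simd_shuffle_and_fill_down(val, fill=0, offset):
            # lanes shifted past the end of the group read the fill value
            vals[i] += vals[i + offset] if i + offset < width else 0
        offset //= 2
    return vals[0]  # lane 0 holds the total; simd_broadcast then shares it with all lanes

assert simd_sum_model(range(32)) == sum(range(32))  # 496
```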
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150566 Approved by: https://github.com/manuelcandales ghstack dependencies: #150452, #150457 --- c10/metal/reduction_utils.h | 89 ++++++++++++++++++++++++----- test/inductor/test_torchinductor.py | 4 ++ test/test_mps.py | 22 +++++++ torch/_inductor/codegen/mps.py | 49 ++++++++-------- 4 files changed, 127 insertions(+), 37 deletions(-) diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 5445d53039b1..b6f7f6bc83ee 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -6,27 +6,88 @@ namespace c10 { namespace metal { +constant constexpr ushort simdgroup_size = 32; + template -opmath_t threadgroup_sum(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee - ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - opmath_t rc = data[0]; - // TODO: Use `simd_shuffle_down` - for (unsigned idx = 1; idx < size; ++idx) { - rc += data[idx]; +inline ::metal::enable_if_t, T> simd_sum(T val) { + return ::metal::simd_sum(val); +} + +template +inline ::metal::enable_if_t, T> simd_prod(T val) { + return ::metal::simd_product(val); +} + +// Metal does not support SIMD reductions over 64-bit types, but it could be +// implement using simd_shuffle_down, that yields result in log2(simdgroup_size) +// iterations Use fill variant, as shuffle down returns garbage if inactive +// thread is referenced (on M1/M2, works fine on M4) and broadcast result to all +// threads in the end. Implementation heavily borrows from +// https://github.com/ml-explore/mlx/blob/86389bf9707f46101af45d90510e8e97c8a90b93/mlx/backend/metal/kernels/reduction/ops.h#L16 +template +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_sum(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val += as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); } - return rc; + return as_type(::metal::simd_broadcast(as_type(val), 0)); } template -opmath_t threadgroup_prod(threadgroup T* data, unsigned size) { - // TODO: This should be moved to the callee +inline ::metal::enable_if_t<::metal::is_same_v, T> simd_prod(T val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val *= as_type( + ::metal::simd_shuffle_and_fill_down(as_type(val), int2(0), i)); + } + return as_type(::metal::simd_broadcast(as_type(val), 0)); +} + +// Below algorithms are written with hardcoded assumption that simdgroup is 32 +// and threadgroup_max is 1024, i.e. 
reduction can be done in two stages max +template +opmath_t threadgroup_sum( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_sum(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; + } + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_sum(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); - opmath_t rc = data[0]; - for (unsigned idx = 1; idx < size; ++idx) { - rc *= data[idx]; + return data[0]; +} + +template +opmath_t threadgroup_prod( + threadgroup opmath_t* data, + T val, + unsigned idx, + unsigned size) { + auto rc = simd_prod(static_cast>(val)); + if (idx % simdgroup_size == 0) { + data[idx / simdgroup_size] = rc; } - return rc; + if (size > simdgroup_size) { + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + if (idx < ((size + simdgroup_size - 1) / simdgroup_size)) { + auto rc1 = simd_prod(data[idx]); + if (idx == 0) { + data[0] = rc1; + } + } + } + ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup); + return data[0]; } template diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 54524be5314c..fa83302732f1 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1996,6 +1996,8 @@ def fn(a): return torch.max(a), torch.sum(a) # Requires masked loading for the intermediate reduction + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") sample = torch.full((3999971,), 0, dtype=torch.int64) sample[-1] = 1 self.common(fn, (sample,)) @@ -2492,6 +2494,8 @@ def fn(x): dtypes = torch.bool, torch.uint8, torch.int inps = [torch.randint(2, (64,), dtype=dtype) for dtype in dtypes] + if self.device == "mps" and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Fails with internal compiler error on MacOS-13") for i in inps: self.common(fn, (i,), check_lowp=False) diff --git a/test/test_mps.py b/test/test_mps.py index e27a78785be7..a950ae28b767 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -12937,6 +12937,27 @@ def test_metal_include(self): lib = torch.mps.compile_shader("#include ") self.assertIsNotNone(lib) + @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.int64]) + def test_reduction_utils(self, dtype): + if dtype == torch.int64 and MACOS_VERSION < 13.3: + raise unittest.SkipTest("Using simd_shuffle_down_and_fill results in ICE on MacOS-13") + from torch._inductor.codegen.mps import DTYPE_TO_METAL + lib = torch.mps.compile_shader(f""" + #include + kernel void do_sum(device {DTYPE_TO_METAL[dtype]}* out, + constant {DTYPE_TO_METAL[dtype]}* inp, + uint idx [[thread_position_in_grid]]) {{ + out[idx] = c10::metal::simd_sum(inp[idx]); + }} + """) + x = torch.testing.make_tensor(28, device="mps", dtype=dtype) + y = torch.empty_like(x) + lib.do_sum(y, x) + x_sum = x.sum() + max_err = (y - x_sum).abs().max().item() + self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5, + f"results are {y}, but all elements should have been {x_sum.item()}") + @unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again") def test_metal_capture(self): lib = torch.mps.compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; 
}") @@ -12968,6 +12989,7 @@ def test_metal_capture(self): instantiate_parametrized_tests(TestMPS) instantiate_parametrized_tests(TestSDPA) instantiate_parametrized_tests(TestSmoothL1Loss) +instantiate_parametrized_tests(TestMetalLibrary) if __name__ == "__main__": run_tests() diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 1aae913fa0cf..470e7f049cf2 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -13,7 +13,7 @@ from torch.utils._sympy.printers import ExprPrinter as ExprPrinter_ from torch.utils._sympy.value_ranges import ValueRanges -from ..utils import get_bounds_index_expr, get_kernel_metadata +from ..utils import ceildiv, get_bounds_index_expr, get_kernel_metadata from ..virtualized import ops, OpsWrapper, V from .common import ( CSEVariable, @@ -462,6 +462,7 @@ class MetalKernel(SIMDKernel): suffix = ";" newvar_prefix = "auto " max_threadgroup_size = 1024 + simd_group_size = 32 pexpr = PythonPrinter().doprint sexpr = MetalExprPrinter().doprint kexpr = sexpr @@ -507,22 +508,23 @@ def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable) -> N line = f"if ({reduction_dim.name} == 0) {line}" self.stores.writeline(DeferredLine(name, line)) - def _new_accvar( + def _new_idxvar( self, dtype: torch.dtype, elem_count: Optional[int] = None, + default_value: Optional[Any] = None, + is_threadgroup: bool = True, bounds: ValueRanges[Any] = ValueRanges.unknown(), ) -> CSEVariable: var_name = f"tmp_acc_{next(self.acc_var_ids)}" var = V.kernel.create_cse_var(var_name, bounds, dtype) + var_def = "threadgroup " if is_threadgroup else "" + var_def += f"{self.dtype_to_str(dtype)} {var_name}" if elem_count: - self.indexing_code.writeline( - f"threadgroup {self.dtype_to_str(dtype)} {var_name}[{elem_count}];" - ) - else: - self.indexing_code.writeline( - f"threadgroup {self.dtype_to_str(dtype)} {var_name};" - ) + var_def += f"[{elem_count}]" + if default_value is not None: + var_def += f" = {default_value}" + self.indexing_code.writeline(var_def + self.suffix) return var def reduction( @@ -536,7 +538,7 @@ def reduction( reduction_dim = next(t for t in self.range_trees if t.is_reduction) acc_buf_size = min(reduction_dim.numel, self.max_threadgroup_size) if reduction_type == "any": - acc = self._new_accvar(dtype) + acc = self._new_idxvar(dtype) self.indexing_code.writeline(f"{acc} = false;") self.indexing_code.writeline( "threadgroup_barrier(metal::mem_flags::mem_threadgroup);" @@ -553,26 +555,27 @@ def reduction( ) return acc if reduction_type in ["prod", "sum"]: - acc_buf = self._new_accvar(src_dtype, acc_buf_size) - if self.multistage_reduction: + acc_dtype = DTYPE_TO_COMPUTATION_DTYPE[src_dtype] + acc_buf = self._new_idxvar( + acc_dtype, ceildiv(acc_buf_size, self.simd_group_size) + ) + if not self.multistage_reduction: + val = value + else: default_val, reduction_op = ( (0, "+") if reduction_type == "sum" else (1, "*") ) - self.indexing_code.writeline( - f"{acc_buf}[{reduction_dim.name}] = {default_val};" - ) - self.compute.splice( - f"{acc_buf}[{reduction_dim.name}] {reduction_op}= {value};" + val = self._new_idxvar( + acc_dtype, default_value=default_val, is_threadgroup=False ) - else: - self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") + self.compute.splice(f"{val} {reduction_op}= {value};") return self.cse.generate( self.stores, - f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_dim.name}, 
{acc_buf_size})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["max", "min", "argmin", "argmax"]: - acc_buf = self._new_accvar(src_dtype, acc_buf_size) + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) acc_thread_var = f"{acc_buf}[{reduction_dim.name}]" src_metal_type = DTYPE_TO_METAL[src_dtype] if not self.multistage_reduction: @@ -592,7 +595,7 @@ def reduction( idx_var = next( t for t in self.range_tree_nodes.values() if t.is_reduction ) - idx_acc_buf = self._new_accvar(torch.long, acc_buf_size) + idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size) cmp_op = ">" if reduction_type == "argmax" else "<" idx_thread_var = f"{idx_acc_buf}[{reduction_dim.name}]" self.indexing_code.splice(f"{idx_thread_var} = -1;") @@ -619,7 +622,7 @@ def reduction( assert not self.multistage_reduction, ( f"Multistage reduction not yet supported for {reduction_type}" ) - acc_buf = self._new_accvar(src_dtype, acc_buf_size) + acc_buf = self._new_idxvar(src_dtype, acc_buf_size) self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") wf_res = self.cse.generate( self.compute, From 60a45eb862d5e8b4ba2dd435d34ef04ae231e885 Mon Sep 17 00:00:00 2001 From: Mu-Chu Lee Date: Fri, 4 Apr 2025 15:39:41 -0700 Subject: [PATCH 209/332] [AOTInductor] Introduce MaybeOwningAtenTensorHandle for ConstantMap (#150275) Summary: We used RAIIAtenTensorHandle for ConstantMap, where RAIIAtenTensorHandle is a unique_ptr, indicating that all memory handling is by the AOTInductor internally. In this PR, we introduce ConstantAtenTensorHandle which replaces RAIIATenTensorHandle. This class holds a raw AtenTensorHandle, and also owns a RAIIAtenTensorHandle if user decides to delegate memory management to AOTInductor. This is a prerequisite for user managed buffer, this PR, however only introduces this class and make sure it works with existing AOTInductor and has the default behavior identical as using RAIIAtenTensorHandle. Test Plan: Existing tests. No change should be introduced within this PR. Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/150275 Approved by: https://github.com/chenyang78, https://github.com/desertfire --- torch/csrc/inductor/aoti_runtime/model.h | 3 +- .../inductor/aoti_runtime/model_container.h | 6 +- torch/csrc/inductor/aoti_runtime/utils.h | 116 ++++++++++++++++++ 3 files changed, 122 insertions(+), 3 deletions(-) diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index d3789def392a..83dd1c4e7437 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -91,7 +91,8 @@ RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { namespace torch::aot_inductor { -using ConstantMap = std::unordered_map; +using ConstantMap = + std::unordered_map; // valid device strs are: cpu, cuda, cuda:0, cuda:1, ... // Update the list here if more devices are supported in the future diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 42f6157f5eef..408a9274417c 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -349,7 +349,8 @@ class AOTInductorModelContainer { tensor = it->second; } - constants_map_to_update->insert_or_assign(constant_name, tensor); + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor)); } // Update the inactive constant array. 
update_array_from_map( @@ -437,7 +438,8 @@ class AOTInductorModelContainer { // Now place the tensor to constants_map. Note at this point the ownership // of the tensor_handle will be taken over. - constants_map_to_update->insert_or_assign(constant_name, tensor_handle); + constants_map_to_update->insert_or_assign( + constant_name, RAIIAtenTensorHandle(tensor_handle)); } // Update the inactive constant array. update_array_from_map( diff --git a/torch/csrc/inductor/aoti_runtime/utils.h b/torch/csrc/inductor/aoti_runtime/utils.h index 2f23826be77f..9e2f5c160f73 100644 --- a/torch/csrc/inductor/aoti_runtime/utils.h +++ b/torch/csrc/inductor/aoti_runtime/utils.h @@ -135,6 +135,122 @@ class RAIIAtenTensorHandle { std::unique_ptr handle_; }; +class MaybeOwningAtenTensorHandle { + public: + MaybeOwningAtenTensorHandle() : handle_(nullptr), raii_handle_() {} + // We skip copy constructor as MaybeOwningAtenTensorHandle might be RAII which + // makes it undefined. + MaybeOwningAtenTensorHandle(const MaybeOwningAtenTensorHandle& other) = + delete; + MaybeOwningAtenTensorHandle& operator=( + const MaybeOwningAtenTensorHandle& other) = delete; + + // Move constructor and move assignment operator + MaybeOwningAtenTensorHandle(MaybeOwningAtenTensorHandle&& other) = default; + MaybeOwningAtenTensorHandle& operator=(MaybeOwningAtenTensorHandle&& other) = + default; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + MaybeOwningAtenTensorHandle(RAIIAtenTensorHandle&& other) + : raii_handle_(std::move(other)) { + handle_ = raii_handle_.get(); + } + MaybeOwningAtenTensorHandle& operator=(RAIIAtenTensorHandle&& other) { + raii_handle_ = std::move(other); + handle_ = raii_handle_.get(); + return *this; + } + + // By default, steal the ownership from raw AtenTensorHandle + MaybeOwningAtenTensorHandle(AtenTensorHandle handle) : raii_handle_(handle) { + handle_ = raii_handle_.get(); + } + + // If user_managed is true, we do not steal the ownership. + MaybeOwningAtenTensorHandle(AtenTensorHandle handle, bool user_managed) { + if (user_managed) { + handle_ = handle; + } else { + raii_handle_ = RAIIAtenTensorHandle(handle); + handle_ = raii_handle_.get(); + } + } + + ~MaybeOwningAtenTensorHandle() { + // This is no-op if we don't hold raii_handle with the + // MaybeOwningAtenTensorHandle. 
+ raii_handle_.reset(); + } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { + return handle_; + } + + AtenTensorHandle release() { + if (raii_handle_) { + return raii_handle_.release(); + } else { + AtenTensorHandle handle = handle_; + handle_ = nullptr; + return handle; + } + } + + AtenTensorHandle get() const { + return handle_; + } + + void reset() { + handle_ = nullptr; + raii_handle_.reset(); + } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_, d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_stride(handle_, d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_, &storage_offset)); + return storage_offset; + } + + void* data_ptr() const { + void* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &result)); + return result; + } + + int64_t* sizes() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t* strides() const { + int64_t* result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + // handle_ is the underlying AtenTensorHandle of raii_handle_ if raii_handle_ + // exists. Otherwise it would just be the AtenTensorHandle passed in by users. + AtenTensorHandle handle_; + RAIIAtenTensorHandle raii_handle_; +}; + // Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle inline std::vector steal_from_raw_handles_to_raii_handles( AtenTensorHandle* handles, From cfea55dbecf93a88a40290a69c5e3b324dcec69c Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sat, 5 Apr 2025 21:49:21 +0000 Subject: [PATCH 210/332] [MPS] fix inverse bug for N>1024 (#146754) Fixes #138200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/146754 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .../src/ATen/native/mps/operations/Inverse.mm | 61 ------------------- .../native/mps/operations/LinearAlgebra.mm | 44 +++++++++++-- test/test_mps.py | 5 +- 3 files changed, 41 insertions(+), 69 deletions(-) delete mode 100644 aten/src/ATen/native/mps/operations/Inverse.mm diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm deleted file mode 100644 index 5574df89afe5..000000000000 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ /dev/null @@ -1,61 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - -namespace at::native { - -TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { - TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); - TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!"); - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) { - TORCH_WARN_ONCE( - "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. 
Falling back to CPU."); - auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt); - auto cpu_result = result.to("cpu"); - at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu")); - info.copy_(cpu_info); - result.copy_(cpu_result); - return; - } - - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - MPSStream* stream = getCurrentMPSStream(); - info.zero_(); - - if (A.numel() == 0) { - return; - } - - if (!result.is_contiguous()) { - result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); - } - - @autoreleasepool { - string key = "inv_out_mps" + getTensorsStringKey({A}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A); - MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - auto feeds = dictionaryFromPlaceholders(inputPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } -} - -} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index 22aee2307f69..1a9e841cfbcf 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -2,6 +2,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include // For MTLLanguageVersion_3_1 @@ -22,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A, } } -static void linalg_solve_out_mps_impl(const at::Tensor& A, - const at::Tensor& B, +static void linalg_solve_out_mps_impl(const Tensor& A, + const Tensor& B, bool left, bool check_errors, - const at::Tensor& result, - const at::Tensor& LU, - const at::Tensor& pivots, - const at::Tensor& info) { + const Tensor& result, + const Tensor& LU, + const Tensor& pivots, + const Tensor& info) { using namespace mps; TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), @@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A, } } +static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + using namespace mps; + TORCH_CHECK(result.is_mps(), "Output tensor is not MPS"); + TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!"); + using CachedGraph = MPSUnaryCachedGraph; + + MPSStream* stream = getCurrentMPSStream(); + info.zero_(); + + if (A.numel() == 0) { + return; + } + + if (!result.is_contiguous()) { + result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous); + } + auto A_sizes = A.sizes(); + int ndim = A.dim(); + + Tensor LU = empty_like(A); + Tensor identity = zeros_like(A); + Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt)); + (ndim == 2 ? 
identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1); + linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info); +} + static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { using namespace mps; static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); @@ -1427,4 +1455,8 @@ Tensor linalg_solve_triangular_mps(const Tensor& A, const Tensor& B, bool upper, (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) { mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors); } + +TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) { + mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info); +} } // namespace at::native diff --git a/test/test_mps.py b/test/test_mps.py index a950ae28b767..61903bd39005 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -7815,18 +7815,19 @@ def helper(shape, diag=0): # Test inverse def test_inverse(self): - def helper(n): + def helper(n, atol=1e-5, rtol=1e-6): cpu_input = torch.randn(n, n, device='cpu') mps_input = cpu_input.to('mps') cpu_result = torch.linalg.inv(cpu_input) mps_result = torch.linalg.inv(mps_input) - self.assertEqual(cpu_result, mps_result) + self.assertEqual(cpu_result, mps_result, atol=atol, rtol=rtol) helper(2) helper(6) helper(3) helper(8) + helper(1025, atol=1e-4) # Test tril def test_tril(self): From c830c12a87d0e8e5871cd7f2e62e1c8805916879 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 5 Apr 2025 15:50:29 -0700 Subject: [PATCH 211/332] [MPSInductor] Fix tiled reduction logic (#150737) In case of tiles, index must include both reduction dimentions Pull Request resolved: https://github.com/pytorch/pytorch/pull/150737 Approved by: https://github.com/dcci --- test/inductor/test_mps_basic.py | 1 + torch/_inductor/codegen/mps.py | 22 ++++++++++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index 47dab3ad972c..d2b1c5c2bec2 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -231,6 +231,7 @@ def fn(a): "test_sum_keepdims", "test_tanh", "test_vectorized_ops_masked", + "test_var_mean_tile_reduction_True", "test_view_as_complex", "test_view_on_aliased", "test_views3", diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 470e7f049cf2..c83c572e6e94 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -535,8 +535,18 @@ def reduction( value: Union[CSEVariable, tuple[CSEVariable, ...]], ) -> Union[CSEVariable, tuple[CSEVariable, ...]]: """Codegen a reduction operation""" - reduction_dim = next(t for t in self.range_trees if t.is_reduction) - acc_buf_size = min(reduction_dim.numel, self.max_threadgroup_size) + # Establish reduction buffer size and index expression + reduction_idx = "" + acc_buf_size = 1 + for rd in self.range_trees: + if not rd.is_reduction: + continue + if reduction_idx: + reduction_idx += " + " + reduction_idx += f"{rd.name} * {acc_buf_size}" + acc_buf_size *= rd.numel + acc_buf_size = min(acc_buf_size, self.max_threadgroup_size) + if reduction_type == "any": acc = self._new_idxvar(dtype) self.indexing_code.writeline(f"{acc} = false;") @@ -571,12 +581,12 @@ def reduction( self.compute.splice(f"{val} {reduction_op}= {value};") return self.cse.generate( self.stores, - 
f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_dim.name}, {acc_buf_size})", + f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {val}, {reduction_idx}, {acc_buf_size})", dtype=DTYPE_TO_COMPUTATION_DTYPE[dtype], ) if reduction_type in ["max", "min", "argmin", "argmax"]: acc_buf = self._new_idxvar(src_dtype, acc_buf_size) - acc_thread_var = f"{acc_buf}[{reduction_dim.name}]" + acc_thread_var = f"{acc_buf}[{reduction_idx}]" src_metal_type = DTYPE_TO_METAL[src_dtype] if not self.multistage_reduction: self.compute.splice( @@ -597,7 +607,7 @@ def reduction( ) idx_acc_buf = self._new_idxvar(torch.long, acc_buf_size) cmp_op = ">" if reduction_type == "argmax" else "<" - idx_thread_var = f"{idx_acc_buf}[{reduction_dim.name}]" + idx_thread_var = f"{idx_acc_buf}[{reduction_idx}]" self.indexing_code.splice(f"{idx_thread_var} = -1;") self.compute.splice(f""" if ({value} {cmp_op} {acc_thread_var}) {{ @@ -623,7 +633,7 @@ def reduction( f"Multistage reduction not yet supported for {reduction_type}" ) acc_buf = self._new_idxvar(src_dtype, acc_buf_size) - self.compute.splice(f"{acc_buf}[{reduction_dim.name}] = {value};") + self.compute.splice(f"{acc_buf}[{reduction_idx}] = {value};") wf_res = self.cse.generate( self.compute, f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})", From 83b870a28a3b3d922667c086380d30639bddfc24 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sun, 6 Apr 2025 01:29:59 +0000 Subject: [PATCH 212/332] Fix missing braces for clang CUDA (#150736) Test Plan: Sandcastle Differential Revision: D72469764 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150736 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu index 35d6559b62ce..47a19d26342e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu @@ -209,7 +209,7 @@ void spgemm_cutlass( std::is_same_v) { return {ElementComputeEpilogue{alpha.to()}}; } else { - return {alpha.to()}; + return {{alpha.to()}}; } }() }; @@ -219,7 +219,7 @@ void spgemm_cutlass( std::is_same_v) { return {ElementComputeEpilogue{beta.to()}}; } else { - return {beta.to()}; + return {{beta.to()}}; } }() }; @@ -230,7 +230,7 @@ void spgemm_cutlass( ElementC(0), {cute::_1{}, cute::_0{}, problem_size.m()}}; } else { - return {ElementC(0)}; + return {{ElementC(0)}}; } }() }; From 15768cc34b9cdafdac645b9c22806e0bf4e74100 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Sun, 6 Apr 2025 01:44:07 +0000 Subject: [PATCH 213/332] add unit test for preferred_blas_library settings (#150581) Follow up to #150212 that was committed without a unit test. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150581 Approved by: https://github.com/atalman, https://github.com/malfet Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- test/test_cuda.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/test/test_cuda.py b/test/test_cuda.py index a3cc62c5e1d4..192b41fed324 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -595,6 +595,65 @@ def test_serialization_array_with_storage(self): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Does not work in fbcode yet") + @setBlasBackendsToDefaultFinally + def test_preferred_blas_library_settings(self): + def _check_default(): + default = torch.backends.cuda.preferred_blas_library() + if torch.version.cuda: + # CUDA logic is easy, it's always cublas + self.assertTrue(default == torch._C._BlasBackend.Cublas) + else: + # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else + gcn_arch = str( + torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] + ) + if gcn_arch in ["gfx90a", "gfx942", "gfx950"]: + self.assertTrue(default == torch._C._BlasBackend.Cublaslt) + else: + self.assertTrue(default == torch._C._BlasBackend.Cublas) + + _check_default() + # "Default" can be set but is immediately reset internally to the actual default value. + self.assertTrue( + torch.backends.cuda.preferred_blas_library("default") + != torch._C._BlasBackend.Default + ) + _check_default() + self.assertTrue( + torch.backends.cuda.preferred_blas_library("cublas") + == torch._C._BlasBackend.Cublas + ) + self.assertTrue( + torch.backends.cuda.preferred_blas_library("hipblas") + == torch._C._BlasBackend.Cublas + ) + # check bad strings + with self.assertRaisesRegex( + RuntimeError, + "Unknown input value. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", + ): + torch.backends.cuda.preferred_blas_library("unknown") + # check bad input type + with self.assertRaisesRegex(RuntimeError, "Unknown input value type."): + torch.backends.cuda.preferred_blas_library(1.0) + # check env var override + custom_envs = [ + {"TORCH_BLAS_PREFER_CUBLASLT": "1"}, + {"TORCH_BLAS_PREFER_HIPBLASLT": "1"}, + ] + test_script = "import torch;print(torch.backends.cuda.preferred_blas_library())" + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("_BlasBackend.Cublaslt", r) + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @setBlasBackendsToDefaultFinally def test_cublas_workspace_explicit_allocation(self): From 2d98a1caf552c468bb72a2b98e53ad423e14a5de Mon Sep 17 00:00:00 2001 From: Klint Qinami Date: Sun, 6 Apr 2025 03:11:11 +0000 Subject: [PATCH 214/332] [MTIA] Map names to operand indices when folding submodules (#150692) When replacing placeholders with getattrs during constant folding, we can have an argument and parameter name mismatch. In fact, there is no guarantee that the parameter name is equivalent to the argument name used in the module call. 
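A minimal illustration of the mismatch (hypothetical functions, not the const_fold code path itself): placeholder targets come from the callee's parameter names, so only positional order is guaranteed to line up with the caller's arguments, which is what the index-based mapping in the change below relies on.

    import torch
    from torch import fx

    def inner(x, y):   # traced placeholders will be named "x" and "y"
        return x + y

    def outer(a, b):   # the caller's names are "a" and "b"
        return inner(a, b)

    inner_gm = fx.symbolic_trace(inner)
    print([n.target for n in inner_gm.graph.nodes if n.op == "placeholder"])
    # ['x', 'y'] -- matching by name against "a"/"b" would fail;
    # matching by position (index 0 -> a, index 1 -> b) always agrees.
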
Differential Revision: D72415970 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150692 Approved by: https://github.com/jfix71 --- torch/fx/experimental/const_fold.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 483b7e8b2ea2..525014bf1e80 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -252,13 +252,20 @@ def mod_partition(node: torch.fx.Node): # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) # return add root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + + # The order of placeholders in the const_gm graph should match the order of + # args in the outer module, so we can simply use an index for the + # placeholder mapping + ph_idx = 0 for node in root_const_gm.graph.nodes: if node.op == "output": multiple_outputs = isinstance(node.args[0], tuple) continue if node.op != "placeholder": continue - in_node = next(n for n in call_const_gm_args if n.name == node.target) + assert ph_idx < len(call_const_gm_args) + in_node = call_const_gm_args[ph_idx] + ph_idx += 1 assert in_node.op == "get_attr" with root_const_gm.graph.inserting_before(node): new_node = root_const_gm.graph.get_attr(in_node.target) From caf8d9bc1744eca8f1bba370f267a986f141b60e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sun, 6 Apr 2025 04:50:15 +0000 Subject: [PATCH 215/332] Revert "Fix conv2d strided prologue (#150697)" This reverts commit 2e4ae2ab41dbe1939bd1ffb427af8e5ea8eaff41. Reverted https://github.com/pytorch/pytorch/pull/150697 on behalf of https://github.com/ngimel due to breaks rocm build ([comment](https://github.com/pytorch/pytorch/pull/150697#issuecomment-2781218658)) --- test/inductor/test_max_autotune.py | 38 ----------------------------- torch/_inductor/select_algorithm.py | 11 +++++---- 2 files changed, 6 insertions(+), 43 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 65721ba67a5e..499dcbf4ae47 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1380,44 +1380,6 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) - @parametrize("prologue", (False, True)) - def test_conv1x1_cast(self, prologue): - with torch._inductor.config.patch(prologue_fusion=prologue): - conv1x1 = ( - torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) - .to(memory_format=torch.channels_last) - .to(GPU_TYPE) - .to(dtype=torch.float16) - ) - input_tensor = ( - torch.randn(4, 3, 32, 32) - .contiguous(memory_format=torch.channels_last) - .to(GPU_TYPE) - ) - - def foo(mod, input): - return torch.nn.functional.conv2d( - input, - mod.weight.to(input.dtype), - None, - mod.stride, - mod.padding, - mod.dilation, - mod.groups, - ) - - with torch.no_grad(): - out_eager = foo(conv1x1, input_tensor) - foo_c = torch.compile(foo) - out, code = run_and_get_code(foo_c, conv1x1, input_tensor) - - FileCheck().check_not("extern_kernels.convolution").run(code[0]) - if prologue: - self.check_code( - code[0], num_kernels=1, num_allocs=1, num_deallocs=2 - ) - self.assertEqual(out_eager, out, atol=1e-2, rtol=0) - @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes diff --git a/torch/_inductor/select_algorithm.py 
b/torch/_inductor/select_algorithm.py index 1fbb9aff8580..fed0f9ebebd7 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -746,10 +746,11 @@ def load_input( indices, self.range_trees[0].construct_entries(lengths) ): range_tree_entry.set_name(name) - - strided_index = sympy_dot(input_node.get_stride(), index_symbols) - strided_index = self.rename_indexing(strided_index) - self.body.writeline("xindex = " + texpr(strided_index)) + contiguous_index = sympy_dot( + ir.FlexibleLayout.contiguous_strides(lengths), index_symbols + ) + contiguous_index = self.rename_indexing(contiguous_index) + self.body.writeline("xindex = " + texpr(contiguous_index)) xindex_range_root = self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) @@ -822,7 +823,7 @@ def store( output_index = self.rename_indexing(output_index) - if output_index == strided_index: + if output_index == contiguous_index: output_index_str = "xindex" else: out_indexing = self.indexing( From 55e62ff74ad5614faf80b060c7bfc551e3b7af5a Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Sun, 6 Apr 2025 04:53:24 +0000 Subject: [PATCH 216/332] bf16 grouped gemm (#150374) Enabled bf16 grouped gemm with an API similar to _scaled_group_gemm, except without scale and fast accum arguments. All transpose variants are enabled, unlike scaled gemm. Ideally we'd factor out a lot more code from scaled gemm, currently there's a lot of repetition between scaled and non-scaled versions. I factored out only a helper kernel that prepares arguments. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150374 Approved by: https://github.com/drisspg --- aten/src/ATen/native/cuda/Blas.cpp | 64 ++- aten/src/ATen/native/cuda/GroupMM.cu | 383 ++++++++++++++++++ aten/src/ATen/native/cuda/GroupMM.h | 12 + aten/src/ATen/native/cuda/GroupMMCommon.cuh | 122 ++++++ aten/src/ATen/native/cuda/ScaledGroupMM.cu | 196 +++------ aten/src/ATen/native/native_functions.yaml | 5 + cmake/Codegen.cmake | 6 +- ...asDecompTest.test_has_decomposition.expect | 1 + test/test_matmul_cuda.py | 188 ++++++++- tools/autograd/derivatives.yaml | 4 + torch/csrc/autograd/FunctionsManual.cpp | 56 +++ torch/csrc/autograd/FunctionsManual.h | 16 + 12 files changed, 891 insertions(+), 162 deletions(-) create mode 100644 aten/src/ATen/native/cuda/GroupMM.cu create mode 100644 aten/src/ATen/native/cuda/GroupMM.h create mode 100644 aten/src/ATen/native/cuda/GroupMMCommon.cuh diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index dd04e58c3721..90a1a8ee07f2 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -942,7 +943,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { return _int_mm_out_cuda(self, mat2, result); } -static bool _scaled_mm_allowed_device() { +static bool _scaled_mm_allowed_device(bool sm90_only=false) { #ifdef USE_ROCM static const std::vector archs = { "gfx942", @@ -956,7 +957,11 @@ static bool _scaled_mm_allowed_device() { return at::detail::getCUDAHooks().isGPUArch(archs); #else auto dprops = at::cuda::getCurrentDeviceProperties(); - return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); + if (sm90_only) { + return dprops->major == 9; + } else { + return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); + } #endif } @@ -1401,16 +1406,20 @@ namespace { } } - bool transposed(const Tensor& 
mat) { + bool check_valid_strides_and_return_transposed(const Tensor& mat) { IntArrayRef tensor_strides = mat.strides(); IntArrayRef tensor_sizes = mat.sizes(); int end_dim = mat.dim() - 1; + int alignment = 16 / mat.element_size(); + TORCH_CHECK(uint64_t(mat.data_ptr()) % 16 ==0, "expected data_ptr to be aligned to 16 bytes\n"); if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max(1, tensor_sizes[end_dim - 1]))) { + TORCH_CHECK(tensor_strides[end_dim] % alignment == 0, "strides should be multiple of 16 bytes"); return true; } else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max(1, tensor_sizes[end_dim]))) { + TORCH_CHECK(tensor_strides[end_dim - 1] % alignment == 0, "strides should be multiple of 16 bytes"); return false; } else { - TORCH_CHECK(false, "Tensor should not be self-overlapping"); + TORCH_CHECK(false, "Tensor should have a contiguous dimension and not be self-overlapping, got ", mat.strides(), " for strides and ", mat.sizes(), " for sizes"); } } @@ -1476,13 +1485,13 @@ const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum) { #ifndef USE_ROCM - bool allowed_device = _scaled_mm_allowed_device(); - TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = 9.0"); TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); - TORCH_CHECK(!transposed(mat_a), "Expected mat1 to not be transposed"); - TORCH_CHECK(transposed(mat_b), "Expected mat2 to be transposed"); + TORCH_CHECK(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); + TORCH_CHECK(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); const bool a_is_2d = mat_a.dim() == 2; @@ -1500,7 +1509,7 @@ bool use_fast_accum) { ")."); - + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); if (offs.has_value()) { @@ -1543,5 +1552,42 @@ bool use_fast_accum) { } +Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, +const std::optional& offs, +const std::optional& bias, +std::optional out_dtype) { +#ifndef USE_ROCM + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true); + TORCH_CHECK(allowed_device, "torch._grouped_mm is only supported on CUDA devices with compute capability = 9.0"); + + TORCH_CHECK(mat_a.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_a.scalar_type()); + TORCH_CHECK(mat_b.dtype() == at::kBFloat16, "Expected mat_a to be BFloat16 matrix got ", mat_b.scalar_type()); + TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); + TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + // check that the strides are valid, the fn will throw an error if not + 
check_valid_strides_and_return_transposed(mat_a); + check_valid_strides_and_return_transposed(mat_b); + TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix, or no offset if both matrices are 3d"); + + if (offs.has_value()) { + TORCH_CHECK(offs->dim() == 1, "offs has to be 1D"); + TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32"); + } + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high output type is supported for grouped gemm"); + TORCH_CHECK(!bias.has_value(), "Bias not supported yet"); + + const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs); + Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_)); + at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); + return out; +#else + TORCH_CHECK(false, "grouped gemm is not supported on ROCM") +#endif +} + + } // namespace at::native diff --git a/aten/src/ATen/native/cuda/GroupMM.cu b/aten/src/ATen/native/cuda/GroupMM.cu new file mode 100644 index 000000000000..d43875e3c8a6 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMM.cu @@ -0,0 +1,383 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include +#include +#include +#include + + +// Two warninngs in Cutlass included header files +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used") +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") + +// Determine if the architecture supports rowwise scaled mm +// Currently failing on windows with: +// https://github.com/NVIDIA/cutlass/issues/1571 +#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \ + CUDA_VERSION >= 12000 + +#define BUILD_GG_KERNEL +#endif + +#if defined(BUILD_GG_KERNEL) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { +using Strides = at::cuda::detail::Strides; // std::array; + +template +struct Schedule { + using CooperativeSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using CooperativeEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using PongEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using KernelSchedule = + cute::conditional_t; + using EpilogueSchedule = cute:: + conditional_t; +}; + +int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +int round_up_to_nearest_multiple(int a, int b) { + return ceildiv(a, b) * b; +} + +template < + bool a_row_major, + bool b_row_major, + bool Pong, + typename TB_M, + typename TB_N, + typename TB_K> +void bf16bf16_grouped_gemm_impl_sm90( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + using DtypeA = cutlass::bfloat16_t; + using DtypeB = cutlass::bfloat16_t; + using DtypeOutput = cutlass::bfloat16_t; + using DtypeAccum = float; + using LayoutA = cute::conditional_t< + a_row_major, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor>; + constexpr int AlignmentA = 16 / sizeof(DtypeA); + + using LayoutB = cute::conditional_t< + b_row_major, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor>; + constexpr int AlignmentB = 16 / sizeof(DtypeB); + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int 
AlignmentOutput = 16 / sizeof(DtypeOutput); + using ArchTag = cutlass::arch::Sm90; + using OperatorClass = cutlass::arch::OpClassTensorOp; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using KernelSchedule = + typename Schedule::KernelSchedule; + using EpilogueSchedule = + typename Schedule::EpilogueSchedule; + using ProblemShape = cutlass::gemm::GroupProblemShape< + cute::Shape>; // per + // group + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + DtypeAccum, + DtypeAccum, + DtypeOutput, + LayoutOutput*, + AlignmentOutput, + DtypeOutput, + LayoutOutput*, + AlignmentOutput, + EpilogueSchedule, + cutlass::epilogue::fusion:: + LinearCombination>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + DtypeA, + LayoutA*, + AlignmentA, + DtypeB, + LayoutB*, + AlignmentB, + DtypeAccum, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideOutput = typename Gemm::GemmKernel::InternalStrideD; + int32_t M, N, K, group_count; + + M = mat_a.size(-2); + K = mat_a.size(-1); + N = mat_b.size(-1); + + if (mat_a.dim() == 2 && mat_b.dim() == 2) { + // if both inputs are ragged, K is dynamic, M and N come from inputs + group_count = offs->size(0); + K = -1; + } else if (mat_a.dim() == 2) { + group_count = mat_b.size(0); + M = -1; + } else if (mat_b.dim() == 2) { + group_count = mat_a.size(0); + N = -1; + } else { + // regular bmm + group_count = mat_a.size(0); + } + + TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); + const int64_t problem_shape_size = + group_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape)); + + const int64_t stride_size = 3 * group_count * ((int64_t)sizeof(StrideA)); + + // dummy tmas are created based on these pointer-to-pointers + // the actual values are never used, they are replaced + // by real addresses, but for dummy tma creation to succeed + // due to bug in cuda < 12.4 the pointers have to be aligned to 128 bits + const int group_alignment = 16 / sizeof(void*); + const int aligned_group_count = + round_up_to_nearest_multiple(group_count, group_alignment); + int64_t input_args_size = aligned_group_count * 3 * sizeof(void*) + + problem_shape_size + stride_size; + + auto& allocator = *c10::cuda::CUDACachingAllocator::get(); + auto input_buf = allocator.allocate(input_args_size); + void* buf_ptr = input_buf.get(); + DtypeA** inputA_ptrs = reinterpret_cast(buf_ptr); + DtypeB** inputB_ptrs = + reinterpret_cast(inputA_ptrs + aligned_group_count); + DtypeOutput** output_ptrs = + reinterpret_cast(inputB_ptrs + aligned_group_count); + static_assert( + sizeof(StrideA) == 8, "expected StrideA to be 8 bytes for alignment"); + StrideA* stride_A = + reinterpret_cast(output_ptrs + aligned_group_count); + StrideB* stride_B = reinterpret_cast(stride_A + group_count); + StrideOutput* stride_output = + reinterpret_cast(stride_B + group_count); + ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast( + stride_output + 
group_count); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto make_strides = [](at::IntArrayRef strides) -> Strides { + Strides out; + std::copy(strides.begin(), strides.end(), out.begin()); + return out; + }; + + Strides tensor_StrideA = make_strides(mat_a.strides()); + Strides tensor_StrideB = make_strides(mat_b.strides()); + Strides tensor_StrideOutput = make_strides(out.strides()); + + at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( + reinterpret_cast(mat_a.data_ptr()), + reinterpret_cast(mat_b.data_ptr()), + reinterpret_cast(out.data_ptr()), + static_cast(nullptr), // type for template inference + static_cast(nullptr), // type for template inference + inputA_ptrs, + inputB_ptrs, + output_ptrs, + static_cast(nullptr), // type for template inference + static_cast(nullptr), // type for template inference + problem_sizes, + stride_A, + stride_B, + stride_output, + offs.has_value() ? offs->const_data_ptr() : nullptr, + M, + N, + K, + tensor_StrideA, + tensor_StrideB, + tensor_StrideOutput, + 0, + 0, + a_row_major, + b_row_major); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {group_count, problem_sizes, nullptr}, + {(const DtypeA**)inputA_ptrs, + stride_A, + (const DtypeB**)inputB_ptrs, + stride_B}, + {{}, + (const DtypeOutput**)output_ptrs, + stride_output, + output_ptrs, + stride_output}}; + + arguments.epilogue.thread.alpha = 1.0; + arguments.epilogue.thread.dAlpha = {cute::_0{}, cute::_0{}, 0}; + + int sm_count = + at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); + } + arguments.hw_info.sm_count = sm_count; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = allocator.allocate(workspace_size); + Gemm gemm; + TORCH_CHECK( + gemm.can_implement(arguments) == cutlass::Status::kSuccess, + "cutlass cannot implement"); + TORCH_CHECK( + gemm.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess, + "cutlass cannot initialize"); + auto status = gemm(at::cuda::getCurrentCUDAStream()); + TORCH_CHECK( + status == cutlass::Status::kSuccess, + "cutlass cannot run, error ", + int(status)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void dispatch_bf16_grouped_kernel_on_tile_size( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + int32_t M, N, K, group_count; + + M = mat_a.size(-2); + K = mat_a.size(-1); + N = mat_b.size(-1); + + // below we assume that gemms are approx same size + if (mat_a.dim() == 2 && mat_b.dim() == 2) { + // if both inputs are ragged, K is dynamic, M and N come from inputs + group_count = offs->size(0); + K = K / group_count; + } else if (mat_a.dim() == 2) { + group_count = mat_b.size(0); + M = M / group_count; + } else if (mat_b.dim() == 2) { + group_count = mat_a.size(0); + N = N / group_count; + } + // bool large = + // ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) || + // (K >= 2048 && N >= 2048)); + bool small = (M <= 128 || N <= 128); + if (small) { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ true, + cute::_64, + cute::_128, + cute::_128>(mat_a, mat_b, offs, bias, out); + } else { + bf16bf16_grouped_gemm_impl_sm90< + a_row_major, + b_row_major, + /*Pong*/ false, + cute::_128, + cute::_256, + cute::_64>(mat_a, mat_b, 
offs, bias, out); + } +} + +void dispatch_bf16_grouped_kernel_on_ab_transpose( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { + // we already checked that one of the strides is 1 + bool a_row_major = mat_a.stride(-1) == 1; + bool b_row_major = mat_b.stride(-1) == 1; + if (a_row_major && b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else if (a_row_major && !b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else if (!a_row_major && b_row_major) { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } else { + dispatch_bf16_grouped_kernel_on_tile_size( + mat_a, mat_b, offs, bias, out); + } +} + +} // namespace +#endif + +namespace at::cuda::detail { + +void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out) { +#if defined(BUILD_GG_KERNEL) + dispatch_bf16_grouped_kernel_on_ab_transpose(mat_a, mat_b, offs, bias, out); +#else + TORCH_CHECK(false, "grouped mm is not supported on your system"); +#endif +} + +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/GroupMM.h b/aten/src/ATen/native/cuda/GroupMM.h new file mode 100644 index 000000000000..1fc23207a090 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMM.h @@ -0,0 +1,12 @@ +#pragma once +#include +#include + +namespace at::cuda::detail { +TORCH_API void bf16bf16_grouped_mm( + at::Tensor mat_a, // bf16 + at::Tensor mat_b, // bf16 + std::optional offs, + std::optional bias, // BF16 + at::Tensor& out); +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/GroupMMCommon.cuh b/aten/src/ATen/native/cuda/GroupMMCommon.cuh new file mode 100644 index 000000000000..613e2a6331d1 --- /dev/null +++ b/aten/src/ATen/native/cuda/GroupMMCommon.cuh @@ -0,0 +1,122 @@ +#pragma once +#include + +namespace at::cuda::detail { + +using Strides = std::array; + +template < + typename DtypeA, + typename DtypeB, + typename DtypeOutput, + typename DtypeScale, + typename ProblemShape, + typename StrideA, + typename StrideB, + typename StrideOutput> +__global__ void prepare_grouped_gemm_data( + DtypeA* A, + DtypeB* B, + DtypeOutput* output, + DtypeScale* scale_A, + DtypeScale* scale_B, + DtypeA** A_ptrs, + DtypeB** B_ptrs, + DtypeOutput** output_ptrs, + DtypeScale** inputA_scale_ptrs, + DtypeScale** inputB_scale_ptrs, + ProblemShape* problem_sizes, + // Strides for cutlass, cute::Stride + StrideA* stride_A, + StrideB* stride_B, + StrideOutput* stride_output, + const int32_t* offs, + int32_t M, + int32_t N, + int32_t K, + // Original strides of the input tensors + Strides tensor_StrideA, + Strides tensor_StrideB, + Strides tensor_StrideOutput, + int64_t a_scale_stride, + int64_t b_scale_stride, + bool a_row_major = true, + bool b_row_major = false) { + int32_t tid = threadIdx.x; + int32_t delta = 0; + if (offs != nullptr) { + int32_t start = tid == 0 ? 0 : offs[tid - 1]; + delta = offs[tid] - start; + int align = 16 / sizeof(DtypeA); + CUDA_KERNEL_ASSERT( + delta % align == 0 && + "expected dynamic dimension byte size to be multiple of 16 \n"); + } + int64_t lda, ldb, ldoutput; + if (M < 0) { + // A and output is 2d + M = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = tid == 0 ? 
A : A + offs[tid - 1] * tensor_StrideA[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + } else if (N < 0) { + N = delta; + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; // B is transposed + ldoutput = tensor_StrideOutput[0]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[1]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; + } + } else if (K < 0) { + // A, B is 2d, output is 3d + K = delta; + lda = a_row_major ? tensor_StrideA[0] : tensor_StrideA[1]; + ldb = b_row_major ? tensor_StrideB[0] : tensor_StrideB[1]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * tensor_StrideA[1]; + B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * M; + inputB_scale_ptrs[tid] = scale_B + tid * N; + } + } else { + // A, B, output are 3D + lda = a_row_major ? tensor_StrideA[1] : tensor_StrideA[2]; + ldb = b_row_major ? tensor_StrideB[1] : tensor_StrideB[2]; + ldoutput = tensor_StrideOutput[1]; + A_ptrs[tid] = A + tid * tensor_StrideA[0]; + B_ptrs[tid] = B + tid * tensor_StrideB[0]; + output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; + if (scale_A != nullptr) { + inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; + inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; + } + } + problem_sizes[tid] = ProblemShape(M, N, K); + + // make_cute_packed_stride only replaces one of the stride elements with + // one the provided values in the shape arguments + // the indices of the src/dst depend on whether A/B are row-major + // so constructing shape argument with two similar lda values + // while it looks non-sensical (and it is a nonsensical shape) + // is fine for these stride construction purposes - the one that will be used + // for replacement is correct, the other one is ignored, and we don't have to + // branch on whether A/B are row-major + stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {lda, lda, 1}); + stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {ldb, ldb, 1}); + stride_output[tid] = + cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); +} +} // namespace at::cuda::detail diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index 7573dd943498..fe6fb2dba0b6 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -22,16 +22,15 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter") #if defined(BUILD_ROWWISE_FP8_KERNEL) +#include #include #include #include #include #include -#include #include #include -#include #include #include @@ -51,101 +50,7 @@ C10_DIAGNOSTIC_POP() namespace { -using Strides = std::array; - -template < - typename DtypeA, - typename DtypeB, - typename DtypeOutput, - typename DtypeScale, - typename ProblemShape, - typename StrideA, - typename StrideB, - typename StrideOutput> -__global__ 
void prepare_gemm_data( - DtypeA* A, - DtypeB* B, - DtypeOutput* output, - DtypeScale* scale_A, - DtypeScale* scale_B, - DtypeA** A_ptrs, - DtypeB** B_ptrs, - DtypeOutput** output_ptrs, - DtypeScale** inputA_scale_ptrs, - DtypeScale** inputB_scale_ptrs, - ProblemShape* problem_sizes, - // Strides for cutlass, cute::Stride - StrideA* stride_A, - StrideB* stride_B, - StrideOutput* stride_output, - const int32_t* offs, - int32_t M, - int32_t N, - int32_t K, - // Original strides of the input tensors - Strides tensor_StrideA, - Strides tensor_StrideB, - Strides tensor_StrideOutput, - int64_t a_scale_stride, - int64_t b_scale_stride) { - int32_t tid = threadIdx.x; - int32_t delta = 0; - if (offs != nullptr) { - int32_t start = tid == 0 ? 0 : offs[tid - 1]; - delta = offs[tid] - start; - CUDA_KERNEL_ASSERT(delta % 16 == 0 && "expected dynamic dimension to be multiple of 16\n"); - } - int64_t lda, ldb, ldoutput; - if (M < 0) { - // A and output is 2d - M = delta; - lda = tensor_StrideA[0]; - ldb = tensor_StrideB[2]; // B is transposed - ldoutput = tensor_StrideOutput[0]; - A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * lda; - inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1]; - output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput; - B_ptrs[tid] = B + tid * tensor_StrideB[0]; - inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; - } else if (N < 0) { - N = delta; - lda = tensor_StrideA[1]; - ldb = tensor_StrideB[1]; // B is transposed - ldoutput = tensor_StrideOutput[0]; - A_ptrs[tid] = A + tid * tensor_StrideA[0]; - inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; - output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1]; - B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * ldb; - inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1]; - } else if (K < 0) { - // A, B is 2d, output is 3d - K = delta; - lda = tensor_StrideA[0]; - ldb = tensor_StrideB[1]; // B is transposed - ldoutput = tensor_StrideOutput[1]; - A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1]; - B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1]; - inputA_scale_ptrs[tid] = scale_A + tid * M; - inputB_scale_ptrs[tid] = scale_B + tid * N; - output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; - } else { - // A, B, output are 3D - lda = tensor_StrideA[1]; - ldb = tensor_StrideB[2]; - ldoutput = tensor_StrideOutput[1]; - A_ptrs[tid] = A + tid * tensor_StrideA[0]; - B_ptrs[tid] = B + tid * tensor_StrideB[0]; - inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride; - inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride; - output_ptrs[tid] = output + tid * tensor_StrideOutput[0]; - } - problem_sizes[tid] = ProblemShape(M, N, K); - - stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {M, lda, 1}); - stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {N, ldb, 1}); - stride_output[tid] = - cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1}); -} +using Strides = at::cuda::detail::Strides; using DtypeScale = float; using DtypeAccum = float; @@ -205,7 +110,6 @@ struct Schedule { using ClusterShape = cute::Shape; }; - int ceildiv(int a, int b) { return (a + b - 1) / b; } @@ -257,8 +161,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( typename Schedule:: EpilogueSchedule; // TODO remove *BroadcastPtrArrays and replace with just Broadcast - // when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass version - // Implement rowwise scaling epilogue. 
+ // when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass + // version Implement rowwise scaling epilogue. using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcastPtrArray< 0, TileShape, @@ -345,6 +249,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( group_count = mat_a.size(0); } + TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); + const int64_t problem_shape_size = group_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape)); @@ -383,7 +289,6 @@ void f8f8bf16_grouped_gemm_impl_sm90( reinterpret_cast( stride_output + group_count); - TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups"); auto stream = at::cuda::getCurrentCUDAStream().stream(); auto make_strides = [](at::IntArrayRef strides) -> Strides { @@ -400,7 +305,7 @@ void f8f8bf16_grouped_gemm_impl_sm90( int64_t a_scale_stride = scale_a.stride(0); int64_t b_scale_stride = scale_b.stride(0); - prepare_gemm_data<<<1, group_count, 0, stream>>>( + at::cuda::detail::prepare_grouped_gemm_data<<<1, group_count, 0, stream>>>( reinterpret_cast(mat_a.data_ptr()), reinterpret_cast(mat_b.data_ptr()), reinterpret_cast(out.data_ptr()), @@ -427,46 +332,50 @@ void f8f8bf16_grouped_gemm_impl_sm90( C10_CUDA_KERNEL_LAUNCH_CHECK(); -// auto buf_cpu = mat_a.new_empty( -// input_args_size, at::TensorOptions().dtype(at::kByte).device(at::kCPU)); -// AT_CUDA_CHECK(cudaMemcpy( -// (char*)buf_cpu.data_ptr(), -// buf_ptr, -// input_args_size, -// cudaMemcpyDeviceToHost)); -// char* buf_ptr_cpu = (char*)buf_cpu.data_ptr(); -// DtypeA** inputA_ptrs_h = reinterpret_cast(buf_ptr_cpu); -// DtypeB** inputB_ptrs_h = -// reinterpret_cast(inputA_ptrs_h + aligned_group_count); -// DtypeOutput** output_ptrs_h = -// reinterpret_cast(inputB_ptrs_h + aligned_group_count); -// DtypeScale** inputA_scale_ptrs_h = -// reinterpret_cast(output_ptrs_h + aligned_group_count); -// DtypeScale** inputB_scale_ptrs_h = -// reinterpret_cast(inputA_scale_ptrs_h + aligned_group_count); -// StrideA* stride_A_h = -// reinterpret_cast(inputB_scale_ptrs_h + aligned_group_count); -// StrideB* stride_B_h = reinterpret_cast(stride_A_h + group_count); -// StrideOutput* stride_output_h = -// reinterpret_cast(stride_B_h + group_count); -// ProblemShape::UnderlyingProblemShape* problem_sizes_h = -// reinterpret_cast( -// stride_output_h + group_count); - -// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << " " -// << out.data_ptr() << " " << scale_a.data_ptr() << " " -// << scale_b.data_ptr() << "\n"; -// for (int i = 0; i < group_count; i++) { -// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n"; -// std::cout << "B " << (void*)inputB_ptrs_h[i] << "\n"; -// std::cout << "O " << (void*)output_ptrs_h[i] << "\n"; -// std::cout << "A_scale " << (void*)inputA_scale_ptrs_h[i] << "\n"; -// std::cout << "B_scale " << (void*)inputB_scale_ptrs_h[i] << "\n"; -// std::cout << "sizes " << problem_sizes_h[i] << "\n"; -// std::cout << "strideA" << stride_A_h[i] << "\n"; -// std::cout << "strideB" << stride_B_h[i] << "\n"; -// std::cout << "stride_output" << stride_output_h[i] << "\n"; -// } + // auto buf_cpu = mat_a.new_empty( + // input_args_size, + // at::TensorOptions().dtype(at::kByte).device(at::kCPU)); + // AT_CUDA_CHECK(cudaMemcpy( + // (char*)buf_cpu.data_ptr(), + // buf_ptr, + // input_args_size, + // cudaMemcpyDeviceToHost)); + // char* buf_ptr_cpu = (char*)buf_cpu.data_ptr(); + // DtypeA** inputA_ptrs_h = reinterpret_cast(buf_ptr_cpu); + // DtypeB** inputB_ptrs_h = + // reinterpret_cast(inputA_ptrs_h + 
aligned_group_count); + // DtypeOutput** output_ptrs_h = + // reinterpret_cast(inputB_ptrs_h + aligned_group_count); + // DtypeScale** inputA_scale_ptrs_h = + // reinterpret_cast(output_ptrs_h + aligned_group_count); + // DtypeScale** inputB_scale_ptrs_h = + // reinterpret_cast(inputA_scale_ptrs_h + + // aligned_group_count); + // StrideA* stride_A_h = + // reinterpret_cast(inputB_scale_ptrs_h + + // aligned_group_count); + // StrideB* stride_B_h = reinterpret_cast(stride_A_h + + // group_count); StrideOutput* stride_output_h = + // reinterpret_cast(stride_B_h + group_count); + // ProblemShape::UnderlyingProblemShape* problem_sizes_h = + // reinterpret_cast( + // stride_output_h + group_count); + + // std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << " + // " + // << out.data_ptr() << " " << scale_a.data_ptr() << " " + // << scale_b.data_ptr() << "\n"; + // for (int i = 0; i < group_count; i++) { + // std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n"; + // std::cout << "B " << (void*)inputB_ptrs_h[i] << "\n"; + // std::cout << "O " << (void*)output_ptrs_h[i] << "\n"; + // std::cout << "A_scale " << (void*)inputA_scale_ptrs_h[i] << "\n"; + // std::cout << "B_scale " << (void*)inputB_scale_ptrs_h[i] << "\n"; + // std::cout << "sizes " << problem_sizes_h[i] << "\n"; + // std::cout << "strideA" << stride_A_h[i] << "\n"; + // std::cout << "strideB" << stride_B_h[i] << "\n"; + // std::cout << "stride_output" << stride_output_h[i] << "\n"; + // } // int device_id = 0; // cutlass::KernelHardwareInfo kernel_hw_info = // cutlass::KernelHardwareInfo::make_kernel_hardware_info(device_id); @@ -484,7 +393,8 @@ void f8f8bf16_grouped_gemm_impl_sm90( output_ptrs, stride_output}}; - int sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; + int sm_count = + at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount; if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e3a1cd175c86..b4e000a2a3ca 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7076,6 +7076,11 @@ dispatch: CUDA: _scaled_grouped_mm_cuda +- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? 
out_dtype=None) -> Tensor + variants: function + dispatch: + CUDA: _grouped_mm_cuda + # NOTE [ Sparse: autograd and API ] # # diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index 5ca808f20c8a..dc6cf8db5c37 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -120,7 +120,11 @@ if(INTERN_BUILD_ATEN_OPS) "89;90a;100a") _BUILD_FOR_ADDITIONAL_ARCHS( "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu" - "89;90a") + "90a") + _BUILD_FOR_ADDITIONAL_ARCHS( + "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu" + "90a") + endif() set(GEN_ROCM_FLAG) diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 3faa1186562f..9eb7c572228f 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -383,6 +383,7 @@ aten::_fw_primal_copy aten::_fw_primal_copy.out aten::_grid_sampler_2d_cpu_fallback aten::_grid_sampler_2d_cpu_fallback.out +aten::_grouped_mm aten::_has_same_storage_numel aten::_histogramdd_bin_edges aten::_histogramdd_bin_edges.out diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 49da165ca20e..7de6f3d725ed 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -258,6 +258,175 @@ def _expand_to_batch(t: torch.Tensor): # cross comparison self.assertEqual(out1_gpu, out2_gpu[0]) + def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist): + for a, b, gO, agrad, bgrad, out in zip(alist, blist, gOlist, agradlist, bgradlist, outlist): + a = a.clone().detach().requires_grad_() + b = b.clone().detach().requires_grad_() + out_ref = torch.mm(a, b.t()) + out_ref.backward(gO) + self.assertEqual(out, out_ref) + self.assertEqual(agrad, a.grad) + self.assertEqual(bgrad, b.grad) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 + if a_row_major: + a = torch.randn(m, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + a = torch.randn(k * n_groups + k * int(strided), m, device=device, dtype=dtype).t()[:, :k * n_groups] + + if b_row_major: + b = torch.randn(n, k * n_groups + k * int(strided), device=device, dtype=dtype)[:, :k * n_groups] + else: + b = torch.randn(k * n_groups + k * int(strided), n, device=device, dtype=dtype).t()[:, :k * n_groups] + + a.requires_grad_(True) + b.requires_grad_(True) + offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) + out = torch._grouped_mm(a, b.t(), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, blist, agradlist, bgradlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[:, start:offs_cpu[i]]) + blist.append(b[:, start:offs_cpu[i]]) + agradlist.append(a.grad[:, start:offs_cpu[i]]) + bgradlist.append(b.grad[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, blist, gO, agradlist, bgradlist, out) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + 
@parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_2d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + a = torch.randn(m * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + a = torch.randn(k, (m + 2 * s_int) * n_groups, device=device, dtype=dtype).t()[:m * n_groups, :] + + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.t() + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) + + out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + alist, agradlist, gOlist, outlist = [], [], [], [] + start = 0 + for i in range(n_groups): + alist.append(a[start:offs_cpu[i]]) + agradlist.append(a.grad[start:offs_cpu[i]]) + outlist.append(out[start:offs_cpu[i]]) + gOlist.append(gO[start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(alist, b, gOlist, agradlist, b.grad, outlist) + + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + b = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), n, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + + out = torch._grouped_mm(a, b.transpose(-2, -1), out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + self.grouped_mm_helper(a, b, gO, a.grad, b.grad, out) + + @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") + @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") + @parametrize("strided", [False, True]) + @parametrize("a_row_major", [False, True]) + @parametrize("b_row_major", [False, True]) + def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major): + device = "cuda" + dtype = torch.bfloat16 + s_int = int(strided) + m, n, k, n_groups = 16, 32, 16, 4 + if a_row_major: + 
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, dtype=dtype)[::(1 + s_int), :, :k] + else: + a = torch.randn(n_groups * (1 + s_int), k * (1 + s_int), m, device=device, + dtype=dtype).transpose(-2, -1)[::(1 + s_int), :, :k] + if b_row_major: + b = torch.randn(n * n_groups, k * (1 + s_int), device=device, dtype=dtype)[:, :k] + else: + b = torch.randn(k, n * (n_groups + s_int), device=device, dtype=dtype).transpose(-2, -1)[:n * n_groups, :] + + a.requires_grad_(True) + b.requires_grad_(True) + + a_contig = a if a_row_major else a.transpose(-2, -1) + self.assertTrue(a_contig.is_contiguous() is not strided) + b_contig = b if b_row_major else b.transpose(-2, -1) + self.assertTrue(b_contig.is_contiguous() is not strided) + offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32) + out = torch._grouped_mm(a, b.transpose(-2, -1), offs=offs, + out_dtype=torch.bfloat16) + gO = torch.rand_like(out) + out.backward(gO) + offs_cpu = offs.cpu() + blist, outlist, bgradlist, gOlist = [], [], [], [] + start = 0 + for i in range(n_groups): + blist.append(b[start:offs_cpu[i]]) + bgradlist.append(b.grad[start:offs_cpu[i]]) + outlist.append(out[:, start:offs_cpu[i]]) + gOlist.append(gO[:, start:offs_cpu[i]]) + start = offs_cpu[i] + self.grouped_mm_helper(a, blist, gOlist, a.grad, bgradlist, outlist) + f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" @@ -1258,7 +1427,7 @@ def test_blockwise_mxfp8_nvfp4_error_messages(self, device, recipe) -> None: out_dtype=torch.bfloat16, ) - def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum): + def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum): for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist): out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1), out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum) @@ -1268,7 +1437,7 @@ def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_f @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_2d_2d(self, fast_accum, strided): + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): device = "cuda" m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16 a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups] @@ -1287,14 +1456,14 @@ def test_grouped_gemm_2d_2d(self, fast_accum, strided): ascalelist.append(scale_a[i * m : (i + 1) * m]) bscalelist.append(scale_b[i * n : (i + 1) * n]) start = offs_cpu[i] - self.grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) + self.scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_2d_3d(self, fast_accum, strided): + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1317,14 +1486,14 @@ def test_grouped_gemm_2d_3d(self, fast_accum, strided): ascalelist.append(scale_a[start:offs_cpu[i]]) outlist.append(out[start:offs_cpu[i]]) start = offs_cpu[i] 
- self.grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) + self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_3d_3d(self, fast_accum, strided): + def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1338,14 +1507,14 @@ def test_grouped_gemm_3d_3d(self, fast_accum, strided): out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=fast_accum) - self.grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) + self.scaled_grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum) @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90") @parametrize("fast_accum", [False, True]) @parametrize("strided", [False, True]) - def test_grouped_gemm_3d_2d(self, fast_accum, strided): + def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided): device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 16, 4 @@ -1367,7 +1536,8 @@ def test_grouped_gemm_3d_2d(self, fast_accum, strided): bscalelist.append(scale_b[start:offs_cpu[i]]) outlist.append(out[:, start:offs_cpu[i]]) start = offs_cpu[i] - self.grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) + self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum) + @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg) def test_blockwise_mxfp8_compile(self) -> None: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 25749372a1f3..6a42e26d7618 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1206,6 +1206,10 @@ mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) result: at::mm(self_t, mat2_p) + at::mm(self_p, mat2_t) +- name: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? 
out_dtype=None) -> Tensor + self: _grouped_mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), offs, 1) + mat2: _grouped_mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), offs, 1) + - name: mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) self: value_selecting_reduction_backward_symint(grad, dim, indices, self.sym_sizes(), keepdim) values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index e5ee41e6fd56..498259c8fa12 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1450,6 +1450,62 @@ Tensor mm_mat2_backward( return maybe_multiply(mat1.t().conj().mm(grad), alpha.conj()); } +Tensor _grouped_mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::SymIntArrayRef mat1_sizes, + at::SymIntArrayRef mat1_strides, + c10::Layout mat1_layout, + std::optional offs, + const Scalar& alpha) { + TORCH_CHECK( + grad.layout() == c10::kStrided && mat2.layout() == c10::kStrided && + mat1_layout == c10::kStrided, + "only strided layout supported for grouped mm"); + // if input was column-major, return grad as column-order for efficiency + if (offs.has_value() && !offs->defined()) { + offs = std::nullopt; + } + auto mat1_dim = mat1_sizes.size(); + if (mat1_strides[mat1_dim - 2] == 1 && + mat1_strides[mat1_dim - 1] == mat1_sizes[mat1_dim - 2]) { + auto grad_inp = + (at::_grouped_mm(mat2, grad.transpose(-2, -1), offs)).transpose(-2, -1); + return maybe_multiply(grad_inp, alpha.conj()); + } else { + auto grad_inp = (at::_grouped_mm(grad, mat2.transpose(-2, -1), offs)); + return maybe_multiply(grad_inp, alpha.conj()); + } +} + +Tensor _grouped_mm_mat2_backward( + const Tensor& grad, + const Tensor& mat1, + at::SymIntArrayRef mat2_sizes, + at::SymIntArrayRef mat2_strides, + c10::Layout mat2_layout, + std::optional offs, + const Scalar& alpha) { + TORCH_CHECK( + grad.layout() == c10::kStrided && mat1.layout() == c10::kStrided && + mat2_layout == c10::kStrided, + "only strided layout supported for grouped mm"); + // if input was column-major, return grad as column-order for efficiency + auto mat2_dim = mat2_sizes.size(); + if (offs.has_value() && !offs->defined()) { + offs = std::nullopt; + } + if (mat2_strides[mat2_dim - 2] == 1 && + mat2_strides[mat2_dim - 1] == mat2_sizes[mat2_dim - 2]) { + auto grad_inp = + at::_grouped_mm(grad.transpose(-2, -1), mat1, offs).transpose(-2, -1); + return maybe_multiply(grad_inp, alpha.conj()); + } else { + auto grad_inp = at::_grouped_mm(mat1.transpose(-2, -1), grad, offs); + return maybe_multiply(grad_inp, alpha.conj()); + } +} + Tensor mm_mat1_sparse_backward( const Tensor& grad, const Tensor& mat1, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 8d01a80eb406..1bbad0ae92dd 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -306,6 +306,22 @@ at::Tensor mm_mat2_backward( at::SymIntArrayRef strides, c10::Layout layout, const at::Scalar& alpha); +at::Tensor _grouped_mm_mat1_backward( + const Tensor& grad, + const Tensor& mat2, + at::SymIntArrayRef mat1_sizes, + at::SymIntArrayRef mat1_strides, + c10::Layout mat1_layout, + std::optional offs, + const Scalar& alpha); +at::Tensor _grouped_mm_mat2_backward( + const at::Tensor& grad, + const at::Tensor& mat1, + at::SymIntArrayRef sizes, + at::SymIntArrayRef strides, + 
c10::Layout layout, + std::optional offs, + const at::Scalar& alpha); at::Tensor mm_mat1_sparse_backward( const at::Tensor& grad, const at::Tensor& mat1, From 49f6cce736fc7590d2a881a9971149bb5dce7279 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Sun, 6 Apr 2025 17:06:55 +0000 Subject: [PATCH 217/332] [MPS] grad scaler (#150255) Fixes #142397 Basic implementation is done. What's left: - [x] Different dtype/device tensors in the TensorList - [x] fast path for grouping the foreach kernel - [x] Tests Regarding tests, I found some tests in `test/test_torch.py` for GradScaler but I couldn't figure out what is the best way to enable the test for MPS device. By removing `@onlyNativeDeviceTypes`, one enables the tests for MPS but also enables tests for all other devices which are not included in the native device types. If I put: `instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)` This enables lots of tests in that class for MPS which were not(?) being tested before? This part needs some clarification Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- aten/src/ATen/native/mps/kernels/Amp.metal | 132 ++++++++++++++++++ .../mps/kernels/FusedOptimizerOps.metal | 6 +- aten/src/ATen/native/mps/operations/Amp.mm | 132 ++++++++++++++++++ .../native/mps/operations/FusedSgdKernel.mm | 4 - .../native/mps/operations/MultiTensorApply.h | 114 ++++++++++++++- aten/src/ATen/native/native_functions.yaml | 2 + test/test_mps.py | 119 ++++++++++++++++ torch/amp/grad_scaler.py | 6 +- torch/optim/adam.py | 3 - 9 files changed, 503 insertions(+), 15 deletions(-) create mode 100644 aten/src/ATen/native/mps/kernels/Amp.metal create mode 100644 aten/src/ATen/native/mps/operations/Amp.mm diff --git a/aten/src/ATen/native/mps/kernels/Amp.metal b/aten/src/ATen/native/mps/kernels/Amp.metal new file mode 100644 index 000000000000..f32621320ab4 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Amp.metal @@ -0,0 +1,132 @@ +#include +using namespace metal; + +constant constexpr unsigned kmaxThreadGroups = 32; +constant constexpr unsigned kmaxTensors = 32; +constant constexpr unsigned kChunkSize = 65536; + +template +struct AmpNonFiniteCheckAndUnscaleArgs { + metal::array data [[id(0)]]; +}; + +struct MetadataArguments { + ulong numels[kmaxTensors]; + ulong threadgroup_to_tensor[kmaxThreadGroups]; + ulong threadgroup_to_chunk[kmaxThreadGroups]; +}; + +template +kernel void ampNonFiniteCheckAndUnscale( + constant AmpNonFiniteCheckAndUnscaleArgs& pointerArgs [[buffer(0)]], + constant MetadataArguments& metadata [[buffer(1)]], + device float& foundInf [[buffer(2)]], + constant T& invScale [[buffer(3)]], + uint local_tid [[thread_position_in_threadgroup]], + uint tgSize [[threads_per_threadgroup]], + uint group_id [[threadgroup_position_in_grid]]) { + uint threadGroupSize = tgSize; + uint tensor_index = metadata.threadgroup_to_tensor[group_id]; + uint chunk = metadata.threadgroup_to_chunk[group_id]; + uint numel = metadata.numels[tensor_index]; + + uint offset = chunk * kChunkSize; + uint chunk_size = + ((offset + kChunkSize) > numel) ? (numel - offset) : kChunkSize; + + device T* data = pointerArgs.data[tensor_index]; + + for (uint i = local_tid; i < chunk_size; i += threadGroupSize) { + uint index = offset + i; + T val = data[index]; + if (!isfinite(val)) { + foundInf = 1.0f; + } + data[index] = (invScale == static_cast(1.0) ? 
val : val * invScale); + } +} + +template +kernel void ampNonFiniteCheckAndUnscaleSingle( + device T* data [[buffer(0)]], + device float& foundInf [[buffer(1)]], + constant T& invScale [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + T val = data[tid]; + if (!isfinite(val)) { + foundInf = 1.0f; + } + data[tid] = (invScale == T(1.0) ? val : val * invScale); +} + +template +kernel void ampUpdateScale( + device T& scale [[buffer(0)]], + device int& growth_tracker [[buffer(1)]], + device float& foundInf [[buffer(2)]], + constant T& scaleGrowthFactor [[buffer(3)]], + constant T& scaleBackoffFactor [[buffer(4)]], + constant int& growthInterval [[buffer(5)]], + uint tid [[thread_position_in_grid]]) { + if (foundInf != 0.0f) { + scale *= scaleBackoffFactor; + growth_tracker = 0; + } else { + int g = growth_tracker + 1; + if (g >= growthInterval) { + scale *= scaleGrowthFactor; + g = 0; + } + growth_tracker = g; + } +} + +#define INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(DTYPE) \ + template [[host_name("ampNonFiniteCheckAndUnscale_" #DTYPE)]] kernel void \ + ampNonFiniteCheckAndUnscale( \ + constant AmpNonFiniteCheckAndUnscaleArgs & \ + pointerArgs [[buffer(0)]], \ + constant MetadataArguments & metadata [[buffer(1)]], \ + device float& foundInf [[buffer(2)]], \ + constant DTYPE& invScale [[buffer(3)]], \ + uint local_tid [[thread_position_in_threadgroup]], \ + uint tgSize [[threads_per_threadgroup]], \ + uint group_id [[threadgroup_position_in_grid]]) + +#define INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(DTYPE) \ + template \ + [[host_name("ampNonFiniteCheckAndUnscaleSingle_" #DTYPE)]] kernel void \ + ampNonFiniteCheckAndUnscaleSingle( \ + device DTYPE * data [[buffer(0)]], \ + device float& foundInf [[buffer(1)]], \ + constant DTYPE& invScale [[buffer(2)]], \ + uint tid [[thread_position_in_grid]]) + +#define INSTANTIATE_AMP_UPDATE_SCALE(DTYPE) \ + template [[host_name("ampUpdateScale_" #DTYPE)]] kernel void \ + ampUpdateScale( \ + device DTYPE & scale [[buffer(0)]], \ + device int& growth_tracker [[buffer(1)]], \ + device float& foundInf [[buffer(2)]], \ + constant DTYPE& scaleGrowthFactor [[buffer(3)]], \ + constant DTYPE& scaleBackoffFactor [[buffer(4)]], \ + constant int& growthInterval [[buffer(5)]], \ + uint tid [[thread_position_in_grid]]) + +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float); +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat); +#endif + +INSTANTIATE_AMP_UPDATE_SCALE(float); +INSTANTIATE_AMP_UPDATE_SCALE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_UPDATE_SCALE(bfloat); +#endif + +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); +#endif \ No newline at end of file diff --git a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal index 2006e768d826..fe5605226748 100644 --- a/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal +++ b/aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal @@ -57,9 +57,9 @@ struct SgdMomentumArguments { }; struct MetadataArguments { - uint32_t numels[kmaxTensors]; - uint32_t threadgroup_to_tensor[kmaxThreadGroups]; - uint32_t threadgroup_to_chunk[kmaxThreadGroups]; + ulong numels[kmaxTensors]; + ulong threadgroup_to_tensor[kmaxThreadGroups]; + ulong threadgroup_to_chunk[kmaxThreadGroups]; }; enum ADAM_MODE : 
uint8_t { ORIGINAL = 0, ADAMW = 1 }; diff --git a/aten/src/ATen/native/mps/operations/Amp.mm b/aten/src/ATen/native/mps/operations/Amp.mm new file mode 100644 index 000000000000..e410d434ec7a --- /dev/null +++ b/aten/src/ATen/native/mps/operations/Amp.mm @@ -0,0 +1,132 @@ +// Copyright © 2022 Apple Inc. +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +namespace at::native { +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif +namespace mps { + +static void _amp_non_finite_check_and_unscale_mps_single_impl(const Tensor& scaled_grad, + at::Tensor& found_inf, + const at::Tensor& inv_scale) { + if (scaled_grad.numel() == 0) { + return; + } + TORCH_CHECK(scaled_grad.is_mps(), "Tensor is not on the MPS device."); + TORCH_CHECK(scaled_grad.numel() <= std::numeric_limits::max(), "scaled_grad is too large"); + float inv_scale_val = inv_scale.item(); + auto stream = getCurrentMPSStream(); + auto device = MPSDevice::getInstance()->device(); + auto ampPipelineState = + lib.getPipelineStateForFunc("ampNonFiniteCheckAndUnscaleSingle_" + mps::scalarToMetalTypeString(scaled_grad)); + + const uint32_t threadsPerThreadgroup = 256; + uint32_t numel = static_cast(scaled_grad.numel()); + MTLSize threadGroupSize = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize gridSize = MTLSizeMake(numel, 1, 1); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:ampPipelineState]; + mtl_setArgs(computeEncoder, scaled_grad, found_inf, inv_scale_val); + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; + }); +} + +static void _amp_update_scale_mps_impl(Tensor& self, + Tensor& growth_tracker, + const Tensor& found_inf, + float scale_growth_factor, + float scale_backoff_factor, + int32_t growth_interval) { + auto stream = getCurrentMPSStream(); + auto ampUpdatePipelineState = lib.getPipelineStateForFunc("ampUpdateScale_" + mps::scalarToMetalTypeString(self)); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:ampUpdatePipelineState]; + + mtl_setArgs( + computeEncoder, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + mtl_dispatch1DJob(computeEncoder, ampUpdatePipelineState, 1); + }); +} + +std::pair, id> getAmpCPLState(const std::string& fname) { + return {lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname)}; +} +} // namespace mps + +void _amp_foreach_non_finite_check_and_unscale_mps_(at::TensorList self, + at::Tensor& found_inf, + const at::Tensor& inv_scale) { + if (self.size() == 0) { + return; + } + TORCH_CHECK(inv_scale.is_mps(), "inv_scale must be a MPS tensor."); + TORCH_CHECK(found_inf.is_mps(), "found_inf must be a MPS tensor."); + TORCH_CHECK(inv_scale.numel() == 1, "inv_scale must be a 1-element tensor."); + TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor."); + TORCH_CHECK(inv_scale.scalar_type() == at::ScalarType::Float, "inv_scale must be a float tensor."); + TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float, "found_inf must be a float tensor."); + // Ensures client code (GradScaler) filtered scaled_grads by API restrictions. 
+ check_foreach_api_restrictions(self); + + // Prepare a vector of tensor lists. + std::vector> tensor_lists; + if (can_use_fast_route(self)) { + TORCH_CHECK(self[0].is_mps(), "scaled_grads must be MPS tensors."); + tensor_lists.emplace_back(self.vec()); + } else { + tensor_lists.resize(1); + tensor_lists[0].reserve(self.size()); + auto expected_device = self[0].device(); + const auto expected_dtype = self[0].scalar_type(); + for (const at::Tensor& t : self) { + // Ensure that GradScaler has filtered by device, layout, and dtype. + TORCH_CHECK(t.is_mps(), "one of scaled_grads was not a MPS tensor."); + TORCH_CHECK(t.device() == expected_device, "scaled_grads must be on the same device."); + TORCH_CHECK(t.layout() == at::kStrided, "one of scaled_grads was not a strided tensor."); + if (!t.is_non_overlapping_and_dense() || t.scalar_type() != expected_dtype) { + // Fall back to the single-tensor implementation + mps::_amp_non_finite_check_and_unscale_mps_single_impl(const_cast(t), found_inf, inv_scale); + } else { + tensor_lists[0].push_back(t); + } + } + if (tensor_lists[0].empty()) { + return; + } + } + + std::string kernel_name = + "ampNonFiniteCheckAndUnscale_" + mps::scalarToMetalTypeString(tensor_lists[0][0].scalar_type()); + mps::multi_tensor_apply<1>(kernel_name, tensor_lists, found_inf, inv_scale); +} + +Tensor& _amp_update_scale_mps_(Tensor& self, + Tensor& growth_tracker, + const Tensor& found_inf, + double scale_growth_factor, + double scale_backoff_factor, + int64_t growth_interval) { + mps::_amp_update_scale_mps_impl( + self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + return self; +} +} // namespace at::native \ No newline at end of file diff --git a/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm b/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm index 538d04fed999..4057f2dcbac5 100644 --- a/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm +++ b/aten/src/ATen/native/mps/operations/FusedSgdKernel.mm @@ -114,8 +114,6 @@ void _fused_sgd_kernel_mps_(TensorList params, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { - TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); - if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, @@ -163,8 +161,6 @@ void _fused_sgd_kernel_mps_(TensorList params, const bool is_first_step, const std::optional& grad_scale, const std::optional& found_inf) { - TORCH_CHECK(!grad_scale.has_value() && !found_inf.has_value(), "grad_scale and found_inf are not supported on MPS"); - if (!momentum_buffer_list.empty()) { return _fused_sgd_with_momentum_kernel_mps_(params, grads, diff --git a/aten/src/ATen/native/mps/operations/MultiTensorApply.h b/aten/src/ATen/native/mps/operations/MultiTensorApply.h index cb8d65a129c5..2897d643648a 100644 --- a/aten/src/ATen/native/mps/operations/MultiTensorApply.h +++ b/aten/src/ATen/native/mps/operations/MultiTensorApply.h @@ -11,10 +11,10 @@ static constexpr int64_t kChunkSize = 65536; static constexpr int64_t kmaxThreadGroups = 32; static constexpr int64_t kmaxTensors = 32; -struct MetadataArguments { // the size of this struct must be less than 4 bytes - uint numels[kmaxTensors]; - uint threadgroup_to_tensor[kmaxThreadGroups]; - uint threadgroup_to_chunk[kmaxThreadGroups]; +struct MetadataArguments { // the size of this struct must be less than 4 kilobytes + uint64_t numels[kmaxTensors]; + uint64_t 
threadgroup_to_tensor[kmaxThreadGroups]; + uint64_t threadgroup_to_chunk[kmaxThreadGroups]; }; struct FusedAdamEncodingFunctor { @@ -253,4 +253,110 @@ static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_nam }); } +std::pair, id> getAmpCPLState(const std::string& fname); +template +void multi_tensor_apply(const std::string& kernel_name, + std::vector>& tensor_lists, + ArgTypes... args) { + const auto num_tensors = tensor_lists[0].size(); + if (num_tensors == 0) { + return; + } + + TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists must match depth."); + + id device = MPSDevice::getInstance()->device(); + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + auto [pipeline, function] = getAmpCPLState(kernel_name); + [computeEncoder setComputePipelineState:pipeline]; + + id argumentEncoder = [function newArgumentEncoderWithBufferIndex:0]; + auto tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + int tensor_loc = 0; + int threadgroup_loc = 0; + MetadataArguments metadata_arguments; + std::memset(&metadata_arguments, 0, sizeof(metadata_arguments)); + + for (size_t t = 0; t < num_tensors; t++) { + if (tensor_lists[0][t].numel() == 0) + continue; + + // bind each tensor in this list to the correct slots across depths + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + tensor_loc); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + + // save number of elements for this tensor + metadata_arguments.numels[tensor_loc] = tensor_lists[0][t].numel(); + int currentTensorIndex = tensor_loc; + tensor_loc++; + + const auto numel = tensor_lists[0][t].numel(); + const auto chunks = numel / kChunkSize + ((numel % kChunkSize) ? 1 : 0); + + // process tensor in chunks based on max chunk size + for (uint chunk = 0; chunk < chunks; chunk++) { + metadata_arguments.threadgroup_to_tensor[threadgroup_loc] = currentTensorIndex; + metadata_arguments.threadgroup_to_chunk[threadgroup_loc] = chunk; + threadgroup_loc++; + + // dispatch when we've filled the threadgroup array or finished the chunks + const bool dispatch_now = (threadgroup_loc == kmaxThreadGroups) || (chunk == chunks - 1); + if (dispatch_now) { + // check for a partial dispatch (i.e. 
more chunks remain for the current tensor) + bool partial = (chunk != chunks - 1); + uint carried_numels = 0; + if (partial) { + carried_numels = metadata_arguments.numels[currentTensorIndex]; + } + + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, (uint32_t)64), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + + // prepare for the next batch: reset threadgroup count and create a new buffer + threadgroup_loc = 0; + tensorArgumentBuffer = [[device newBufferWithLength:argumentEncoder.encodedLength options:0] autorelease]; + [argumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0]; + + if (partial) { + // for a partial dispatch, rebind the partially processed tensor to slot 0 + // so that its metadata is in the correct location + for (int d = 0; d < depth; d++) { + mtl_setBuffer(argumentEncoder, tensor_lists[d][t], d * kmaxTensors + 0); + [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][t]) + usage:(MTLResourceUsageRead | MTLResourceUsageWrite)]; + } + metadata_arguments.numels[0] = carried_numels; + // the currently processed tensor now lives at index 0 + currentTensorIndex = 0; + tensor_loc = 1; + } else { + tensor_loc = 0; + } + } + } + } + + if (threadgroup_loc != 0) { + mtl_setArgs(computeEncoder, tensorArgumentBuffer, metadata_arguments, args...); + MTLSize gridSize = MTLSizeMake(threadgroup_loc, 1, 1); + uint32_t maxThreads = [pipeline maxTotalThreadsPerThreadgroup]; + MTLSize threadGroupSize = MTLSizeMake(std::min(maxThreads, static_cast(64)), 1, 1); + [computeEncoder dispatchThreadgroups:gridSize threadsPerThreadgroup:threadGroupSize]; + } + } + }); +} + } // namespace at::native::mps diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index b4e000a2a3ca..a29fee8c7066 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -10394,6 +10394,7 @@ dispatch: CUDA: _amp_foreach_non_finite_check_and_unscale_cuda_ CPU: _amp_foreach_non_finite_check_and_unscale_cpu_ + MPS: _amp_foreach_non_finite_check_and_unscale_mps_ autogen: _amp_foreach_non_finite_check_and_unscale, _amp_foreach_non_finite_check_and_unscale.out - func: _amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) 
@@ -10401,6 +10402,7 @@ dispatch: CUDA: _amp_update_scale_cuda_ CPU: _amp_update_scale_cpu_ + MPS: _amp_update_scale_mps_ autogen: _amp_update_scale, _amp_update_scale.out #- func: _cat(Tensor[] tensors, int dim=0) -> Tensor diff --git a/test/test_mps.py b/test/test_mps.py index 61903bd39005..576659ae29d6 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -1089,6 +1089,125 @@ def test_scaled_dot_product_attention_autocast(self): y = F.scaled_dot_product_attention(query, key, value.to(torch.float32)) self.assertEqual(y.to(y_autocast.dtype), y_autocast) + def test_gradscaler_mps(self): + # big model to force chunking/depth in the gradscaler dispatch + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 2048) + self.fc2 = nn.Linear(2048, 2048) + self.fc3 = nn.Linear(2048, 2048) + self.fc4 = nn.Linear(2048, 2048) + self.fc5 = nn.Linear(2048, 5) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) + x = self.relu(self.fc3(x)) + x = self.relu(self.fc4(x)) + return self.fc5(x) + torch.manual_seed(42) + + def helper(model_cpu, model_mps, dtype, iterations, batch_size, atol=3e-4, rtol=1e-5): + if dtype == torch.bfloat16 and MACOS_VERSION < 14.0: + raise unittest.SkipTest("bfloat16 needs MacOS14+") + optimizer_cpu = torch.optim.SGD(model_cpu.parameters(), lr=0.01) + optimizer_mps = torch.optim.SGD(model_mps.parameters(), lr=0.01) + loss_fn = nn.MSELoss() + + input_cpu = torch.randn(batch_size, 10) + target_cpu = torch.randn(batch_size, 5) + input_mps = input_cpu.to('mps') + target_mps = target_cpu.to('mps') + + scaler_cpu = torch.amp.GradScaler(device="cpu") + scaler_mps = torch.amp.GradScaler(device="mps") + for _ in range(iterations): + optimizer_cpu.zero_grad() + optimizer_mps.zero_grad() + + with torch.amp.autocast(device_type="cpu", dtype=dtype): + output_cpu = model_cpu(input_cpu) + loss_cpu = loss_fn(output_cpu, target_cpu) + scaler_cpu.scale(loss_cpu).backward() + scaler_cpu.step(optimizer_cpu) + scaler_cpu.update() + + with torch.autocast(device_type="mps", dtype=dtype): + output_mps = model_mps(input_mps) + loss_mps = loss_fn(output_mps, target_mps) + scaler_mps.scale(loss_mps).backward() + scaler_mps.step(optimizer_mps) + scaler_mps.update() + + for p_cpu, p_mps in zip(model_cpu.parameters(), model_mps.parameters()): + self.assertEqual(p_mps.cpu(), p_cpu, rtol=rtol, atol=atol) + + model_cpu = Model().to('cpu') + model_mps = Model().to('mps') + model_mps.load_state_dict(model_cpu.state_dict()) + + helper(model_cpu, model_mps, torch.float16, iterations=5, batch_size=4) + helper(model_cpu, model_mps, torch.bfloat16, iterations=5, batch_size=4) + + def test_non_fast_path_amp_unscale(self): + torch.manual_seed(42) + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x): + x = self.linear1(x) + x = F.relu(x) + x = self.linear2(x) + x = x.mean(dim=1) + return x + + cpu_model = Model().to("cpu") + mps_model = copy.deepcopy(cpu_model).to("mps") + + cpu_optimizer = torch.optim.SGD(cpu_model.parameters(), lr=0.01) + mps_optimizer = torch.optim.SGD(mps_model.parameters(), lr=0.01) + cpu_scaler = torch.amp.GradScaler(device="cpu") + mps_scaler = torch.amp.GradScaler(device="mps") + + def helper(model, optimizer, scaler, device, input, target, apply_grad_transform=False): + optimizer.zero_grad() + with torch.autocast(device_type=device, dtype=torch.bfloat16): + output = model(input) + loss = 
nn.MSELoss()(output, target) + scaler.scale(loss).backward() + + if apply_grad_transform: + for p in model.parameters(): + if p.grad is not None and p.grad.dim() >= 2: + p.grad = p.grad.as_strided(p.grad.size(), (1,) * p.grad.dim()) + + scaler.unscale_(optimizer) + scaler.step(optimizer) + scaler.update() + + # CPU forward/backward pass + input_cpu = torch.randn(32, 10, device="cpu") + target_cpu = torch.randn(32, device="cpu") + helper(cpu_model, cpu_optimizer, cpu_scaler, "cpu", input_cpu, target_cpu) + + # MPS forward/backward pass + input_mps = input_cpu.to("mps") + target_mps = target_cpu.to("mps") + helper(mps_model, mps_optimizer, mps_scaler, "mps", input_mps, target_mps, apply_grad_transform=True) + + updated_linear1_weight_cpu = cpu_model.linear1.weight.detach() + updated_linear2_weight_cpu = cpu_model.linear2.weight.detach() + updated_linear1_weight_mps = mps_model.linear1.weight.detach().cpu() + updated_linear2_weight_mps = mps_model.linear2.weight.detach().cpu() + + self.assertEqual(updated_linear1_weight_cpu, updated_linear1_weight_mps, atol=6e-4, rtol=1e-6) + self.assertEqual(updated_linear2_weight_cpu, updated_linear2_weight_mps, atol=6e-4, rtol=1e-6) # Expand TestCase class with Memory Leak Detection on MPS device class TestCaseMPS(TestCase): diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index 93b1d667c08a..2931b5b9fadd 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -336,7 +336,11 @@ def unscale_(self, optimizer: torch.optim.Optimizer) -> None: # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. assert self._scale is not None - inv_scale = self._scale.double().reciprocal().float() + inv_scale = ( + self._scale.double().reciprocal().float() + if self._scale.device != torch.device("mps:0") + else self._scale.reciprocal() + ) found_inf = torch.full((), 0.0, dtype=torch.float32, device=self._scale.device) optimizer_state["found_inf_per_device"] = self._unscale_grads_( diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 2f01b1d683bb..a86cb340082f 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -832,9 +832,6 @@ def _fused_adam( device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_) device_state_steps = cast(list[Tensor], device_state_steps_) - if device.type == "mps": # type: ignore[union-attr] - assert found_inf is None and grad_scale is None - device_grad_scale, device_found_inf = None, None if grad_scale is not None: device_grad_scale = grad_scale_dict.setdefault( From 6c38b9be730269e23a3feb2a480ccf5bfa291b65 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sun, 6 Apr 2025 17:50:35 +0000 Subject: [PATCH 218/332] [typing] Add type hints to `__init__` methods in `torch.distributions`. (#144197) Fixes #144196 Extends #144106 and #144110 ## Open Problems: - [ ] Annotating with `numbers.Number` is a bad idea, should consider using `float`, `SupportsFloat` or some `Procotol`. https://github.com/pytorch/pytorch/pull/144197#discussion_r1903324769 # Notes - `beta.py`: needed to add `type: ignore` since `broadcast_all` is untyped. - `categorical.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - ~~`dirichlet.py`: replaced `axis` with `dim` arguments.~~ #144402 - `gemoetric.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. 
- ~~`independent.py`: fixed bug in `Independent.__init__` where `tuple[int, ...]` could be passed to `Distribution.__init__` instead of `torch.Size`.~~ **EDIT:** turns out the bug is related to typing of `torch.Size`. #144218 - `independent.py`: made `Independent` a generic class of its base distribution. - `multivariate_normal.py`: converted `else` branches of mutually exclusive arguments to `if` branch[^2]. - `relaxed_bernoulli.py`: added class-level type hint for `base_dist`. - `relaxed_categorical.py`: added class-level type hint for `base_dist`. - ~~`transforms.py`: Added missing argument to docstring of `ReshapeTransform`~~ #144401 - ~~`transforms.py`: Fixed bug in `AffineTransform.sign` (could return `Tensor` instead of `int`).~~ #144400 - `transforms.py`: Added `type: ignore` comments to `AffineTransform.log_abs_det_jacobian`[^1]; replaced `torch.abs(scale)` with `scale.abs()`. - `transforms.py`: Added `type: ignore` comments to `AffineTransform.__eq__`[^1]. - `transforms.py`: Fixed type hint on `CumulativeDistributionTransform.domain`. Note that this is still an LSP violation, because `Transform.domain` is defined as `Constraint`, but `Distribution.domain` is defined as `Optional[Constraint]`. - skipped: `constraints.py`, `constraints_registry.py`, `kl.py`, `utils.py`, `exp_family.py`, `__init__.py`. ## Remark `TransformedDistribution`: `__init__` uses the check `if reinterpreted_batch_ndims > 0:`, which can lead to the creation of `Independent` distributions with only 1 component. This results in awkward code like `base_dist.base_dist` in `LogisticNormal`. ```python import torch from torch.distributions import * b1 = Normal(torch.tensor([0.0]), torch.tensor([1.0])) b2 = MultivariateNormal(torch.tensor([0.0]), torch.eye(1)) t = StickBreakingTransform() d1 = TransformedDistribution(b1, t) d2 = TransformedDistribution(b2, t) print(d1.base_dist) # Independent with 1 dimension print(d2.base_dist) # MultivariateNormal ``` One could consider changing this to `if reinterpreted_batch_ndims > 1:`. [^1]: Usage of `isinstance(value, numbers.Real)` leads to problems with static typing, as the `numbers` module is not supported by `mypy` (see ). This results in us having to add type-ignore comments in several places [^2]: Otherwise, we would have to add a bunch of `type: ignore` comments to make `mypy` happy, as it isn't able to perform the type narrowing. Ideally, such code should be replaced with structural pattern matching once support for Python 3.9 is dropped. 
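As a rough illustration of [^2], a minimal sketch of how the mutually exclusive `probs`/`logits` handling could read with structural pattern matching (hypothetical code, assuming Python 3.10+; `resolve_probs_or_logits` is an invented helper name for illustration, not an existing API):

```python
# Hypothetical sketch only -- assumes Python 3.10+ for the match statement.
from typing import Optional, Union

from torch import Tensor


def resolve_probs_or_logits(
    probs: Optional[Union[Tensor, float]] = None,
    logits: Optional[Union[Tensor, float]] = None,
) -> tuple[str, Union[Tensor, float]]:
    # One match statement instead of the `(probs is None) == (logits is None)`
    # check plus separate if/else branches used in the current constructors.
    match (probs, logits):
        case (None, None):
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
        case (p, None):
            return "probs", p
        case (None, lg):
            return "logits", lg
        case _:
            raise ValueError("Either `probs` or `logits` must be specified, but not both.")
```

Whether mypy narrows these patterns any better than the current `assert ... is not None` hints would still need to be checked; the sketch only shows the shape such a rewrite could take.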
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144197 Approved by: https://github.com/malfet Co-authored-by: Aaron Gokaslan --- torch/distributions/bernoulli.py | 12 +++- torch/distributions/beta.py | 9 ++- torch/distributions/binomial.py | 11 +++- torch/distributions/categorical.py | 10 ++- torch/distributions/cauchy.py | 8 ++- torch/distributions/chi2.py | 8 ++- torch/distributions/continuous_bernoulli.py | 10 ++- torch/distributions/dirichlet.py | 8 ++- torch/distributions/distribution.py | 2 +- torch/distributions/exponential.py | 8 ++- torch/distributions/fishersnedecor.py | 9 ++- torch/distributions/gamma.py | 9 ++- torch/distributions/geometric.py | 13 +++- torch/distributions/gumbel.py | 8 ++- torch/distributions/half_cauchy.py | 8 ++- torch/distributions/half_normal.py | 8 ++- torch/distributions/independent.py | 20 ++++-- torch/distributions/inverse_gamma.py | 12 +++- torch/distributions/kumaraswamy.py | 9 ++- torch/distributions/laplace.py | 9 ++- torch/distributions/lkj_cholesky.py | 9 ++- torch/distributions/log_normal.py | 12 +++- torch/distributions/logistic_normal.py | 12 +++- .../lowrank_multivariate_normal.py | 9 ++- torch/distributions/mixture_same_family.py | 3 +- torch/distributions/multinomial.py | 10 ++- torch/distributions/multivariate_normal.py | 14 +++-- torch/distributions/negative_binomial.py | 11 +++- torch/distributions/normal.py | 8 ++- torch/distributions/one_hot_categorical.py | 9 ++- torch/distributions/pareto.py | 7 ++- torch/distributions/poisson.py | 10 ++- torch/distributions/relaxed_bernoulli.py | 24 +++++-- torch/distributions/relaxed_categorical.py | 21 ++++++- torch/distributions/studentT.py | 9 ++- .../distributions/transformed_distribution.py | 8 ++- torch/distributions/transforms.py | 63 +++++++++++++------ torch/distributions/uniform.py | 9 ++- torch/distributions/utils.py | 6 +- torch/distributions/von_mises.py | 8 ++- torch/distributions/weibull.py | 9 ++- torch/distributions/wishart.py | 4 +- 42 files changed, 382 insertions(+), 84 deletions(-) diff --git a/torch/distributions/bernoulli.py b/torch/distributions/bernoulli.py index 105038641bcc..659f9a20b10e 100644 --- a/torch/distributions/bernoulli.py +++ b/torch/distributions/bernoulli.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -10,7 +12,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Bernoulli"] @@ -41,7 +43,12 @@ class Bernoulli(ExponentialFamily): has_enumerate_support = True _mean_carrier_measure = 0 - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." 
@@ -50,6 +57,7 @@ def __init__(self, probs=None, logits=None, validate_args=None): is_scalar = isinstance(probs, _Number) (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits diff --git a/torch/distributions/beta.py b/torch/distributions/beta.py index e030b648a88e..e06a28ca5aa4 100644 --- a/torch/distributions/beta.py +++ b/torch/distributions/beta.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -36,7 +38,12 @@ class Beta(ExponentialFamily): support = constraints.unit_interval has_rsample = True - def __init__(self, concentration1, concentration0, validate_args=None): + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: if isinstance(concentration1, _Number) and isinstance(concentration0, _Number): concentration1_concentration0 = torch.tensor( [float(concentration1), float(concentration0)] diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index 6cbfae150844..90461784c06d 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -50,7 +52,13 @@ class Binomial(Distribution): } has_enumerate_support = True - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: Union[Tensor, int] = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -62,6 +70,7 @@ def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: + assert logits is not None # helps mypy ( self.total_count, self.logits, diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index 715429c66552..1c8fed2636ad 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -51,7 +53,12 @@ class Categorical(Distribution): arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." 
@@ -61,6 +68,7 @@ def __init__(self, probs=None, logits=None, validate_args=None): raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs / probs.sum(-1, keepdim=True) else: + assert logits is not None # helps mypy if logits.dim() < 1: raise ValueError("`logits` parameter must be at least one-dimensional.") # Normalize diff --git a/torch/distributions/cauchy.py b/torch/distributions/cauchy.py index 582c08ebb858..84c1d34bda79 100644 --- a/torch/distributions/cauchy.py +++ b/torch/distributions/cauchy.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, nan, Tensor @@ -34,7 +35,12 @@ class Cauchy(Distribution): support = constraints.real has_rsample = True - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/chi2.py b/torch/distributions/chi2.py index f175bc44f69e..fa23115fc035 100644 --- a/torch/distributions/chi2.py +++ b/torch/distributions/chi2.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor from torch.distributions import constraints from torch.distributions.gamma import Gamma @@ -25,7 +27,11 @@ class Chi2(Gamma): arg_constraints = {"df": constraints.positive} - def __init__(self, df, validate_args=None): + def __init__( + self, + df: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: super().__init__(0.5 * df, 0.5, validate_args=validate_args) def expand(self, batch_shape, _instance=None): diff --git a/torch/distributions/continuous_bernoulli.py b/torch/distributions/continuous_bernoulli.py index b1e8eddfb0ec..14d0d6a9c177 100644 --- a/torch/distributions/continuous_bernoulli.py +++ b/torch/distributions/continuous_bernoulli.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -13,7 +14,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number, _size +from torch.types import _Number, _size, Number __all__ = ["ContinuousBernoulli"] @@ -52,7 +53,11 @@ class ContinuousBernoulli(ExponentialFamily): has_rsample = True def __init__( - self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + lims: tuple[float, float] = (0.499, 0.501), + validate_args: Optional[bool] = None, ) -> None: if (probs is None) == (logits is None): raise ValueError( @@ -68,6 +73,7 @@ def __init__( raise ValueError("The parameter probs has invalid values") self.probs = clamp_probs(self.probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits diff --git a/torch/distributions/dirichlet.py b/torch/distributions/dirichlet.py index f656a0582e89..414ad6efe47e 100644 --- a/torch/distributions/dirichlet.py +++ b/torch/distributions/dirichlet.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.autograd import Function @@ -54,7 +56,11 @@ class 
Dirichlet(ExponentialFamily): support = constraints.simplex has_rsample = True - def __init__(self, concentration, validate_args=None): + def __init__( + self, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: if concentration.dim() < 1: raise ValueError( "`concentration` parameter must be at least one-dimensional." diff --git a/torch/distributions/distribution.py b/torch/distributions/distribution.py index 75ea50d24860..b2895cb3b0d7 100644 --- a/torch/distributions/distribution.py +++ b/torch/distributions/distribution.py @@ -44,7 +44,7 @@ def __init__( batch_shape: torch.Size = torch.Size(), event_shape: torch.Size = torch.Size(), validate_args: Optional[bool] = None, - ): + ) -> None: self._batch_shape = batch_shape self._event_shape = event_shape if validate_args is not None: diff --git a/torch/distributions/exponential.py b/torch/distributions/exponential.py index 8ca2636e1f52..d15cb1f7a258 100644 --- a/torch/distributions/exponential.py +++ b/torch/distributions/exponential.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -46,7 +48,11 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return self.rate.pow(-2) - def __init__(self, rate, validate_args=None): + def __init__( + self, + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: (self.rate,) = broadcast_all(rate) batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size() super().__init__(batch_shape, validate_args=validate_args) diff --git a/torch/distributions/fishersnedecor.py b/torch/distributions/fishersnedecor.py index 053686c6de07..4755bd0d8bde 100644 --- a/torch/distributions/fishersnedecor.py +++ b/torch/distributions/fishersnedecor.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -31,7 +33,12 @@ class FisherSnedecor(Distribution): support = constraints.positive has_rsample = True - def __init__(self, df1, df2, validate_args=None): + def __init__( + self, + df1: Union[Tensor, float], + df2: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.df1, self.df2 = broadcast_all(df1, df2) self._gamma1 = Gamma(self.df1 * 0.5, self.df1) self._gamma2 = Gamma(self.df2 * 0.5, self.df2) diff --git a/torch/distributions/gamma.py b/torch/distributions/gamma.py index 5e0fe3fc7823..9df91ebee640 100644 --- a/torch/distributions/gamma.py +++ b/torch/distributions/gamma.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -52,7 +54,12 @@ def mode(self) -> Tensor: def variance(self) -> Tensor: return self.concentration / self.rate.pow(2) - def __init__(self, concentration, rate, validate_args=None): + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.concentration, self.rate = broadcast_all(concentration, rate) if isinstance(concentration, _Number) and isinstance(rate, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/geometric.py b/torch/distributions/geometric.py index b8b05142db5b..b5ceac39e94e 100644 --- a/torch/distributions/geometric.py +++ b/torch/distributions/geometric.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + 
import torch from torch import Tensor from torch.distributions import constraints @@ -10,7 +12,7 @@ probs_to_logits, ) from torch.nn.functional import binary_cross_entropy_with_logits -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Geometric"] @@ -45,7 +47,12 @@ class Geometric(Distribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.nonnegative_integer - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." @@ -53,11 +60,13 @@ def __init__(self, probs=None, logits=None, validate_args=None): if probs is not None: (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy (self.logits,) = broadcast_all(logits) probs_or_logits = probs if probs is not None else logits if isinstance(probs_or_logits, _Number): batch_shape = torch.Size() else: + assert probs_or_logits is not None # helps mypy batch_shape = probs_or_logits.size() super().__init__(batch_shape, validate_args=validate_args) if self._validate_args and probs is not None: diff --git a/torch/distributions/gumbel.py b/torch/distributions/gumbel.py index 623cc7edbda6..6d097c9324e2 100644 --- a/torch/distributions/gumbel.py +++ b/torch/distributions/gumbel.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -33,7 +34,12 @@ class Gumbel(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.real - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) finfo = torch.finfo(self.loc.dtype) if isinstance(loc, _Number) and isinstance(scale, _Number): diff --git a/torch/distributions/half_cauchy.py b/torch/distributions/half_cauchy.py index da17c40da2ed..572ae080ac3e 100644 --- a/torch/distributions/half_cauchy.py +++ b/torch/distributions/half_cauchy.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, Tensor @@ -33,8 +34,13 @@ class HalfCauchy(TransformedDistribution): arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True + base_dist: Cauchy - def __init__(self, scale, validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Cauchy(0, scale, validate_args=False) super().__init__(base_dist, AbsTransform(), validate_args=validate_args) diff --git a/torch/distributions/half_normal.py b/torch/distributions/half_normal.py index 5850f883e908..21e1b9d2c506 100644 --- a/torch/distributions/half_normal.py +++ b/torch/distributions/half_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, Tensor @@ -33,8 +34,13 @@ class HalfNormal(TransformedDistribution): arg_constraints = {"scale": constraints.positive} support = constraints.nonnegative has_rsample = True + base_dist: Normal - def __init__(self, scale, 
validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(0, scale, validate_args=False) super().__init__(base_dist, AbsTransform(), validate_args=validate_args) diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 0442a4c1b483..b66406681bb8 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -1,7 +1,8 @@ # mypy: allow-untyped-defs +from typing import Generic, Optional, TypeVar import torch -from torch import Tensor +from torch import Size, Tensor from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import _sum_rightmost @@ -11,7 +12,10 @@ __all__ = ["Independent"] -class Independent(Distribution): +D = TypeVar("D", bound=Distribution) + + +class Independent(Distribution, Generic[D]): r""" Reinterprets some of the batch dims of a distribution as event dims. @@ -42,17 +46,21 @@ class Independent(Distribution): """ arg_constraints: dict[str, constraints.Constraint] = {} + base_dist: D def __init__( - self, base_distribution, reinterpreted_batch_ndims, validate_args=None - ): + self, + base_distribution: D, + reinterpreted_batch_ndims: int, + validate_args: Optional[bool] = None, + ) -> None: if reinterpreted_batch_ndims > len(base_distribution.batch_shape): raise ValueError( "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}" ) - shape = base_distribution.batch_shape + base_distribution.event_shape - event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape) + shape: Size = base_distribution.batch_shape + base_distribution.event_shape + event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape) batch_shape = shape[: len(shape) - event_dim] event_shape = shape[len(shape) - event_dim :] self.base_dist = base_distribution diff --git a/torch/distributions/inverse_gamma.py b/torch/distributions/inverse_gamma.py index aaee976b7f17..de432a34434e 100644 --- a/torch/distributions/inverse_gamma.py +++ b/torch/distributions/inverse_gamma.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -38,8 +40,14 @@ class InverseGamma(TransformedDistribution): } support = constraints.positive has_rsample = True - - def __init__(self, concentration, rate, validate_args=None): + base_dist: Gamma + + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Gamma(concentration, rate, validate_args=validate_args) neg_one = -base_dist.rate.new_ones(()) super().__init__( diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index d38efb631e86..53c09ab9870d 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -45,7 +47,12 @@ class Kumaraswamy(TransformedDistribution): support = constraints.unit_interval has_rsample = True - def __init__(self, concentration1, concentration0, validate_args=None): + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], 
+ validate_args: Optional[bool] = None, + ) -> None: self.concentration1, self.concentration0 = broadcast_all( concentration1, concentration0 ) diff --git a/torch/distributions/laplace.py b/torch/distributions/laplace.py index 39ef9b1efdb7..0d50712fb26f 100644 --- a/torch/distributions/laplace.py +++ b/torch/distributions/laplace.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -46,7 +48,12 @@ def variance(self) -> Tensor: def stddev(self) -> Tensor: return (2**0.5) * self.scale - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/lkj_cholesky.py b/torch/distributions/lkj_cholesky.py index a18f2ed9f52a..d2c29a9286de 100644 --- a/torch/distributions/lkj_cholesky.py +++ b/torch/distributions/lkj_cholesky.py @@ -9,8 +9,10 @@ """ import math +from typing import Optional, Union import torch +from torch import Tensor from torch.distributions import Beta, constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import broadcast_all @@ -61,7 +63,12 @@ class LKJCholesky(Distribution): arg_constraints = {"concentration": constraints.positive} support = constraints.corr_cholesky - def __init__(self, dim, concentration=1.0, validate_args=None): + def __init__( + self, + dim: int, + concentration: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: if dim < 2: raise ValueError( f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}." 
diff --git a/torch/distributions/log_normal.py b/torch/distributions/log_normal.py index a048f94286c8..2c6dbc6bf55c 100644 --- a/torch/distributions/log_normal.py +++ b/torch/distributions/log_normal.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor from torch.distributions import constraints from torch.distributions.normal import Normal @@ -32,8 +34,14 @@ class LogNormal(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.positive has_rsample = True - - def __init__(self, loc, scale, validate_args=None): + base_dist: Normal + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(loc, scale, validate_args=validate_args) super().__init__(base_dist, ExpTransform(), validate_args=validate_args) diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index a8f7c099d1e8..729e3a67419f 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + from torch import Tensor -from torch.distributions import constraints +from torch.distributions import constraints, Independent from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution from torch.distributions.transforms import StickBreakingTransform @@ -36,8 +38,14 @@ class LogisticNormal(TransformedDistribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} support = constraints.simplex has_rsample = True + base_dist: Independent[Normal] - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: base_dist = Normal(loc, scale, validate_args=validate_args) if not base_dist.batch_shape: base_dist = base_dist.expand([1]) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index c6f739a595a3..968e4634ba62 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch from torch import Tensor @@ -93,7 +94,13 @@ class LowRankMultivariateNormal(Distribution): support = constraints.real_vector has_rsample = True - def __init__(self, loc, cov_factor, cov_diag, validate_args=None): + def __init__( + self, + loc: Tensor, + cov_factor: Tensor, + cov_diag: Tensor, + validate_args: Optional[bool] = None, + ) -> None: if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") event_shape = loc.shape[-1:] diff --git a/torch/distributions/mixture_same_family.py b/torch/distributions/mixture_same_family.py index 1fc2c1052d03..79a7029e1d72 100644 --- a/torch/distributions/mixture_same_family.py +++ b/torch/distributions/mixture_same_family.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional import torch from torch import Tensor @@ -59,7 +60,7 @@ def __init__( self, mixture_distribution: Categorical, component_distribution: Distribution, - validate_args=None, + validate_args: Optional[bool] = None, ) -> None: self._mixture_distribution = mixture_distribution self._component_distribution = component_distribution 
diff --git a/torch/distributions/multinomial.py b/torch/distributions/multinomial.py index 85a227f5c403..41d8ded53fd6 100644 --- a/torch/distributions/multinomial.py +++ b/torch/distributions/multinomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import inf, Tensor from torch.distributions import Categorical, constraints @@ -59,7 +61,13 @@ def mean(self) -> Tensor: def variance(self) -> Tensor: return self.total_count * self.probs * (1 - self.probs) - def __init__(self, total_count=1, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: int = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if not isinstance(total_count, int): raise NotImplementedError("inhomogeneous total_count is not supported") self.total_count = total_count diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 849ee4170015..c15a84815b06 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch from torch import Tensor @@ -133,12 +134,12 @@ class MultivariateNormal(Distribution): def __init__( self, - loc, - covariance_matrix=None, - precision_matrix=None, - scale_tril=None, - validate_args=None, - ): + loc: Tensor, + covariance_matrix: Optional[Tensor] = None, + precision_matrix: Optional[Tensor] = None, + scale_tril: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if loc.dim() < 1: raise ValueError("loc must be at least one-dimensional.") if (covariance_matrix is not None) + (scale_tril is not None) + ( @@ -167,6 +168,7 @@ def __init__( ) self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) else: + assert precision_matrix is not None # helps mypy if precision_matrix.dim() < 2: raise ValueError( "precision_matrix must be at least two-dimensional, " diff --git a/torch/distributions/negative_binomial.py b/torch/distributions/negative_binomial.py index e5b0e128efe6..f28222f92f78 100644 --- a/torch/distributions/negative_binomial.py +++ b/torch/distributions/negative_binomial.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch import torch.nn.functional as F from torch import Tensor @@ -38,7 +40,13 @@ class NegativeBinomial(Distribution): } support = constraints.nonnegative_integer - def __init__(self, total_count, probs=None, logits=None, validate_args=None): + def __init__( + self, + total_count: Union[Tensor, float], + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." 
@@ -50,6 +58,7 @@ def __init__(self, total_count, probs=None, logits=None, validate_args=None): ) = broadcast_all(total_count, probs) self.total_count = self.total_count.type_as(self.probs) else: + assert logits is not None # helps mypy ( self.total_count, self.logits, diff --git a/torch/distributions/normal.py b/torch/distributions/normal.py index 86e30ba450f5..626358d14795 100644 --- a/torch/distributions/normal.py +++ b/torch/distributions/normal.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import Tensor @@ -51,7 +52,12 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return self.stddev.pow(2) - def __init__(self, loc, scale, validate_args=None): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.scale = broadcast_all(loc, scale) if isinstance(loc, _Number) and isinstance(scale, _Number): batch_shape = torch.Size() diff --git a/torch/distributions/one_hot_categorical.py b/torch/distributions/one_hot_categorical.py index 7e0bc03c5aba..8edb6da0b8dd 100644 --- a/torch/distributions/one_hot_categorical.py +++ b/torch/distributions/one_hot_categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.distributions import constraints @@ -44,7 +46,12 @@ class OneHotCategorical(Distribution): support = constraints.one_hot has_enumerate_support = True - def __init__(self, probs=None, logits=None, validate_args=None): + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: self._categorical = Categorical(probs, logits) batch_shape = self._categorical.batch_shape event_shape = self._categorical.param_shape[-1:] diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 2cc1e298ba25..bbca7e0cba35 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union from torch import Tensor from torch.distributions import constraints @@ -31,7 +31,10 @@ class Pareto(TransformedDistribution): arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} def __init__( - self, scale: Tensor, alpha: Tensor, validate_args: Optional[bool] = None + self, + scale: Union[Tensor, float], + alpha: Union[Tensor, float], + validate_args: Optional[bool] = None, ) -> None: self.scale, self.alpha = broadcast_all(scale, alpha) base_dist = Exponential(self.alpha, validate_args=validate_args) diff --git a/torch/distributions/poisson.py b/torch/distributions/poisson.py index c3b4bacc54cb..d3fb4446baf4 100644 --- a/torch/distributions/poisson.py +++ b/torch/distributions/poisson.py @@ -1,10 +1,12 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints from torch.distributions.exp_family import ExponentialFamily from torch.distributions.utils import broadcast_all -from torch.types import _Number +from torch.types import _Number, Number __all__ = ["Poisson"] @@ -45,7 +47,11 @@ def mode(self) -> Tensor: def variance(self) -> Tensor: return self.rate - def __init__(self, rate, validate_args=None): + def __init__( + self, + rate: Union[Tensor, Number], + validate_args: Optional[bool] = None, + ) -> None: (self.rate,) = broadcast_all(rate) if isinstance(rate, 
_Number): batch_shape = torch.Size() diff --git a/torch/distributions/relaxed_bernoulli.py b/torch/distributions/relaxed_bernoulli.py index 4c1549660313..16ad4219627e 100644 --- a/torch/distributions/relaxed_bernoulli.py +++ b/torch/distributions/relaxed_bernoulli.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -12,7 +14,7 @@ logits_to_probs, probs_to_logits, ) -from torch.types import _Number, _size +from torch.types import _Number, _size, Number __all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] @@ -41,7 +43,13 @@ class LogitRelaxedBernoulli(Distribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.real - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: self.temperature = temperature if (probs is None) == (logits is None): raise ValueError( @@ -51,6 +59,7 @@ def __init__(self, temperature, probs=None, logits=None, validate_args=None): is_scalar = isinstance(probs, _Number) (self.probs,) = broadcast_all(probs) else: + assert logits is not None # helps mypy is_scalar = isinstance(logits, _Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits @@ -131,8 +140,15 @@ class RelaxedBernoulli(TransformedDistribution): arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval has_rsample = True - - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist: LogitRelaxedBernoulli + + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: base_dist = LogitRelaxedBernoulli(temperature, probs, logits) super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args) diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 97ae3ed1857b..47314be9e44a 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional + import torch from torch import Tensor from torch.distributions import constraints @@ -42,7 +44,13 @@ class ExpRelaxedCategorical(Distribution): ) # The true support is actually a submanifold of this. 
has_rsample = True - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: self._categorical = Categorical(probs, logits) self.temperature = temperature batch_shape = self._categorical.batch_shape @@ -121,8 +129,15 @@ class RelaxedOneHotCategorical(TransformedDistribution): arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} support = constraints.simplex has_rsample = True - - def __init__(self, temperature, probs=None, logits=None, validate_args=None): + base_dist: ExpRelaxedCategorical + + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: base_dist = ExpRelaxedCategorical( temperature, probs, logits, validate_args=validate_args ) diff --git a/torch/distributions/studentT.py b/torch/distributions/studentT.py index e141939b2745..d98554f413c0 100644 --- a/torch/distributions/studentT.py +++ b/torch/distributions/studentT.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional, Union import torch from torch import inf, nan, Tensor @@ -60,7 +61,13 @@ def variance(self) -> Tensor: m[self.df <= 1] = nan return m - def __init__(self, df, loc=0.0, scale=1.0, validate_args=None): + def __init__( + self, + df: Union[Tensor, float], + loc: Union[Tensor, float] = 0.0, + scale: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: self.df, self.loc, self.scale = broadcast_all(df, loc, scale) self._chi2 = Chi2(self.df) batch_shape = self.df.size() diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index 02792ce9d309..d5fbff877413 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +from typing import Optional, Union import torch from torch import Tensor @@ -49,7 +50,12 @@ class TransformedDistribution(Distribution): arg_constraints: dict[str, constraints.Constraint] = {} - def __init__(self, base_distribution, transforms, validate_args=None): + def __init__( + self, + base_distribution: Distribution, + transforms: Union[Transform, list[Transform]], + validate_args: Optional[bool] = None, + ) -> None: if isinstance(transforms, Transform): self.transforms = [ transforms, diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 8958f1a63c87..a033ce14408b 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -3,11 +3,14 @@ import math import operator import weakref -from typing import Optional +from collections.abc import Sequence +from typing import Optional, Union import torch import torch.nn.functional as F +from torch import Tensor from torch.distributions import constraints +from torch.distributions.distribution import Distribution from torch.distributions.utils import ( _sum_rightmost, broadcast_all, @@ -92,7 +95,7 @@ class Transform: domain: constraints.Constraint codomain: constraints.Constraint - def __init__(self, cache_size=0): + def __init__(self, cache_size: int = 0) -> None: self._cache_size = cache_size self._inv: Optional[weakref.ReferenceType[Transform]] = None if cache_size == 0: @@ -218,7 +221,7 @@ class _InverseTransform(Transform): This class is private; please 
instead use the ``Transform.inv`` property. """ - def __init__(self, transform: Transform): + def __init__(self, transform: Transform) -> None: super().__init__(cache_size=transform._cache_size) self._inv: Transform = transform # type: ignore[assignment] @@ -285,7 +288,7 @@ class ComposeTransform(Transform): the latest single value is cached. Only 0 and 1 are supported. """ - def __init__(self, parts: list[Transform], cache_size=0): + def __init__(self, parts: list[Transform], cache_size: int = 0) -> None: if cache_size: parts = [part.with_cache(cache_size) for part in parts] super().__init__(cache_size=cache_size) @@ -413,7 +416,12 @@ class IndependentTransform(Transform): dimensions to treat as dependent. """ - def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): + def __init__( + self, + base_transform: Transform, + reinterpreted_batch_ndims: int, + cache_size: int = 0, + ) -> None: super().__init__(cache_size=cache_size) self.base_transform = base_transform.with_cache(cache_size) self.reinterpreted_batch_ndims = reinterpreted_batch_ndims @@ -442,7 +450,7 @@ def bijective(self) -> bool: # type: ignore[override] return self.base_transform.bijective @property - def sign(self) -> int: # type: ignore[override] + def sign(self) -> int: return self.base_transform.sign def _call(self, x): @@ -486,7 +494,12 @@ class ReshapeTransform(Transform): bijective = True - def __init__(self, in_shape, out_shape, cache_size=0): + def __init__( + self, + in_shape: torch.Size, + out_shape: torch.Size, + cache_size: int = 0, + ) -> None: self.in_shape = torch.Size(in_shape) self.out_shape = torch.Size(out_shape) if self.in_shape.numel() != self.out_shape.numel(): @@ -571,7 +584,7 @@ class PowerTransform(Transform): codomain = constraints.positive bijective = True - def __init__(self, exponent, cache_size=0): + def __init__(self, exponent: Tensor, cache_size: int = 0) -> None: super().__init__(cache_size=cache_size) (self.exponent,) = broadcast_all(exponent) @@ -582,7 +595,7 @@ def with_cache(self, cache_size=1): @lazy_property def sign(self) -> int: # type: ignore[override] - return self.exponent.sign() + return self.exponent.sign() # type: ignore[return-value] def __eq__(self, other): if not isinstance(other, PowerTransform): @@ -734,7 +747,13 @@ class AffineTransform(Transform): bijective = True - def __init__(self, loc, scale, event_dim=0, cache_size=0): + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + event_dim: int = 0, + cache_size: int = 0, + ) -> None: super().__init__(cache_size=cache_size) self.loc = loc self.scale = scale @@ -771,20 +790,20 @@ def __eq__(self, other): if self.loc != other.loc: return False else: - if not (self.loc == other.loc).all().item(): + if not (self.loc == other.loc).all().item(): # type: ignore[union-attr] return False if isinstance(self.scale, _Number) and isinstance(other.scale, _Number): if self.scale != other.scale: return False else: - if not (self.scale == other.scale).all().item(): + if not (self.scale == other.scale).all().item(): # type: ignore[union-attr] return False return True @property - def sign(self) -> int: + def sign(self) -> Union[Tensor, int]: # type: ignore[override] if isinstance(self.scale, _Number): return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 return self.scale.sign() @@ -1022,7 +1041,7 @@ class PositiveDefiniteTransform(Transform): """ domain = constraints.independent(constraints.real, 2) - codomain = constraints.positive_definite # type: ignore[assignment] + 
codomain = constraints.positive_definite def __eq__(self, other): return isinstance(other, PositiveDefiniteTransform) @@ -1053,7 +1072,13 @@ class CatTransform(Transform): transforms: list[Transform] - def __init__(self, tseq, dim=0, lengths=None, cache_size=0): + def __init__( + self, + tseq: Sequence[Transform], + dim: int = 0, + lengths: Optional[Sequence[int]] = None, + cache_size: int = 0, + ) -> None: assert all(isinstance(t, Transform) for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] @@ -1157,7 +1182,9 @@ class StackTransform(Transform): transforms: list[Transform] - def __init__(self, tseq, dim=0, cache_size=0): + def __init__( + self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0 + ) -> None: assert all(isinstance(t, Transform) for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] @@ -1237,12 +1264,12 @@ class CumulativeDistributionTransform(Transform): codomain = constraints.unit_interval sign = +1 - def __init__(self, distribution, cache_size=0): + def __init__(self, distribution: Distribution, cache_size: int = 0) -> None: super().__init__(cache_size=cache_size) self.distribution = distribution @property - def domain(self) -> constraints.Constraint: # type: ignore[override] + def domain(self) -> Optional[constraints.Constraint]: # type: ignore[override] return self.distribution.support def _call(self, x): diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index 31007c924de0..37decbaadce5 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import nan, Tensor from torch.distributions import constraints @@ -50,7 +52,12 @@ def stddev(self) -> Tensor: def variance(self) -> Tensor: return (self.high - self.low).pow(2) / 12 - def __init__(self, low, high, validate_args=None): + def __init__( + self, + low: Union[Tensor, float], + high: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.low, self.high = broadcast_all(low, high) if isinstance(low, _Number) and isinstance(high, _Number): diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index f83d75c904ab..b53c4721ffc7 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -7,7 +7,7 @@ import torch.nn.functional as F from torch import Tensor from torch.overrides import is_tensor_like -from torch.types import _Number +from torch.types import _Number, Number euler_constant = 0.57721566490153286060 # Euler Mascheroni Constant @@ -23,7 +23,9 @@ ] -def broadcast_all(*values): +# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added. 
+# See https://github.com/python/typing/issues/1216#issuecomment-2126153831 +def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: r""" Given a list of values (possibly containing numbers), returns a list where each value is broadcasted based on the following rules: diff --git a/torch/distributions/von_mises.py b/torch/distributions/von_mises.py index 9a144fe10817..4f96a23cf55b 100644 --- a/torch/distributions/von_mises.py +++ b/torch/distributions/von_mises.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import math +from typing import Optional import torch import torch.jit @@ -126,7 +127,12 @@ class VonMises(Distribution): support = constraints.real has_rsample = False - def __init__(self, loc, concentration, validate_args=None): + def __init__( + self, + loc: Tensor, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: self.loc, self.concentration = broadcast_all(loc, concentration) batch_shape = self.loc.shape event_shape = torch.Size() diff --git a/torch/distributions/weibull.py b/torch/distributions/weibull.py index e7b3c5e0cebe..98132472b4ee 100644 --- a/torch/distributions/weibull.py +++ b/torch/distributions/weibull.py @@ -1,4 +1,6 @@ # mypy: allow-untyped-defs +from typing import Optional, Union + import torch from torch import Tensor from torch.distributions import constraints @@ -34,7 +36,12 @@ class Weibull(TransformedDistribution): } support = constraints.positive - def __init__(self, scale, concentration, validate_args=None): + def __init__( + self, + scale: Union[Tensor, float], + concentration: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: self.scale, self.concentration = broadcast_all(scale, concentration) self.concentration_reciprocal = self.concentration.reciprocal() base_dist = Exponential( diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index 225aeeb97430..1b5a51ea88f9 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -80,8 +80,8 @@ def __init__( covariance_matrix: Optional[Tensor] = None, precision_matrix: Optional[Tensor] = None, scale_tril: Optional[Tensor] = None, - validate_args=None, - ): + validate_args: Optional[bool] = None, + ) -> None: assert (covariance_matrix is not None) + (scale_tril is not None) + ( precision_matrix is not None ) == 1, ( From 6a8ab902a27fe5dbcd2521c1a6e8ebdfe454c58a Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Sun, 6 Apr 2025 07:57:40 -0700 Subject: [PATCH 219/332] [AOTI][dashboard] Fix mis-calculated memory compression ratio (#150695) Summary: https://github.com/pytorch/pytorch/pull/149817 introduced an extra warmup run to compute AOTI memory compression ratio, but since weights are only loaded once in the AOTI run, the peak memory seen in the extra warmup won't include the weight, which causes an aritifically high memory compression ratio. This PR removes that extra warmup run, and calls reset_peak_memory_stats in the proper place instead. 
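For reference, a minimal sketch of the measurement pattern this relies on (illustrative only, not the actual harness in benchmarks/dynamo/common.py; the helper name, the run_eager/run_aoti callables, and the ratio formula are assumptions):

    import torch

    def measure_peak(run):
        # Reset the peak counter immediately before the measured run, so
        # allocations from an earlier eager/warmup pass do not shadow this
        # run's peak.
        torch.cuda.reset_peak_memory_stats()
        run()
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated()

    # With AOTI, the measured run must include its one-time weight loading;
    # otherwise the compiled peak is understated and the compression ratio
    # looks artificially high.
    compression_ratio = measure_peak(run_eager) / measure_peak(run_aoti)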
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150695 Approved by: https://github.com/yushangdi --- benchmarks/dynamo/common.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 7c8a91de5202..45fe1fb9f7bf 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1395,6 +1395,8 @@ def load(cls, model, example_inputs): with torch.no_grad(): # copy.deepcopy is required to prevent any surprising side-effect, # see https://github.com/pytorch/pytorch/issues/113029 + # This will cause memory stats to be overshadowed by this eager run. + # To fix that, memory stats will be reset later. example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs) if pytree.is_namedtuple_instance(example_outputs): @@ -1411,6 +1413,14 @@ def load(cls, model, example_inputs): _produce_dynamic_shapes_for_export, combined_args ) + # delete example_outputs and reset memory stats here + del example_outputs + if current_device == "cuda": + torch.cuda.reset_peak_memory_stats() + empty_gpu_cache(current_device) + elif current_device == "hpu": + torch.hpu.reset_peak_memory_stats() + ep = torch.export.export( model, example_args, @@ -3735,10 +3745,6 @@ def run(runner, args, original_dir=None): # AOTInductor doesn't support control flow yet runner.skip_models.update(runner.skip_models_due_to_control_flow) runner.skip_models.update(runner.skip_models_due_to_export_not_supported) - - # For AOTI, we only measure the memory compression ratio at the run time - # instead of the compile time, so use a warmup run to trigger AOTI compilation. - args.use_warm_peak_memory = True elif args.backend == "torchao": assert "cuda" in args.devices, "Quantization requires CUDA device." assert args.bfloat16, "Quantization requires dtype bfloat16." 
From 8adfcd35c36c085e40170761daebf1c54b309254 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Sun, 6 Apr 2025 20:31:11 +0000 Subject: [PATCH 220/332] [cuDNN][SDPA] Loosen constraints for GQA for cuDNN Attention (#150337) cuDNN attention doesn't require key and value tensors to have the same number of heads Pull Request resolved: https://github.com/pytorch/pytorch/pull/150337 Approved by: https://github.com/drisspg --- aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 5 +++-- aten/src/ATen/native/transformers/sdp_utils_cpp.h | 12 +++++++----- test/test_transformers.py | 3 ++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index bb4c3d9cbc18..05acc275b468 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -553,9 +553,10 @@ bool check_for_nested_inputs(sdp_params const& params, bool debug) { TORCH_WARN("Experimental cuDNN SDPA nested tensor support is not enabled."); } return false; - } else if (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad()) { + } else if (has_for_nested_inputs(params) && (params.query.requires_grad() || params.key.requires_grad() || params.value.requires_grad())) { if (debug) { TORCH_WARN("Experimental cuDNN SDPA nested tensor support does not support backward."); + return false; } } @@ -645,7 +646,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { constexpr auto dense_constraints = c10::array_of( check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense + check_batch_size_and_num_heads_dense ); if (has_only_dense_inputs(params)) { diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index 22afbac1d079..4591fa253824 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -333,13 +333,14 @@ inline bool check_safe_kv_broadcast(at::Tensor const& param, bool debug) { return true; } +template inline bool check_grouped_query_attention(sdp_params const& params, bool debug) { const auto q_num_heads = params.query.sym_size(-3); const auto k_num_heads = params.key.sym_size(-3); const auto v_num_heads = params.value.sym_size(-3); const bool same_kv_heads = k_num_heads == v_num_heads; - if (!(same_kv_heads)){ + if (requires_same_num_heads && !(same_kv_heads)){ if (debug) { TORCH_WARN( "Both fused kernels require key and value to have the same num_heads and batch_size but got: ", @@ -355,10 +356,10 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) } // Check if grouped query attention is supported and validate the number of // heads - if (q_num_heads % k_num_heads != 0) { + if (q_num_heads % k_num_heads != 0 || (!requires_same_num_heads && (q_num_heads % v_num_heads != 0))) { if (debug) { TORCH_WARN( - "FlashAttentionV2 only supports grouped query attention, where the number of heads in key/value must divide number of heads in query.", + "The number of heads in key/value must divide number of heads in query.", "Got input Key sizes(): ", params.key.sym_size(-3), ", Value sizes(): ", @@ -372,7 +373,7 @@ inline bool check_grouped_query_attention(sdp_params const& params, bool debug) return true; } -template +template inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool debug) { // This is expected to 
be called after check_tensor_shapes ensuring that the // size() calls won't error since the inputs are all 4 dimensional @@ -407,9 +408,10 @@ inline bool check_batch_size_and_num_heads_dense(sdp_params const& params, bool } if(params.enable_gqa && supports_gqa){ - return check_grouped_query_attention(params, debug); + return check_grouped_query_attention(params, debug); } + // same num heads condition for non-gqa case if (!same_num_heads){ if (debug) { TORCH_WARN( diff --git a/test/test_transformers.py b/test/test_transformers.py index 3a22d382d3c5..42950a84f154 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2479,7 +2479,8 @@ def test_cudnn_attention_gqa(self, device): # Sample call to SDPA - GQ query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16) key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) - value = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) + # cuDNN supports h_k != h_v + value = torch.rand(batch, 4, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) with sdpa_kernel([SDPBackend.MATH]): output_math = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True) From 912102b4ecf776711436f95d2fe62d78e39ad880 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sun, 6 Apr 2025 17:28:12 +0000 Subject: [PATCH 221/332] Make at::vec::Vectorized ops work with scalars (#150380) I noticed that I couldn't use `vec::Vectorized` operations with scalars, even though there is an implicit conversion from `T` to `vec::Vectorized`, so I made it work. Test Plan: Added tests. Reverted vec_base.h, left the new tests in place, and confirmed that new tests don't compile in that state. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150380 Approved by: https://github.com/Skylion007 --- aten/src/ATen/cpu/vec/vec_base.h | 93 +++++++++++++++++++++++ aten/src/ATen/test/vec_test_all_types.cpp | 34 ++++----- aten/src/ATen/test/vec_test_all_types.h | 40 +++++++++- 3 files changed, 149 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 3e6124cbc500..02e3b65dbd46 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -737,6 +737,25 @@ struct Vectorized { } }; +// There is an implicit conversion that would make this work if +// these operators weren't template functions, but they are template +// functions (and can't be moved to be non-member friends defined in +// the class body as suggested in +// https://stackoverflow.com/questions/9787593/implicit-type-conversion-with-template/9788255#9788255 +// because we have a lot of disparate specializations of +// Vectorized). So, just explicitly make scalars work. 
+#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(name) \ + template \ + Vectorized inline name(const Vectorized& a, T b) { \ + return name(a, Vectorized(b)); \ + } \ + template \ + Vectorized inline name(T a, const Vectorized& b) { \ + return name(Vectorized(a), b); \ + } +#define VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(op) \ + VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(operator op) + template Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { Vectorized c; @@ -746,6 +765,8 @@ Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(+) + template Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { Vectorized c; @@ -755,6 +776,8 @@ Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(-) + template Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { Vectorized c; @@ -764,6 +787,8 @@ Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(*) + template Vectorized inline operator/(const Vectorized& a, const Vectorized& b) __ubsan_ignore_float_divide_by_zero__ { @@ -774,12 +799,16 @@ Vectorized inline operator/(const Vectorized& a, const Vectorized& b) return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(/) + template , int> = 0> Vectorized inline operator%(const Vectorized& a, const Vectorized& b) __ubsan_ignore_float_divide_by_zero__ { return a - a / b * b; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(%) + template Vectorized inline operator||( const Vectorized& a, @@ -791,6 +820,8 @@ Vectorized inline operator||( return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(||) + // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template < @@ -827,6 +858,8 @@ Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(maximum) + // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. 
template < @@ -863,6 +896,8 @@ Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(minimum) + template < class T, typename std::enable_if_t::value, int> = 0> @@ -877,6 +912,42 @@ Vectorized inline clamp( return c; } +#define VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(name) \ + template \ + Vectorized inline name( \ + const Vectorized& a, const Vectorized& b, T c) { \ + return name(a, b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + const Vectorized& a, T b, const Vectorized& c) { \ + return name(a, Vectorized(b), c); \ + } \ + \ + template \ + Vectorized inline name(const Vectorized& a, T b, T c) { \ + return name(a, Vectorized(b), Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name( \ + T a, const Vectorized& b, const Vectorized& c) { \ + return name(Vectorized(a), b, c); \ + } \ + \ + template \ + Vectorized inline name(T a, const Vectorized& b, T c) { \ + return name(Vectorized(a), b, Vectorized(c)); \ + } \ + \ + template \ + Vectorized inline name(T a, T b, const Vectorized& c) { \ + return name(Vectorized(a), Vectorized(b), c); \ + } + +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(clamp) + template < class T, typename std::enable_if_t::value, int> = 0> @@ -890,6 +961,8 @@ Vectorized inline clamp_max( return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_max) + template < class T, typename std::enable_if_t::value, int> = 0> @@ -903,6 +976,8 @@ Vectorized inline clamp_min( return c; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(clamp_min) + struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) @@ -1046,6 +1121,10 @@ inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(|) +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(^) + template < class T, typename std:: @@ -1139,6 +1218,8 @@ inline Vectorized fmadd( return a * b + c; } +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmadd) + template inline Vectorized fmsub( const Vectorized& a, @@ -1147,6 +1228,8 @@ inline Vectorized fmsub( return a * b - c; } +VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC(fmsub) + template Vectorized inline operator&&( const Vectorized& a, @@ -1158,6 +1241,8 @@ Vectorized inline operator&&( return ret; } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP(&&) + template std::enable_if_t< scale == 1 || scale == 2 || scale == 4 || scale == 8, @@ -1295,6 +1380,8 @@ deinterleave2(const Vectorized& a, const Vectorized& b) { Vectorized::loadu(static_cast(buffer2))); } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(deinterleave2) + // clang-format off // inverse operation of deinterleave2 // Example inputs for AVX512: @@ -1332,6 +1419,12 @@ interleave2(const Vectorized& a, const Vectorized& b) { Vectorized::loadu(static_cast(buffer2))); } +VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC(interleave2) + +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_FUNC +#undef VECTORIZED_SUPPORT_SCALARS_FOR_BINARY_OP +#undef VECTORIZED_SUPPORT_SCALARS_FOR_TERNARY_FUNC + template inline void convert(const src_T* src, dst_T* dst, int64_t n) { #ifndef _MSC_VER diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index beca3043ce71..ab9c8a2aea6a 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -329,7 +329,7 @@ namespace { test_binary( 
NAME_INFO(fmod), RESOLVE_OVERLOAD(std::fmod), - [](vec v0, vec v1) { return v0.fmod(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).fmod(v1); }, createDefaultBinaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_fmod)); } @@ -599,8 +599,8 @@ namespace { test_binary( NAME_INFO(atan2), RESOLVE_OVERLOAD(std::atan2), - [](vec v0, vec v1) { - return v0.atan2(v1); + [](const auto& v0, const auto& v1) { + return vec(v0).atan2(v1); }, createDefaultBinaryTestCase(TestSeed())); } @@ -609,7 +609,7 @@ namespace { test_binary( NAME_INFO(pow), RESOLVE_OVERLOAD(std::pow), - [](vec v0, vec v1) { return v0.pow(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).pow(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(RealTests, Hypot) { @@ -617,7 +617,7 @@ namespace { test_binary( NAME_INFO(hypot), RESOLVE_OVERLOAD(std::hypot), - [](vec v0, vec v1) { return v0.hypot(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).hypot(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(RealTests, NextAfter) { @@ -625,7 +625,7 @@ namespace { test_binary( NAME_INFO(nextafter), RESOLVE_OVERLOAD(std::nextafter), - [](vec v0, vec v1) { return v0.nextafter(v1); }, + [](const auto& v0, const auto& v1) { return vec(v0).nextafter(v1); }, createDefaultBinaryTestCase(TestSeed(), false, true)); } TYPED_TEST(Interleave, Interleave) { @@ -675,7 +675,7 @@ namespace { test_binary( NAME_INFO(plus), std::plus(), - [](const vec& v0, const vec& v1) -> vec { + [](const auto& v0, const auto& v1) -> vec { return v0 + v1; }, createDefaultBinaryTestCase(TestSeed()), @@ -687,7 +687,7 @@ namespace { test_binary( NAME_INFO(minus), std::minus(), - [](const vec& v0, const vec& v1) -> vec { + [](const auto& v0, const auto& v1) -> vec { return v0 - v1; }, createDefaultBinaryTestCase(TestSeed()), @@ -698,7 +698,7 @@ namespace { test_binary( NAME_INFO(mult), RESOLVE_OVERLOAD(local_multiply), - [](const vec& v0, const vec& v1) { return v0 * v1; }, + [](const auto& v0, const auto& v1) { return v0 * v1; }, createDefaultBinaryTestCase(TestSeed(), false, true), RESOLVE_OVERLOAD(filter_mult_overflow)); } @@ -708,7 +708,7 @@ namespace { test_binary( NAME_INFO(division), RESOLVE_OVERLOAD(local_division), - [](const vec& v0, const vec& v1) { return v0 / v1; }, + [](const auto& v0, const auto& v1) { return v0 / v1; }, createDefaultBinaryTestCase(seed), RESOLVE_OVERLOAD(filter_div_ub)); } @@ -717,7 +717,7 @@ namespace { test_binary( NAME_INFO(bit_and), RESOLVE_OVERLOAD(local_and), - [](const vec& v0, const vec& v1) { return v0 & v1; }, + [](const auto& v0, const auto& v1) { return v0 & v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Bitwise, BitOr) { @@ -725,7 +725,7 @@ namespace { test_binary( NAME_INFO(bit_or), RESOLVE_OVERLOAD(local_or), - [](const vec& v0, const vec& v1) { return v0 | v1; }, + [](const auto& v0, const auto& v1) { return v0 | v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Bitwise, BitXor) { @@ -733,7 +733,7 @@ namespace { test_binary( NAME_INFO(bit_xor), RESOLVE_OVERLOAD(local_xor), - [](const vec& v0, const vec& v1) { return v0 ^ v1; }, + [](const auto& v0, const auto& v1) { return v0 ^ v1; }, createDefaultBinaryTestCase(TestSeed(), true)); } TYPED_TEST(Comparison, Equal) { @@ -796,7 +796,7 @@ namespace { test_binary( NAME_INFO(minimum), minimum, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return minimum(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -807,7 +807,7 @@ namespace { 
test_binary( NAME_INFO(maximum), maximum, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return maximum(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -818,7 +818,7 @@ namespace { test_binary( NAME_INFO(clamp min), clamp_min, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return clamp_min(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); @@ -829,7 +829,7 @@ namespace { test_binary( NAME_INFO(clamp max), clamp_max, - [](const vec& v0, const vec& v1) { + [](const auto& v0, const auto& v1) { return clamp_max(v0, v1); }, createDefaultBinaryTestCase(TestSeed())); diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 6ff988709582..cb877a9f77eb 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -991,6 +991,10 @@ void test_binary( CACHE_ALIGN VT vals0[el_count]; CACHE_ALIGN VT vals1[el_count]; CACHE_ALIGN VT expected[el_count]; + [[maybe_unused]] CACHE_ALIGN VT expectedWithLeftScalar[el_count]; + [[maybe_unused]] CACHE_ALIGN VT expectedWithRightScalar[el_count]; + [[maybe_unused]] VT scalar0; + [[maybe_unused]] VT scalar1; bool bitwise = testCase.isBitwise(); UVT default_start = std::is_floating_point_v ? std::numeric_limits::lowest() : std::numeric_limits::min(); UVT default_end = std::numeric_limits::max(); @@ -1000,6 +1004,7 @@ void test_binary( int trialCount = getTrialCount(test_trials, domains_size); TestSeed seed = testCase.getTestSeed(); uint64_t changeSeedBy = 0; + constexpr bool kCanUseScalar = std::is_invocable_v && std::is_invocable_v; for (const CheckWithinDomains& dmn : testCase.getDomains()) { size_t dmn_argc = dmn.ArgsDomain.size(); UVT start0 = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; @@ -1012,9 +1017,23 @@ void test_binary( for (const auto k : c10::irange(el_count)) { vals0[k] = generator0.get(); vals1[k] = generator1.get(); + if (k == 0) { + scalar0 = vals0[0]; + scalar1 = vals1[0]; + } call_filter(filter, vals0[k], vals1[k]); + if constexpr (kCanUseScalar) { + call_filter(filter, vals0[k], scalar1); + call_filter(filter, scalar0, vals1[k]); + } + } + for (const auto k : c10::irange(el_count)) { // map operator expected[k] = expectedFunction(vals0[k], vals1[k]); + if constexpr (kCanUseScalar) { + expectedWithLeftScalar[k] = expectedFunction(scalar0, vals1[k]); + expectedWithRightScalar[k] = expectedFunction(vals0[k], scalar1); + } } // test auto input0 = vec_type::loadu(vals0); @@ -1024,8 +1043,27 @@ void test_binary( AssertVectorized vecAssert( testNameInfo, seed, vec_expected, actual, input0, input1); if (vecAssert.check( - bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) { return; + } + if constexpr (kCanUseScalar) { + auto actualWithLeftScalar = actualFunction(scalar0, input1); + auto actualWithRightScalar = actualFunction(input0, scalar1); + auto vec_expectedWithLeftScalar = vec_type::loadu(expectedWithLeftScalar); + auto vec_expectedWithRightScalar = vec_type::loadu(expectedWithRightScalar); + AssertVectorized vecAssertWithLeftScalar( + testNameInfo, seed, vec_expectedWithLeftScalar, actualWithLeftScalar, scalar0, input1); + if (vecAssertWithLeftScalar.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) { + return; + } + AssertVectorized vecAssertWithRightScalar( + testNameInfo, seed, vec_expectedWithRightScalar, actualWithRightScalar, input0, scalar1); + if (vecAssertWithRightScalar.check( + bitwise, 
dmn.CheckWithTolerance, dmn.ToleranceError)) { + return; + } + } } // trial changeSeedBy += 1; } From 0aaf35310afbe67580a0aa4ea53738ede73e44dd Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sun, 6 Apr 2025 17:28:12 +0000 Subject: [PATCH 222/332] Overload unary - operator on at::vec::Vectorized to call neg() (#150568) Makes Vectorized look even more like a scalar type, getting me closer to being able to use the same generic code with scalars and Vectorized (e.g., for sigmoid, which needs `exp(-x)`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/150568 Approved by: https://github.com/Skylion007 ghstack dependencies: #150380 --- aten/src/ATen/cpu/vec/vec_base.h | 5 +++++ aten/src/ATen/test/vec_test_all_types.cpp | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index 02e3b65dbd46..0f24ccf385df 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -737,6 +737,11 @@ struct Vectorized { } }; +template +Vectorized inline operator-(const Vectorized& a) { + return a.neg(); +} + // There is an implicit conversion that would make this work if // these operators weren't template functions, but they are template // functions (and can't be moved to be non-member friends defined in diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index ab9c8a2aea6a..db37925add67 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -192,6 +192,11 @@ namespace { [](vec v) { return v.neg(); }, createDefaultUnaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_int_minimum)); + test_unary( + NAME_INFO(negate), std::negate>(), + [](vec v) { return -v; }, + createDefaultUnaryTestCase(TestSeed()), + RESOLVE_OVERLOAD(filter_int_minimum)); } TYPED_TEST(SignManipulationHalfPrecision, AbsNegate) { typedef enum { From 47b494ef69c8e9d8133a1464d760860dcd235c95 Mon Sep 17 00:00:00 2001 From: Paul Ganssle Date: Sun, 6 Apr 2025 22:25:32 +0000 Subject: [PATCH 223/332] Add type hints to `_tensor_docs.add_docstr_all` (#150715) There is some sort of bug in `pytype` where if this function doesn't have type hints, `pytype` will spend 10 minutes inferring the types. Not that this matters much for a project not using `pytype`, but it led me to realize that this function could easily be type hinted and is not, so here is a PR adding some type hints. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150715 Approved by: https://github.com/Skylion007 --- torch/_tensor_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9e1956763242..076491993d46 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -6,7 +6,7 @@ from torch._torch_docs import parse_kwargs, reproducibility_notes -def add_docstr_all(method, docstr): +def add_docstr_all(method: str, docstr: str) -> None: add_docstr(getattr(torch._C.TensorBase, method), docstr) From 370ba6b96f8e492c91c2e05a3e7a4fc7199100d4 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Mon, 7 Apr 2025 01:45:03 +0000 Subject: [PATCH 224/332] [codemod] Fix `-Wambiguous-reversed-operator` in aten/src/ATen/cuda/tunable/Tunable.h (#150744) Summary: `-Wambiguous-reversed-operator` warns about ambiguous reversed operators, e.g. `a < b` and `b > a` are both valid. Such operators are disallowed in C++20. This codemod fixes the warnings. 
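For readers hitting the warning elsewhere, a minimal self-contained sketch of the pattern being fixed (`Entry` and `same` are illustrative stand-ins, not code from this patch):

```cpp
#include <string>

// Hypothetical stand-in for ResultEntry, stripped to the relevant part.
struct Entry {
  std::string key_;
  // Non-const comparison operator.
  bool operator==(const Entry& other) { return key_ == other.key_; }
};

bool same(Entry& a, Entry& b) {
  // C++20 also considers the reversed candidate b.operator==(a). With a
  // non-const operator== neither candidate is a better match, so ISO C++20
  // treats the call as ambiguous; Clang accepts it as an extension and warns
  // via -Wambiguous-reversed-operator. Const-qualifying operator== (as this
  // patch does for ResultEntry) makes the non-reversed candidate win cleanly.
  return a == b;
}
```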
#buildsonlynotests - If this diff compiles, it works. - If you approve of this diff, please use the "Accept & Ship" button :-) Test Plan: Sandcastle Differential Revision: D72535527 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150744 Approved by: https://github.com/drisspg --- aten/src/ATen/cuda/tunable/Tunable.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index b8187b4254bf..3ea292582f60 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -41,8 +41,8 @@ class TORCH_CUDA_CPP_API ResultEntry { public: explicit ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} explicit ResultEntry(std::string key, double time, const std::string& blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(blas_sig) {} - bool operator==(const ResultEntry& other) { return key_ == other.key_; } - bool operator!=(const ResultEntry& other) { return key_ != other.key_; } + bool operator==(const ResultEntry& other) const { return key_ == other.key_; } + bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } operator std::string () { return key_; } std::string GetKey() const { return key_; } double GetTime() const { return time_; } From d8d306cbc645466c96d5de690e3c77bc33f8d0c1 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Mon, 7 Apr 2025 01:47:32 +0000 Subject: [PATCH 225/332] Suppress `-Wunused-function` for DSA (#150735) Test Plan: Sandcastle Reviewed By: dtolnay Differential Revision: D72458590 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150735 Approved by: https://github.com/eqy, https://github.com/cyyever --- c10/cuda/CUDADeviceAssertion.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/c10/cuda/CUDADeviceAssertion.h b/c10/cuda/CUDADeviceAssertion.h index 063c7836932a..6b98e78aa469 100644 --- a/c10/cuda/CUDADeviceAssertion.h +++ b/c10/cuda/CUDADeviceAssertion.h @@ -6,6 +6,7 @@ namespace c10::cuda { #ifdef TORCH_USE_CUDA_DSA +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") // Copy string from `src` to `dst` static __device__ void dstrcpy(char* dst, const char* src) { int i = 0; @@ -64,6 +65,7 @@ static __device__ void dsa_add_new_assertion_failure( self.thread_id[1] = thread_id.y; self.thread_id[2] = thread_id.z; } +C10_CLANG_DIAGNOSTIC_POP() // Emulates a kernel assertion. The assertion won't stop the kernel's progress, // so you should assume everything the kernel produces is garbage if there's an From d98575806ba3f2b67439c241e980df8f98923f44 Mon Sep 17 00:00:00 2001 From: "Han, Chao1" Date: Mon, 7 Apr 2025 01:54:17 +0000 Subject: [PATCH 226/332] Generalize compile collective to avoid cuda-bias (#150405) Fixes https://github.com/intel/torch-xpu-ops/issues/1527 Let the combination of `compile` and `collective` to support more devices. 
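In short, the device context for the compiler collective is now derived from the process group's own device type instead of assuming CUDA. A sketch of the pattern, distilled from the diff below (`compile_pg` stands for the compiler-collective process group and is assumed to expose `_device_types`):

```python
import torch
from torch._dynamo.device_interface import get_interface_for_device

def _device_ctx_for(compile_pg):
    # The process group records the backend device types it was built with,
    # e.g. {"cuda"} or {"xpu"}; exactly one is expected here.
    device_types = set(compile_pg._device_types)
    assert len(device_types) == 1, "expected exactly one device type"
    # Dispatch through dynamo's per-backend device interface rather than
    # torch.cuda, and count devices via torch.accelerator so the same code
    # path serves CUDA, XPU, and other accelerators.
    interface = get_interface_for_device(device_types.pop())
    return interface.device(compile_pg.rank() % torch.accelerator.device_count())
```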
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150405 Approved by: https://github.com/guangyey, https://github.com/jansel Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> --- torch/_dynamo/output_graph.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 92a6ea2f15c8..856ae4e32973 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -79,6 +79,7 @@ from .code_context import code_context from .codegen import PyCodegen from .current_scope_id import enter_new_scope +from .device_interface import get_interface_for_device from .exc import ( BackendCompilerFailed, exceptions_allowed_to_be_fallback, @@ -1347,8 +1348,14 @@ def run_compiler_collective(self, tx): }, payload_fn=lambda: ds.local_state.render(), ) + device_types = compile_pg._device_types + assert len(device_types) == 1, ( + "Expect only one device type but got {}".format("+".join(device_types)) + ) with ( - torch.cuda.device(compile_pg.rank() % torch.cuda.device_count()), + get_interface_for_device(device_types.pop()).device( # type: ignore[attr-defined] + compile_pg.rank() % torch.accelerator.device_count() + ), dynamo_timed("compiler_collective", log_pt2_compile_event=True), ): all_states = [None] * compile_pg.size() From d86c14156d875b782b82dda96842a1f77910f010 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Sun, 6 Apr 2025 09:08:18 +0000 Subject: [PATCH 227/332] Generalize poison fork logic for each device backend (#144664) # Motivation Generalize the posion_fork code to make it reusable across different devices. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144664 Approved by: https://github.com/EikanWang, https://github.com/albanD --- torch/csrc/cuda/Module.cpp | 36 ++++++------------------ torch/csrc/mps/Module.cpp | 30 +++----------------- torch/csrc/mtia/Module.cpp | 31 ++++----------------- torch/csrc/utils/device_lazy_init.cpp | 40 +++++++++++++++++++++++++++ torch/csrc/utils/device_lazy_init.h | 17 ++++++++++++ torch/csrc/xpu/Module.cpp | 34 ++++------------------- 6 files changed, 80 insertions(+), 108 deletions(-) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 1ff4079a56e5..f5365a674d29 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -51,32 +51,9 @@ #include #include #include -#ifndef WIN32 -#include -#endif using namespace torch; -static bool in_bad_fork = false; // True for children forked after cuda init - -#ifndef WIN32 -// Called in the forked child if cuda has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kCUDA, true); -} -#endif - -// Should be called before the first cuda call. -// Note: This is distinct from initExtension because a stub cuda implementation -// has some working functions (e.g. device_count) but cannot fully initialize. 
-static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - //////////////////////////////////////////////////////////////////////////////// // CUDA management methods //////////////////////////////////////////////////////////////////////////////// @@ -160,14 +137,17 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) { PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + // Note: This is distinct from initExtension because a stub cuda + // implementation has some working functions (e.g. device_count) but cannot + // fully initialize. + torch::utils::register_fork_handler_for_device_init(at::kCUDA); return THPUtils_packUInt64(at::cuda::device_count()); END_HANDLE_TH_ERRORS } PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + torch::utils::register_fork_handler_for_device_init(at::kCUDA); #ifdef CUDA_ARCH_FLAGS static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS); return THPUtils_packString(flags); @@ -179,7 +159,7 @@ PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kCUDA)); END_HANDLE_TH_ERRORS } @@ -1513,8 +1493,8 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { "please rebuild pytorch without asan if you need to use this module"); #endif HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kCUDA)); + torch::utils::register_fork_handler_for_device_init(at::kCUDA); at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 3694cd194179..0ec9b8418c6e 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -6,16 +6,12 @@ #include #include #include +#include #include #include #include #include -// pthread.h is included for tracking bad forks -#ifndef WIN32 -#include -#endif - #ifdef USE_MPS #include #include @@ -23,27 +19,9 @@ namespace torch::mps { -namespace { -// True for children forked after mps init -static bool in_bad_fork = false; - -// Called in the forked child if mps has already been initialized -static void forked_mps_child() { - in_bad_fork = true; -} - -// Should be called before the first mps call. 
-static void track_bad_mps_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_mps_child); -#endif -} -} // namespace - static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kMPS)); END_HANDLE_TH_ERRORS } @@ -51,7 +29,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator( PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - track_bad_mps_fork(); + torch::utils::register_fork_handler_for_device_init(at::kMPS); return THPGenerator_initDefaultGenerator( at::detail::getMPSHooks().getDefaultGenerator()); END_HANDLE_TH_ERRORS @@ -59,8 +37,8 @@ static PyObject* MPSModule_getDefaultMPSGenerator( static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - track_bad_mps_fork(); if (at::detail::getMPSHooks().hasMPS()) { + torch::utils::register_fork_handler_for_device_init(at::kMPS); Py_RETURN_TRUE; } else { Py_RETURN_FALSE; diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 405b9d780023..ec6229967e0b 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -7,38 +7,15 @@ #include #include #include -#ifndef WIN32 -#include -#endif namespace torch::mtia { -static bool in_bad_fork = false; // True for children forked after mtia init - -#ifndef WIN32 -// Called in the forked child if mtia has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kMTIA, true); -} -#endif - -// Should be called before the first mtia call. -// Note: This is distinct from initExtension because a stub mtia implementation -// has some working functions (e.g. device_count) but cannot fully initialize. 
-static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - void initModule(PyObject* module) { auto m = py::handle(module).cast(); m.def("_mtia_init", []() { - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA)); + torch::utils::register_fork_handler_for_device_init(at::kMTIA); at::globalContext().lazyInitDevice(c10::DeviceType::MTIA); }); @@ -47,7 +24,9 @@ void initModule(PyObject* module) { return at::detail::isMTIAHooksBuilt(); }); - m.def("_mtia_isInBadFork", []() { return in_bad_fork; }); + m.def("_mtia_isInBadFork", []() { + return torch::utils::is_device_in_bad_fork(at::kMTIA); + }); m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) { torch::utils::device_lazy_init(at::kMTIA); diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp index 74adb6b5e6b0..c5a6512b363c 100644 --- a/torch/csrc/utils/device_lazy_init.cpp +++ b/torch/csrc/utils/device_lazy_init.cpp @@ -1,13 +1,23 @@ #include +#include #include #include #include #include + +#ifndef WIN32 +#include +#endif + namespace torch::utils { namespace { std::array is_initialized{}; +std::array is_in_bad_fork{}; +std::array + at_fork_once_flags{}; +std::optional at_fork_device_type{}; } // anonymous namespace @@ -58,4 +68,34 @@ void set_requires_device_init(at::DeviceType device_type, bool value) { is_initialized[static_cast(device_type)] = !value; } +bool is_device_in_bad_fork(at::DeviceType device_type) { + return is_in_bad_fork[static_cast(device_type)]; +} + +void set_device_in_bad_fork(at::DeviceType device_type, bool value) { + is_in_bad_fork[static_cast(device_type)] = value; +} + +// Should be called before the first device runtime call. +void register_fork_handler_for_device_init(at::DeviceType device_type) { +#ifndef WIN32 + auto& flag = at_fork_once_flags[static_cast(device_type)]; + c10::call_once(flag, [device_type]() { + TORCH_CHECK( + !at_fork_device_type, + "Only one device type can be registered. 
But now, we have two device types: ", + at_fork_device_type.value(), + " and ", + device_type); + at_fork_device_type = device_type; + pthread_atfork(nullptr, nullptr, []() { + set_device_in_bad_fork(at_fork_device_type.value(), true); + if (is_device_lazy_init_supported(at_fork_device_type.value())) { + set_requires_device_init(at_fork_device_type.value(), true); + } + }); + }); +#endif +} + } // namespace torch::utils diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index e1f480a60f77..e65f16ace163 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -67,4 +67,21 @@ inline void maybe_initialize_device( bool is_device_initialized(at::DeviceType device_type); +TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type); + +TORCH_PYTHON_API void set_device_in_bad_fork( + at::DeviceType device_type, + bool value); + +TORCH_PYTHON_API void register_fork_handler_for_device_init( + at::DeviceType device_type); + +inline void maybe_register_fork_handler_for_device_init( + std::optional& device_type) { + if (!device_type.has_value()) { + return; + } + register_fork_handler_for_device_init(device_type.value()); +} + } // namespace torch::utils diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 43ad06365efc..8144dddd8298 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -11,32 +11,8 @@ #include #include -#ifndef WIN32 -#include -#endif - using namespace torch; -static bool in_bad_fork = false; // True for children forked after xpu init - -#ifndef WIN32 -// Called in the forked child if xpu has already been initialized -static void forked_child() { - in_bad_fork = true; - torch::utils::set_requires_device_init(at::kXPU, true); -} -#endif - -// Should be called before the first xpu call. It is mainly called in lazy_init. -// Note: This is distinct from initExtension because a stub xpu implementation -// has some working functions (e.g. device_count) but cannot fully initialize. -static void poison_fork() { -#ifndef WIN32 - static auto result [[maybe_unused]] = - pthread_atfork(nullptr, nullptr, forked_child); -#endif -} - // XPU management methods static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { @@ -52,7 +28,7 @@ static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(in_bad_fork); + return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kXPU)); END_HANDLE_TH_ERRORS } @@ -115,7 +91,9 @@ static PyObject* THXPModule_getDeviceCount_wrap( PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - poison_fork(); + // Note: This is distinct from initExtension because a stub xpu implementation + // has some working functions (e.g. device_count) but cannot fully initialize. 
+ torch::utils::register_fork_handler_for_device_init(at::kXPU); return THPUtils_packUInt64(at::xpu::device_count()); END_HANDLE_TH_ERRORS } @@ -420,8 +398,8 @@ static void initXpuMethodBindings(PyObject* module) { // classes static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level - poison_fork(); + TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kXPU)); + torch::utils::register_fork_handler_for_device_init(at::kXPU); at::globalContext().lazyInitDevice(c10::DeviceType::XPU); auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu")); From b6929aef08eff63e67094ccec2233b6bfdec931d Mon Sep 17 00:00:00 2001 From: eellison Date: Sun, 6 Apr 2025 15:30:22 -0700 Subject: [PATCH 228/332] Fix conv2d strided prologue (#150697) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150697 Approved by: https://github.com/drisspg --- test/inductor/test_max_autotune.py | 41 +++++++++++++++++++++++++++++ torch/_inductor/select_algorithm.py | 11 ++++---- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 499dcbf4ae47..ce1263d502a1 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1380,6 +1380,47 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): "del", num_deallocs, exactly=True ).run(code_str) + @parametrize("prologue", (False, True)) + @unittest.skipIf(TEST_WITH_ROCM, "ROCM Different layout decisions") + def test_conv1x1_cast(self, prologue): + with torch._inductor.config.patch( + prologue_fusion=prologue, force_layout_optimization=True + ): + conv1x1 = ( + torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1) + .to(memory_format=torch.channels_last) + .to(GPU_TYPE) + .to(dtype=torch.float16) + ) + input_tensor = ( + torch.randn(4, 3, 32, 32) + .contiguous(memory_format=torch.channels_last) + .to(GPU_TYPE) + ) + + def foo(mod, input): + return torch.nn.functional.conv2d( + input, + mod.weight.to(input.dtype), + None, + mod.stride, + mod.padding, + mod.dilation, + mod.groups, + ) + + with torch.no_grad(): + out_eager = foo(conv1x1, input_tensor) + foo_c = torch.compile(foo) + out, code = run_and_get_code(foo_c, conv1x1, input_tensor) + + FileCheck().check_not("extern_kernels.convolution").run(code[0]) + if prologue: + self.check_code( + code[0], num_kernels=1, num_allocs=1, num_deallocs=2 + ) + self.assertEqual(out_eager, out, atol=1e-2, rtol=0) + @parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250))) def test_upcast(self, sizes): M, K, N = sizes diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index fed0f9ebebd7..1fbb9aff8580 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -746,11 +746,10 @@ def load_input( indices, self.range_trees[0].construct_entries(lengths) ): range_tree_entry.set_name(name) - contiguous_index = sympy_dot( - ir.FlexibleLayout.contiguous_strides(lengths), index_symbols - ) - contiguous_index = self.rename_indexing(contiguous_index) - self.body.writeline("xindex = " + texpr(contiguous_index)) + + strided_index = sympy_dot(input_node.get_stride(), index_symbols) + strided_index = self.rename_indexing(strided_index) + self.body.writeline("xindex = " + texpr(strided_index)) xindex_range_root = self.range_trees[0].lookup( sympy.Integer(1), sympy_product(lengths) @@ -823,7 +822,7 @@ 
def store( output_index = self.rename_indexing(output_index) - if output_index == contiguous_index: + if output_index == strided_index: output_index_str = "xindex" else: out_indexing = self.indexing( From 24aadb40fb23a9de67c5c147d0679dfe6ab6fc95 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Mon, 7 Apr 2025 03:10:03 +0000 Subject: [PATCH 229/332] [precompile] Serialization for GlobalStateGuard (#150636) Summary: To preserve global state guards we need to make the C++ type serialzable. Using json because it's easier to do and we don't have a lot of data in global state. Test Plan: test_dynamo -k test_global_state_guard_serialization Differential Revision: D72410611 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150636 Approved by: https://github.com/williamwen42 --- c10/util/typeid.h | 2 +- test/dynamo/test_misc.py | 53 +++++++++++++++++++++++++ torch/csrc/dynamo/guards.cpp | 77 ++++++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 1 deletion(-) diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 20959f64180e..1140fc703b59 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -477,7 +477,7 @@ class C10_API TypeMeta final { /** * convert TypeMeta handles to ScalarType enum values */ - inline ScalarType toScalarType() { + inline ScalarType toScalarType() const { if (C10_LIKELY(isScalarType())) { return static_cast(index_); } diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b91129e6c1c4..53c3e8b624ca 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -11,6 +11,7 @@ import gc import importlib import itertools +import json import logging import math import operator @@ -3185,6 +3186,58 @@ def fn(m, x): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 4) + def test_global_state_guard_serialization(self): + GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard + guards = GlobalStateGuard() + serialized_guards = guards.dump() + json_guards = json.loads(serialized_guards) + + samples = [] + # Test on non autocast state and autocast cache states. + self.assertIn("autocast_state", json_guards) + for key, value in json_guards.items(): + if type(value) == int: + variant = value + 1 + elif type(value) == bool: + variant = not value + elif isinstance(value, dict) and key == "autocast_state": + variant = value.copy() + variant["cached_enabled"] = not variant["cached_enabled"] + continue + else: + self.fail(f"Unknown global state type {key}: {value}") + new_dict = json_guards.copy() + new_dict[key] = variant + samples.append(new_dict) + + for sample in samples: + guards.load(json.dumps(sample)) + self.assertFalse(guards.check()) + + guards.load(json.dumps(json_guards)) + self.assertTrue(guards.check()) + + # Test on autocast states. 
+ def _test_autocast(dtype): + with torch.autocast("cpu", dtype): + guards = GlobalStateGuard() + serialized_guards = guards.dump() + json_guards = json.loads(serialized_guards) + + for i, enabled in enumerate(json_guards["autocast_state"]["enabled"]): + if enabled: + self.assertEqual( + type(json_guards["autocast_state"]["dtype"][i]), int + ) + json_guards["autocast_state"]["dtype"][i] += 1 + guards.load(json.dumps(json_guards)) + self.assertFalse(guards.check()) + + _test_autocast(torch.float16) + _test_autocast(torch.float32) + _test_autocast(torch.float64) + _test_autocast(torch.bfloat16) + def test_type_copy(self): def fn(seq): a, b = seq diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 9df910662742..bbf66afeb18d 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -20,6 +20,8 @@ #include +#include + #ifdef USE_CUDA #include #endif @@ -552,6 +554,20 @@ struct AutocastState { } return true; } + + template + friend void to_json(T& json_j, const AutocastState& json_t) { + json_j["enabled"] = json_t.enabled; + json_j["dtype"] = json_t.dtype; + json_j["cached_enabled"] = json_t.cache_enabled; + } + + template + friend void from_json(const T& json_j, AutocastState& json_t) { + json_t.enabled = json_j.at("enabled"); + json_t.dtype = json_j.at("dtype"); + json_t.cache_enabled = json_j.at("cached_enabled"); + } }; // TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is @@ -623,6 +639,40 @@ struct GlobalStateGuard { return os.str(); } + template + friend void to_json(T& json_j, const GlobalStateGuard& json_t) { + json_j["grad_mode"] = json_t._grad_mode; + json_j["autocast_state"] = json_t._autocast_state; + json_j["torch_function"] = json_t._torch_function; + json_j["torch_function_all_disabled"] = json_t._torch_function_all_disabled; + json_j["deterministic_algorithms"] = json_t._deterministic_algorithms; + json_j["deterministic_algorithms_warn_only"] = + json_t._deterministic_algorithms_warn_only; + json_j["allow_tf32"] = json_t._allow_tf32; + json_j["allow_fp16_reduce"] = json_t._allow_fp16_reduce; + json_j["allow_bf16_reduce"] = json_t._allow_bf16_reduce; + json_j["num_threads"] = json_t._num_threads; + json_j["default_dtype"] = json_t._default_dtype.toScalarType(); + } + + template + friend void from_json(const T& json_j, GlobalStateGuard& json_t) { + json_t._grad_mode = json_j.at("grad_mode"); + json_t._autocast_state = json_j.at("autocast_state"); + json_t._torch_function = json_j.at("torch_function"); + json_t._torch_function_all_disabled = + json_j.at("torch_function_all_disabled"); + json_t._deterministic_algorithms = json_j.at("deterministic_algorithms"); + json_t._deterministic_algorithms_warn_only = + json_j.at("deterministic_algorithms_warn_only"); + json_t._allow_tf32 = json_j.at("allow_tf32"); + json_t._allow_fp16_reduce = json_j.at("allow_fp16_reduce"); + json_t._allow_bf16_reduce = json_j.at("allow_bf16_reduce"); + json_t._num_threads = json_j.at("num_threads"); + json_t._default_dtype = + caffe2::TypeMeta::fromScalarType(json_j.at("default_dtype")); + } + bool _grad_mode; AutocastState _autocast_state; bool _torch_function; @@ -663,6 +713,25 @@ PyObject* GlobalStateGuard_reason( return PyUnicode_FromString(self->reason().c_str()); } +PyObject* GlobalStateGuard_dump( + GlobalStateGuard* self, + PyObject* args, + PyObject* kwargs) { + return PyUnicode_FromString(nlohmann::json(*self).dump().c_str()); +} + +PyObject* GlobalStateGuard_load( + GlobalStateGuard* self, + PyObject* args, + PyObject* kwargs) 
{ + char* json; + if (!PyArg_ParseTuple(args, "s", &json)) { + throw std::runtime_error("Cannot parse as json string."); + } + nlohmann::json::parse(json).get_to(*self); + Py_RETURN_NONE; +} + // NOLINTNEXTLINE(*array*) static PyMethodDef GlobalStateGuard_methods[] = { {"check", @@ -673,6 +742,14 @@ static PyMethodDef GlobalStateGuard_methods[] = { (PyCFunction)(void*)GlobalStateGuard_reason, METH_NOARGS, "Return string reason for guard check failing"}, + {"dump", + (PyCFunction)(void*)GlobalStateGuard_dump, + METH_NOARGS, + "Return serialized json format"}, + {"load", + (PyCFunction)(void*)GlobalStateGuard_load, + METH_VARARGS, + "Parse serialized json format"}, {nullptr}}; static PyTypeObject GlobalStateGuardType = { PyVarObject_HEAD_INIT(nullptr, 0) }; From 164d2c887b45ec2673b9bbc08e457bfab0a1eb3c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Sat, 5 Apr 2025 00:46:49 +0000 Subject: [PATCH 230/332] Add check in `test_cow_input` to ensure COW data is never changed (#150723) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150723 Approved by: https://github.com/Skylion007 --- test/test_ops.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 871b643568eb..d15fa7c6659d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1825,6 +1825,7 @@ def check_ignore_materialize(idx_or_kw, allow_list): def check_cow_input( arg, arg_copy, + arg_raw, idx_or_kw, backward_or_forward="forward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward, @@ -1837,6 +1838,13 @@ def check_cow_input( ) + f" during {backward_or_forward} call" if is_strided_tensor(arg): + self.assertTrue( + torch._C._is_cow_tensor(arg_raw), + msg=( + f"{arg_name} raw input should remain COW, but it " + "unexpectedly materialized." + ), + ) is_cow = torch._C._is_cow_tensor(arg) if supports_cow_input_no_materialize and not check_ignore_materialize( @@ -1861,6 +1869,17 @@ def check_cow_input( "but the operation mutated its data." ), ) + else: + self.assertTrue( + torch.allclose( + arg_raw, arg_copy, rtol=0, atol=0, equal_nan=True + ), + msg=( + f"{arg_name} materialized, which is allowed in this " + "case, but the COW input data was mutated, which is " + "not allowed." + ), + ) for sample in samples: args_raw = [sample.input] + list(sample.args) @@ -1901,10 +1920,10 @@ def check_cow_input( # Check that COW inputs remain COW after the forward op is executed for idx, arg in enumerate(args): - check_cow_input(arg, args_copy[idx], idx) + check_cow_input(arg, args_copy[idx], args_raw[idx], idx) for kw, arg in kwargs.items(): - check_cow_input(arg, kwargs_copy[kw], kw) + check_cow_input(arg, kwargs_copy[kw], kwargs_raw[kw], kw) # Call backward op if it is supported. 
This part of the test is # based on `composite_compliance.check_backward_formula` @@ -1954,6 +1973,7 @@ def check_cow_input( check_cow_input( arg, args_copy[idx], + args_raw[idx], idx, backward_or_forward="backward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, @@ -1965,6 +1985,7 @@ def check_cow_input( check_cow_input( output_grad, output_grads_copy[idx], + output_grads_raw[idx], f"output grad {idx}", backward_or_forward="backward", supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward, From 25662d38d59bde6a10ef8cc32f65a510d30606cd Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 7 Apr 2025 11:35:56 +0000 Subject: [PATCH 231/332] [xla hash update] update the pinned xla hash (#132021) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned xla hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/132021 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/xla.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 2925b494d999..96bf43f4c0e2 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -760675ad9aa8e7202d4f9f51fe862e8a9bedb713 +ac9a39f4b768cef09b9d2be8e074be496d7783b6 From cdf3b63e32dc757ac0a9ab3d7c52d5c95ae2353b Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Mon, 7 Apr 2025 11:49:55 +0000 Subject: [PATCH 232/332] Update slow tests (#150283) This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/weekly.yml). Update the list of slow tests. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150283 Approved by: https://github.com/pytorchbot --- test/slow_tests.json | 596 +++++++++++++++++++++---------------------- 1 file changed, 293 insertions(+), 303 deletions(-) diff --git a/test/slow_tests.json b/test/slow_tests.json index 7434d944c2d0..bbda0f96278f 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,305 +1,295 @@ { - "EndToEndLSTM (__main__.RNNTest)": 187.95632934570312, - "MultiheadAttention (__main__.ModulesTest)": 137.24066670735678, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 216.9356689453125, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 159.3027776082357, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 112.10600026448567, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 65.38766564263238, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 173.56966654459634, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 75.28399658203125, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 163.19466654459634, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 88.1193339029948, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 60.00295284816197, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 83.75133260091145, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 691.2717827690972, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 117.44299926757813, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 503.3826666937934, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 503.24066840277777, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 126.52850087483723, - "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 61.86766688028971, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 163.50066630045572, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 97.42933400472005, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 308.0576663547092, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 134.46916961669922, - "test_builtin_equivalent_funcs (__main__.TorchFunctionModeTests)": 81.6673030275287, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 355.91133287217883, - "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 73.32400004069011, - "test_comprehensive_constant_pad_nd_cpu_float32 (__main__.TestInductorOpInfoCPU)": 70.80933125813802, - "test_comprehensive_constant_pad_nd_cpu_float64 (__main__.TestInductorOpInfoCPU)": 70.98533376057942, - "test_comprehensive_constant_pad_nd_cpu_int32 (__main__.TestInductorOpInfoCPU)": 67.57033284505208, - "test_comprehensive_constant_pad_nd_cpu_int64 (__main__.TestInductorOpInfoCPU)": 70.75233205159505, - "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 102.2750015258789, - "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 103.07066599527995, - "test_comprehensive_diff_cpu_float64 
(__main__.TestInductorOpInfoCPU)": 105.27833557128906, - "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 100.10233561197917, - "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 102.20266977945964, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 93.59800084431966, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 93.51633326212566, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 62.04499944051107, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 63.05183347066244, - "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 86.4076639811198, - "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 81.19499969482422, - "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 86.38233439127605, - "test_comprehensive_eye_cpu_bool (__main__.TestInductorOpInfoCPU)": 124.90833536783855, - "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 123.35333251953125, - "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 121.35933430989583, - "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 123.5403340657552, - "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 120.98033396402995, - "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 124.76566823323567, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 71.77733357747395, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 83.0576655069987, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 83.4250005086263, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 353.14801025390625, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 79.26999918619792, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 329.7780049641927, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 80.16866556803386, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 273.2213312784831, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 249.29500325520834, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 988.9061686197916, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.60549990336101, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1203.5001627604167, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.39716657002766, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.78449948628743, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 61.64166704813639, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 89.50603711163556, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.10983276367188, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 61.83733304341634, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 203.87232971191406, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 203.09432983398438, - "test_comprehensive_linalg_vector_norm_cpu_float64 
(__main__.TestInductorOpInfoCPU)": 199.30699666341147, - "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.42596266004774, - "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 67.53049977620442, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 423.8486633300781, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 425.5379943847656, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 403.22300211588544, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 409.60033162434894, - "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 93.37733459472656, - "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 99.49733225504558, - "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 94.82899983723958, - "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 89.32633209228516, - "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 90.41433207194011, - "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 95.9903335571289, - "test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 95.3953348795573, - "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 93.07833607991536, - "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 89.55566660563152, - "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 86.22466786702473, - "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 94.80033111572266, - "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 93.42666625976562, - "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 93.45800018310547, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 466.69366455078125, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 464.84532674153644, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 468.4709981282552, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 125.94750086466472, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 120.40383402506511, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 133.67750295003256, - "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 90.84866841634114, - "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 96.20899963378906, - "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 90.58700052897136, - "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 99.9510014851888, - "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 94.47566731770833, - "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 89.86966705322266, - "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 89.43766530354817, - "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 97.86233266194661, - "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 87.95466613769531, - "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 90.6480000813802, - 
"test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 91.357666015625, - "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 94.107666015625, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 92.32383346557617, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 107.00616836547852, - "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 71.70499928792317, - "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 72.04166666666667, - "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 74.28933461507161, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 87.73799896240234, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 81.04799906412761, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 242.09933217366537, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 256.25333404541016, - "test_comprehensive_nn_functional_interpolate_bicubic_cpu_uint8 (__main__.TestInductorOpInfoCPU)": 60.534000396728516, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 75.61316553751628, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 76.84416834513347, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 188.01399739583334, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 186.28333536783853, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 185.177001953125, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 974.4946695963541, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 874.4259847005209, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 882.1919962565104, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 836.5886433919271, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 833.1363525390625, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 848.4001770019531, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 855.3283386230469, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 863.9473368326823, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 198.97533671061197, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 199.50466918945312, - "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 204.54600524902344, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 72.47933260599773, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 126.71599833170573, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 128.18866729736328, - "test_comprehensive_nn_functional_max_unpool3d_cpu_float64 
(__main__.TestInductorOpInfoCPU)": 125.28499857584636, - "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 68.7433344523112, - "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 69.4153315226237, - "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 69.83100128173828, - "test_comprehensive_nn_functional_pad_constant_cpu_int32 (__main__.TestInductorOpInfoCPU)": 67.97833251953125, - "test_comprehensive_nn_functional_pad_constant_cpu_int64 (__main__.TestInductorOpInfoCPU)": 68.58200073242188, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 123.47900136311848, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.12900034586589, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 118.65166473388672, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 115.42100016276042, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 111.11299896240234, - "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 131.9026641845703, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 229.06666564941406, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 230.85599772135416, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 229.9073282877604, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 115.78150049845378, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 109.48800150553386, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 62.58650016784668, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 72.28583272298177, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.01150004069011, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.31016667683919, - "test_cond_autograd_nested (__main__.TestControlFlow)": 108.9411112467448, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 100.8696657816569, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 91.36616770426433, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 80.7226676940918, - "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 70.30566660563152, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 224.8618867662218, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 521.492443508572, - "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 73.4326680501302, - "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 64.06488927205403, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 95.46499888102214, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 83.36849975585938, - "test_count_nonzero_all (__main__.TestBool)": 621.7835659450955, - "test_cusparse_multiple_threads_same_device (__main__.TestCuda)": 88.83855459425185, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 782.4322068956163, - 
"test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 83.058167775472, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 95.14833323160808, - "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 243.97850879033408, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 66.31777699788411, - "test_fail_creation_ops.py (__main__.TestTyping)": 73.5800605542732, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 80.3489990234375, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 84.08483378092448, - "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 76.93700154622395, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 492.0260009765625, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 327.5421651204427, - "test_forward_ad_svd_lowrank_cpu_float32 (__main__.TestCompositeComplianceCPU)": 68.51100158691406, - "test_fuse_large_params_cpu (__main__.CpuTests)": 78.46166653103299, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 160.96700032552084, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 166.3767784966363, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 85.631165822347, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 99.18250020345052, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 91.73133341471355, - "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 108.87999979654948, - "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 118.38499959309895, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 203.54966990152994, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 123.6168327331543, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 140.19833119710287, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 576.1204986572266, - "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 194.1616668701172, - "test_group_norm (__main__.TestQuantizedOps)": 240.9851115544637, - "test_indexing (__main__.TestAutogradWithCompiledAutograd)": 88.37566757202148, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 261.59466552734375, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 66.35699971516927, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 98.20366668701172, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 133.5326656765408, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.19766489664714, - "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 597.1708386739095, - "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 77.99583435058594, - "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 95.48333422342937, - "test_linear (__main__.TestStaticQuantizedModule)": 201.3015539381239, - "test_linear_relu (__main__.TestStaticQuantizedModule)": 198.11822424994574, - "test_lobpcg_ortho_cuda_float64 (__main__.TestLinalgCUDA)": 111.03733523686726, - 
"test_lstm_cpu (__main__.TestMkldnnCPU)": 70.34333419799805, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 114.91833411322699, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 577.1563313802084, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 135.72266642252603, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 452.1196695963542, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 63.14066653781467, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 63.6434457567003, - "test_memory_format_operators_cpu (__main__.TestTorchDeviceTypeCPU)": 74.79505585485862, - "test_nccl_non_blocking_wait_with_barrier (__main__.NcclErrorHandlingTest)": 69.80233256022136, - "test_proper_exit (__main__.TestDataLoader)": 229.3759969075521, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 254.74083709716797, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 65.72250080108643, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 153.51377783881293, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 60.370178349812825, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 93.32955551147461, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 126.44500223795573, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 87.6626688639323, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 93.46333312988281, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 125.66800181070964, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.86966705322266, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 94.73033396402995, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 123.07366689046223, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.68800099690755, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 90.31833394368489, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 135.38099670410156, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 93.70433298746745, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.45233154296875, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 132.78799947102866, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 89.87099965413411, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper 
(__main__.DynamicShapesCppWrapperCpuTests)": 91.96466827392578, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 128.1060002644857, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 92.87266794840495, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 95.6653340657552, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 137.1143341064453, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.06833394368489, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 97.31199900309245, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 125.58800252278645, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 92.6046651204427, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.97366587320964, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 328.73166910807294, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 796.4181518554688, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 544.1849975585938, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1095.7953186035156, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 61.016167958577476, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 72.16200129191081, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 203.41483052571616, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 87.03033192952473, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 174.36033376057944, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 61.354000091552734, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 111.04199981689453, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 75.12533315022786, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 132.99366505940756, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 91.84250005086263, - "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 119.75666681925456, - "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 67.09033457438152, - "test_rosenbrock_sparse_with_lrsched_False_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 66.16733169555664, - "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 67.3231650988261, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 109.39099884033203, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 141.49566650390625, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 
146.1365534464518, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.30600229899089, - "test_sparse_gradients (__main__.DistributedDataParallelTest)": 104.54216623306274, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 89.89100138346355, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 182.42533469200134, - "test_terminate_handler_on_crash (__main__.TestTorch)": 100.94433457321591, - "test_terminate_signal (__main__.ForkTest)": 137.2934450134635, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 137.37077751341792, - "test_terminate_signal (__main__.SpawnTest)": 139.78100167380438, - "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 95.95577812194824, - "test_transformer_backend_inductor_fullgraph_True (__main__.TestFullyShardCompile)": 95.92808405558269, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 70.52516492207845, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 67.16916783650716, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 77.5228328704834, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 118.96799850463867, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 117.68416659037273, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 100.82866477966309, - "test_unary_ops (__main__.TestTEFuserDynamic)": 174.203000386556, - "test_unary_ops (__main__.TestTEFuserStatic)": 162.15266492631702, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 96.66299947102864, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 79.73133341471355, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 75.25083287556966, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 86.45800018310547, - "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 63.62757146926153, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 75.84433301289876, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 71.28416697184245, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 71.99233182271321, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 73.53566614786784, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 91.30016708374023, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 64.31961922418503, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 72.40933481852214, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 80.20533307393391, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 76.08066749572754, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 64.3009055001395, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 80.77966817220052, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 88.6466687520345, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 85.9961675008138, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 
(__main__.TestOperatorsCUDA)": 139.46716435750326 + "EndToEndLSTM (__main__.RNNTest)": 181.61566162109375, + "MultiheadAttention (__main__.ModulesTest)": 136.4750010172526, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 151.13477834065756, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.89133326212565, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.0672378540039, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 83.74566650390625, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 151.28533426920572, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 76.96799977620442, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 86.31200154622395, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 574.2255004882812, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 112.03270034790039, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 501.27077229817706, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 485.51055908203125, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 120.31833012898763, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 132.48300425211588, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 94.59216562906902, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 301.2558898925781, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 133.76050313313803, + "test_builtin_equivalent_funcs (__main__.TorchFunctionModeTests)": 82.59212158665513, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 328.1419949001736, + "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 70.02100118001302, + "test_comprehensive_constant_pad_nd_cpu_float32 (__main__.TestInductorOpInfoCPU)": 74.87266540527344, + "test_comprehensive_constant_pad_nd_cpu_float64 (__main__.TestInductorOpInfoCPU)": 69.08433278401692, + "test_comprehensive_constant_pad_nd_cpu_int32 (__main__.TestInductorOpInfoCPU)": 72.38800303141277, + "test_comprehensive_constant_pad_nd_cpu_int64 (__main__.TestInductorOpInfoCPU)": 68.1750005086263, + "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 103.65033467610677, + "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 113.07499694824219, + "test_comprehensive_diff_cpu_float64 (__main__.TestInductorOpInfoCPU)": 102.12066650390625, + "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 102.65233357747395, + "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 110.2530008951823, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 82.65933227539062, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 79.33149973551433, + "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.96466573079427, + "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 82.62900034586589, + "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 84.08733367919922, + "test_comprehensive_eye_cpu_bool 
(__main__.TestInductorOpInfoCPU)": 126.77433268229167, + "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 129.90166727701822, + "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 130.88333129882812, + "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 128.96799977620444, + "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 124.88400014241536, + "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 125.25633239746094, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 85.94333394368489, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 76.60233306884766, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 82.14366658528645, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 333.54833984375, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 83.21299997965495, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 348.1693420410156, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 86.17266591389973, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 233.37083180745444, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 240.9846674601237, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 922.0073445638021, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.75899823506673, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 957.1233317057291, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 65.89716720581055, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.13633410135905, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 68.52150026957194, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 69.30249913533528, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 61.91258398691813, + "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 211.29166666666666, + "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 201.2199961344401, + "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 201.05166625976562, + "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 448.63133748372394, + "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 435.1319986979167, + "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 414.0263366699219, + "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 428.4053446451823, + "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 97.50900014241536, + "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 96.23233286539714, + "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 98.6259994506836, + "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 99.11599985758464, + "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 89.52233632405598, + "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 100.05933125813802, + 
"test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 92.08133188883464, + "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 102.49733479817708, + "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 93.6953353881836, + "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 94.12633260091145, + "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 90.63199869791667, + "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 93.61466471354167, + "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 95.45333353678386, + "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 471.6109924316406, + "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 478.7690022786458, + "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 483.9660135904948, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 104.6216672261556, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 125.5418332417806, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 102.33516438802083, + "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 91.59299977620442, + "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 96.90999857584636, + "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 95.03333282470703, + "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 96.96366628011067, + "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 93.97466532389323, + "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 91.50166829427083, + "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 96.39866892496745, + "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 93.34033457438152, + "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 97.53666687011719, + "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 87.80099995930989, + "test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 98.83033243815105, + "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 98.4626693725586, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 83.55299886067708, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 82.62733205159505, + "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 74.2403335571289, + "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 73.23299916585286, + "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 74.39199829101562, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 88.33433532714844, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 88.76199849446614, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 212.46066538492838, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 215.0308354695638, + "test_comprehensive_nn_functional_interpolate_bicubic_cpu_uint8 (__main__.TestInductorOpInfoCPU)": 
63.83266576131185, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 72.45750109354655, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 84.0174986521403, + "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 184.9403330485026, + "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 186.5510050455729, + "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 182.49533081054688, + "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 1921.0713297526042, + "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 1740.4580078125, + "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 1776.2012939453125, + "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 1599.7586263020833, + "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 1617.5953369140625, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1075.434326171875, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1099.4353332519531, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1125.7143249511719, + "test_comprehensive_nn_functional_max_pool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 950.7925381130642, + "test_comprehensive_nn_functional_max_pool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 812.9512176513672, + "test_comprehensive_nn_functional_max_pool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 829.8953365749783, + "test_comprehensive_nn_functional_max_pool3d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 889.6016608344185, + "test_comprehensive_nn_functional_max_pool3d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 899.8731655544705, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 464.61108271280926, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 459.0932896931966, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 197.40933227539062, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 200.65933227539062, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 195.45833333333334, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 131.92733510335287, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 133.2663319905599, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 124.51133473714192, + "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 70.27266693115234, + "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 68.51133219401042, + "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 72.49566650390625, + "test_comprehensive_nn_functional_pad_constant_cpu_int32 (__main__.TestInductorOpInfoCPU)": 70.40933481852214, + "test_comprehensive_nn_functional_pad_constant_cpu_int64 (__main__.TestInductorOpInfoCPU)": 
69.66466522216797, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 119.83233388264973, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 114.48733266194661, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 120.08599853515625, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 127.59833017985027, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 120.84366353352864, + "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 123.23733266194661, + "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 231.90233357747397, + "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 225.70599873860678, + "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 237.0050048828125, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 102.30183537801106, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 105.99450047810872, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 64.52433395385742, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 67.21816571553548, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 63.552083015441895, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 68.47633298238118, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 63.37950070699056, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 86.54149881998698, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 82.44583511352539, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 70.82466634114583, + "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 61.23749987284342, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 174.3193367852105, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 325.76544019911023, + "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 79.62433369954427, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 75.77216720581055, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 71.49933369954427, + "test_count_nonzero_all (__main__.TestBool)": 607.5451117621528, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 795.3888888888889, + "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 185.69566524028778, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 81.4071667989095, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 211.3948280016581, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 67.51188871595595, + "test_fail_random.py (__main__.TestTyping)": 72.58257559574011, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 83.84800211588542, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 103.40933100382487, + "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 84.23000081380208, + "test_fn_gradgrad_map_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 61.798926176848234, + 
"test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 506.09766642252606, + "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 345.16650390625, + "test_forward_ad_svd_lowrank_cpu_float32 (__main__.TestCompositeComplianceCPU)": 87.62466684977214, + "test_fuse_large_params_cpu (__main__.CpuTests)": 74.28099937438965, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 167.2507781982422, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 170.01244269476996, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 85.7643330891927, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 82.86033376057942, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 92.42216618855794, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 93.47633107503255, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 95.30599975585938, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 204.81483459472656, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 122.41116460164388, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 143.45366795857748, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 545.9098256429037, + "test_group_norm (__main__.TestQuantizedOps)": 106.1533326043023, + "test_indexing (__main__.TestAutogradWithCompiledAutograd)": 73.70611148410373, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 260.4870096842448, + "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 67.80922275119357, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 104.3433354695638, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 131.30877770317926, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 116.27200317382812, + "test_large_bmm_bfloat16 (__main__.TestMPS)": 1425.0152994791667, + "test_large_bmm_float16 (__main__.TestMPS)": 1253.7086181640625, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 626.6733360290527, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 83.12500190734863, + "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 95.26766586303711, + "test_linear (__main__.TestStaticQuantizedModule)": 111.07222196790907, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 186.7755593193902, + "test_low_memory_max_pool_dilation_1_dim_2_cpu_halide (__main__.HalideCpuTests)": 60.36066691080729, + "test_low_memory_max_pool_dilation_1_dim_3_cpu_halide (__main__.HalideCpuTests)": 654.1646728515625, + "test_low_memory_max_pool_dilation_2_dim_3_cpu_halide (__main__.HalideCpuTests)": 516.5246785481771, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 78.14400100708008, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 111.50900014241536, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 64.09444597032335, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 61.463999854193794, + "test_proper_exit (__main__.TestDataLoader)": 224.20867156982422, + 
"test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 223.44366709391275, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 63.79383373260498, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 135.53544277615018, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 126.51955371432834, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 124.93733469645183, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 88.31800079345703, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 91.14400227864583, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 131.42633819580078, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 87.89099884033203, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.00499979654948, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 130.41100311279297, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 80.20366668701172, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 92.31099955240886, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 132.61800130208334, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 89.43633270263672, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 87.20300038655598, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 141.2316640218099, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 85.5403340657552, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.23933410644531, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 134.7986628214518, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 84.30833180745442, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 98.22066752115886, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 126.87933603922527, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 80.81599934895833, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 88.3933334350586, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True 
(__main__.TestPatternMatcher)": 136.1326649983724, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 86.34866587320964, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 89.5760014851888, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 339.4856669108073, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 688.2038370768229, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 563.0550130208334, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1057.2596740722656, + "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 60.16333325703939, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 65.99533589680989, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 225.16200002034506, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 84.37333424886067, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 161.50900268554688, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 72.27966562906902, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 110.45833333333333, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 69.66266632080078, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 133.47400283813477, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 83.216002146403, + "test_register_spills_cuda (__main__.BenchmarkFusionCudaTest)": 113.03650029500325, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 66.53833262125652, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 118.47366658846538, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 142.11158307393393, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 106.017333984375, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 130.9388910929362, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 147.9407755533854, + "test_sum_all_cpu_float64 (__main__.TestReductionsCPU)": 252.02999792054848, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 232.58900745709738, + "test_terminate_handler_on_crash (__main__.TestTorch)": 100.51200015015073, + "test_terminate_signal (__main__.ForkTest)": 137.1949984199471, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 137.9567770593696, + "test_terminate_signal (__main__.SpawnTest)": 140.8028925259908, + "test_torchvision_smoke (__main__.TestTensorBoardPytorchGraph)": 76.12111282348633, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 124.51886669516874, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 79.41683387756348, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 78.52750078837077, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 
83.28249867757161, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 117.90133094787598, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 118.63733228047688, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 100.5381685892741, + "test_unary_ops (__main__.TestTEFuserDynamic)": 173.42911020914713, + "test_unary_ops (__main__.TestTEFuserStatic)": 158.1659984588623, + "test_unwaited (__main__.CommTest)": 60.680667877197266, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 94.42666371663411, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 74.38800048828125, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 95.0030008951823, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cuda_float32 (__main__.TestOperatorsCUDA)": 89.27025032043457, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 67.01800028483073, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 68.22083346048991, + "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 63.439666112264, + "test_vmapjvpvjp_linalg_solve_triangular_cuda_float32 (__main__.TestOperatorsCUDA)": 63.80483341217041, + "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 65.49283345540364, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 81.31166585286458, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 67.66283289591472, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.59249941507976, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 69.74266688028972, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 70.75883356730144, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 74.43816630045573, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 72.2706667582194, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 129.9483324686686 } \ No newline at end of file From e209625334a6c0a02bf5ba477865481b66286e70 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Mon, 7 Apr 2025 13:00:52 +0000 Subject: [PATCH 233/332] [torchrec] update local_shards_wrapper to latest version (#150469) Summary: Adding new ops, support for empty shards, and fixed initializations for downstream checkpointing. Test Plan: buck2 run 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_shards_wrapper Differential Revision: D72271275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150469 Approved by: https://github.com/XilunWu --- torch/distributed/tensor/_shards_wrapper.py | 145 +++++++++++++------- 1 file changed, 94 insertions(+), 51 deletions(-) diff --git a/torch/distributed/tensor/_shards_wrapper.py b/torch/distributed/tensor/_shards_wrapper.py index 11bdb4ec2ef2..3102b84c11d1 100644 --- a/torch/distributed/tensor/_shards_wrapper.py +++ b/torch/distributed/tensor/_shards_wrapper.py @@ -21,12 +21,10 @@ ) -aten = ( - torch.ops.aten -) # pyre-ignore[5]: Globally accessible variable `aten` has no type specified. 
+aten = torch.ops.aten -class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__ +class LocalShardsWrapper(torch.Tensor): """ A wrapper class to hold local shards of a DTensor. This class is used largely for checkpointing purposes and implicity subtypes @@ -41,18 +39,39 @@ class LocalShardsWrapper(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new def __new__( cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]] ) -> "LocalShardsWrapper": - assert len(local_shards) > 0 - assert len(local_shards) == len(local_offsets) assert all( tensor.device == local_shards[0].device for tensor in local_shards[1:] ) + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + # we calculate the total tensor size by "concat" on second tensor dimension cat_tensor_shape = list(local_shards[0].size()) - if len(local_shards) > 1: # column-wise sharding + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding for shard in local_shards[1:]: cat_tensor_shape[1] += shard.size()[1] + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) wrapper_shape = torch.Size(cat_tensor_shape) chunks_meta = [ @@ -78,9 +97,7 @@ def __new__( # necessary for ops dispatching from this subclass to its local shards @classmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] kwargs = kwargs or {} dispatcher = { @@ -91,21 +108,18 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): aten.equal.default: cls.handle_equal, aten.detach.default: cls.handle_detach, aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, } if func in dispatcher: - return dispatcher[func]( - args, kwargs - ) # pyre-ignore [29] - `Variable[_VT]` is not a function. + return dispatcher[func](args, kwargs) else: raise NotImplementedError( f"{func} is not supported for LocalShardsWrapper!" ) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_all_gather_into_tensor(args, kwargs): + def handle_all_gather_into_tensor(args, kwargs) -> torch.Tensor: dim = args[0].local_sizes()[0][1] cat_tensor = torch.cat( [t.view(-1) for t in args[0].local_shards()], dim=0 @@ -115,15 +129,11 @@ def handle_all_gather_into_tensor(args, kwargs): ) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_wait_tensor(args, kwargs): + def handle_wait_tensor(args, kwargs) -> torch.Tensor: return torch.ops._c10d_functional.wait_tensor(args[0]) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def handle_to_copy(args, kwargs): + def handle_to_copy(args, kwargs) -> torch.Tensor: res_shards_list = [ aten._to_copy.default(shard, *args[1:], **kwargs) for shard in args[0].local_shards() @@ -131,20 +141,41 @@ def handle_to_copy(args, kwargs): return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_view(args, kwargs): - # TODO, do we need to change the shape of associated offsets? - res_shards_list = [ - aten.view.default(shard, args[1], **kwargs) - for shard in args[0].local_shards() - ] + def handle_view(args, kwargs) -> "LocalShardsWrapper": + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_equal(args, kwargs): + def handle_equal(args, kwargs) -> bool: """ LocalShardsWrapper equal impl also checks for equality of storage metadata and the order of shards @@ -161,9 +192,7 @@ def handle_equal(args, kwargs): return True @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def handle_detach(args, kwargs): + def handle_detach(args, kwargs) -> "LocalShardsWrapper": self_ls = args[0] deatched_local_shards = [ aten.detach.default(shard) for shard in self_ls.local_shards() @@ -173,9 +202,7 @@ def handle_detach(args, kwargs): return self_ls @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. 
- def handle_clone(args, kwargs): + def handle_clone(args, kwargs) -> "LocalShardsWrapper": self_ls = args[0] desired_memory_format = kwargs.get("memory_format", None) if desired_memory_format and desired_memory_format != torch.preserve_format: @@ -188,19 +215,27 @@ def handle_clone(args, kwargs): ] return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + @staticmethod + def handle_new_empty(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + @property def device(self) -> torch._C.device: # type: ignore[override] - return self._local_shards[0].device + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) @property def is_meta(self) -> bool: # type: ignore[override] - return self._local_shards[0].is_meta + return self._local_shards[0].is_meta if self._local_shards else True - # pyre-ignore[14] def is_pinned(self) -> bool: # type: ignore[override] return self._storage_meta.properties.pin_memory - # pyre-ignore[14] def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": self._storage_meta.properties.requires_grad = requires_grad [shard.requires_grad_(requires_grad) for shard in self._local_shards] @@ -233,7 +268,7 @@ def local_offsets(self) -> list[torch.Size]: @property def local_chunks(self) -> list[ChunkStorageMetadata]: """ - Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the + Returns a :class:`list[ChunkStorageMetadata]` object corresponding to the metadata for each tensor shard """ return self._storage_meta.chunks @@ -245,9 +280,14 @@ def storage_metadata(self) -> TensorStorageMetadata: """ return self._storage_meta - def __create_write_items__( - self, fqn: str, object: Any - ) -> list[WriteItem]: # pyre-ignore[2] + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__(self, fqn: str, object: Any) -> list[WriteItem]: """ For compatibility with DCP, we support creation of WriteItems such that they can be saved properly. @@ -293,6 +333,12 @@ def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: if chunk.offsets == index.offset: return shard + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + raise ValueError( f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" ) @@ -303,12 +349,9 @@ def _get_tensor_size_bytes(self) -> int: object_size += shard.nelement() * shard.element_size() return object_size - # pyre-fixme[3]: Return type must be annotated. - def __hash__(self): + def __hash__(self) -> int: return id(self) - # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently. - # pyre-fixme[3]: Return type must be annotated. def __repr__(self) -> str: # type: ignore[override] return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" From 99c9a31386c5da2fe47c3755f71ade5ebb9615eb Mon Sep 17 00:00:00 2001 From: Shivam Raikundalia Date: Mon, 7 Apr 2025 13:04:38 +0000 Subject: [PATCH 234/332] [submodule] [Snapshot/Profiler] Memory Snapshot On Demand (#150559) Summary: Profiler side of memory snapshot. 1. Add API to actually do snapshot when client interface is called 2. 
Add ifdefs to builds so that kineto hooks snapshot correctly. Design Philosophy: There is one interesting part of this implementation and it is during export. For export we are callign the python impl of the export rather than CPP even though we are already in CPP. This is because it is better to simply have one path of export rather than 2. Personally, I want there to be parity between auto-trace and on-demand so it if we can limit the side paths then we will have an easier time maintaining this relationship Test Plan: {F1976563426} Reviewed By: sanrise Differential Revision: D70733247 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150559 Approved by: https://github.com/sanrise --- buckbuild.bzl | 3 + cmake/Dependencies.cmake | 2 +- test/cpp/jit/CMakeLists.txt | 2 +- third_party/kineto | 2 +- torch/csrc/autograd/profiler_kineto.cpp | 19 ++++- torch/csrc/autograd/profiler_kineto.h | 4 + torch/csrc/autograd/profiler_python.cpp | 83 +++++++++++++++++++ .../csrc/profiler/kineto_client_interface.cpp | 14 ++++ .../profiler/orchestration/python_tracer.cpp | 21 +++++ .../profiler/orchestration/python_tracer.h | 16 ++++ 10 files changed, 160 insertions(+), 6 deletions(-) diff --git a/buckbuild.bzl b/buckbuild.bzl index b208a4d25c18..f7fac4bf49dd 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1725,6 +1725,7 @@ def define_buck_targets( compiler_flags = get_pt_compiler_flags() + ["-Wno-error"], exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DUSE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", # Need this otherwise USE_KINETO is undefed # for mobile "-DEDGE_PROFILER_USE_KINETO", @@ -1750,6 +1751,7 @@ def define_buck_targets( exported_preprocessor_flags = get_pt_preprocessor_flags() + [ "-DUSE_KINETO", "-DEDGE_PROFILER_USE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", ], # @lint-ignore BUCKLINT link_whole link_whole = True, @@ -1836,6 +1838,7 @@ def define_buck_targets( # Need this otherwise USE_KINETO is undefed # for mobile "-DEDGE_PROFILER_USE_KINETO", + "-DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND", ] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []), extra_flags = { "fbandroid_compiler_flags": ["-frtti"], diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1df6b350b9b1..7627c3d9c7bb 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1715,7 +1715,7 @@ if(USE_KINETO) set_property(TARGET kineto PROPERTY POSITION_INDEPENDENT_CODE ON) endif() list(APPEND Caffe2_DEPENDENCY_LIBS kineto) - string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO") + string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND") if(LIBKINETO_NOCUPTI) string(APPEND CMAKE_CXX_FLAGS " -DLIBKINETO_NOCUPTI") endif() diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index cd2eaf761dff..75bf60b0654e 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -27,7 +27,7 @@ add_library(backend_with_compiler SHARED ) if(USE_KINETO) set_target_properties(backend_with_compiler PROPERTIES COMPILE_FLAGS - "-DUSE_KINETO") + "-DUSE_KINETO -DTMP_IMPL_MEMORY_PROFILING_ON_DEMAND") endif() target_link_libraries(backend_with_compiler torch) diff --git a/third_party/kineto b/third_party/kineto index 2859721fd9e7..d6796921fdde 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 2859721fd9e73d3ca1c56f827dbc64e6d68f78a2 +Subproject commit d6796921fdde135cb94d2dd04fe2071a5424a321 diff --git a/torch/csrc/autograd/profiler_kineto.cpp 
b/torch/csrc/autograd/profiler_kineto.cpp index 2b1e6f2e0104..447ca88f0e84 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -8,7 +8,6 @@ #include #include #include - #include #include #include @@ -21,8 +20,6 @@ #include #include -#include - #include #include @@ -860,6 +857,22 @@ std::unique_ptr disableProfiler() { return result; } +namespace tracer = torch::profiler::impl::python_tracer; +std::unique_ptr memory_tracer; +void startMemoryProfile() { + if (memory_tracer == nullptr) { + memory_tracer = tracer::PythonMemoryTracerBase::make(); + } + memory_tracer->start(); +} + +void stopMemoryProfile() { + memory_tracer->stop(); +} + +void exportMemoryProfile(const std::string& filename) { + memory_tracer->export_memory_history(filename); +} KinetoEvent::KinetoEvent( const std::shared_ptr& result, diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index cedf58123381..2e4b89da4b79 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -185,6 +185,10 @@ TORCH_API void toggleCollectionDynamic( const bool enable, const std::set& activities); +TORCH_API void startMemoryProfile(); +TORCH_API void stopMemoryProfile(); +TORCH_API void exportMemoryProfile(const std::string& path); + /** * When a C++ thread really has no control over how the profiler was enabled, * for example, by some unreachable Python code, it can call these functions diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 17da6cf3d70b..045c47902516 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -1144,6 +1145,81 @@ std::vector> PythonTracer::getEvents( return out; } +// ============================================================================ +// == Memory Tracer ====================================================== +// ============================================================================ + +// Assuming python_tracer::PythonMemoryTracerBase is defined elsewhere +class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase { + public: + explicit PythonMemoryTracer(); + ~PythonMemoryTracer() override; + void start() override; + void stop() override; + void export_memory_history(const std::string path) override; +}; + +PythonMemoryTracer::PythonMemoryTracer() {} +PythonMemoryTracer::~PythonMemoryTracer() {} + +static void toggle_memory_tracing(bool enable) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + THPObjectPtr torch_cuda_memory_module( + PyImport_ImportModule("torch.cuda.memory")); + if (!torch_cuda_memory_module) { + return; + } + THPObjectPtr snapshot_func(PyObject_GetAttrString( + torch_cuda_memory_module.get(), "_record_memory_history_impl")); + if (!snapshot_func) { + return; + } + // Call the function with arguments + PyObject* args = PyTuple_New(6); + PyTuple_SetItem(args, 0, enable ? 
PyUnicode_FromString("all") : Py_None); + PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context + PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks + PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries + PyTuple_SetItem(args, 4, Py_None); // device (None) + PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False) + PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); + Py_DECREF(args); + if (result == NULL) { + return; + } + PyGILState_Release(gil_state); +} + +void PythonMemoryTracer::start() { + toggle_memory_tracing(true); +} + +void PythonMemoryTracer::export_memory_history(const std::string path) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + THPObjectPtr torch_cuda_memory_module( + PyImport_ImportModule("torch.cuda.memory")); + if (!torch_cuda_memory_module) { + return; + } + THPObjectPtr snapshot_func( + PyObject_GetAttrString(torch_cuda_memory_module.get(), "_dump_snapshot")); + if (!snapshot_func) { + return; + } + PyObject* py_filename = PyUnicode_FromString(path.c_str()); + // Call the function with arguments (e.g., a file path) + PyObject* args = PyTuple_Pack(1, py_filename); + PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); + Py_DECREF(args); + if (result == NULL) { + return; + } + PyGILState_Release(gil_state); +} + +void PythonMemoryTracer::stop() { + toggle_memory_tracing(false); +} // ============================================================================ // == API ===================================================================== @@ -1181,6 +1257,11 @@ std::unique_ptr getTracer( torch::profiler::impl::RecordQueue* queue) { return std::make_unique(queue); } + +std::unique_ptr getMemoryTracer() { + return std::make_unique(); +} + } // namespace } // namespace torch::profiler::impl @@ -1191,5 +1272,7 @@ void init() { TORCH_CHECK(PyType_Ready(&torch::profiler::impl::TraceContextType) == 0); torch::profiler::impl::python_tracer::registerTracer( &torch::profiler::impl::getTracer); + torch::profiler::impl::python_tracer::registerMemoryTracer( + &torch::profiler::impl::getMemoryTracer); } } // namespace torch::autograd::profiler::python_tracer diff --git a/torch/csrc/profiler/kineto_client_interface.cpp b/torch/csrc/profiler/kineto_client_interface.cpp index fd145f4c4fa6..89c824cd578f 100644 --- a/torch/csrc/profiler/kineto_client_interface.cpp +++ b/torch/csrc/profiler/kineto_client_interface.cpp @@ -58,6 +58,20 @@ class LibKinetoClient : public libkineto::ClientInterface { (void)disableProfiler(); } + void start_memory_profile() override { + LOG(INFO) << "Starting on-demand memory profile"; + startMemoryProfile(); + } + + void stop_memory_profile() override { + LOG(INFO) << "Stopping on-demand memory profile"; + stopMemoryProfile(); + } + + void export_memory_profile(const std::string& path) override { + exportMemoryProfile(path); + } + private: // Temporarily disable shape collection until // we re-roll out the feature for on-demand cases diff --git a/torch/csrc/profiler/orchestration/python_tracer.cpp b/torch/csrc/profiler/orchestration/python_tracer.cpp index e570a69cb696..73bdf3ccb017 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.cpp +++ b/torch/csrc/profiler/orchestration/python_tracer.cpp @@ -3,6 +3,7 @@ namespace torch::profiler::impl::python_tracer { namespace { MakeFn make_fn; +MakeMemoryFn memory_make_fn; struct NoOpPythonTracer : public PythonTracerBase { NoOpPythonTracer() = default; @@ -17,6 +18,15 @@ struct NoOpPythonTracer : public 
PythonTracerBase { return {}; } }; + +struct NoOpMemoryPythonTracer : public PythonMemoryTracerBase { + NoOpMemoryPythonTracer() = default; + ~NoOpMemoryPythonTracer() override = default; + void start() override {} + void stop() override {} + void export_memory_history(const std::string path) override {} +}; + } // namespace void registerTracer(MakeFn make_tracer) { @@ -29,4 +39,15 @@ std::unique_ptr PythonTracerBase::make(RecordQueue* queue) { } return make_fn(queue); } + +void registerMemoryTracer(MakeMemoryFn make_memory_tracer) { + memory_make_fn = make_memory_tracer; +} + +std::unique_ptr PythonMemoryTracerBase::make() { + if (memory_make_fn == nullptr) { + return std::make_unique(); + } + return memory_make_fn(); +} } // namespace torch::profiler::impl::python_tracer diff --git a/torch/csrc/profiler/orchestration/python_tracer.h b/torch/csrc/profiler/orchestration/python_tracer.h index 580bf523e7f5..725c6d8a5c95 100644 --- a/torch/csrc/profiler/orchestration/python_tracer.h +++ b/torch/csrc/profiler/orchestration/python_tracer.h @@ -56,5 +56,21 @@ struct TORCH_API PythonTracerBase { using MakeFn = std::unique_ptr (*)(RecordQueue*); TORCH_API void registerTracer(MakeFn make_tracer); + +/** + * Memory Tracer Implementation + */ +struct TORCH_API PythonMemoryTracerBase { + static std::unique_ptr make(); + virtual ~PythonMemoryTracerBase() = default; + + virtual void start() = 0; + virtual void stop() = 0; + virtual void export_memory_history(const std::string path) = 0; +}; + +using MakeMemoryFn = std::unique_ptr (*)(); +TORCH_API void registerMemoryTracer(MakeMemoryFn make_memory_tracer); + } // namespace python_tracer } // namespace torch::profiler::impl From 5e3c8214b52a8b00e750b8997a58d2915db562ce Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 7 Apr 2025 10:12:34 +0000 Subject: [PATCH 235/332] cpp_wrapper: Re-enable code disabled for forward compatibility (#150671) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150671 Approved by: https://github.com/desertfire --- torch/_inductor/codegen/cpp_wrapper_cpu.py | 17 ++++++----------- .../codegen/cpp_wrapper_cpu_array_ref.py | 16 ++++++---------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 9f163256b311..3dfe5434722f 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1568,17 +1568,12 @@ def create_dtypeview_call(reinterpret_call: str) -> tuple[str, list[str]]: return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs def create_new_tensor_handle() -> tuple[str, list[str]]: - # TODO (benjaminglass1): uncomment this and remove the call to - # create_reinterpret_view after the AOTI forwards compatibility window has - # passed. 
- # - # tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" - # tmp_call_strs = [ - # f"AtenTensorHandle {tmp_AtenTensorHandle};", - # f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", - # ] - # return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs - return create_reinterpret_call(), [] + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs if ( size == data.layout.size diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index a3e472834518..67ea2e2166e8 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -835,16 +835,12 @@ def create_new_tensor_handle() -> tuple[str, list[str]]: if (name := data.get_name()) in self.stack_allocated_buffers: return name, [] - # TODO (benjaminglass1): uncomment this and remove create_reinterpret_view - # after the AOTI forwards compatibility window has passed. - # - # tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" - # tmp_call_strs = [ - # f"AtenTensorHandle {tmp_AtenTensorHandle};", - # f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", - # ] - # return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs - return create_reinterpret_call(), [] + tmp_AtenTensorHandle = f"tmp_{name}_{next(self.tmp_tensor_id)}" + tmp_call_strs = [ + f"AtenTensorHandle {tmp_AtenTensorHandle};", + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle({data.get_name()}, &{tmp_AtenTensorHandle}));", + ] + return f"RAIIAtenTensorHandle({tmp_AtenTensorHandle})", tmp_call_strs if ( size == data.layout.size From f0abbabac189c59fe4ed7f930872e245bb0cc096 Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 7 Apr 2025 10:12:34 +0000 Subject: [PATCH 236/332] AOTI fallback ops: sort alphabetically (#150672) This is just a housekeeping task that makes the listed fallback op order match what's in the generated C shim files. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150672 Approved by: https://github.com/desertfire ghstack dependencies: #150671 --- torchgen/aoti/fallback_ops.py | 98 +++++++++++++++++------------------ 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index a2a6cf1b1afc..567ccdf1ee7a 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -10,16 +10,55 @@ inductor_fallback_ops = { "aten._adaptive_avg_pool2d_backward.default", "aten._adaptive_avg_pool2d.default", - "aten._adaptive_avg_pool3d.default", "aten._adaptive_avg_pool3d_backward.default", + "aten._adaptive_avg_pool3d.default", + "aten._addmm_activation.default", + "aten._cdist_backward.default", + "aten._cdist_forward.default", + "aten._cudnn_rnn.default", + "aten._dyn_quant_matmul_4bit.default", + "aten._dyn_quant_pack_4bit_weight.default", + "aten._efficient_attention_backward.default", + "aten._efficient_attention_forward.default", + "aten._efficientzerotensor.default", + "aten._embedding_bag_dense_backward.default", + "aten._embedding_bag_forward_only.default", + "aten._embedding_bag_per_sample_weights_backward.default", + "aten._embedding_bag.default", + "aten._fft_c2c.default", + "aten._fft_r2c.default", + "aten._flash_attention_backward.default", + "aten._flash_attention_forward.default", + "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten._fused_moving_avg_obs_fq_helper.default", + "aten._histogramdd_from_bin_cts.default", + "aten._int_mm.out", + "aten._pdist_backward.default", + "aten._pdist_forward.default", + "aten._scaled_dot_product_cudnn_attention_backward.default", + "aten._scaled_dot_product_cudnn_attention.default", + "aten._scaled_dot_product_efficient_attention_backward.default", + "aten._scaled_dot_product_efficient_attention.default", + "aten._scaled_dot_product_flash_attention_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", + "aten._scaled_dot_product_flash_attention_for_cpu.default", + "aten._scaled_dot_product_flash_attention.default", + "aten._scaled_dot_product_fused_attention_overrideable_backward.default", + "aten._scaled_dot_product_fused_attention_overrideable.default", + "aten._scaled_mm.default", + "aten._scaled_mm.out", + "aten._segment_reduce_backward.default", + "aten._thnn_fused_lstm_cell.default", + "aten._to_sparse.default", + "aten._trilinear.default", + "aten._weight_int8pack_mm.default", "aten.adaptive_max_pool2d_backward.default", "aten.adaptive_max_pool2d.default", - "aten.adaptive_max_pool3d.default", "aten.adaptive_max_pool3d_backward.default", + "aten.adaptive_max_pool3d.default", "aten.add.Scalar", "aten.add.Tensor", "aten.addbmm.default", - "aten._addmm_activation.default", "aten.addmm.out", "aten.addmv.default", "aten.angle.default", @@ -33,57 +72,37 @@ "aten.bmm.out", "aten.bucketize.Tensor", "aten.cat.default", - "aten._cdist_backward.default", - "aten._cdist_forward.default", "aten.cholesky_inverse.default", "aten.cholesky_solve.default", "aten.convolution_backward.default", - "aten._cudnn_rnn.default", "aten.convolution.default", "aten.cummax.default", "aten.cummin.default", "aten.cumprod.default", "aten.cumsum.default", - "aten._dyn_quant_matmul_4bit.default", - "aten._dyn_quant_pack_4bit_weight.default", - "aten._efficient_attention_backward.default", - "aten._efficient_attention_forward.default", - "aten._efficientzerotensor.default", - "aten._embedding_bag.default", - 
"aten._embedding_bag_dense_backward.default", - "aten._embedding_bag_forward_only.default", - "aten._embedding_bag_per_sample_weights_backward.default", "aten.exponential.default", - "aten._fft_c2c.default", - "aten._fft_r2c.default", - "aten._flash_attention_backward.default", - "aten._flash_attention_forward.default", "aten.fractional_max_pool2d_backward.default", "aten.fractional_max_pool2d.default", - "aten.fractional_max_pool3d.default", "aten.fractional_max_pool3d_backward.default", - "aten._fused_moving_avg_obs_fq_helper.default", - "aten._fused_moving_avg_obs_fq_helper_functional.default", + "aten.fractional_max_pool3d.default", "aten.gcd.default", "aten.geqrf.default", "aten.grid_sampler_2d_backward.default", "aten.histc.default", "aten.histogram.bin_ct", - "aten._histogramdd_from_bin_cts.default", "aten.index_put.default", "aten.index_reduce.default", "aten.index.Tensor", - "aten._int_mm.out", "aten.kthvalue.default", "aten.logcumsumexp.default", "aten.lu_unpack.default", - "aten.masked_select.default", - "aten.masked_scatter.default", "aten.masked_scatter_backward.default", + "aten.masked_scatter.default", + "aten.masked_select.default", "aten.max_pool2d_with_indices_backward.default", "aten.max_pool2d_with_indices.default", - "aten.max_pool3d_with_indices.default", "aten.max_pool3d_with_indices_backward.default", + "aten.max_pool3d_with_indices.default", "aten.max_unpool2d.default", "aten.max_unpool3d.default", "aten.median.default", @@ -93,11 +112,9 @@ "aten.mul.Tensor", "aten.nanmedian.default", "aten.native_dropout.default", - "aten.normal_functional.default", "aten.nonzero.default", + "aten.normal_functional.default", "aten.ormqr.default", - "aten._pdist_backward.default", - "aten._pdist_forward.default", "aten.polar.default", "aten.pow.Scalar", "aten.pow.Tensor_Scalar", @@ -106,8 +123,8 @@ "aten.rand.generator", "aten.randint.default", "aten.randint.generator", - "aten.randint.low", "aten.randint.low_out", + "aten.randint.low", "aten.randn.default", "aten.randn.generator", "aten.randperm.default", @@ -117,36 +134,20 @@ "aten.reshape.default", "aten.resize_.default", "aten.resize_as_.default", - "aten._scaled_dot_product_efficient_attention_backward.default", - "aten._scaled_dot_product_efficient_attention.default", - "aten._scaled_dot_product_flash_attention_backward.default", - "aten._scaled_dot_product_flash_attention.default", - "aten._scaled_dot_product_cudnn_attention_backward.default", - "aten._scaled_dot_product_cudnn_attention.default", - "aten._scaled_dot_product_flash_attention_for_cpu_backward.default", - "aten._scaled_dot_product_flash_attention_for_cpu.default", - "aten._scaled_dot_product_fused_attention_overrideable_backward.default", - "aten._scaled_dot_product_fused_attention_overrideable.default", - "aten._scaled_mm.default", - "aten._scaled_mm.out", "aten.scatter_reduce.two_out", "aten.scatter.src_out", "aten.scatter.value_out", "aten.searchsorted.Scalar", "aten.searchsorted.Tensor", - "aten._segment_reduce_backward.default", "aten.segment_reduce.default", "aten.set_.source_Tensor", "aten.slice.Tensor", "aten.soft_margin_loss_backward.default", "aten.sort.default", "aten.sort.stable", - "aten._thnn_fused_lstm_cell.default", - "aten.topk.default", - "aten._to_sparse.default", "aten.to_sparse.default", + "aten.topk.default", "aten.triangular_solve.default", - "aten._trilinear.default", "aten.uniform.default", "aten.upsample_bicubic2d_backward.default", "aten.upsample_linear1d_backward.default", @@ -154,5 +155,4 @@ "aten.view_as_complex.default", 
"aten.view_as_real.default", "aten.view.dtype", - "aten._weight_int8pack_mm.default", } From f813d64f54cf2e8b0a1f6e78589c78d761346e9f Mon Sep 17 00:00:00 2001 From: Benjamin Glass Date: Mon, 7 Apr 2025 10:12:35 +0000 Subject: [PATCH 237/332] cpp_wrapper: Fix even more tests (#147225) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147225 Approved by: https://github.com/desertfire ghstack dependencies: #150671, #150672 --- test/inductor/test_benchmark_fusion.py | 34 +++++++++-------- test/inductor/test_compiled_autograd.py | 17 +++++++-- test/inductor/test_max_autotune.py | 45 +++++++++++------------ test/inductor/test_torchinductor.py | 3 -- torch/testing/_internal/inductor_utils.py | 6 +++ 5 files changed, 60 insertions(+), 45 deletions(-) diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 2192e58f0f3f..ca542c81eea1 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -10,7 +10,7 @@ from torch._inductor.utils import fresh_inductor_cache, is_big_gpu, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_utils import slowTest -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import get_func_call, HAS_CPU, HAS_CUDA # Make the helper files in test/ importable @@ -24,6 +24,7 @@ check_model, check_model_cuda, copy_tests, + skip_if_cpp_wrapper, ) from torch._inductor import config from torch._inductor.scheduler import Scheduler @@ -126,7 +127,7 @@ def f(a, b): self.common(f, (a, b)) - @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON") + @config.patch(max_autotune_gemm_backends="TRITON") def test_avoid_register_spilling(self): if self.device != "cuda": raise unittest.SkipTest("CUDA only") @@ -196,6 +197,7 @@ class BenchmarkingTest(TestCase): @unittest.skipIf( torch.cuda.device_count() < 2, "The test need at least 2 devices" ) + @skip_if_cpp_wrapper("This tests triton scheduling directly") def test_benchmark_on_non_zero_device(self): hit_count = 0 with torch.cuda.device("cuda:0"): @@ -265,9 +267,7 @@ def foo(m, inp): res, code = run_and_get_code(foo_c, m, inp) torch._dynamo.reset() - with unittest.mock.patch.object( - torch._inductor.config, "benchmark_epilogue_fusion", False - ): + with config.patch(benchmark_epilogue_fusion=False): foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) with torch.no_grad(): res2, code2 = run_and_get_code(foo_c, m, inp) @@ -276,32 +276,34 @@ def foo(m, inp): return code, code2 @fresh_inductor_cache() - @torch._inductor.config.patch(max_autotune_gemm_backends="TRITON") + @config.patch(max_autotune_gemm_backends="TRITON") def test_equivalent_template_code(self): code, code2 = self._equivalent_output_code_impl(256) for out_code in [code, code2]: - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("triton_tem_fused_addmm_relu_0.run").check_count( - "del", 3, exactly=True + FileCheck().check(get_func_call()).check_count( + "empty_strided", 1, exactly=True + ).check("triton_tem_fused_addmm_relu_0").check_count( + ".reset()" if config.cpp_wrapper else "del", 3, exactly=True ).check( - "return" + "" if config.cpp_wrapper else "return" ).run( out_code[0] ) @fresh_inductor_cache() - @torch._inductor.config.patch(max_autotune_gemm_backends="ATEN") + @config.patch(max_autotune_gemm_backends="ATEN") def test_equivalent_extern_code(self): torch._dynamo.reset() code, code2 = 
self._equivalent_output_code_impl(512, 1, False) for out_code in [code, code2]: - FileCheck().check("def call").check_count( - "empty_strided_cuda", 1, exactly=True - ).check("extern_kernels.").check_count("del", 3, exactly=True).check( - "return" + FileCheck().check(get_func_call()).check_count( + "empty_strided", 1, exactly=True + ).check("" if config.cpp_wrapper else "extern_kernels.").check_count( + ".reset()" if config.cpp_wrapper else "del", 3, exactly=True + ).check( + "" if config.cpp_wrapper else "return" ).run( out_code[0] ) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index f10bf940e711..ec0ba10b9cb2 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -2801,7 +2801,12 @@ def test_cudagraphs_cpu_division(self): loss.backward() torch._inductor.config.triton.cudagraphs = False - self.assertFalse("skipping cudagraphs" in stderr_msgs.getvalue()) + if inductor_config.cpp_wrapper: + self.assertIn("skipping cudagraphs", stderr_msgs.getvalue()) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + else: + self.assertNotIn("skipping cudagraphs", stderr_msgs.getvalue()) + self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) def test_cudagraphs_cpu_graph(self): from torch._dynamo.testing import reduce_to_scalar_loss @@ -2834,7 +2839,10 @@ def test_cudagraphs_sdpa(self): opt_bwd() self.assertEqual(counters["compiled_autograd"]["captures"], 1) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 2 if inductor_config.cpp_wrapper else 0, + ) @unittest.skipIf(not HAS_CUDA, "requires cuda") def test_cudagraphs_cpu_scalar_used_in_python_custom_op(self): @@ -2927,7 +2935,10 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): # into it. We must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. 
# In the future, we can consider having a cpu scalar movement pass sometime after we trace # into the custom C++ autograd::Function (like in AOTDispatcher) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + self.assertEqual( + counters["inductor"]["cudagraph_skips"], + 2 if inductor_config.cpp_wrapper else 1, + ) def test_logs(self): logs, ctx = logs_to_string( diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index ce1263d502a1..3aa7ee276fc6 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -46,7 +46,14 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU +from torch.testing._internal.inductor_utils import ( + get_func_call, + get_kernel_launch, + GPU_TYPE, + HAS_CPU, + HAS_CUDA, + HAS_GPU, +) torch.set_float32_matmul_precision("high") @@ -54,14 +61,6 @@ torch.cuda.memory._set_allocator_settings("expandable_segments:False") -def _get_func_call() -> str: - return "void inductor_entry_impl(" if config.cpp_wrapper else "def call(" - - -def _get_kernel_launch() -> str: - return "call_triton_" if config.cpp_wrapper else ".run(" - - def benchmark_choice(choice, args, out, expected_out, timings): result = choice.benchmark(*args, out=out) if expected_out is not None: @@ -899,8 +898,8 @@ def f(x, y): # mm kernel, and cos kernel count = 2 if using_triton_mm else 1 - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), count, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), count, exactly=True ).run(code[0]) def f(x, y): @@ -912,8 +911,8 @@ def f(x, y): f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) _, code = run_and_get_code(f_c, inps[0], inps[1]) self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 2, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 2, exactly=True ).run(code[0]) def f(x, y): @@ -1362,21 +1361,21 @@ def setUpClass(cls): ) def check_code(self, code_str, num_kernels, num_allocs, num_deallocs): - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), num_kernels, exactly=True, ).run(code_str) if num_allocs is not None: - FileCheck().check(_get_func_call()).check_count( + FileCheck().check(get_func_call()).check_count( "empty_strided", num_allocs, exactly=True ).run(code_str) # skip the deallocation check when using cpp_wrapper; most deallocations happen # outside of our control via RAIIAtenTensorHandle if num_deallocs is not None and not config.cpp_wrapper: - FileCheck().check(_get_func_call()).check_count( + FileCheck().check(get_func_call()).check_count( "del", num_deallocs, exactly=True ).run(code_str) @@ -1557,8 +1556,8 @@ def multi_use(x, y): out, code = run_and_get_code(torch.compile(multi_use), x, y) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 2, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 2, exactly=True ).run(code[0]) self.assertEqual(out, multi_use(x, y), atol=0.05, rtol=0.05) @@ -1567,8 +1566,8 @@ def resolve_pending(x): x = torch.rand([128, 128], device=GPU_TYPE) out, code = run_and_get_code(torch.compile(resolve_pending), x) 
- FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 1, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 1, exactly=True ).run(code[0]) self.assertEqual(out, resolve_pending(x), atol=0.05, rtol=0.05) @@ -1591,8 +1590,8 @@ def test_multiple_fusions(x): x = torch.rand([128, 128], dtype=torch.float16, device=GPU_TYPE) out, code = run_and_get_code(torch.compile(test_multiple_fusions), x) - FileCheck().check(_get_func_call()).check_count( - _get_kernel_launch(), 1, exactly=True + FileCheck().check(get_func_call()).check_count( + get_kernel_launch(), 1, exactly=True ).run(code[0]) self.assertEqual(out, test_multiple_fusions(x), atol=0.05, rtol=0.05) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index fa83302732f1..9716486dd7ee 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -10124,9 +10124,6 @@ def fn(x): for x in (torch.randn(2, 3), torch.randn(2, 2), torch.randn(3, 2)): self.common(fn, (x,)) - @skip_if_cpp_wrapper( - "cannot currently handle fallback ops with return types containing list[Tensor]" - ) def test_kwargs(self): if self.device == GPU_TYPE: raise unittest.SkipTest("histogramdd only supports cpu") diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 1501a3bfcb36..4461a62bbe57 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -210,6 +210,12 @@ def maybe_skip_size_asserts(op): else: return contextlib.nullcontext() +def get_func_call() -> str: + return "void inductor_entry_impl(" if torch._inductor.config.cpp_wrapper else "def call(" + +def get_kernel_launch() -> str: + return "call_triton_" if torch._inductor.config.cpp_wrapper else ".run(" + def clone_preserve_strides_offset(x, device=None): if not isinstance(x, torch.Tensor): return x From 0ad2c5d7e2e0ee6c847e167c2e2d65ec8d7321fa Mon Sep 17 00:00:00 2001 From: shiyang-weng Date: Mon, 7 Apr 2025 15:12:26 +0000 Subject: [PATCH 238/332] Add RECORD_FUNCTION for AOTI (#150150) Only add RECORD_FUNCTION for shim_fn now. Next step need to add RECORD_FUNCTION for all the aoti_torch_* functions. 
Fixes https://github.com/pytorch/pytorch/issues/148650 Some code gen by aoti ```c++ AtenTensorHandle buf1_handle; AtenTensorHandle buf2_handle; AtenTensorHandle buf3_handle; AtenTensorHandle buf4_handle; {RECORD_FUNCTION("aoti_torch_cpu__embedding_bag", c10::ArrayRef());AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cpu__embedding_bag(L__self___sparse_arch_embedding_bag_collection_embedding_bags_t_cat_0_weight, arg80_1, arg81_1, 0, 0L, 0, nullptr, 1, -1L, &buf1_handle, &buf2_handle, &buf3_handle, &buf4_handle));} RAIIAtenTensorHandle buf1(buf1_handle); RAIIAtenTensorHandle buf2(buf2_handle); RAIIAtenTensorHandle buf3(buf3_handle); RAIIAtenTensorHandle buf4(buf4_handle); arg80_1.reset(); arg81_1.reset(); ``` On trace ``` { "name": "aoti_torch_cpu__embedding_bag", "ph": "X", "ts": 68874.450000, "dur": 361.291000, "tid": 2, "pid": "CPU Functions", "args": {} }, ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150150 Approved by: https://github.com/desertfire, https://github.com/EikanWang --- test/inductor/test_aot_inductor.py | 41 ++++++++++++++++++++++ torch/_inductor/codegen/cpp_wrapper_cpu.py | 16 ++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 973e720c7eb9..d008dba80421 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3991,6 +3991,47 @@ def forward(self, a): FileCheck().check_not(f"before_launch - {kernel_name}").run(code) FileCheck().check_not(f"after_launch - {kernel_name}").run(code) + @common_utils.parametrize("enable_kernel_profile", (True, False)) + def test_aoti_profiler(self, enable_kernel_profile): + # basic addmm model + class Model(torch.nn.Module): + def __init__(self, n, k, device): + super().__init__() + self.weight = torch.randn(n, k, device=device) + self.bias = torch.randn(n, device=device) + + def forward(self, a): + return torch.nn.functional.linear(a, self.weight, self.bias) + + if sys.platform not in ["linux", "win32"]: + raise unittest.SkipTest( + "enable_kernel_profile only supported on linux and win32" + ) + + M = 8 + N = 6 + K = 16 + model = Model(N, K, self.device) + batch = 2 + a = torch.randn(batch, M, K, device=self.device) + example_inputs = (a,) + kernel_calls = ( + f"aoti_torch_{GPU_TYPE}_addmm_out" + if self.device == GPU_TYPE + else "aoti_torch_cpu_addmm_out" + ) + with config.patch({"cpp.enable_kernel_profile": enable_kernel_profile}): + _, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + shim_fn_codes = ( + f'RECORD_FUNCTION("{kernel_calls}", c10::ArrayRef());' + ) + if enable_kernel_profile: + FileCheck().check(shim_fn_codes).run(code) + else: + FileCheck().check_not(shim_fn_codes).run(code) + def test_aoti_debug_printer_user_defined_triton_kernel(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 3dfe5434722f..6a0a4b4ba888 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1117,11 +1117,25 @@ def generate_c_shim_extern_kernel_call( debug_printer_manager.set_printer_args( debug_args if debug_args is not None else args, kernel, None, None, "extern" ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ + "linux", + "win32", + ] with debug_printer_manager: shim_fn = self.get_c_shim_func_name(kernel, device) - 
self.writeline( + shim_fn_codes = ( f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));" ) + if enable_kernel_profile: + shim_fn_codes = textwrap.dedent( + f""" + {{ + RECORD_FUNCTION("{shim_fn}", c10::ArrayRef()); + {shim_fn_codes} + }} + """ + ) + self.writeline(shim_fn_codes) def generate_c_shim_extern_kernel_alloc( self, extern_kernel: ir.ExternKernelAlloc, args: list[str] From 56ab71de98ea65b3d001e7d50546d250326db9fc Mon Sep 17 00:00:00 2001 From: jpvillam Date: Mon, 7 Apr 2025 16:05:56 +0000 Subject: [PATCH 239/332] [ROCm] Expand workspace size for gfx95 (#150632) Use same workspace size for gfx95* as gfx94* Pull Request resolved: https://github.com/pytorch/pytorch/pull/150632 Approved by: https://github.com/jeffdaily Co-authored-by: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com> --- aten/src/ATen/cuda/CublasHandlePool.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index e88c0bd5dab2..06fa4f91abff 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -123,9 +123,9 @@ size_t parseChosenWorkspaceSize() { // for extra convenience val = getenv("ROCBLAS_WORKSPACE_CONFIG"); } - /* 32MiB default, 128MiB for MI300 */ - const bool gfx94 = at::detail::getCUDAHooks().isGPUArch({"gfx94"}); - const size_t default_size = gfx94 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; + /* 32MiB default, 128MiB for gfx94x/gfx95x */ + const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"}); + const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; #else /* :4096:2:16:8 default, 32MiB for Hopper */ cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); From 06e9deabb623e004eb6024e703a976c5748d51e6 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Fri, 4 Apr 2025 17:16:01 -0700 Subject: [PATCH 240/332] [c10d][fr] Improve FR dump robustness with all watchdog broadcast wait and more frequent store check (#150652) When debugging FR missing dump and missing dump logs, I have couple initial findings: 1. On the same rank, if a second watchdog timeout triggers on a different PG(or subPG), that watchdog thread will immediately throw exception instead of sleeping. We want to fix that by still making the watchdog thread to wait for 1 min. 2. The FR dump takes about 900ms to 1200ms so, we are not checking the store frequently enough. But instead of changing the frequency from 1sec to 300ms, we finally decided to just let all ranks just sleep for 1 min universally rather than using a promise. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150652 Approved by: https://github.com/kwen2501 --- .../distributed/c10d/ProcessGroupNCCL.cpp | 34 ++++++------------- .../distributed/c10d/ProcessGroupNCCL.hpp | 3 -- 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3dc7f23860f3..ba1516a45f65 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1815,8 +1815,6 @@ void ProcessGroupNCCL::heartbeatMonitor() { if (logger) { logger->log(debugLog); } - // Indicate to watchdog thread that we have finished dumping. - promiseFlightRecorderDump_.set_value(); } // GIL deadlock check. 
@@ -2141,27 +2139,10 @@ void ProcessGroupNCCL::broadcastDumpSignal() { // broadcast dump signal to all other global ranks. broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank()); // signal the local rank to start dumping - if (shouldDump_.load()) { - // already signaled dump, skipping signal again and wait for the dump - // future. - return; - } - LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; - // signal the monitor thread on PG0 to start dumping - shouldDump_.store(true); - // Give time for dumping before throwing exception - auto start = std::chrono::steady_clock::now(); - auto status = promiseFlightRecorderDump_.get_future().wait_for( - std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); - if (status == std::future_status::timeout) { - LOG(WARNING) << logPrefix() << "timed out after waiting for " - << waitTimeoutDumpInMilSec_ << "ms" - << " flight recorder dumps to finish."; - } else if (status == std::future_status::ready) { - auto end = std::chrono::steady_clock::now(); - LOG(INFO) << logPrefix() << "slept for " << computeDeltaMS(start, end) - << "ms" - << " giving time for flight recorder dumps to finish."; + if (!shouldDump_.load()) { + LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping."; + // signal the monitor thread on PG0 to start dumping + shouldDump_.store(true); } } @@ -2336,6 +2317,13 @@ void ProcessGroupNCCL::watchdogHandler() { // recorder behavior is independent of desync Debug. if (dumpOnTimeoutOrEx_) { broadcastDumpSignal(); + // Give time for dumping before throwing exception for all ranks. + // It is hard to presume or control what the pattern of watchdog might + // look like, so it is better to let all ranks universally sleep for a + // short period of time, in this case, 60 seconds, which is also the + // maximum time we leave for FR dump. + std::this_thread::sleep_for( + std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); } if (SHOULD_CLEAN_UP(asyncErrorHandling_)) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index aa0021d7608f..0896dd0de290 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1173,9 +1173,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // timeout for the dump to finish. int waitTimeoutDumpInMilSec_; - // promise to coordinate flight recorder dump. - std::promise promiseFlightRecorderDump_; - // Interval of check coordinated signals in ProcessGroupNCCL from other ranks // e.g., trigger the dump of the debugging info for timeout when notified. int coordCheckIntervalMilSec_; From 957faaadca78ec453d60f2fe986c1191e2e7c5b6 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Sat, 5 Apr 2025 00:44:57 +0000 Subject: [PATCH 241/332] Avoid overflow in vector_norm for scalar input (#144073) Fixes https://github.com/pytorch/pytorch/issues/143960 where torch.dist gave different results from eager due to vector_norm overflowing and eager mode avoids the overflow for single element reductions by not computing the power and then the root. 
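As a quick illustration (a hedged sketch with made-up numbers, not code from this patch), computing the power and then the root can lose the value entirely for a single element, while taking the absolute value directly does not:

```python
# Hedged sketch: values and variable names are illustrative only.
import math

x = 1e-5     # a single-element input
p = 100.0    # a large-magnitude norm order (the test below uses p=-41.0)

pow_then_root = math.pow(abs(x) ** p, 1.0 / p)  # abs(x)**p underflows to 0.0
single_element = abs(x)                         # exact p-norm of one element

print(pow_then_root, single_element)  # 0.0 vs 1e-05
```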
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144073 Approved by: https://github.com/eellison, https://github.com/laithsakka --- test/inductor/test_torchinductor.py | 11 +++++++++++ torch/_refs/linalg/__init__.py | 16 ++++++++++++++++ .../_internal/opinfo/definitions/linalg.py | 7 ------- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 9716486dd7ee..bbec7ab9bbe4 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -2642,6 +2642,17 @@ def fn(a, b): self.common(fn, (torch.randn(4, 4), torch.randn(4, 4))) + @skip_if_halide # different pow accuracies + @xfail_if_triton_cpu + def test_norm_constant_overflow(self): + def fn(a): + return ( + torch.norm(a, p=-41.0, dim=1), + torch.norm(a, p=-41.0, dim=0), + ) + + self.common(fn, (torch.randn(4, 1, 4),)) + @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 def test_dist_bf16(self): diff --git a/torch/_refs/linalg/__init__.py b/torch/_refs/linalg/__init__.py index 00d95445c6f3..c85962f22842 100644 --- a/torch/_refs/linalg/__init__.py +++ b/torch/_refs/linalg/__init__.py @@ -151,6 +151,22 @@ def vector_norm( reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim) is_ord_even = ord % 2 == 0 if isinstance(ord, IntLike) else ord % 2.0 == 0.0 + if (dim is None and x.numel() == 1) or ( + dim is not None and (x.ndim > 0 and all(x.shape[d] == 1 for d in dim)) + ): + if x.ndim > 64: + raise RuntimeError( + f"Received a tensor with {x.ndim} dimensions, but only tensors with up to 64 dims are supported!" + ) + x = torch.abs(x) + if keepdim or x.ndim == 0: + return to_result_dtype(x).contiguous() + elif dim is None: + return x.flatten()[0] + else: + new_shape = [s for d, s in enumerate(x.shape) if d not in dim] + return to_result_dtype(x.view(new_shape)).contiguous() + if not (is_ord_even and utils.is_float_dtype(x.dtype)): x = torch.abs(x) return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord)) # type: ignore[return-value] diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index 26be0b5255ef..822e664270db 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -2327,13 +2327,6 @@ def make_input(): torch_opinfo_name="linalg.vector_norm", supports_out=True, op_db=op_db, - skips=( - # FIXME: sum reduces all dimensions when dim=[] - DecorateInfo(unittest.expectedFailure, "TestReductions", "test_dim_empty"), - DecorateInfo( - unittest.expectedFailure, "TestReductions", "test_dim_empty_keepdim" - ), - ), ), PythonRefInfo( "_refs.linalg.matrix_norm", From 7d2411d30ed7fd91d7b262a189f20286a6f28f1e Mon Sep 17 00:00:00 2001 From: Saurabh Mishra Date: Mon, 7 Apr 2025 17:33:04 +0000 Subject: [PATCH 242/332] [DCP][OSS] Introduce barrier util in the DistWrapper for rank local checkpointing (#150748) Summary: Introduce barrier util in the DistWrapper for rank local checkpointing. This barrier will be used at the end of the rank local checkpointing to ensure all ranks synchronize. 
Test Plan: UTs Differential Revision: D72541431 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150748 Approved by: https://github.com/MeetVadakkanchery --- test/distributed/checkpoint/test_utils.py | 13 +++++++++++++ torch/distributed/checkpoint/utils.py | 10 ++++++++++ 2 files changed, 23 insertions(+) diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py index d3b3441039d4..9dc730379ecf 100644 --- a/test/distributed/checkpoint/test_utils.py +++ b/test/distributed/checkpoint/test_utils.py @@ -242,6 +242,19 @@ def test_scatter_object(self): expected_objects = rank assert scattered_objects == expected_objects + @with_comms + @skip_if_lt_x_gpu(2) + def test_barrier(self): + mesh_2d = dist.init_device_mesh(self.device_type, (2, self.world_size // 2)) + torch.random.manual_seed(dist.get_rank()) + + dist_wrapper = _DistWrapper( + mesh_2d.get_group(1), use_dist=True, coordinator_rank=0 + ) + + # No exception should be raised. + dist_wrapper.barrier() + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index 0615721228b0..dd9c27f6542c 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -307,6 +307,16 @@ def broadcast( raise final_result return cast(T, final_result) + def barrier(self) -> None: + """ + Add a synchronization point across all processes when using distributed. + If torch.distributed is initialized, this function will invoke a barrier across the global process group. + If torch.distributed is not initialized, this function is a no-op. + """ + if not self.use_dist: + return + dist.barrier(group=self.group) + def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: if index.offset is None: From 6fcffd8cd1d9d8634410a79cc299c34784457bb7 Mon Sep 17 00:00:00 2001 From: Annop Wongwathanarat Date: Mon, 7 Apr 2025 18:01:54 +0000 Subject: [PATCH 243/332] Optimize SVE embedding performance (#150176) Change loop unrolling strategy. Previously, the script only unrolls the inner loop over block_size when block size is multiple of vector length. This version instead unrolls the outer loop which reduces the number of load/store for accumulation into the output array and improves performance for cases when block size is not multiple of vector length. 
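In plain scalar terms, the new accumulation pattern looks roughly like the sketch below (a hedged Python illustration, not the generated SVE kernel or the codegen script): each output element is read and written once per group of 16 gathered rows instead of once per row (the real kernel additionally uses a predicated SVE remainder over the block so block_size need not be a multiple of the vector length).

```python
# Hedged scalar sketch of outer-loop unrolling; names are illustrative only.
def embedding_sum_outer_unrolled(table, indices, weights, out, unroll=16):
    j = 0
    while j + unroll <= len(indices):
        for k in range(len(out)):        # out[k] loaded/stored once ...
            acc = out[k]
            for u in range(unroll):      # ... and shared by 16 gathered rows
                acc += weights[j + u] * table[indices[j + u]][k]
            out[k] = acc
        j += unroll
    while j < len(indices):              # tail: remaining rows one at a time
        for k in range(len(out)):
            out[k] += weights[j] * table[indices[j]][k]
        j += 1
```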
Benchmarking script: ```python # SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliate # SPDX-License-Identifier: BSD-3-Clause import torch import torch.nn as nn import numpy as np import time import sys np.random.seed(0) torch.manual_seed(0) num_embeddings = 400000 embedding_dim = int(sys.argv[1]) multi_hot = 100 batch_size = 400 nrun = 1000 class SimpleEmbeddingBagModel(nn.Module): def __init__(self, num_embeddings, embedding_dim): super(SimpleEmbeddingBagModel, self).__init__() weights = torch.from_numpy((np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32)).to(torch.float16) # Defining the EmbeddingBag layer self.embedding_bag = torch.nn.EmbeddingBag(num_embeddings, embedding_dim, _weight=weights, mode='sum', include_last_offset=True, dtype=torch.float32) def forward(self, input, offsets): # Forward pass through the EmbeddingBag layer result32 = self.embedding_bag(input, offsets, per_sample_weights=None) return result32 # Instantiate the model model = SimpleEmbeddingBagModel(num_embeddings=num_embeddings, embedding_dim=embedding_dim) model.eval() # Example input input_tensor = torch.randint(0, num_embeddings, (batch_size * multi_hot,), dtype=torch.long) offsets = torch.tensor(range(0, batch_size * multi_hot + 1, multi_hot)) with torch.no_grad(): # warm up output32 = model(input_tensor, offsets) ti = time.time_ns() for i in range(nrun): _ = model(input_tensor, offsets) tf = time.time_ns() print("{:3d} {:.3E}".format(embedding_dim, (tf-ti)/nrun/1.e6)) ``` Speedup on NEOVERSEV1 with 1 thread ![embedding](https://github.com/user-attachments/assets/16e567ed-b9a5-4db3-90b8-dec66d5414a7) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150176 Approved by: https://github.com/digantdesai, https://github.com/malfet --- .../perfkernels/embedding_lookup_idx_sve.cc | 9932 +++++++---------- caffe2/perfkernels/sve_emblookup_codegen.py | 379 +- 2 files changed, 3984 insertions(+), 6327 deletions(-) diff --git a/caffe2/perfkernels/embedding_lookup_idx_sve.cc b/caffe2/perfkernels/embedding_lookup_idx_sve.cc index 873823536b55..3e211a5ba1f5 100644 --- a/caffe2/perfkernels/embedding_lookup_idx_sve.cc +++ b/caffe2/perfkernels/embedding_lookup_idx_sve.cc @@ -28,517 +28,406 @@ static bool EmbeddingLookupIdx_int32_t_float_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t 
vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = 
svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - 
svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, 
vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = 
weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + const float* const ip8 = &input[idx8 * block_size]; + const float* const ip9 = &input[idx9 * block_size]; + const float* const ip10 = &input[idx10 * block_size]; + const float* const ip11 = &input[idx11 * block_size]; + const float* const ip12 = &input[idx12 * block_size]; + const float* const ip13 = &input[idx13 * block_size]; + const float* const ip14 = &input[idx14 * block_size]; + const float* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, 
&ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - 
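// The 8-index unrolled loop that replaces the old 8 * vLen branch follows the same
// pattern as the 16-index one: validate all eight indices against data_size, read
// their weights (defaulting to 1.0f, positional or per-index via IS_WEIGHT_POSITIONAL),
// then accumulate weight * input row into the output block vLen floats at a time with
// svmla_x, finishing with an svwhilelt-predicated tail for the last partial vector.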
svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = 
&out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 1 - start_offset) : pos + 1]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 0 - start_offset) : pos + 0]; + } + const float* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -611,517 +500,406 @@ static bool EmbeddingLookupIdx_int64_t_float_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - vsum16 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); - vsum17 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); - vsum18 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); - vsum19 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); - vsum20 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); - vsum21 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); - vsum22 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); - vsum23 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); - vsum24 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); - vsum25 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); - vsum26 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); - vsum27 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); - vsum28 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); - vsum29 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); - vsum30 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); - vsum31 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], 
svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* 
const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - vsum8 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); - vsum9 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); - vsum10 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); - vsum11 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); - vsum12 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); - vsum13 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); - vsum14 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); - vsum15 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], 
svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + const float* const ip8 = &input[idx8 * block_size]; + const float* const ip9 = &input[idx9 * block_size]; + const float* const ip10 = &input[idx10 * block_size]; + const float* const ip11 = &input[idx11 * block_size]; + const float* const ip12 = &input[idx12 * block_size]; + const float* const ip13 = &input[idx13 * block_size]; + const float* const ip14 = &input[idx14 * block_size]; + const float* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(svAll, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(svAll, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(svAll, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(svAll, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(svAll, output, svld1(svAll, &ip12[k]), wgt12); + output = svmla_x(svAll, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(svAll, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(svAll, output, svld1(svAll, &ip15[k]), wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + output = svmla_x(pg, output, svld1(svAll, &ip8[k]), wgt8); + output = svmla_x(pg, output, svld1(svAll, &ip9[k]), wgt9); + output = svmla_x(pg, output, svld1(svAll, &ip10[k]), wgt10); + output = svmla_x(pg, output, svld1(svAll, &ip11[k]), wgt11); + output = svmla_x(pg, output, svld1(svAll, &ip12[k]), wgt12); + 
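// Tail of the block: pg = svwhilelt_b32_s64(k, block_size) masks the final partial
// vector so only the in-range lanes of op[k] are loaded, updated and stored.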
output = svmla_x(pg, output, svld1(svAll, &ip13[k]), wgt13); + output = svmla_x(pg, output, svld1(svAll, &ip14[k]), wgt14); + output = svmla_x(pg, output, svld1(svAll, &ip15[k]), wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - vsum4 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); - vsum5 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); - vsum6 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); - vsum7 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return 
false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + const float* const ip4 = &input[idx4 * block_size]; + const float* const ip5 = &input[idx5 * block_size]; + const float* const ip6 = &input[idx6 * block_size]; + const float* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(svAll, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(svAll, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(svAll, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(svAll, output, svld1(svAll, &ip7[k]), wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + output = svmla_x(pg, output, svld1(svAll, &ip4[k]), wgt4); + output = svmla_x(pg, output, svld1(svAll, &ip5[k]), wgt5); + output = svmla_x(pg, output, svld1(svAll, &ip6[k]), wgt6); + output = svmla_x(pg, output, svld1(svAll, &ip7[k]), wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - vsum2 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); - vsum3 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + const float* const ip2 = &input[idx2 * block_size]; + const float* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(svAll, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(svAll, output, svld1(svAll, &ip3[k]), wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + output = svmla_x(pg, output, svld1(svAll, &ip2[k]), wgt2); + output = svmla_x(pg, output, svld1(svAll, &ip3[k]), wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); - vsum1 = - svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 1 - start_offset) : pos + 1]; + } + const float* const ip0 = &input[idx0 * block_size]; + const float* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(svAll, output, svld1(svAll, &ip1[k]), wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + output = svmla_x(pg, output, svld1(svAll, &ip1[k]), wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const float* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 0 - start_offset) : pos + 0]; + } + const float* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svmla_x(svAll, output, svld1(svAll, &ip0[k]), wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svmla_x(pg, output, svld1(svAll, &ip0[k]), wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -1194,895 +972,530 @@ static bool EmbeddingLookupIdx_int32_t_half_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - 
vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - 
svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - 
} - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - 
svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + const at::Half* const ip8 = &input[idx8 * block_size]; + const at::Half* const ip9 = &input[idx9 * block_size]; + const at::Half* const ip10 = &input[idx10 * block_size]; + const at::Half* const ip11 = &input[idx11 * block_size]; + const at::Half* const ip12 = &input[idx12 * block_size]; + const at::Half* const ip13 = &input[idx13 * block_size]; + const at::Half* const ip14 = &input[idx14 * block_size]; + const at::Half* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])))); + auto 
input14 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip15[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + 
pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + 
const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, 
reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return 
false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -2155,895 +1568,530 @@ static bool EmbeddingLookupIdx_int64_t_half_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); 
- svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])))), - vsum16); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, 
reinterpret_cast(&ip[17 * vLen])))), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])))), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])))), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])))), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])))), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])))), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])))), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])))), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])))), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])))), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])))), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])))), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])))), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])))), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])))), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, 
vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = 
svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])))), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])))), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])))), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])))), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])))), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])))), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])))), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])))), - vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv 
= svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + 
float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + const at::Half* const ip8 = &input[idx8 * block_size]; + const at::Half* const ip9 = &input[idx9 * block_size]; + const at::Half* const ip10 = &input[idx10 * block_size]; + const at::Half* const ip11 = &input[idx11 * block_size]; + const at::Half* const ip12 = &input[idx12 * block_size]; + const at::Half* const ip13 = &input[idx13 * block_size]; + const at::Half* const ip14 = &input[idx14 * block_size]; + const at::Half* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, 
reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + auto input8 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip8[k])))); + auto input9 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip9[k])))); + auto input10 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip10[k])))); + auto input11 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip11[k])))); + auto input12 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip12[k])))); + auto input13 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip13[k])))); + auto input14 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip14[k])))); + auto input15 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip15[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = 
svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])))), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])))), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])))), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])))), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - 
svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 7 - start_offset) : pos + 7]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + const at::Half* const ip4 = &input[idx4 * block_size]; + const at::Half* const ip5 = &input[idx5 * block_size]; + const at::Half* const ip6 = &input[idx6 * block_size]; + const at::Half* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + auto input4 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip4[k])))); + auto input5 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip5[k])))); + auto input6 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip6[k])))); + auto input7 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip7[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t 
vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])))), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])))), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + const at::Half* const ip2 = &input[idx2 * block_size]; + const at::Half* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + auto input2 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip2[k])))); + auto input3 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip3[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])))), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_f16_x( - svAll, - svreinterpret_f16_u32(svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])))), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + const at::Half* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])))); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + auto input1 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip1[k])))); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::Half* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_f16_x( - pg, - svreinterpret_f16_u32(svld1uh_u32( - pg, reinterpret_cast(&ip[k])))), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::Half* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svcvt_f32_x(svAll, svreinterpret_f16( + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])))); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svcvt_f32_x(pg, svreinterpret_f16( + svld1uh_u32(pg, reinterpret_cast(&ip0[k])))); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -3116,958 +2164,530 @@ static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = 
svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - 
vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, 
vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - 
svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - 
vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { 
+ return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + const at::BFloat16* const ip8 = &input[idx8 * block_size]; + const at::BFloat16* const ip9 = &input[idx9 * block_size]; + const at::BFloat16* const ip10 = &input[idx10 * block_size]; + const at::BFloat16* const ip11 = &input[idx11 * block_size]; + const at::BFloat16* const ip12 = &input[idx12 * block_size]; + const at::BFloat16* const ip13 = &input[idx13 * block_size]; + const at::BFloat16* const ip14 = &input[idx14 * block_size]; + const at::BFloat16* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = 
svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + 
output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], 
svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 7 - start_offset) : pos + 7]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - 
svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -4140,958 +2760,530 @@ static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = 
svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - vsum15); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[16 * vLen])), - 16)), - vsum16); - vsum17 = svmad_f32_x( - svAll, - 
vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[17 * vLen])), - 16)), - vsum17); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[18 * vLen])), - 16)), - vsum18); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[19 * vLen])), - 16)), - vsum19); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[20 * vLen])), - 16)), - vsum20); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[21 * vLen])), - 16)), - vsum21); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[22 * vLen])), - 16)), - vsum22); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[23 * vLen])), - 16)), - vsum23); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[24 * vLen])), - 16)), - vsum24); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[25 * vLen])), - 16)), - vsum25); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[26 * vLen])), - 16)), - vsum26); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[27 * vLen])), - 16)), - vsum27); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[28 * vLen])), - 16)), - vsum28); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[29 * vLen])), - 16)), - vsum29); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[30 * vLen])), - 16)), - vsum30); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[31 * vLen])), - 16)), - vsum31); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, 
vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - 
svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[8 * vLen])), - 16)), - vsum8); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[9 * vLen])), - 16)), - vsum9); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[10 * vLen])), - 16)), - vsum10); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[11 * vLen])), - 16)), - vsum11); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[12 * vLen])), - 16)), - vsum12); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[13 * vLen])), - 16)), - vsum13); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[14 * vLen])), - 16)), - vsum14); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[15 * vLen])), - 16)), - 
vsum15); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { 
+ return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + const at::BFloat16* const ip8 = &input[idx8 * block_size]; + const at::BFloat16* const ip9 = &input[idx9 * block_size]; + const at::BFloat16* const ip10 = &input[idx10 * block_size]; + const at::BFloat16* const ip11 = &input[idx11 * block_size]; + const at::BFloat16* const ip12 = &input[idx12 * block_size]; + const at::BFloat16* const ip13 = &input[idx13 * block_size]; + const at::BFloat16* const ip14 = &input[idx14 * block_size]; + const at::BFloat16* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = 
svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + auto input8 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip8[k])), 16)); + auto input9 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip9[k])), 16)); + auto input10 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip10[k])), 16)); + auto input11 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip11[k])), 16)); + auto input12 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip12[k])), 16)); + auto input13 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip13[k])), 16)); + auto input14 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip14[k])), 16)); + auto input15 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip15[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + 
output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[4 * vLen])), - 16)), - vsum4); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[5 * vLen])), - 16)), - vsum5); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[6 * vLen])), - 16)), - vsum6); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[7 * vLen])), - 16)), - vsum7); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], 
svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 7 - start_offset) : pos + 7]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + const at::BFloat16* const ip4 = &input[idx4 * block_size]; + const at::BFloat16* const ip5 = &input[idx5 * block_size]; + const at::BFloat16* const ip6 = &input[idx6 * block_size]; + const at::BFloat16* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + auto input4 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip4[k])), 16)); + auto input5 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip5[k])), 16)); + auto input6 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip6[k])), 16)); + auto input7 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip7[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - 
svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[2 * vLen])), - 16)), - vsum2); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[3 * vLen])), - 16)), - vsum3); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + const at::BFloat16* const ip2 = &input[idx2 * block_size]; + const at::BFloat16* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + auto input2 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip2[k])), 16)); + auto input3 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip3[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[0 * vLen])), - 16)), - vsum0); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - svAll, - svld1uh_u32( - svAll, reinterpret_cast(&ip[1 * vLen])), - 16)), - vsum1); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + const at::BFloat16* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + auto input1 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip1[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - const svfloat32_t vwgt = svdup_n_f32(wgt); - const at::BFloat16* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svreinterpret_f32_u32(svlsl_n_u32_x( - pg, - svld1uh_u32( - pg, reinterpret_cast(&ip[k])), - 16)), - svld1_f32(pg, &op[k]))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + const at::BFloat16* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(svAll, + svld1uh_u32(svAll, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + auto input0 = svreinterpret_f32(svlsl_x(pg, + svld1uh_u32(pg, reinterpret_cast(&ip0[k])), 16)); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -5164,743 +3356,555 @@ static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = 
svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, 
svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], 
svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float 
bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], 
svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + bio += wgt8 * scale_bias[2 * idx8 + 1]; + wgt8 = wgt8 * scale_bias[2 * idx8]; + bio += wgt9 * scale_bias[2 * idx9 + 1]; + wgt9 = wgt9 * scale_bias[2 * idx9]; + bio += wgt10 * scale_bias[2 * idx10 + 1]; + wgt10 = wgt10 * scale_bias[2 * idx10]; + bio += wgt11 * scale_bias[2 * idx11 + 1]; + wgt11 = wgt11 * scale_bias[2 * idx11]; + bio += wgt12 * scale_bias[2 * idx12 + 1]; + wgt12 = wgt12 * scale_bias[2 * idx12]; + bio += wgt13 * scale_bias[2 * idx13 + 1]; + wgt13 = wgt13 * scale_bias[2 * idx13]; + bio += wgt14 * scale_bias[2 * idx14 + 1]; + wgt14 = wgt14 * scale_bias[2 * idx14]; + bio += wgt15 * scale_bias[2 * idx15 + 1]; + wgt15 = wgt15 * scale_bias[2 * idx15]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + const uint8_t* const ip8 = &input[idx8 * block_size]; + const uint8_t* const ip9 = &input[idx9 * block_size]; + const uint8_t* const ip10 = &input[idx10 * block_size]; + const uint8_t* const ip11 = &input[idx11 * block_size]; + const uint8_t* const ip12 = &input[idx12 * block_size]; + const uint8_t* const ip13 = &input[idx13 * block_size]; + const uint8_t* const ip14 = &input[idx14 * block_size]; + const uint8_t* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = 
svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k])); + auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k])); + auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k])); + auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k])); + auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k])); + auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k])); + auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k])); + auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k])); + auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k])); + auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k])); + auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k])); + auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k])); + auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k])); + auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k])); + auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, &op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < 
output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = 
indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k 
< block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 
0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } @@ -5973,743 +3977,555 @@ static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve( const svbool_t svAll = svptrue_b32(); const auto vLen = static_cast(svcntw()); int64_t pos = 0; - if (block_size == 32 * vLen) { - // unrolling 32 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - svfloat32_t vsum16 = svdup_n_f32(0); - svfloat32_t vsum17 = svdup_n_f32(0); - svfloat32_t vsum18 = svdup_n_f32(0); - svfloat32_t vsum19 = svdup_n_f32(0); - 
svfloat32_t vsum20 = svdup_n_f32(0); - svfloat32_t vsum21 = svdup_n_f32(0); - svfloat32_t vsum22 = svdup_n_f32(0); - svfloat32_t vsum23 = svdup_n_f32(0); - svfloat32_t vsum24 = svdup_n_f32(0); - svfloat32_t vsum25 = svdup_n_f32(0); - svfloat32_t vsum26 = svdup_n_f32(0); - svfloat32_t vsum27 = svdup_n_f32(0); - svfloat32_t vsum28 = svdup_n_f32(0); - svfloat32_t vsum29 = svdup_n_f32(0); - svfloat32_t vsum30 = svdup_n_f32(0); - svfloat32_t vsum31 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - vsum16 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), - svadd_f32_x(svAll, vsum16, vbio)); - vsum17 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), - svadd_f32_x(svAll, vsum17, 
vbio)); - vsum18 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), - svadd_f32_x(svAll, vsum18, vbio)); - vsum19 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), - svadd_f32_x(svAll, vsum19, vbio)); - vsum20 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), - svadd_f32_x(svAll, vsum20, vbio)); - vsum21 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), - svadd_f32_x(svAll, vsum21, vbio)); - vsum22 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), - svadd_f32_x(svAll, vsum22, vbio)); - vsum23 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), - svadd_f32_x(svAll, vsum23, vbio)); - vsum24 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), - svadd_f32_x(svAll, vsum24, vbio)); - vsum25 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), - svadd_f32_x(svAll, vsum25, vbio)); - vsum26 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), - svadd_f32_x(svAll, vsum26, vbio)); - vsum27 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), - svadd_f32_x(svAll, vsum27, vbio)); - vsum28 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), - svadd_f32_x(svAll, vsum28, vbio)); - vsum29 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), - svadd_f32_x(svAll, vsum29, vbio)); - vsum30 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), - svadd_f32_x(svAll, vsum30, vbio)); - vsum31 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), - svadd_f32_x(svAll, vsum31, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); - svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); - svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, 
vsum18, vlen_inv)); - svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); - svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); - svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); - svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); - svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); - svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); - svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); - svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); - svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); - svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); - svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); - svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); - svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); - svst1_f32(svAll, &op[16 * vLen], vsum16); - svst1_f32(svAll, &op[17 * vLen], vsum17); - svst1_f32(svAll, &op[18 * vLen], vsum18); - svst1_f32(svAll, &op[19 * vLen], vsum19); - svst1_f32(svAll, &op[20 * vLen], vsum20); - svst1_f32(svAll, &op[21 * vLen], vsum21); - svst1_f32(svAll, &op[22 * vLen], vsum22); - svst1_f32(svAll, &op[23 * vLen], vsum23); - svst1_f32(svAll, &op[24 * vLen], vsum24); - svst1_f32(svAll, &op[25 * vLen], vsum25); - svst1_f32(svAll, &op[26 * vLen], vsum26); - svst1_f32(svAll, &op[27 * vLen], vsum27); - svst1_f32(svAll, &op[28 * vLen], vsum28); - svst1_f32(svAll, &op[29 * vLen], vsum29); - svst1_f32(svAll, &op[30 * vLen], vsum30); - svst1_f32(svAll, &op[31 * vLen], vsum31); - } + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; } - } else if (block_size == 16 * vLen) { + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + int64_t j = start_offset; // unrolling 16 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - svfloat32_t vsum8 = svdup_n_f32(0); - svfloat32_t vsum9 = svdup_n_f32(0); - svfloat32_t vsum10 = svdup_n_f32(0); - svfloat32_t vsum11 = svdup_n_f32(0); - svfloat32_t vsum12 = svdup_n_f32(0); - svfloat32_t vsum13 = svdup_n_f32(0); - svfloat32_t vsum14 = svdup_n_f32(0); - svfloat32_t vsum15 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t 
end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - vsum8 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), - svadd_f32_x(svAll, vsum8, vbio)); - vsum9 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), - svadd_f32_x(svAll, vsum9, vbio)); - vsum10 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), - svadd_f32_x(svAll, vsum10, vbio)); - vsum11 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), - svadd_f32_x(svAll, vsum11, vbio)); - vsum12 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), - svadd_f32_x(svAll, vsum12, vbio)); - vsum13 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), - svadd_f32_x(svAll, vsum13, vbio)); - vsum14 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), - svadd_f32_x(svAll, vsum14, vbio)); - vsum15 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), - svadd_f32_x(svAll, vsum15, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, 
vlen_inv)); - svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); - svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); - svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); - svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); - svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); - svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); - svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); - svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); - svst1_f32(svAll, &op[8 * vLen], vsum8); - svst1_f32(svAll, &op[9 * vLen], vsum9); - svst1_f32(svAll, &op[10 * vLen], vsum10); - svst1_f32(svAll, &op[11 * vLen], vsum11); - svst1_f32(svAll, &op[12 * vLen], vsum12); - svst1_f32(svAll, &op[13 * vLen], vsum13); - svst1_f32(svAll, &op[14 * vLen], vsum14); - svst1_f32(svAll, &op[15 * vLen], vsum15); + while (j + 15 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + const auto idx8 = indices[pos + 8]; + const auto idx9 = indices[pos + 9]; + const auto idx10 = indices[pos + 10]; + const auto idx11 = indices[pos + 11]; + const auto idx12 = indices[pos + 12]; + const auto idx13 = indices[pos + 13]; + const auto idx14 = indices[pos + 14]; + const auto idx15 = indices[pos + 15]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; + } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + if (idx8 < 0 || idx8 >= data_size) { + return false; + } + if (idx9 < 0 || idx9 >= data_size) { + return false; + } + if (idx10 < 0 || idx10 >= data_size) { + return false; + } + if (idx11 < 0 || idx11 >= data_size) { + return false; + } + if (idx12 < 0 || idx12 >= data_size) { + return false; + } + if (idx13 < 0 || idx13 >= data_size) { + return false; + } + if (idx14 < 0 || idx14 >= data_size) { + return false; } + if (idx15 < 0 || idx15 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float wgt8 = 1.f; + float wgt9 = 1.f; + float wgt10 = 1.f; + float wgt11 = 1.f; + float wgt12 = 1.f; + float wgt13 = 1.f; + float wgt14 = 1.f; + float wgt15 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + wgt8 = weights[IS_WEIGHT_POSITIONAL ? (j + 8 - start_offset) : pos + 8]; + wgt9 = weights[IS_WEIGHT_POSITIONAL ? (j + 9 - start_offset) : pos + 9]; + wgt10 = weights[IS_WEIGHT_POSITIONAL ? (j + 10 - start_offset) : pos + 10]; + wgt11 = weights[IS_WEIGHT_POSITIONAL ? (j + 11 - start_offset) : pos + 11]; + wgt12 = weights[IS_WEIGHT_POSITIONAL ? (j + 12 - start_offset) : pos + 12]; + wgt13 = weights[IS_WEIGHT_POSITIONAL ? (j + 13 - start_offset) : pos + 13]; + wgt14 = weights[IS_WEIGHT_POSITIONAL ? (j + 14 - start_offset) : pos + 14]; + wgt15 = weights[IS_WEIGHT_POSITIONAL ? (j + 15 - start_offset) : pos + 15]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + bio += wgt8 * scale_bias[2 * idx8 + 1]; + wgt8 = wgt8 * scale_bias[2 * idx8]; + bio += wgt9 * scale_bias[2 * idx9 + 1]; + wgt9 = wgt9 * scale_bias[2 * idx9]; + bio += wgt10 * scale_bias[2 * idx10 + 1]; + wgt10 = wgt10 * scale_bias[2 * idx10]; + bio += wgt11 * scale_bias[2 * idx11 + 1]; + wgt11 = wgt11 * scale_bias[2 * idx11]; + bio += wgt12 * scale_bias[2 * idx12 + 1]; + wgt12 = wgt12 * scale_bias[2 * idx12]; + bio += wgt13 * scale_bias[2 * idx13 + 1]; + wgt13 = wgt13 * scale_bias[2 * idx13]; + bio += wgt14 * scale_bias[2 * idx14 + 1]; + wgt14 = wgt14 * scale_bias[2 * idx14]; + bio += wgt15 * scale_bias[2 * idx15 + 1]; + wgt15 = wgt15 * scale_bias[2 * idx15]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + const uint8_t* const ip8 = &input[idx8 * block_size]; + const uint8_t* const ip9 = &input[idx9 * block_size]; + const uint8_t* const ip10 = &input[idx10 * block_size]; + const uint8_t* const ip11 = &input[idx11 * block_size]; + const uint8_t* const ip12 = &input[idx12 * block_size]; + const uint8_t* const ip13 = &input[idx13 * block_size]; + const uint8_t* const ip14 = &input[idx14 * block_size]; + const uint8_t* const ip15 = &input[idx15 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto 
input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + auto input8 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip8[k])); + auto input9 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip9[k])); + auto input10 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip10[k])); + auto input11 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip11[k])); + auto input12 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip12[k])); + auto input13 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip13[k])); + auto input14 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip14[k])); + auto input15 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip15[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + output = svmla_x(svAll, output, input8, wgt8); + output = svmla_x(svAll, output, input9, wgt9); + output = svmla_x(svAll, output, input10, wgt10); + output = svmla_x(svAll, output, input11, wgt11); + output = svmla_x(svAll, output, input12, wgt12); + output = svmla_x(svAll, output, input13, wgt13); + output = svmla_x(svAll, output, input14, wgt14); + output = svmla_x(svAll, output, input15, wgt15); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + auto input8 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip8[k])); + auto input9 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip9[k])); + auto input10 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip10[k])); + auto input11 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip11[k])); + auto input12 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip12[k])); + auto input13 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip13[k])); + auto input14 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip14[k])); + auto input15 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip15[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + output = svmla_x(pg, output, input8, wgt8); + output = svmla_x(pg, output, input9, wgt9); + output = svmla_x(pg, output, input10, wgt10); + output = svmla_x(pg, output, input11, wgt11); + output = svmla_x(pg, output, input12, wgt12); + output = svmla_x(pg, output, input13, wgt13); + output = svmla_x(pg, output, input14, wgt14); + output = svmla_x(pg, output, input15, wgt15); + svst1(pg, 
&op[k], output); + k += vLen; + } + j += 16; + pos += 16; } - } else if (block_size == 8 * vLen) { // unrolling 8 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - svfloat32_t vsum4 = svdup_n_f32(0); - svfloat32_t vsum5 = svdup_n_f32(0); - svfloat32_t vsum6 = svdup_n_f32(0); - svfloat32_t vsum7 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - vsum4 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), - svadd_f32_x(svAll, vsum4, vbio)); - vsum5 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), - svadd_f32_x(svAll, vsum5, vbio)); - vsum6 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), - svadd_f32_x(svAll, vsum6, vbio)); - vsum7 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), - svadd_f32_x(svAll, vsum7, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); - svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); - svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); - svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); - svst1_f32(svAll, &op[4 * vLen], vsum4); - svst1_f32(svAll, &op[5 * vLen], vsum5); - svst1_f32(svAll, &op[6 * vLen], vsum6); - svst1_f32(svAll, &op[7 * vLen], vsum7); + while (j + 7 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = 
indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + const auto idx4 = indices[pos + 4]; + const auto idx5 = indices[pos + 5]; + const auto idx6 = indices[pos + 6]; + const auto idx7 = indices[pos + 7]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; + } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + if (idx4 < 0 || idx4 >= data_size) { + return false; + } + if (idx5 < 0 || idx5 >= data_size) { + return false; + } + if (idx6 < 0 || idx6 >= data_size) { + return false; } + if (idx7 < 0 || idx7 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float wgt4 = 1.f; + float wgt5 = 1.f; + float wgt6 = 1.f; + float wgt7 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? (j + 3 - start_offset) : pos + 3]; + wgt4 = weights[IS_WEIGHT_POSITIONAL ? (j + 4 - start_offset) : pos + 4]; + wgt5 = weights[IS_WEIGHT_POSITIONAL ? (j + 5 - start_offset) : pos + 5]; + wgt6 = weights[IS_WEIGHT_POSITIONAL ? (j + 6 - start_offset) : pos + 6]; + wgt7 = weights[IS_WEIGHT_POSITIONAL ? (j + 7 - start_offset) : pos + 7]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + bio += wgt4 * scale_bias[2 * idx4 + 1]; + wgt4 = wgt4 * scale_bias[2 * idx4]; + bio += wgt5 * scale_bias[2 * idx5 + 1]; + wgt5 = wgt5 * scale_bias[2 * idx5]; + bio += wgt6 * scale_bias[2 * idx6 + 1]; + wgt6 = wgt6 * scale_bias[2 * idx6]; + bio += wgt7 * scale_bias[2 * idx7 + 1]; + wgt7 = wgt7 * scale_bias[2 * idx7]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + const uint8_t* const ip4 = &input[idx4 * block_size]; + const uint8_t* const ip5 = &input[idx5 * block_size]; + const uint8_t* const ip6 = &input[idx6 * block_size]; + const uint8_t* const ip7 = &input[idx7 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + auto input4 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip4[k])); + auto input5 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip5[k])); + auto input6 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip6[k])); + auto input7 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip7[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + output = svmla_x(svAll, output, input4, wgt4); + output = svmla_x(svAll, output, input5, wgt5); + 
output = svmla_x(svAll, output, input6, wgt6); + output = svmla_x(svAll, output, input7, wgt7); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + auto input4 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip4[k])); + auto input5 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip5[k])); + auto input6 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip6[k])); + auto input7 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip7[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + output = svmla_x(pg, output, input4, wgt4); + output = svmla_x(pg, output, input5, wgt5); + output = svmla_x(pg, output, input6, wgt6); + output = svmla_x(pg, output, input7, wgt7); + svst1(pg, &op[k], output); + k += vLen; + } + j += 8; + pos += 8; } - } else if (block_size == 4 * vLen) { // unrolling 4 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - svfloat32_t vsum2 = svdup_n_f32(0); - svfloat32_t vsum3 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - vsum2 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), - svadd_f32_x(svAll, vsum2, vbio)); - vsum3 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), - svadd_f32_x(svAll, vsum3, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); - svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); - svst1_f32(svAll, &op[2 * vLen], vsum2); - svst1_f32(svAll, &op[3 * vLen], vsum3); + while (j + 3 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + const auto idx2 = indices[pos + 2]; + const auto idx3 = indices[pos + 3]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; + } + if (idx2 < 0 || idx2 >= data_size) { + return false; } + if (idx3 < 0 || idx3 >= data_size) { + return false; + } + float wgt0 = 1.f; + float wgt1 = 1.f; + float wgt2 = 1.f; + float wgt3 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + wgt2 = weights[IS_WEIGHT_POSITIONAL ? (j + 2 - start_offset) : pos + 2]; + wgt3 = weights[IS_WEIGHT_POSITIONAL ? 
(j + 3 - start_offset) : pos + 3]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + bio += wgt2 * scale_bias[2 * idx2 + 1]; + wgt2 = wgt2 * scale_bias[2 * idx2]; + bio += wgt3 * scale_bias[2 * idx3 + 1]; + wgt3 = wgt3 * scale_bias[2 * idx3]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + const uint8_t* const ip2 = &input[idx2 * block_size]; + const uint8_t* const ip3 = &input[idx3 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + auto input2 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip2[k])); + auto input3 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip3[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + output = svmla_x(svAll, output, input2, wgt2); + output = svmla_x(svAll, output, input3, wgt3); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + auto input2 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip2[k])); + auto input3 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip3[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + output = svmla_x(pg, output, input2, wgt2); + output = svmla_x(pg, output, input3, wgt3); + svst1(pg, &op[k], output); + k += vLen; + } + j += 4; + pos += 4; } - } else if (block_size == 2 * vLen) { // unrolling 2 times - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - if (pos != offsets[i] - offsets[0]) { - return false; - } - svfloat32_t vsum0 = svdup_n_f32(0); - svfloat32_t vsum1 = svdup_n_f32(0); - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* const ip = &input[idx * block_size]; - // weight * input + out - vsum0 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), - svadd_f32_x(svAll, vsum0, vbio)); - vsum1 = svmad_f32_x( - svAll, - vwgt, - svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), - svadd_f32_x(svAll, vsum1, vbio)); - ++pos; - } - // Normalisation - const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - const svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); - svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); - } else { - svst1_f32(svAll, &op[0 * vLen], vsum0); - svst1_f32(svAll, &op[1 * vLen], vsum1); + while (j + 1 < end_offset) { + const auto idx0 = indices[pos + 0]; + const auto idx1 = indices[pos + 1]; + if (idx0 < 0 || idx0 >= data_size) { + return false; + } + if (idx1 < 0 || idx1 >= data_size) { + return false; } + float wgt0 = 1.f; + float wgt1 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + wgt1 = weights[IS_WEIGHT_POSITIONAL ? (j + 1 - start_offset) : pos + 1]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + bio += wgt1 * scale_bias[2 * idx1 + 1]; + wgt1 = wgt1 * scale_bias[2 * idx1]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + const uint8_t* const ip1 = &input[idx1 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + auto input1 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip1[k])); + output = svmla_x(svAll, output, input0, wgt0); + output = svmla_x(svAll, output, input1, wgt1); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + auto input1 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip1[k])); + output = svmla_x(pg, output, input0, wgt0); + output = svmla_x(pg, output, input1, wgt1); + svst1(pg, &op[k], output); + k += vLen; + } + j += 2; + pos += 2; } - } else { - // generic code: - for (int64_t i = 0; i < output_size; ++i) { - float* const op = &out[i * block_size]; - memset(op, 0, sizeof(float) * block_size); - if (pos != offsets[i] - offsets[0]) { - return false; - } - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1]; - for (int64_t j = start_offset; j < end_offset; ++j) { - const auto idx = indices[pos]; - if (idx < 0 || idx >= data_size) { - return false; - } - // unimplemented - float wgt = 1.f; - float bio{}; - if (weights) { - wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos]; - } - if (scale_bias) { - bio = wgt * scale_bias[2 * idx + 1]; - wgt = wgt * scale_bias[2 * idx]; - } - svfloat32_t vbio = svdup_n_f32(bio); - const svfloat32_t vwgt = svdup_n_f32(wgt); - const uint8_t* ip = &input[idx * block_size]; - svbool_t pg; - for (int64_t k = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); - k += vLen) { - svst1_f32( - pg, - &op[k], - svmad_f32_x( - pg, - vwgt, - svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), - svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); - } - - ++pos; + // tail loop + if (j < end_offset) { + const auto idx0 = indices[pos + 0]; + if (idx0 < 0 || idx0 >= data_size) { + return false; } - const int64_t length = end_offset - start_offset; + float wgt0 = 1.f; + float bio = 0.f; + if (weights) { + wgt0 = weights[IS_WEIGHT_POSITIONAL ? (j + 0 - start_offset) : pos + 0]; + } + if (scale_bias) { + bio += wgt0 * scale_bias[2 * idx0 + 1]; + wgt0 = wgt0 * scale_bias[2 * idx0]; + } + const uint8_t* const ip0 = &input[idx0 * block_size]; + svbool_t pg; + int64_t k = 0; + while (k + vLen - 1 < block_size) { + auto output = svld1(svAll, &op[k]); + output = svadd_x(svAll, output, bio); + auto input0 = svcvt_f32_x(svAll, svld1ub_u32(svAll, &ip0[k])); + output = svmla_x(svAll, output, input0, wgt0); + svst1(svAll, &op[k], output); + k += vLen; + } + if (k < block_size) { + pg = svwhilelt_b32_s64(k, block_size); + auto output = svld1(pg, &op[k]); + output = svadd_x(pg, output, bio); + auto input0 = svcvt_f32_x(pg, svld1ub_u32(pg, &ip0[k])); + output = svmla_x(pg, output, input0, wgt0); + svst1(pg, &op[k], output); + k += vLen; + } + pos ++; + } + const int64_t length = end_offset - start_offset; - if (normalize_by_lengths && length != 0) { - const float len_inv = 1.0f / length; - svfloat32_t vlen_inv = svdup_n_f32(len_inv); - svbool_t pg; - for (int64_t j = 0; - svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); - j += vLen) { - svst1_f32( - pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); - } + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svbool_t pg; + int64_t j = 0; + while (j + vLen - 1 < block_size) { + svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv)); + j += vLen; + } + if (j < block_size) { + pg = svwhilelt_b32_s64(j, block_size); + svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv)); } } } diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py index 643b614c9081..4c5ad01bdc10 100644 --- a/caffe2/perfkernels/sve_emblookup_codegen.py +++ b/caffe2/perfkernels/sve_emblookup_codegen.py @@ -4,289 +4,105 @@ # Unroll loops when block_size is a multiple of vector length. 
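For reference, the accumulation that the regenerated uint8 kernels above vectorize can be written as a short scalar sketch. This is illustrative only: the function name is made up, positional weights (IS_WEIGHT_POSITIONAL) are omitted, and the fused per-row layout follows the scale_bias array used in the diff (scale at 2*idx, bias at 2*idx+1). The design point is that out[k] += w * (scale*q[k] + bias) is folded into wgt = w*scale and bio = w*bias, so the inner loop needs only one multiply-add plus one bias add per element.

#include <cstdint>
#include <cstring>

// Scalar reference for the weighted, fused-uint8 embedding bag.
bool embedding_bag_u8_reference(
    int64_t block_size, int64_t output_size, int64_t index_size,
    int64_t data_size, const uint8_t* input, const int64_t* indices,
    const int64_t* offsets, const float* weights, const float* scale_bias,
    bool normalize_by_lengths, float* out) {
  int64_t pos = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    float* op = &out[i * block_size];
    std::memset(op, 0, sizeof(float) * block_size);
    for (int64_t j = offsets[i]; j < offsets[i + 1]; ++j, ++pos) {
      const int64_t idx = indices[pos];
      if (idx < 0 || idx >= data_size) {
        return false; // out-of-range index, mirrors the SVE kernels
      }
      float wgt = weights ? weights[pos] : 1.f;
      float bio = 0.f;
      if (scale_bias) {
        bio = wgt * scale_bias[2 * idx + 1]; // weight * per-row bias
        wgt = wgt * scale_bias[2 * idx];     // weight * per-row scale
      }
      const uint8_t* ip = &input[idx * block_size];
      for (int64_t k = 0; k < block_size; ++k) {
        op[k] += wgt * ip[k] + bio;
      }
    }
    const int64_t length = offsets[i + 1] - offsets[i];
    if (normalize_by_lengths && length != 0) {
      for (int64_t k = 0; k < block_size; ++k) {
        op[k] *= 1.f / length;
      }
    }
  }
  return pos == index_size;
}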
-def unroll(num_unrolls, IndexType, InType, OutType, use_weights): - def compute(regid, InType, use_weights): +def unroll(num_unrolls, IndexType, InType, OutType): + def compute_output(num_unrolls, InType, is_main): code = [] + pred = "svAll" if is_main else "pg" if InType == "float": - code.append( - f" vsum{regid} =\n" - " svmad_f32_x(" - f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen])," - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, svld1(svAll, &ip{i}[k]), wgt{i});") elif InType == "at::Half": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " svAll,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])))),\n" - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" auto input{i} = svcvt_f32_x({pred}, svreinterpret_f16(\n" + f" svld1uh_u32({pred}, reinterpret_cast(&ip{i}[k]))));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") elif InType == "at::BFloat16": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " svAll,\n" - " svld1uh_u32(\n" - " svAll, reinterpret_cast(" - f"&ip[{regid} * vLen])),\n" - " 16)),\n" - f" vsum{regid});" - ) + for i in range(num_unrolls): + code.append(f" auto input{i} = svreinterpret_f32(svlsl_x({pred},\n" + f" svld1uh_u32({pred}, reinterpret_cast(&ip{i}[k])), 16));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") elif InType == "uint8_t": - code.append( - f" vsum{regid} = svmad_f32_x(\n" - " svAll,\n" - " vwgt,\n" - " svcvt_f32_u32_x(svAll," - f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" - f" svadd_f32_x(svAll, vsum{regid}, vbio));" - ) + code.append(f" output = svadd_x({pred}, output, bio);") + for i in range(num_unrolls): + code.append(f" auto input{i} = svcvt_f32_x({pred}, svld1ub_u32({pred}, &ip{i}[k]));") + for i in range(num_unrolls): + code.append(f" output = svmla_x({pred}, output, input{i}, wgt{i});") else: raise ValueError(f'Unknown datatype "{InType}"') return code code = [] - code.append(f" // unrolling {num_unrolls} times") - code.append(" for (int64_t i = 0; i < output_size; ++i) {") - - code.append(" " + OutType + "* const op = &out[i * block_size];") - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - - # Initialise vector sum registers + if num_unrolls == 1: + code.append(f" // tail loop") + code.append(" if (j < end_offset) {") + else: + code.append(f" // unrolling {num_unrolls} times") + code.append(f" while (j + {num_unrolls - 1} < end_offset) {{") for i in range(num_unrolls): - code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);") - - # inner loop - code.append("""\ - int64_t start_offset = offsets[i]; - int64_t end_offset = offsets[i + 1];""") - code.append( - " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) + code.append(f" const auto idx{i} = indices[pos + {i}];") - if InType == "uint8_t": - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos];" - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") + # check indices + for i in range(num_unrolls): code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" + f" if (idx{i} < 0 || idx{i} >= data_size) {{\n" + + " return false;\n" + + " }" ) - code.append(" }") - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* const ip = &input[idx * block_size];") - code.append(" // weight * input + out") + if InType == "uint8_t": + for i in range(num_unrolls): + code.append(f" {OutType} wgt{i} = 1.f;") + code.append(f" {OutType} bio = 0.f;") + else: + for i in range(num_unrolls): + code.append(f" {OutType} wgt{i} = 1.f;") + code.append(" if (weights) {") for i in range(num_unrolls): - code.extend(compute(i, InType, use_weights)) - - code.append(" ++pos;") + code.append(f" wgt{i} = weights[IS_WEIGHT_POSITIONAL ? (j + {i} - start_offset) : pos + {i}];") code.append(" }") + if InType == "uint8_t": + code.append(" if (scale_bias) {") + for i in range(num_unrolls): + code.append(f" bio += wgt{i} * scale_bias[2 * idx{i} + 1];") + code.append(f" wgt{i} = wgt{i} * scale_bias[2 * idx{i}];") + code.append(" }") - code.append(" // Normalisation") - code.append(" const int64_t length = end_offset - start_offset;") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - - for i in range(num_unrolls): - code.append( - f" svst1_f32(svAll, &op[{i} * vLen]," - + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));" - ) - - code.append(" } else {") - # inv of length for i in range(num_unrolls): - code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});") - + code.append(f" const {InType}* const ip{i} = &input[idx{i} * block_size];") + + # compute and store + code.append(" svbool_t pg;") + code.append(" int64_t k = 0;") + # main loop + code.append(" while (k + vLen - 1 < block_size) {") + code.append(" auto output = svld1(svAll, &op[k]);") + code.extend(compute_output(num_unrolls, InType, True)) + code.append(" svst1(svAll, &op[k], output);") + code.append(" k += vLen;") code.append(" }") - code.append(" }") - return code - - -# Handle the case where block_size is not a multiple of vector length. 
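The regenerated kernels drop the per-block-size register tiling in favour of a single pattern: full vectors run under the all-true predicate, and the remainder of the block runs once under an svwhilelt predicate. A minimal sketch of that main/tail split, using the same intrinsics as the normalization loop above (the helper name is an assumption, and it needs an SVE-enabled toolchain):

#include <arm_sve.h>
#include <cstdint>

// Scale a float block in place, vLen lanes at a time, with a predicated tail.
void scale_block(float* op, int64_t block_size, float len_inv) {
  const svbool_t svAll = svptrue_b32();
  const int64_t vLen = static_cast<int64_t>(svcntw());
  int64_t k = 0;
  while (k + vLen - 1 < block_size) { // whole vectors, all lanes active
    svst1(svAll, &op[k], svmul_x(svAll, svld1(svAll, &op[k]), len_inv));
    k += vLen;
  }
  if (k < block_size) { // remainder, inactive lanes masked off
    const svbool_t pg = svwhilelt_b32_s64(k, block_size);
    svst1(pg, &op[k], svmul_x(pg, svld1(pg, &op[k]), len_inv));
  }
}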
-def generic(IndexType, InType, OutType, use_weights): - def compute(InType, use_weights): - code = [] - if InType == "float": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg, vwgt, svld1_f32(pg, &ip[k])," - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::Half": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_f16_x(\n" - " pg,\n" - " svreinterpret_f16_u32(svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])))),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "at::BFloat16": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svreinterpret_f32_u32(svlsl_n_u32_x(\n" - " pg,\n" - " svld1uh_u32(\n" - " pg," - " reinterpret_cast(&ip[k])),\n" - " 16)),\n" - " svld1_f32(pg, &op[k])));" - ) - elif InType == "uint8_t": - code.append( - " svst1_f32(\n" - " pg,\n" - " &op[k],\n" - " svmad_f32_x(\n" - " pg,\n" - " vwgt,\n" - " svcvt_f32_u32_x(pg," - " svld1ub_u32(pg, &ip[k])),\n" - " svadd_f32_x(pg," - " svld1_f32(pg, &op[k]), vbio)));" - ) - else: - raise ValueError(f'Unknown datatype "{InType}"') - - return code - - code = [] - - code.append(" for (int64_t i = 0; i < output_size; ++i) {") - - code.append(" " + OutType + "* const op = &out[i * block_size];") - - # initialize to 0 - code.append(" memset(op, 0, sizeof(float) * block_size);") - - # inner loop - code.append( - " if (pos != offsets[i] - offsets[0]) {\n" - + " return false;\n" - + " }" - ) - code.append( - " int64_t start_offset = offsets[i];\n" - + " int64_t end_offset = offsets[i + 1];" - ) - code.append( - " for (" + "int64_t" + " j = start_offset; j < end_offset; ++j) {" - ) - - code.append(" const auto idx = indices[pos];") - code.append( - " if (idx < 0 || idx >= data_size) {\n" - + " return false;\n" - + " }" - ) - - if InType == "uint8_t": - code.append(" // unimplemented") - code.append(" " + OutType + " wgt = 1.f;") - code.append(" " + OutType + " bio{};") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" - ) - code.append(" }") - code.append(" if (scale_bias) {") - code.append(" bio = wgt * scale_bias[2 * idx + 1];") - code.append(" wgt = wgt * scale_bias[2 * idx];") - code.append(" }") - code.append(" svfloat32_t vbio = svdup_n_f32(bio);") - else: - code.append(" " + OutType + " wgt = 1.f;") - code.append(" if (weights) {") - code.append( - " wgt = weights[IS_WEIGHT_POSITIONAL ? 
(j - start_offset) : pos];" - ) - code.append(" }") - - code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") - code.append(f" const {InType}* ip = &input[idx * block_size];") - - # compute and store main loop - code.append(" svbool_t pg;") - code.append(" for (int64_t k = 0;") - code.append( - " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "k, block_size));" - ) - code.append(" k += vLen) {") - code.extend(compute(InType, use_weights)) - code.append(" }\n") - code.append(" ++pos;") + # tail loop + code.append(" if (k < block_size) {") + code.append(" pg = svwhilelt_b32_s64(k, block_size);") + code.append(" auto output = svld1(pg, &op[k]);") + code.extend(compute_output(num_unrolls, InType, False)) + code.append(" svst1(pg, &op[k], output);") + code.append(" k += vLen;") code.append(" }") + if num_unrolls == 1: + code.append(" pos ++;") + else: + code.append(f" j += {num_unrolls};") + code.append(f" pos += {num_unrolls};") - code.append(" const int64_t length = end_offset - start_offset;\n") - code.append(" if (normalize_by_lengths && length != 0) {") - code.append(" const float len_inv = 1.0f / length;") - code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") - code.append(" svbool_t pg;") - code.append( - " for (int64_t j = 0;\n" - " svptest_first(svAll, pg = svwhilelt_b32_s64(" - "j, block_size));" - ) - code.append(" j += vLen) {") - code.append( - " svst1_f32(\n" - " pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));" - ) - code.append(" }") - code.append(" }") code.append(" }") - return code + return code def main(): parser = argparse.ArgumentParser() @@ -352,22 +168,47 @@ def main(): code.append(" const auto vLen = static_cast(svcntw());") code.append(" int64_t pos = 0;") - code.append(" if (block_size == 32 * vLen) {") - code += unroll(32, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 16 * vLen) {") - code += unroll(16, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 8 * vLen) {") - code += unroll(8, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 4 * vLen) {") - code += unroll(4, IndexType, InType, OutType, True) - code.append(" } else if (block_size == 2 * vLen) {") - code += unroll(2, IndexType, InType, OutType, True) - code.append(" } else {") - code.append(" // generic code:") - code += generic(IndexType, InType, OutType, True) + code.append(" for (int64_t i = 0; i < output_size; ++i) {") + code.append(" " + OutType + "* const op = &out[i * block_size];") + + # initialize to 0 + code.append(" memset(op, 0, sizeof(float) * block_size);") + + # inner loop + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + code.append( + " int64_t start_offset = offsets[i];\n" + + " int64_t end_offset = offsets[i + 1];" + ) + code.append(" int64_t j = start_offset;") + + code += unroll(16, IndexType, InType, OutType) + code += unroll(8, IndexType, InType, OutType) + code += unroll(4, IndexType, InType, OutType) + code += unroll(2, IndexType, InType, OutType) + code += unroll(1, IndexType, InType, OutType) + + code.append(" const int64_t length = end_offset - start_offset;\n") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" svbool_t pg;") + code.append(" int64_t j = 0;") + code.append(" while (j + vLen - 1 < block_size) {") + code.append(" svst1(svAll, &op[j], svmul_x(svAll, svld1(svAll, &op[j]), len_inv));") + code.append(" j += vLen;") + code.append(" }") + 
code.append(" if (j < block_size) {") + code.append(" pg = svwhilelt_b32_s64(j, block_size);") + code.append(" svst1(pg, &op[j], svmul_x(pg, svld1(pg, &op[j]), len_inv));") + code.append(" }") + code.append(" }") + code.append(" }") code.append(" return pos == index_size;") - code.append("}") for is_weight_positional in ["false", "true"]: From 2a1e2b88ed7bf7d7436b741ee0c3a2297d7d7bc2 Mon Sep 17 00:00:00 2001 From: Sam Larsen Date: Fri, 4 Apr 2025 11:25:29 -0700 Subject: [PATCH 244/332] [logging] Add pgo remote get/put timings to dynamo_compile (#150322) Test Plan: https://fburl.com/scuba/dynamo_compile/sandbox/xf950tw8 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150322 Approved by: https://github.com/ppanchalia --- test/dynamo/test_utils.py | 4 ++++ torch/_dynamo/pgo.py | 10 ++++++++-- torch/_dynamo/utils.py | 2 ++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index dd5a5c4593eb..9f51c11a87c6 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -337,6 +337,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'non_compliant_ops': set(), 'num_graph_breaks': 0, 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, 'post_grad_pass_time_us': 0, 'pre_grad_pass_time_us': 0, 'python_version': None, @@ -424,6 +426,8 @@ def test_dynamo_timed(self, mock_time, mock_time_ns): 'non_compliant_ops': None, 'num_graph_breaks': 0, 'num_triton_bundles': None, + 'pgo_get_remote_code_state_time_us': None, + 'pgo_put_remote_code_state_time_us': None, 'post_grad_pass_time_us': 0, 'pre_grad_pass_time_us': None, 'python_version': None, diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index 96ace1da75b4..8db484cab727 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -605,7 +605,9 @@ def hit(ty: str) -> defaultdict[CodeId, CodeState]: remote_cache = get_remote_cache() if remote_cache is not None: with dynamo_timed( - name := "pgo.get_remote_code_state", log_pt2_compile_event=True + name := "pgo.get_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_get_remote_code_state_time_us", ): CompileEventLogger.pt2_compile(name, cache_key=cache_key) # TODO: I don't really understand why there's a JSON container format @@ -716,7 +718,11 @@ def put_local_code_state(cache_key: str) -> None: def put_remote_code_state(cache_key: str) -> None: - with dynamo_timed(name := "pgo.put_remote_code_state", log_pt2_compile_event=True): + with dynamo_timed( + name := "pgo.put_remote_code_state", + log_pt2_compile_event=True, + dynamo_compile_column_us="pgo_put_remote_code_state_time_us", + ): CompileEventLogger.pt2_compile(name, cache_key=cache_key) assert _CODE_STATE is not None diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 13ee160a93d5..04ee7aa86d69 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1267,6 +1267,8 @@ class CompilationMetrics: ir_count: Optional[int] = None cudagraph_skip_reason: Optional[str] = None python_version: Optional[str] = None + pgo_put_remote_code_state_time_us: Optional[int] = None + pgo_get_remote_code_state_time_us: Optional[int] = None @classmethod def create(cls, metrics: dict[str, Any]): From f8b53f4a759b6b77ef6894194cb36dffb7312049 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Mon, 7 Apr 2025 18:58:37 +0000 Subject: [PATCH 245/332] [export] raise when Dim.DYNAMIC 0/1 specializes (#150716) Previously we didn't catch 
this, mark_dynamic() just doesn't allocate a symbol for it Differential Revision: D72486930 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150716 Approved by: https://github.com/angelayi --- test/export/test_export.py | 51 ++++++++++++++++++++++++++----- torch/_export/non_strict_utils.py | 12 +++++++- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 988e2fae81c6..343118c715c6 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -2383,6 +2383,49 @@ def forward(self, x, y, z): ): export(Foo(), inputs, dynamic_shapes=shapes) + def test_dim_dynamic_specialization(self): + class Foo(torch.nn.Module): + def forward(self, x): + return x + 2 + + # 0/1 specialization + with self.assertRaisesRegex( + ValueError, + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 0 for dimension " + r"inputs\['x'\]\.shape\[0\](.*\n)*.*" + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 1 for dimension " + r"inputs\['x'\]\.shape\[1\].*", + ): + export( + Foo(), + (torch.randn(0, 1),), + dynamic_shapes={ + "x": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}, + }, + ) + + class Bar(torch.nn.Module): + def forward(self, x): + assert x.shape[0] <= 32 + return x + 2 + + # static specialization + with self.assertRaisesRegex( + ValueError, + r"Received user-specified dim hint Dim.DYNAMIC.*" + r"but tracing inferred a static shape of 32 for dimension " + r"inputs\['x'\]\.shape\[0\](.*\n)*.*", + ): + export( + Bar(), + (torch.randn(32),), + dynamic_shapes={ + "x": {0: Dim.DYNAMIC(min=32)}, + }, + ) + def test_dim_hint_ranges(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -7494,14 +7537,6 @@ def check(inputs, epm): # output shape is (3, 2), with n_row 3 and n_sample 2 <= dist_size 2 check(inputs, epm) - inputs = ( - torch.tensor([[4, 5], [6, 7], [8, 9], [10, 11]], dtype=torch.float32), - torch.ones(1, dtype=torch.int64), - ) - epm = exported_module(inputs) - # output shape is (4, 1), with n_row 4 and n_sample 1 <= dist_size 2 - check(inputs, epm) - inputs = ( torch.tensor([[4, 5], [6, 7], [8, 9]], dtype=torch.float32), torch.ones(3, dtype=torch.int64), diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 6e65141acfac..3db84d43a484 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -27,6 +27,7 @@ _check_dynamic_shapes, _combine_args, _DimHint, + _DimHintType, _process_dynamic_shapes, _RelaxedConstraint, _tree_map_with_path, @@ -435,10 +436,19 @@ def is_int(x: object) -> bool: upper=int_oo if dim.max is None else dim.max, ) if is_int(d): - trace_vr & user_vr + out_vr = trace_vr & user_vr else: range_constraints[d.node.expr] &= user_vr shape_env.var_to_range[d.node._expr] &= user_vr + out_vr = range_constraints[d.node.expr] + # check for specializations + if dim.type == _DimHintType.DYNAMIC and out_vr.is_singleton(): + msg = ( + f"- Received user-specified dim hint Dim.DYNAMIC(min={dim.min}, max={dim.max}), " + f"but tracing inferred a static shape of {out_vr.lower} for dimension " + f"inputs{pytree.keystr(flat_paths[input_index])}.shape[{i}]." 
+ ) + range_violations.append(msg) except torch.utils._sympy.value_ranges.ValueRangeError: msg = ( f"- Received user-specified min/max range of [{dim.min}, {dim.max}], " From bf1132c1967bcef44977887970101ec787d42a90 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Mon, 7 Apr 2025 20:09:53 +0000 Subject: [PATCH 246/332] Revert "Generalize poison fork logic for each device backend (#144664)" This reverts commit d86c14156d875b782b82dda96842a1f77910f010. Reverted https://github.com/pytorch/pytorch/pull/144664 on behalf of https://github.com/atalman due to failing periodic test: python test/test_cpp_extensions_mtia_backend.py TestCppExtensionMTIABackend.test_device_context ([comment](https://github.com/pytorch/pytorch/pull/144664#issuecomment-2784506104)) --- torch/csrc/cuda/Module.cpp | 36 ++++++++++++++++++------ torch/csrc/mps/Module.cpp | 30 +++++++++++++++++--- torch/csrc/mtia/Module.cpp | 31 +++++++++++++++++---- torch/csrc/utils/device_lazy_init.cpp | 40 --------------------------- torch/csrc/utils/device_lazy_init.h | 17 ------------ torch/csrc/xpu/Module.cpp | 34 +++++++++++++++++++---- 6 files changed, 108 insertions(+), 80 deletions(-) diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index f5365a674d29..1ff4079a56e5 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -51,9 +51,32 @@ #include #include #include +#ifndef WIN32 +#include +#endif using namespace torch; +static bool in_bad_fork = false; // True for children forked after cuda init + +#ifndef WIN32 +// Called in the forked child if cuda has already been initialized +static void forked_child() { + in_bad_fork = true; + torch::utils::set_requires_device_init(at::kCUDA, true); +} +#endif + +// Should be called before the first cuda call. +// Note: This is distinct from initExtension because a stub cuda implementation +// has some working functions (e.g. device_count) but cannot fully initialize. +static void poison_fork() { +#ifndef WIN32 + static auto result [[maybe_unused]] = + pthread_atfork(nullptr, nullptr, forked_child); +#endif +} + //////////////////////////////////////////////////////////////////////////////// // CUDA management methods //////////////////////////////////////////////////////////////////////////////// @@ -137,17 +160,14 @@ PyObject* THCPModule_canDeviceAccessPeer_wrap(PyObject* self, PyObject* args) { PyObject* THCPModule_getDeviceCount_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - // Note: This is distinct from initExtension because a stub cuda - // implementation has some working functions (e.g. device_count) but cannot - // fully initialize. 
- torch::utils::register_fork_handler_for_device_init(at::kCUDA); + poison_fork(); return THPUtils_packUInt64(at::cuda::device_count()); END_HANDLE_TH_ERRORS } PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - torch::utils::register_fork_handler_for_device_init(at::kCUDA); + poison_fork(); #ifdef CUDA_ARCH_FLAGS static const char* flags = C10_STRINGIZE(CUDA_ARCH_FLAGS); return THPUtils_packString(flags); @@ -159,7 +179,7 @@ PyObject* THCPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THCPModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kCUDA)); + return PyBool_FromLong(in_bad_fork); END_HANDLE_TH_ERRORS } @@ -1493,8 +1513,8 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { "please rebuild pytorch without asan if you need to use this module"); #endif HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kCUDA)); - torch::utils::register_fork_handler_for_device_init(at::kCUDA); + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + poison_fork(); at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 0ec9b8418c6e..3694cd194179 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -6,12 +6,16 @@ #include #include #include -#include #include #include #include #include +// pthread.h is included for tracking bad forks +#ifndef WIN32 +#include +#endif + #ifdef USE_MPS #include #include @@ -19,9 +23,27 @@ namespace torch::mps { +namespace { +// True for children forked after mps init +static bool in_bad_fork = false; + +// Called in the forked child if mps has already been initialized +static void forked_mps_child() { + in_bad_fork = true; +} + +// Should be called before the first mps call. 
+static void track_bad_mps_fork() { +#ifndef WIN32 + static auto result [[maybe_unused]] = + pthread_atfork(nullptr, nullptr, forked_mps_child); +#endif +} +} // namespace + static PyObject* MPSModule_isInBadFork(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kMPS)); + return PyBool_FromLong(in_bad_fork); END_HANDLE_TH_ERRORS } @@ -29,7 +51,7 @@ static PyObject* MPSModule_getDefaultMPSGenerator( PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS - torch::utils::register_fork_handler_for_device_init(at::kMPS); + track_bad_mps_fork(); return THPGenerator_initDefaultGenerator( at::detail::getMPSHooks().getDefaultGenerator()); END_HANDLE_TH_ERRORS @@ -37,8 +59,8 @@ static PyObject* MPSModule_getDefaultMPSGenerator( static PyObject* MPSModule_isAvailable(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS + track_bad_mps_fork(); if (at::detail::getMPSHooks().hasMPS()) { - torch::utils::register_fork_handler_for_device_init(at::kMPS); Py_RETURN_TRUE; } else { Py_RETURN_FALSE; diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index ec6229967e0b..405b9d780023 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -7,15 +7,38 @@ #include #include #include +#ifndef WIN32 +#include +#endif namespace torch::mtia { +static bool in_bad_fork = false; // True for children forked after mtia init + +#ifndef WIN32 +// Called in the forked child if mtia has already been initialized +static void forked_child() { + in_bad_fork = true; + torch::utils::set_requires_device_init(at::kMTIA, true); +} +#endif + +// Should be called before the first mtia call. +// Note: This is distinct from initExtension because a stub mtia implementation +// has some working functions (e.g. device_count) but cannot fully initialize. 
+static void poison_fork() { +#ifndef WIN32 + static auto result [[maybe_unused]] = + pthread_atfork(nullptr, nullptr, forked_child); +#endif +} + void initModule(PyObject* module) { auto m = py::handle(module).cast(); m.def("_mtia_init", []() { - TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA)); - torch::utils::register_fork_handler_for_device_init(at::kMTIA); + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + poison_fork(); at::globalContext().lazyInitDevice(c10::DeviceType::MTIA); }); @@ -24,9 +47,7 @@ void initModule(PyObject* module) { return at::detail::isMTIAHooksBuilt(); }); - m.def("_mtia_isInBadFork", []() { - return torch::utils::is_device_in_bad_fork(at::kMTIA); - }); + m.def("_mtia_isInBadFork", []() { return in_bad_fork; }); m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) { torch::utils::device_lazy_init(at::kMTIA); diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp index c5a6512b363c..74adb6b5e6b0 100644 --- a/torch/csrc/utils/device_lazy_init.cpp +++ b/torch/csrc/utils/device_lazy_init.cpp @@ -1,23 +1,13 @@ #include -#include #include #include #include #include - -#ifndef WIN32 -#include -#endif - namespace torch::utils { namespace { std::array is_initialized{}; -std::array is_in_bad_fork{}; -std::array - at_fork_once_flags{}; -std::optional at_fork_device_type{}; } // anonymous namespace @@ -68,34 +58,4 @@ void set_requires_device_init(at::DeviceType device_type, bool value) { is_initialized[static_cast(device_type)] = !value; } -bool is_device_in_bad_fork(at::DeviceType device_type) { - return is_in_bad_fork[static_cast(device_type)]; -} - -void set_device_in_bad_fork(at::DeviceType device_type, bool value) { - is_in_bad_fork[static_cast(device_type)] = value; -} - -// Should be called before the first device runtime call. -void register_fork_handler_for_device_init(at::DeviceType device_type) { -#ifndef WIN32 - auto& flag = at_fork_once_flags[static_cast(device_type)]; - c10::call_once(flag, [device_type]() { - TORCH_CHECK( - !at_fork_device_type, - "Only one device type can be registered. 
But now, we have two device types: ", - at_fork_device_type.value(), - " and ", - device_type); - at_fork_device_type = device_type; - pthread_atfork(nullptr, nullptr, []() { - set_device_in_bad_fork(at_fork_device_type.value(), true); - if (is_device_lazy_init_supported(at_fork_device_type.value())) { - set_requires_device_init(at_fork_device_type.value(), true); - } - }); - }); -#endif -} - } // namespace torch::utils diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index e65f16ace163..e1f480a60f77 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -67,21 +67,4 @@ inline void maybe_initialize_device( bool is_device_initialized(at::DeviceType device_type); -TORCH_PYTHON_API bool is_device_in_bad_fork(at::DeviceType device_type); - -TORCH_PYTHON_API void set_device_in_bad_fork( - at::DeviceType device_type, - bool value); - -TORCH_PYTHON_API void register_fork_handler_for_device_init( - at::DeviceType device_type); - -inline void maybe_register_fork_handler_for_device_init( - std::optional& device_type) { - if (!device_type.has_value()) { - return; - } - register_fork_handler_for_device_init(device_type.value()); -} - } // namespace torch::utils diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 8144dddd8298..43ad06365efc 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -11,8 +11,32 @@ #include #include +#ifndef WIN32 +#include +#endif + using namespace torch; +static bool in_bad_fork = false; // True for children forked after xpu init + +#ifndef WIN32 +// Called in the forked child if xpu has already been initialized +static void forked_child() { + in_bad_fork = true; + torch::utils::set_requires_device_init(at::kXPU, true); +} +#endif + +// Should be called before the first xpu call. It is mainly called in lazy_init. +// Note: This is distinct from initExtension because a stub xpu implementation +// has some working functions (e.g. device_count) but cannot fully initialize. +static void poison_fork() { +#ifndef WIN32 + static auto result [[maybe_unused]] = + pthread_atfork(nullptr, nullptr, forked_child); +#endif +} + // XPU management methods static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { @@ -28,7 +52,7 @@ static PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - return PyBool_FromLong(torch::utils::is_device_in_bad_fork(at::kXPU)); + return PyBool_FromLong(in_bad_fork); END_HANDLE_TH_ERRORS } @@ -91,9 +115,7 @@ static PyObject* THXPModule_getDeviceCount_wrap( PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - // Note: This is distinct from initExtension because a stub xpu implementation - // has some working functions (e.g. device_count) but cannot fully initialize. 
- torch::utils::register_fork_handler_for_device_init(at::kXPU); + poison_fork(); return THPUtils_packUInt64(at::xpu::device_count()); END_HANDLE_TH_ERRORS } @@ -398,8 +420,8 @@ static void initXpuMethodBindings(PyObject* module) { // classes static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kXPU)); - torch::utils::register_fork_handler_for_device_init(at::kXPU); + TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level + poison_fork(); at::globalContext().lazyInitDevice(c10::DeviceType::XPU); auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu")); From ed0dea3e24a2ba4d01043c4cfd27e90655692adc Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 7 Apr 2025 10:33:09 -0700 Subject: [PATCH 247/332] [AO] update port_metadata_pass to support quant_affine ops (#150642) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150642 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/pt2e/port_metadata_pass.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index b0946d0075c9..0c96f915306d 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -27,17 +27,20 @@ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.tensor, torch.ops.quantized_decomposed.quantize_per_channel.default, + torch.ops.pt2e_quant.quantize_affine, ] _DEQUANTIZE_OPS = [ torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.dequantize_affine, ] _CHOOSE_QPARAMS_OPS = [ torch.ops.quantized_decomposed.choose_qparams.tensor, torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, + torch.ops.pt2e_quant.choose_qparams_affine, ] From 5653fb352505710a201106c1d6aa8eb57aaca0dd Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 7 Apr 2025 10:33:10 -0700 Subject: [PATCH 248/332] [AO] Add Moving Average Affine Observer (#150643) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150643 Approved by: https://github.com/jerryzh168 ghstack dependencies: #150642 --- test/quantization/pt2e/test_quantize_pt2e.py | 92 ++++++++++ .../quantization/pt2e/_affine_quantization.py | 168 ++++++++++++++++++ 2 files changed, 260 insertions(+) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 08ffecc3aabd..b37904e57799 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2554,5 +2554,97 @@ def forward(self, x): is_debug_mode=True, ) + def test_dynamic_affine_act_per_channel_weights(self): + import operator + + from torch.ao.quantization.observer import ( + MappingType, + PerChannelMinMaxObserver, + PerToken, + ) + from torch.ao.quantization.pt2e._affine_quantization import ( + AffineQuantizedMovingAverageMinMaxObserver, + ) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert 
isinstance(weight, Node) + + activation_dtype = torch.int8 + act_qspec = QuantizationSpec( + dtype=activation_dtype, + quant_min=-128, + quant_max=127, + qscheme=None, + is_dynamic=True, + observer_or_fake_quant_ctr=AffineQuantizedMovingAverageMinMaxObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=activation_dtype, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + averaging_constant=1, + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.pt2e_quant.choose_qparams_affine: 1, + operator.getitem: 2, + torch.ops.pt2e_quant.quantize_affine: 1, + torch.ops.pt2e_quant.dequantize_affine: 1, + torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, + } + node_list = [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.pt2e_quant.choose_qparams_affine, + operator.getitem, + torch.ops.pt2e_quant.quantize_affine, + torch.ops.pt2e_quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py index 70ad5c0cde89..011f019524ff 100644 --- a/torch/ao/quantization/pt2e/_affine_quantization.py +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -2,6 +2,7 @@ # and https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py # PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC import logging +import operator from abc import ABCMeta from typing import Any, Optional, Union @@ -9,6 +10,7 @@ from torch.ao.quantization.observer import ( AffineQuantizedObserverBase, get_block_size, + Granularity, MappingType, TorchAODType, ZeroPointDomain, @@ -773,3 +775,169 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): ) observer_node.replace_all_uses_with(dq_node) model.graph.erase_node(observer_node) + + +class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + averaging_constant=0.01, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + self.averaging_constant = averaging_constant + if is_dynamic and self.averaging_constant != 1: + raise NotImplementedError( + "MovingAverageMinMaxObserver doesn't support dynamic quantization for " + f"averaging constant of 
{self.averaging_constant}" + ) + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input: torch.Tensor): + if input.numel() == 0: + return input + + input_detached = input.detach() + self.original_dtype = input_detached.dtype + assert self.granularity is not None, "granularity is None" + self.block_size = get_block_size(input_detached.shape, self.granularity) + + shape_for_reduction, reduction_dims = _get_reduction_params( + self.block_size, input_detached.size() + ) + input_detached = input_detached.view(shape_for_reduction) + min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + assert ( + self.min_val.shape == min_val.shape + ), f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}" + assert ( + self.max_val.shape == max_val.shape + ), f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}" + min_val = self.min_val + self.averaging_constant * (min_val - self.min_val) + max_val = self.max_val + self.averaging_constant * (max_val - self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + # returning original input + return input + + def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr( + self, "max_val" + ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + [], # BlockSize is not needed because the min/max are already reduced + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + print("calling convert") + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert ( + self.original_dtype is not None + ), "Expecting original_dtype to be populated" + if hasattr(self, "is_dynamic") and self.is_dynamic: + print("is dynamic") + choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + + q_node = model.graph.call_function( + 
torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) From eba05e2d3ef9de7618dd0bedb049104fd6f66a89 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 7 Apr 2025 10:33:12 -0700 Subject: [PATCH 249/332] [AO] Refactor convert and add QuantAffinePlaceholderObserver (#150644) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150644 Approved by: https://github.com/jerryzh168 ghstack dependencies: #150642, #150643 --- test/quantization/pt2e/test_quantize_pt2e.py | 91 ++++++++++ torch/ao/quantization/observer.py | 80 +++++++++ .../quantization/pt2e/_affine_quantization.py | 159 +++++------------- 3 files changed, 212 insertions(+), 118 deletions(-) diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index b37904e57799..87ac89fe852c 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -2646,5 +2646,96 @@ def forward(self, x): is_debug_mode=True, ) + def test_dynamic_per_tok_act_per_group_weights(self): + import operator + + from torch.ao.quantization.observer import MappingType, PerGroup, PerToken + from torch.ao.quantization.pt2e._affine_quantization import ( + AffineQuantizedMinMaxObserver, + AffineQuantizedPlaceholderObserver, + ) + + class BackendAQuantizer(Quantizer): + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: + for node in model.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + activation_dtype = torch.int8 + act_qspec = QuantizationSpec( + dtype=activation_dtype, + quant_min=-128, + quant_max=127, + qscheme=None, + is_dynamic=True, + observer_or_fake_quant_ctr=AffineQuantizedPlaceholderObserver.with_args( + # TODO: maybe align the arg name here + target_dtype=activation_dtype, + mapping_type=MappingType.SYMMETRIC, + granularity=PerToken(), + ), + ) + + weight_qspec = QuantizationSpec( + dtype=torch.int8, + quant_min=-127, + quant_max=127, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + is_dynamic=False, + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( + target_dtype=torch.int8, + mapping_type=MappingType.SYMMETRIC, + granularity=PerGroup(group_size=128), + ), + ) + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map={ + input_act: act_qspec, + weight: weight_qspec, + }, + _annotated=True, + ) + + def validate(self, model: torch.fx.GraphModule) -> None: + pass + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 20) + + def forward(self, x): + return self.linear(x) + + node_occurrence = { + torch.ops.pt2e_quant.choose_qparams_affine: 1, + operator.getitem: 2, + torch.ops.pt2e_quant.quantize_affine: 1, + torch.ops.pt2e_quant.dequantize_affine: 2, + } + node_list = [ + 
torch.ops.pt2e_quant.dequantize_affine, + torch.ops.pt2e_quant.choose_qparams_affine, + operator.getitem, + torch.ops.pt2e_quant.quantize_affine, + torch.ops.pt2e_quant.dequantize_affine, + ] + example_inputs = (torch.randn(5, 128),) + self._test_quantizer( + M().eval(), + example_inputs, + BackendAQuantizer(), + node_occurrence, + node_list, + is_debug_mode=True, + ) + instantiate_parametrized_tests(TestQuantizePT2E) diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index a3672b5cb01d..6a39bdc0fc39 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -8,6 +8,7 @@ the values observed during calibration (PTQ) or training (QAT). """ +import operator import re import warnings from abc import ABCMeta, abstractmethod @@ -24,6 +25,7 @@ is_per_tensor, validate_qmin_qmax, ) +from torch.fx import Node __all__ = [ @@ -1850,6 +1852,84 @@ def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: and returns a tuple of scale and zero_point Tensor """ + def convert(self, model: torch.fx.GraphModule, observer_node: Node): + """ + Converts the observer node in the graph into its quantized representation + + Args: + model: graph module to conver the observer node in + observer_node: the observer node to convert + """ + from torch.ao.quantization.fx.utils import create_getattr_from_value + + with model.graph.inserting_before(observer_node): + assert self.block_size is not None, "Expecting block_size to be populated" + assert ( + self.original_dtype is not None + ), "Expecting original_dtype to be populated" + if hasattr(self, "is_dynamic") and self.is_dynamic: + choose_qparams_affine = model.graph.call_function( + torch.ops.pt2e_quant.choose_qparams_affine, + ( + observer_node.args[0], + self.mapping_type.name, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain.name, + ), + ) + scale_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 0) + ) + zero_point_node = model.graph.call_function( + operator.getitem, (choose_qparams_affine, 1) + ) + else: + scale, zero_point = self.calculate_qparams() + scale_node = create_getattr_from_value( + model, model.graph, "_scale", scale + ) + zero_point_node = create_getattr_from_value( + model, model.graph, "_zero_point", zero_point + ) + + q_node = model.graph.call_function( + torch.ops.pt2e_quant.quantize_affine, + ( + observer_node.args[0], + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {}, + ) + dq_node = model.graph.call_function( + torch.ops.pt2e_quant.dequantize_affine, + ( + q_node, + self.block_size, + scale_node, + zero_point_node, + self.target_dtype, + self.quant_min, + self.quant_max, + self.zero_point_domain.name, + ), + {"output_dtype": self.original_dtype}, + ) + observer_node.replace_all_uses_with(dq_node) + model.graph.erase_node(observer_node) + def _is_observer_script_module(mod, obs_type_name): """Returns true if given mod is an instance of Observer script module.""" diff --git a/torch/ao/quantization/pt2e/_affine_quantization.py b/torch/ao/quantization/pt2e/_affine_quantization.py index 011f019524ff..32b4a773f28f 100644 --- a/torch/ao/quantization/pt2e/_affine_quantization.py +++ b/torch/ao/quantization/pt2e/_affine_quantization.py @@ -2,7 +2,6 @@ # and 
https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py # PLESE DON'T MODIFY THIS FILE SO THAT WE DON'T GET OUT OF SYNC import logging -import operator from abc import ABCMeta from typing import Any, Optional, Union @@ -15,7 +14,6 @@ TorchAODType, ZeroPointDomain, ) -from torch.fx import Node ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: @@ -731,51 +729,6 @@ def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: self.zero_point_domain, ) - def convert(self, model: torch.fx.GraphModule, observer_node: Node): - print("calling convert") - from torch.ao.quantization.fx.utils import create_getattr_from_value - - scale, zero_point = self.calculate_qparams() - with model.graph.inserting_before(observer_node): - assert self.block_size is not None, "Expecting block_size to be populated" - assert ( - self.original_dtype is not None - ), "Expecting original_dtype to be populated" - scale_node = create_getattr_from_value(model, model.graph, "_scale", scale) - zero_point_node = create_getattr_from_value( - model, model.graph, "_zero_point", zero_point - ) - q_node = model.graph.call_function( - torch.ops.pt2e_quant.quantize_affine, - ( - observer_node.args[0], - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {}, - ) - dq_node = model.graph.call_function( - torch.ops.pt2e_quant.dequantize_affine, - ( - q_node, - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {"output_dtype": self.original_dtype}, - ) - observer_node.replace_all_uses_with(dq_node) - model.graph.erase_node(observer_node) - class AffineQuantizedMovingAverageMinMaxObserver(AffineQuantizedObserverBase): def __init__( @@ -869,75 +822,45 @@ def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor]: self.zero_point_domain, ) - def convert(self, model: torch.fx.GraphModule, observer_node: Node): - print("calling convert") - from torch.ao.quantization.fx.utils import create_getattr_from_value - with model.graph.inserting_before(observer_node): - assert self.block_size is not None, "Expecting block_size to be populated" - assert ( - self.original_dtype is not None - ), "Expecting original_dtype to be populated" - if hasattr(self, "is_dynamic") and self.is_dynamic: - print("is dynamic") - choose_qparams_affine = model.graph.call_function( - torch.ops.pt2e_quant.choose_qparams_affine, - ( - observer_node.args[0], - self.mapping_type.name, - self.block_size, - self.target_dtype, - self.quant_min, - self.quant_max, - self.eps, - self.scale_dtype, - self.zero_point_dtype, - self.preserve_zero, - self.zero_point_domain.name, - ), - ) - scale_node = model.graph.call_function( - operator.getitem, (choose_qparams_affine, 0) - ) - zero_point_node = model.graph.call_function( - operator.getitem, (choose_qparams_affine, 1) - ) - else: - scale, zero_point = self.calculate_qparams() - scale_node = create_getattr_from_value( - model, model.graph, "_scale", scale - ) - zero_point_node = create_getattr_from_value( - model, model.graph, "_zero_point", zero_point - ) - - q_node = model.graph.call_function( - torch.ops.pt2e_quant.quantize_affine, - ( - observer_node.args[0], - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {}, - ) - dq_node = model.graph.call_function( - 
torch.ops.pt2e_quant.dequantize_affine, - ( - q_node, - self.block_size, - scale_node, - zero_point_node, - self.target_dtype, - self.quant_min, - self.quant_max, - self.zero_point_domain.name, - ), - {"output_dtype": self.original_dtype}, - ) - observer_node.replace_all_uses_with(dq_node) - model.graph.erase_node(observer_node) +class AffineQuantizedPlaceholderObserver(AffineQuantizedObserverBase): + def __init__( + self, + mapping_type: MappingType, + target_dtype: torch.dtype, + granularity: Granularity, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + is_dynamic=False, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + # there could be some extra args that's ignored + **kwargs, + ): + self.is_dynamic = is_dynamic + + super().__init__( + mapping_type=mapping_type, + target_dtype=target_dtype, + granularity=granularity, + quant_min=quant_min, + quant_max=quant_max, + eps=eps, + scale_dtype=scale_dtype, + zero_point_dtype=zero_point_dtype, + preserve_zero=preserve_zero, + zero_point_domain=zero_point_domain, + ) + + def forward(self, input): + self.block_size = get_block_size(input.shape, self.granularity) + self.original_dtype = input.dtype + return input + + def calculate_qparams(self): + raise Exception( # noqa: TRY002 + "calculate_qparams should not be called for PlaceholderObserver" + ) From fbccbfedaf4e12012dcb24d599f009dc717d3697 Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Mon, 7 Apr 2025 22:05:00 +0000 Subject: [PATCH 250/332] [BE] Fix Amp.metal compilation warning (#150783) Deleting unused `uint tid` fixes ``` [114/1416] Compiling /Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Amp.metal to Amp_30.air /Users/nshulga/git/pytorch/pytorch/aten/src/ATen/native/mps/kernels/Amp.metal:70:10: warning: unused parameter 'tid' [-Wunused-parameter] uint tid [[thread_position_in_grid]]) { ^ 1 warning generated. 
``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150783 Approved by: https://github.com/wdvr, https://github.com/atalman --- aten/src/ATen/native/mps/kernels/Amp.metal | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/Amp.metal b/aten/src/ATen/native/mps/kernels/Amp.metal index f32621320ab4..abe852798f44 100644 --- a/aten/src/ATen/native/mps/kernels/Amp.metal +++ b/aten/src/ATen/native/mps/kernels/Amp.metal @@ -66,8 +66,7 @@ kernel void ampUpdateScale( device float& foundInf [[buffer(2)]], constant T& scaleGrowthFactor [[buffer(3)]], constant T& scaleBackoffFactor [[buffer(4)]], - constant int& growthInterval [[buffer(5)]], - uint tid [[thread_position_in_grid]]) { + constant int& growthInterval [[buffer(5)]]) { if (foundInf != 0.0f) { scale *= scaleBackoffFactor; growth_tracker = 0; @@ -110,8 +109,7 @@ kernel void ampUpdateScale( device float& foundInf [[buffer(2)]], \ constant DTYPE& scaleGrowthFactor [[buffer(3)]], \ constant DTYPE& scaleBackoffFactor [[buffer(4)]], \ - constant int& growthInterval [[buffer(5)]], \ - uint tid [[thread_position_in_grid]]) + constant int& growthInterval [[buffer(5)]]) INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half); @@ -129,4 +127,4 @@ INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float); INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half); #if __METAL_VERSION__ >= 310 INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat); -#endif \ No newline at end of file +#endif From 78fe079c97fd35a2072717cbdcf98bf3a88e61be Mon Sep 17 00:00:00 2001 From: Ankita George Date: Mon, 7 Apr 2025 22:10:35 +0000 Subject: [PATCH 251/332] Support having no metadata file for HuggingFaceStorageReader (#150701) Summary: If there is only one safetensors file, we don't need users to have a metadata file and we can just construct it from the keys of that file. 
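For illustration only, a minimal standalone sketch (not part of this patch; the file name and printout are assumptions) of how a weight-map style mapping can be derived from a single safetensors file's keys, in the spirit of the reader change in this PR:

```python
from safetensors import safe_open

# Hypothetical single-shard checkpoint saved without a metadata index file.
checkpoint_file = "model.safetensors"

# Build a weight_map-style dict: every tensor key points at the one shard file.
weight_map = {}
with safe_open(checkpoint_file, framework="pt") as f:
    for key in f.keys():
        weight_map[key] = checkpoint_file

# weight_map now plays the role normally served by a *.safetensors.index.json file.
print(weight_map)
```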
This is a use-case for some HuggingFace models, so adding support for it Test Plan: ensure existing tests pass tested e2e in a notebook Differential Revision: D72472490 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150701 Approved by: https://github.com/joecummings --- .../distributed/checkpoint/test_hf_storage.py | 31 +++++++++++++++ .../checkpoint/_fsspec_filesystem.py | 3 ++ torch/distributed/checkpoint/_hf_storage.py | 38 ++++++++++++++++--- torch/distributed/checkpoint/filesystem.py | 8 ++++ 4 files changed, 74 insertions(+), 6 deletions(-) diff --git a/test/distributed/checkpoint/test_hf_storage.py b/test/distributed/checkpoint/test_hf_storage.py index 9f099bbd825a..dfa485090c2d 100644 --- a/test/distributed/checkpoint/test_hf_storage.py +++ b/test/distributed/checkpoint/test_hf_storage.py @@ -190,6 +190,37 @@ def test_metadata_hf(self) -> None: metadata = reader.read_metadata() self.assertEqual(metadata.storage_data, expected_metadata["weight_map"]) + def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: + mock_module = MagicMock() + sys.modules["safetensors.torch"] = mock_module + sys.modules["huggingface_hub"] = mock_module + with tempfile.TemporaryDirectory() as path: + reader = _HuggingFaceStorageReader(path=path) + reader.fs = FileSystem() + # there is one safetensor file, but no metadata file, + # so we create metadata from the safetensor file + file_name = "test.safetensors" + open(os.path.join(path, file_name), "w").close() + + keys = ["tensor_0", "tensor_1"] + mock_module.safe_open.return_value.__enter__.return_value.keys.return_value = ( + keys + ) + + metadata = reader.read_metadata() + + self.assertEqual( + metadata.state_dict_metadata, + { + keys[0]: BytesStorageMetadata(), + keys[1]: BytesStorageMetadata(), + }, + ) + self.assertEqual( + metadata.storage_data, + {keys[0]: file_name, keys[1]: file_name}, + ) + if __name__ == "__main__": run_tests() diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index 8363fcf207a3..3bd508f4c2c9 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -91,6 +91,9 @@ def exists(self, path: Union[str, os.PathLike]) -> bool: def rm_file(self, path: Union[str, os.PathLike]) -> None: self.fs.rm(path) + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + return self.fs.ls(path) + # TODO: add the dcp.async_save mixin class FsspecWriter(FileSystemWriter): diff --git a/torch/distributed/checkpoint/_hf_storage.py b/torch/distributed/checkpoint/_hf_storage.py index 6927aed7e570..ef5f5bacb95b 100644 --- a/torch/distributed/checkpoint/_hf_storage.py +++ b/torch/distributed/checkpoint/_hf_storage.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import dataclasses import json +import os import queue from typing import Optional @@ -206,15 +207,40 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: return fut def read_metadata(self) -> Metadata: - path = self.fs.concat_path(self.path, _metadata_fn) - with self.fs.create_stream(path, "r") as metadata_file: - metadata = json.load(metadata_file) + metadata_path = self.fs.concat_path(self.path, _metadata_fn) state_dict_metadata: dict[str, STORAGE_TYPES] = {} - for key in metadata["weight_map"].keys(): - state_dict_metadata[key] = BytesStorageMetadata() + storage_data: dict[str, str] = {} + + if not self.fs.exists(metadata_path): + # if metadata file doesn't exist, create it from the 
safetensors file + from safetensors.torch import safe_open # type: ignore[import-not-found] + + safetensors_files = [] + for file in self.fs.ls(self.path): + if file.endswith(SUFFIX): + safetensors_files.append(os.path.basename(file)) + + if len(safetensors_files) != 1: + raise ValueError( + f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files" + ) + storage_data = {} + with safe_open(safetensors_files[0], framework="pt") as f: + for k in f.keys(): + state_dict_metadata[k] = BytesStorageMetadata() + storage_data[k] = safetensors_files[0] + else: + with self.fs.create_stream(metadata_path, "r") as metadata_file: + metadata = json.load(metadata_file) + + for key in metadata["weight_map"].keys(): + state_dict_metadata[key] = BytesStorageMetadata() + storage_data = metadata["weight_map"] + metadata = Metadata( - state_dict_metadata=state_dict_metadata, storage_data=metadata["weight_map"] + state_dict_metadata=state_dict_metadata, + storage_data=storage_data, ) if getattr(metadata, "storage_meta", None) is None: diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index 76954da21eb0..0c3db0416e90 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -479,6 +479,9 @@ def exists(self, path: Union[str, os.PathLike]) -> bool: ... @abstractmethod def rm_file(self, path: Union[str, os.PathLike]) -> None: ... + @abstractmethod + def ls(self, path: Union[str, os.PathLike]) -> list[str]: ... + class FileSystem(FileSystemBase): @contextmanager @@ -539,6 +542,11 @@ def rm_file(self, path: Union[str, os.PathLike]) -> None: path = Path(path) path.unlink() + def ls(self, path: Union[str, os.PathLike]) -> list[str]: + if not isinstance(path, Path): + path = Path(path) + return [str(p) for p in path.iterdir()] + class _FileSystemWriter(StorageWriter): """ From 6ea5514e0460604e4b0325a7218a7a8ca2e61819 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 7 Apr 2025 12:35:57 -0700 Subject: [PATCH 252/332] [invoke_subgraph] Lazy backward (#150666) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150666 Approved by: https://github.com/zou3519, https://github.com/bdhirsh --- torch/_guards.py | 13 ++ torch/_higher_order_ops/base_hop.py | 8 +- torch/_higher_order_ops/invoke_subgraph.py | 198 ++++++++++++++++----- 3 files changed, 170 insertions(+), 49 deletions(-) diff --git a/torch/_guards.py b/torch/_guards.py index b6b36f637101..c85c7b0d7325 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -672,12 +672,19 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... @abstractmethod def get_proxy_dispatch_entry(self, identifier: str): ... + @abstractmethod + def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): ... + + @abstractmethod + def get_lazy_bwd_entry(self, identifier: str): ... 
+ class InvokeSubgraphCache(HopSubgraphCache): def __init__(self) -> None: self.autograd_cache: dict[str, Callable] = {} self.proxy_dispatch_cache: dict[str, Callable] = {} self.dynamo_identifiers: dict[str, str] = {} + self.lazy_bwd_cache: dict[str, torch.fx.GraphModule] = {} def add_dynamo_identifier(self, cache_key: str, identifier: str): self.dynamo_identifiers[cache_key] = identifier @@ -697,6 +704,12 @@ def add_proxy_dispatch_entry(self, identifier: str, key: Callable): def get_proxy_dispatch_entry(self, identifier: str): return self.proxy_dispatch_cache.get(identifier, None) + def add_lazy_bwd_entry(self, identifier: str, gmod: torch.fx.GraphModule): + self.lazy_bwd_cache[identifier] = gmod + + def get_lazy_bwd_entry(self, identifier: str): + return self.lazy_bwd_cache.get(identifier, None) + class HopDispatchSetCache: def __init__(self) -> None: diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index 5f634f0c6436..a8fc106214b7 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -151,9 +151,11 @@ def backward(ctx, *grad_outputs): from .utils import _from_fun fw_inputs = pytree.tree_map(_from_fun, operands) - _, joint_graph, _, _ = create_fw_bw_graph( - subgraph, fw_inputs, grad_outputs - ) + ( + _, + joint_graph, + _, + ) = create_fw_bw_graph(subgraph, fw_inputs, grad_outputs) # The joint graph returns (*grad_inputs, *fwd_outputs). # We only need the grad_inputs. diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index d508e8ffc0da..833b04e78e43 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs +import contextlib from contextlib import nullcontext from dataclasses import dataclass, field from typing import Optional, Union @@ -42,7 +43,8 @@ # used to filter out grad_outs/tangents in the `backward` method of # InvokeSubgraphAutogradOp. @dataclass -class FilterTangentInfo: +class OutputMetadata: + num_fw_outs: Optional[int] = None indexes_with_none: set[int] = field(default_factory=set) indexes_with_no_grad: set[int] = field(default_factory=set) @@ -144,6 +146,7 @@ def get_invoke_subgraph_cache(): return cache +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra def trace_joint_graph(fn, fw_inputs, fw_outputs): """ Naively trace out a joint graph. This simplifies the reconstruction of joint @@ -184,6 +187,7 @@ def joint_fn(*primals_and_tangents): return _maybe_reenter_make_fx(joint_fn)(*joint_operands) +# TODO (@anijain2305) - Delete this function when base_hop uses invoke_subgraph infra def create_fw_bw_graph(subgraph, operands, grad_outputs=None): with suspend_functionalization(), disable_functional_mode(): with disable_proxy_modes_tracing(): @@ -209,13 +213,14 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): # performed in the autograd.Function - InvokeSubgraphAutogradOp. # Also collect the indexes of no_grad in the output to filter out # the grad_outs in the `backward` method. 
- filter_tangent_info = FilterTangentInfo() + output_metadata = OutputMetadata() + output_metadata.num_fw_outs = num_fw_outs for idx, fw_out in enumerate(fw_outs): if fw_out is None: - filter_tangent_info.indexes_with_none.add(idx) + output_metadata.indexes_with_none.add(idx) elif not fw_out.requires_grad: - filter_tangent_info.indexes_with_no_grad.add(idx) + output_metadata.indexes_with_no_grad.add(idx) if grad_outputs is None: # Infer grad_outputs to be the same properties as the fw_outputs @@ -253,88 +258,182 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None): fw_inputs, grad_outputs, ) - return fw_graph, bw_graph, num_fw_outs, filter_tangent_info + return fw_graph, bw_graph, output_metadata + + +def get_output_metadata(subgraph, operands): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + # args are functional tensors, generate some example tensors + fw_inputs = pytree.tree_map(_from_fun, operands) + + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(fw_inputs) + context = ( + nullcontext() + if fake_mode is None or fake_mode.shape_env is None + else fake_mode.shape_env.ignore_fresh_unbacked_symbols() + ) + + with context: + fw_outs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + + num_fw_outs = len(fw_outs) + + # Collect the indexes of none in the output to check that the grad + # is None at the corresponding index in the backward. This check is + # performed in the autograd.Function - InvokeSubgraphAutogradOp. + # Also collect the indexes of no_grad in the output to filter out + # the grad_outs in the `backward` method. + output_metadata = OutputMetadata() + + output_metadata.num_fw_outs = num_fw_outs + for idx, fw_out in enumerate(fw_outs): + if fw_out is None: + output_metadata.indexes_with_none.add(idx) + elif not fw_out.requires_grad: + output_metadata.indexes_with_no_grad.add(idx) + return output_metadata + + +def trace_joint_graph_as_bwd( + fn, num_primals, joint_operands, include_key_set, exclude_key_set +): + """ + Naively trace out a joint graph. This simplifies the reconstruction of joint + graph in the min-cut partitioner later on. + """ + from torch._functorch.aot_autograd import create_joint + + dummy_aot_config = get_dummy_aot_autograd_config() + + # This joint_fn is inserted as the backward graph as is. This simplifies the + # min-cut partitioner work later on. + # Input signature - (*primals, *tangents) + # Output signature - (*grads, *fw_outs) + # The output signature is deliberately kept grads first and fw_outs second. + # Having grads first makes the min-cut partitioner HOP graph stitching + # easier. + def joint_fn(*primals_and_tangents): + primals = primals_and_tangents[:num_primals] + tangents = primals_and_tangents[num_primals:] + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fn), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + + # return signature is deliberately kept (*grads, *fw_outs). This + # simplifies partitioning work later on. 
+ return pytree.tree_map(maybe_clone, tuple(grads + list(fw_outs))) + + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + joint_operands = [_from_fun(arg) for arg in joint_operands] + with contextlib.ExitStack() as stack: + stack.enter_context( + torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set), + ) + with torch.enable_grad(): + return _maybe_reenter_make_fx(joint_fn)(*joint_operands) class InvokeSubgraphAutogradOp(torch.autograd.Function): """ - This autograd function op is to stash the backward graph in the ctx while - running forward. + Saves the subgraph, i.e. original callable, in the forward method. And then + traces out a joint graph in the backward. This delaying of tracing in + backward, also called as lazy backward, ensures that the assumptions about + the grad_out strides and tensor-subclass-ness are already accounted for. """ @staticmethod def forward( ctx, - fw_graph, - bw_graph, + subgraph, identifier, - num_fw_outs, - filter_tangent_info, + output_metadata, *operands, ): - ctx._fw_graph = fw_graph - ctx._bw_graph = bw_graph + # We want to delay the backward graph construction until the backward. + # So in forward, we just run the fw callable as is. And save all the + # information necessary to construct the backward graph in the ctx. + ctx._subgraph = subgraph ctx._identifier = identifier - ctx._num_fw_outs = num_fw_outs - ctx._filter_tangent_info = filter_tangent_info + ctx._output_metadata = output_metadata + # We snapshot the dispatch keys in forward for materializing the + # the bw_graph in backward. + ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set() + ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set() + + save_tensors_and_symints_for_backward(ctx, operands) with torch._C._AutoDispatchBelowAutograd(): out = invoke_subgraph( - fw_graph, + subgraph, f"___forward_{identifier}", operands, ) - save_tensors_and_symints_for_backward(ctx, operands) - # Check that None is at expected indexes. for idx, o in enumerate(out): if o is None: - assert idx in filter_tangent_info.indexes_with_none + assert idx in output_metadata.indexes_with_none return out @staticmethod - def backward(ctx, *grad_outs): - bw_graph = ctx._bw_graph + def backward( + ctx, + *grad_outs, + ): + subgraph = ctx._subgraph identifier = ctx._identifier + output_metadata = ctx._output_metadata primals = saved_tensors_and_symints(ctx) - num_fw_outs = ctx._num_fw_outs - filter_tangent_info = ctx._filter_tangent_info - # While tracing we made the assumption that tangents are contiguous. So, - # force the grad_outs to be contiguous. - # Also filter out grads that are None or do not require_grad. This was + # Filter out grads that are None or do not require_grad. This was # the assumption we made during the tracing of joint_graph. - contiguous_grad_outs = [] + filtered_grad_outs = [] for idx, o in enumerate(grad_outs): if o is None: - assert idx in filter_tangent_info.indexes_with_none - elif idx in filter_tangent_info.indexes_with_no_grad: + assert idx in output_metadata.indexes_with_none + elif idx in output_metadata.indexes_with_no_grad: # Deliberately skip over the grad_outs which we know should be # None because the corresponding fwd_out does not require_grad. 
pass else: - contiguous_grad_outs.append(o.contiguous()) - contiguous_grad_outs = tuple(contiguous_grad_outs) + filtered_grad_outs.append(o) + filtered_grad_outs = tuple(filtered_grad_outs) # bw_graph is a joint graph with signature (*primals_and_tangents) and # returns (*grads_and_fw_outs). To get the grads, we use the num_fw_outs # to extract the grads. - primals_and_tangents = primals + contiguous_grad_outs - grads = invoke_subgraph( - bw_graph, f"___backward_{identifier}", primals_and_tangents - )[:-num_fw_outs] - return None, None, None, None, None, *grads + primals_and_tangents = primals + filtered_grad_outs + # Check if we have already traced the bwd subgraph. + bw_graph = None + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + bw_graph = invoke_subgraph_cache.get_lazy_bwd_entry(identifier) -@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) -def _(subgraph, identifier, operands): - from torch.utils._python_dispatch import _get_current_dispatch_mode + if bw_graph is None: + bw_graph = trace_joint_graph_as_bwd( + subgraph, + len(primals), + primals_and_tangents, + ctx._fw_include_key_set, + ctx._fw_exclude_key_set, + ) - mode = _get_current_dispatch_mode() - assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return subgraph(*operands) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_lazy_bwd_entry(identifier, bw_graph) + + grads = invoke_subgraph( + bw_graph, f"___backward_{identifier}", primals_and_tangents + )[: -output_metadata.num_fw_outs] + return None, None, None, *grads @invoke_subgraph.py_impl(DispatchKey.Autograd) @@ -361,13 +460,11 @@ def _(subgraph, identifier, operands): ): return saved_autograd_fn(*operands) - fw_graph, bw_graph, num_fw_outs, filter_tangent_info = create_fw_bw_graph( - subgraph, operands - ) + output_metadata = get_output_metadata(subgraph, operands) def autograd_fn_callable(*args): return InvokeSubgraphAutogradOp.apply( - fw_graph, bw_graph, identifier, num_fw_outs, filter_tangent_info, *args + subgraph, identifier, output_metadata, *args ) # Save the autograd_fn_callable in the dispatch set cache. @@ -377,6 +474,15 @@ def autograd_fn_callable(*args): return autograd_fn_callable(*operands) +@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) +def _(subgraph, identifier, operands): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return subgraph(*operands) + + @invoke_subgraph.py_functionalize_impl def _(ctx, subgraph, identifier, operands): unwrapped_operands = ctx.unwrap_tensors(operands) From 91173ff89aab5f632d483c736d11d5dcf60decac Mon Sep 17 00:00:00 2001 From: Hexin Wang Date: Mon, 7 Apr 2025 23:20:49 +0000 Subject: [PATCH 253/332] Fixing NCCL abort hang issue when a ProcessGroupNCCL manages multiple ncclComms (#150690) Detail of the issue: If PyTorch issues send/recv to each 2 rank comm, and these comms are managed by a single ProcessGroupNCCL instance, then comms need to abort either in sequence or in group. I.e. the following sequential abort will cause hang in NCCL. 
recv(..., comm0, stream); send(..., comm1, stream); abort(comm1); abort(comm0); Fixes #119797 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150690 Approved by: https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index ba1516a45f65..9f3e66a5f549 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1353,6 +1353,9 @@ void ProcessGroupNCCL::abortCommsFromMap( const std::optional& abortReason) { // The process may control multiple devices, loop through the communicators on // each device + // NCCL expects Group abort when there are multiple communicators created in a + // device. + groupStart(); for (auto& it : ncclCommsMap) { auto& devName = it.first; auto& ncclComm = it.second; @@ -1373,6 +1376,7 @@ void ProcessGroupNCCL::abortCommsFromMap( VLOG(2) << logPrefix() << "ProcessGroupNCCL destroyed " << " communicator on CUDA device: " << devName; } + groupEnd(); } // Abort all communicators on this rank From e9e5682a4a687719a2928856ce54d7a0c5c95d47 Mon Sep 17 00:00:00 2001 From: Akash Verma Date: Mon, 7 Apr 2025 23:31:29 +0000 Subject: [PATCH 254/332] [ROCm] Build Pytorch extensions with amdclang++ (#150451) Here are the following modifications made to cpp_extension.py- 1) Changed compiler flag to use --version. 2) Added a feature to convert alpha-numeric string to numeric string for the version string returned by compiler. This was the source of error as the parser was failing on parsing alpha-numeric version string. Build with following pytorch extensions- Apex, TorchVision, TorchAudio & DeepSpeed. Unit tested with following pytorch extensions- Apex, TorchVision. 
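As a rough standalone sketch of item 2 above (the sample version string follows the "0.0.0git" shape that amdclang++ reports, per the comment in the diff below; everything else here is illustrative):

```python
import re

# amdclang++ may report a version such as "0.0.0git"; other compilers report plain "0.0.0".
version = "0.0.0git".split(".")                              # ['0', '0', '0git']

# Strip every non-digit character from each component before comparing numerically.
numeric_version = [re.sub(r"\D", "", v) for v in version]    # ['0', '0', '0']

print(tuple(map(int, numeric_version)))                      # (0, 0, 0)
```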
(cherry picked from commit c873aeac35851a7d5000eb7f24561d3f56c2ffbd) Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/150451 Approved by: https://github.com/jeffdaily --- torch/utils/cpp_extension.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 23f1ba6ed559..1ba4891ebb1c 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -382,7 +382,10 @@ def check_compiler_ok_for_platform(compiler: str) -> bool: # If compiler wrapper is used try to infer the actual compiler by invoking it with -v flag env = os.environ.copy() env['LC_ALL'] = 'C' # Don't localize output - version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + try: + version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) + except subprocess.CalledProcessError: + version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) if IS_LINUX: # Check for 'gcc' or 'g++' for sccache wrapper pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) @@ -445,13 +448,17 @@ def get_compiler_abi_compatibility_and_version(compiler) -> tuple[bool, TorchVer warnings.warn(f'Error checking compiler version for {compiler}: {error}') return (False, TorchVersion('0.0.0')) - if tuple(map(int, version)) >= minimum_required_version: - return (True, TorchVersion('.'.join(version))) + # convert alpha-numeric string to numeric string + # amdclang++ returns str like 0.0.0git, others return 0.0.0 + numeric_version = [re.sub(r'\D', '', v) for v in version] + + if tuple(map(int, numeric_version)) >= minimum_required_version: + return (True, TorchVersion('.'.join(numeric_version))) - compiler = f'{compiler} {".".join(version)}' + compiler = f'{compiler} {".".join(numeric_version)}' warnings.warn(ABI_INCOMPATIBILITY_WARNING.format(compiler)) - return (False, TorchVersion('.'.join(version))) + return (False, TorchVersion('.'.join(numeric_version))) def _check_cuda_version(compiler_name: str, compiler_version: TorchVersion) -> None: From 5228986c395dc79f90d2a2b991deea1eef188260 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 8 Apr 2025 00:46:13 +0000 Subject: [PATCH 255/332] [CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705) By addressing a feedback requested at https://github.com/pytorch/pytorch/pull/145746 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150705 Approved by: https://github.com/atalman --- aten/src/ATen/native/cuda/CUDALoops.cuh | 6 ++++-- aten/src/ATen/native/cuda/MemoryAccess.cuh | 4 +++- aten/src/ATen/native/cuda/thread_constants.h | 5 ++++- aten/src/ATen/test/cuda_vectorized_test.cu | 19 ++++++++++++------- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 82d0defd972b..fb71dc5488f5 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -78,7 +78,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence) { } } -#ifdef USE_ROCM +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080) template constexpr auto elems_per_thread(){ if constexpr (io_sizes == 1) { @@ -219,7 +219,7 @@ static inline void launch_vectorized_kernel( 
constexpr auto io_size = calc_io_size(); int64_t grid = (N + io_block_work_size() - 1) / io_block_work_size(); auto stream = at::cuda::getCurrentCUDAStream(); -#ifdef USE_ROCM +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080) int vec_size = memory::can_vectorize_up_to(data); #else using cpp_type = typename function_traits::result_type; @@ -241,11 +241,13 @@ static inline void launch_vectorized_kernel( C10_CUDA_KERNEL_LAUNCH_CHECK(); break; #endif +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) case 8: vectorized_elementwise_kernel<8, func_t, array_t> <<>>(N, f, data); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; +#endif case 4: vectorized_elementwise_kernel<4, func_t, array_t> <<>>(N, f, data); diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index fd88df3f8b17..3e46f873c61d 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -486,7 +486,9 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { uint64_t address = reinterpret_cast(pointer); constexpr int vec2_alignment = std::alignment_of_v>; constexpr int vec4_alignment = std::alignment_of_v>; +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) constexpr int vec8_alignment = std::alignment_of_v>; +#endif #ifdef USE_ROCM constexpr int vec16_alignment = std::alignment_of_v>; constexpr int type_size = sizeof(scalar_t); @@ -495,7 +497,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { } else if (type_size <= 2 && (address % vec8_alignment == 0)) { return 8; } else -#else +#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080 if (address % vec8_alignment == 0) { return 8; } else diff --git a/aten/src/ATen/native/cuda/thread_constants.h b/aten/src/ATen/native/cuda/thread_constants.h index bcc797a26e1c..9299b79916cf 100644 --- a/aten/src/ATen/native/cuda/thread_constants.h +++ b/aten/src/ATen/native/cuda/thread_constants.h @@ -18,8 +18,11 @@ constexpr int thread_work_size() { return 4; } constexpr uint32_t num_threads() { return C10_WARP_SIZE * 4; } - +#if defined(CUDA_VERSION) && CUDA_VERSION < 12080 +constexpr int thread_work_size() { return 4; } +#else constexpr int thread_work_size() { return 8; } #endif +#endif constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu index 6b120f7eb304..4e0c14b17337 100644 --- a/aten/src/ATen/test/cuda_vectorized_test.cu +++ b/aten/src/ATen/test/cuda_vectorized_test.cu @@ -46,12 +46,17 @@ TEST(TestLoops, HasSameArgTypes) { TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { char *ptr = reinterpret_cast(buffer1); +#if defined(CUDA_VERSION) && CUDA_VERSION < 12080 + constexpr auto vectorize_limit = 4; +#else + constexpr auto vectorize_limit= 8; +#endif - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); 
ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); @@ -65,8 +70,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 1); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), vectorize_limit); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 4); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 1); From d7f3cd0ac36bd5d2a33ddbe49846ce2f3b4ac83c Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 8 Apr 2025 01:12:29 +0000 Subject: [PATCH 256/332] Add Half support for weight_norm on CPU (#148878) Fixes #148867. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148878 Approved by: https://github.com/leslie-fang-intel, https://github.com/cyyever, https://github.com/albanD --- aten/src/ATen/native/WeightNorm.cpp | 9 +-- aten/src/ATen/native/cpu/WeightNormKernel.cpp | 72 +++++++++++-------- test/test_nn.py | 2 +- 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/aten/src/ATen/native/WeightNorm.cpp b/aten/src/ATen/native/WeightNorm.cpp index 428669e6466b..bbd39809085a 100644 --- a/aten/src/ATen/native/WeightNorm.cpp +++ b/aten/src/ATen/native/WeightNorm.cpp @@ -53,8 +53,8 @@ std::tuple weight_norm_cpu( int64_t dim) { auto w = at::empty_like(v, at::MemoryFormat::Contiguous); - // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16' - const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ? + // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'/'Half' + const auto dtype = (g.scalar_type() == at::ScalarType::BFloat16 || g.scalar_type() == at::ScalarType::Half) ? 
at::ScalarType::Float : g.scalar_type(); auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype)); weight_norm_stub(kCPU, w, norm, v, g, dim); @@ -93,10 +93,7 @@ Tensor _weight_norm auto v = v_in.contiguous(); auto g = g_in.contiguous(); - auto has_half_dtype = v.scalar_type() == at::ScalarType::Half - || g.scalar_type() == at::ScalarType::Half; - - bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1)); + bool can_use_fused = (dim == 0) || (dim == v.dim() - 1); if (can_use_fused) { // weight_norm does not have a derivative defined for it, so this will route back through diff --git a/aten/src/ATen/native/cpu/WeightNormKernel.cpp b/aten/src/ATen/native/cpu/WeightNormKernel.cpp index 9ee5c97be8bc..5e866d538768 100644 --- a/aten/src/ATen/native/cpu/WeightNormKernel.cpp +++ b/aten/src/ATen/native/cpu/WeightNormKernel.cpp @@ -48,7 +48,8 @@ void weight_norm_first_dim_kernel( } template -inline void sum_norm_per_row( +inline std::enable_if_t, void> +sum_norm_per_row( scalar_t* out_ptr, const scalar_t* v_ptr, int64_t size) { @@ -61,16 +62,18 @@ inline void sum_norm_per_row( size); } -inline void sum_norm_per_row( +template +inline std::enable_if_t, void> +sum_norm_per_row( float* out_ptr, - const BFloat16* v_ptr, + const scalar_t* v_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0; fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1; @@ -84,7 +87,8 @@ inline void sum_norm_per_row( } template -inline void apply_norm_per_row( +inline std::enable_if_t, void> +apply_norm_per_row( scalar_t* w_ptr, const scalar_t* v_ptr, const scalar_t* a_ptr, @@ -98,21 +102,23 @@ inline void apply_norm_per_row( size); } -inline void apply_norm_per_row( - BFloat16* w_ptr, - const BFloat16* v_ptr, +template +inline std::enable_if_t, void> +apply_norm_per_row( + scalar_t* w_ptr, + const scalar_t* v_ptr, const float* a_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0; fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1; - bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1); + bVec w_bvec = vec::convert_from_float(w_fvec0, w_fvec1); w_bvec.store(w_ptr + d); } for(; d < size; ++d) { @@ -222,7 +228,8 @@ void weight_norm_backward_first_dim_kernel( } template -inline void sum_product_per_row( +inline std::enable_if_t, void> +sum_product_per_row( scalar_t* out_ptr, const scalar_t* grad_w_ptr, const scalar_t* v_ptr, @@ -237,19 +244,21 @@ inline void sum_product_per_row( size); } -inline void sum_product_per_row( +template +inline std::enable_if_t, void> +sum_product_per_row( float* out_ptr, - const BFloat16* grad_w_ptr, - const BFloat16* v_ptr, + const scalar_t* grad_w_ptr, + const scalar_t* v_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % 
bVec::size()); d += bVec::size()) { bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d); - auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec); + auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float(grad_w_bvec); bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0; fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1; @@ -264,7 +273,8 @@ inline void sum_product_per_row( } template -inline void apply_per_row_backward( +inline std::enable_if_t, void> +apply_per_row_backward( scalar_t* grad_v_ptr, const scalar_t* grad_w_ptr, const scalar_t* v_ptr, @@ -282,26 +292,28 @@ inline void apply_per_row_backward( size); } -inline void apply_per_row_backward( - BFloat16* grad_v_ptr, - const BFloat16* grad_w_ptr, - const BFloat16* v_ptr, +template +inline std::enable_if_t, void> +apply_per_row_backward( + scalar_t* grad_v_ptr, + const scalar_t* grad_w_ptr, + const scalar_t* v_ptr, const float* a_ptr, const float* b_ptr, int64_t size) { - using bVec = vec::Vectorized; + using bVec = vec::Vectorized; using fVec = vec::Vectorized; int64_t d = 0; for (; d < size - (size % bVec::size()); d += bVec::size()) { bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d); - auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec); + auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float(grad_w_bvec); bVec v_bvec = bVec::loadu(v_ptr + d); - auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec); + auto [v_fvec0, v_fvec1] = vec::convert_to_float(v_bvec); fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0; fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1 - fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1; - bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1); + bVec grad_v_bvec = vec::convert_from_float(grad_v_fvec0, grad_v_fvec1); grad_v_bvec.store(grad_v_ptr + d); } for(; d < size; ++d) { @@ -395,7 +407,7 @@ void weight_norm_kernel( int64_t dim) { TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1, "fused kernels can only be applied for first or last dim"); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(), + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, v.scalar_type(), "weight_norm_kernel", [&]() { using accscalar_t = at::opmath_type; if (dim == 0) { @@ -420,7 +432,7 @@ void weight_norm_backward_kernel( int64_t dim) { TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1, "fused kernels can only be applied for first or last dim"); - AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(), + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, saved_v.scalar_type(), "weight_norm_backward_kernel", [&]() { using accscalar_t = at::opmath_type; if (dim == 0) { diff --git a/test/test_nn.py b/test/test_nn.py index 30fe71b4162e..ff3950ec32e4 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1812,7 +1812,7 @@ def check_weight_norm(l, name, num_params): def test_weight_norm(self): - for dtype in [torch.float, torch.bfloat16]: + for dtype in [torch.float, torch.bfloat16, torch.float16]: input = torch.randn(3, 4, dtype=dtype) m = nn.Linear(4, 5).to(dtype=dtype) expected_output = m(input) From c0991b03163e048ea5abb20412d486b06389b65d Mon Sep 17 00:00:00 2001 From: morotti Date: Tue, 8 Apr 2025 02:10:31 +0000 Subject: [PATCH 
257/332] README: anaconda license violation / no longer recommend anaconda since it's no longer free to use (#150619) Hello, I was going over the documentation to build pytorch from source. Unfortunately, the first thing that comes up is that you strongly recommend using anaconda, which shouldn't be used because it's no longer free to use. Could you please remove that from the doc? I don't know if you are aware, but anaconda is no longer free. They changed their terms of service in 2020 to restrict commercial usage. They changed their terms of service again in 2024 to forbid downloading anaconda and to forbid education and non-profit usage too. The download is open and doesn't require any registration, but if you download anaconda they will sue you. They have been raining lawsuits on users since last year. You may have heard about anaconda vs intel in the news: https://www.reuters.com/legal/litigation/intel-sued-copyright-infringement-over-ai-software-2024-08-09/ They started another 5 or so in the last few months. You may need to adjust more of the docs and your build system. The free-to-use alternative is miniforge with the conda-forge channel. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150619 Approved by: https://github.com/seemethere --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index fcdf761295ae..00c515408528 100644 --- a/README.md +++ b/README.md @@ -169,8 +169,6 @@ Professional, or Community Editions. You can also install the build tools from https://visualstudio.microsoft.com/visual-cpp-build-tools/. The build tools *do not* come with Visual Studio Code by default. -\* We highly recommend installing an [Anaconda](https://www.anaconda.com/download) environment. You will get a high-quality BLAS library (MKL) and you get controlled dependency versions regardless of your Linux distro. - An example of environment setup is shown below: * Linux: From 73b4938f7c825a5614108e2c508ae1b58676e6ab Mon Sep 17 00:00:00 2001 From: Ahmad Sharif Date: Tue, 8 Apr 2025 02:39:41 +0000 Subject: [PATCH 258/332] [cuda] Add new faster gammabeta backward kernel (#148605) (Reapply with launch bounds) (#150625) # Changes over the previous PR This reverts commit 61a1f09 and adds `__launch_bounds__` to the kernel. Previously I merged 114d404 that did not work on Blackwell because it consumed too many registers. It got reverted in 61a1f09. For more context see: https://github.com/pytorch/pytorch/issues/150266. This PR reverts the revert (i.e. reapplies the original diff), with one additional line with `__launch_bounds__` added: ``` git diff HEAD^ diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 0d63a2f979c..3ce2c24c18e 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -657,6 +657,7 @@ bool aligned_grid > __global__ void +__launch_bounds__(block_dim_x * block_dim_y) GammaBetaBackwardCUDAKernelTemplate( int64_t M, int64_t N, ``` I managed to get a Blackwell machine and verified that the fix works. The fix was verified using this repro that I got from @drisspg:
Repro script that fails on Blackwell ``` import torch from torch.nn import init # from transformer_nuggets import init_logging # from transformer_nuggets.utils.benchmark import profiler # from pathlib import Path # init_logging() class PermuteModule(torch.nn.Module): def __init__(self, permutation): super(PermuteModule, self).__init__() self.permutation = permutation def forward(self, x:torch.Tensor) -> torch.Tensor: assert len(x.shape) == len(self.permutation), f"Dimension mismatch! Unable to permute {len(x.shape)} dim input with a {len(self.permutation)} dim permutation!" return x.permute(*self.permutation) def test(n_layers:int, conv_stride:int): _sequence = [] for _ in range(n_layers): # Conv1d inputs are (N x C x L), LayerNorm expects (* x C). Dims must be permuted between modules. _sequence += [ PermuteModule((0,2,1)), torch.nn.Conv1d(in_channels=512, out_channels=512, groups=1, kernel_size=9, dilation=1, stride=conv_stride, padding=0, bias=False), PermuteModule((0,2,1)), torch.nn.LayerNorm(512), torch.nn.ReLU() ] model = torch.nn.Sequential(*_sequence).to(device="cuda") data = torch.randn((100,2048,512), device="cuda") out = model(data) loss = torch.nn.functional.mse_loss(out, torch.rand_like(out)) loss.backward() torch.autograd.set_detect_anomaly(True) print(f"Torch version: {torch.__version__}") # with profiler(Path("conv")): # # print(f"layers=1, stride=1") # # test(n_layers=1, conv_stride=1) # # print(f"layers=2, stride=1") # # test(n_layers=2, conv_stride=1) # # print(f"layers=1, stride=2") # # test(n_layers=1, conv_stride=2) # print(f"layers=2, stride=2") # test(n_layers=2, conv_stride=2) print(f"layers=2, stride=2") test(n_layers=2, conv_stride=2) # we will not reach this print statement. print("DONE.") ```
I also re-ran my performance benchmark and found no regressions over the previous PR. # Full description of the old PR Original PR: https://github.com/pytorch/pytorch/pull/148605 This PR adds a new kernel for producing gamma and beta values for the backward pass in a performant way. To test the performance against the baseline, I measured the backward pass of layernorm while sweeping over the following variables: 1. dtype in {half, float} 2. M in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 3. N in `2**k, 2**k - 1, 2**k + 1 for k in range(...)` 4. Whether we flush the L2 cache before running the backward pass Summary: The new code performs better than the old code, especially for powers of 2. For M >> N case, it performs very well (kernel itself can be 30x faster and the overall backward pass can be 5-10x faster). In order to visualize results of the kernel when choosing different values of M, N and dtype, I wrote some code to generate a heatmap. The heatmap has N on the x-axis, M on the y-axis and color-coded points where green shows performance improvement and red shows regressions. For example, `m=32 n=2048 1.42x` in the heatmap would indicate the normalized shape had 32 elements. The leading dimensions' product was 2048 elements and the new kernel resulted in the *backward pass* being 1.42x faster than the old *backward pass*. Important note: This heatmap shows the total backward pass time as seen by the user. The kernel time difference can be sometimes very large while the total backward pass time is not that high. For example, for dtype=torch.half, M=32 N=2048, flush_l2_cache=True case, the heatmap shows a speedup of 1.42x, while ncu tells me the new kernel is 2.5x faster than the old: M=32 N=2048 dtype=half flush_l2=True Old Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.35 Elapsed Cycles cycle 27,526 Memory Throughput % 2.21 DRAM Throughput % 0.54 Duration us 20.42 L1/TEX Cache Throughput % 4.31 L2 Cache Throughput % 2.62 SM Active Cycles cycle 1,475.02 Compute (SM) Throughput % 0.29 ----------------------- ----------- ------------ ``` M=32 N=2048 dtype=half flush_l2=True New Kernel NCU summary: ``` ----------------------- ----------- ------------ Metric Name Metric Unit Metric Value ----------------------- ----------- ------------ DRAM Frequency Ghz 1.59 SM Frequency Ghz 1.34 Elapsed Cycles cycle 10,920 Memory Throughput % 5.64 DRAM Throughput % 1.35 Duration us 8.13 L1/TEX Cache Throughput % 1.92 L2 Cache Throughput % 6.89 SM Active Cycles cycle 3,554.41 Compute (SM) Throughput % 0.67 ----------------------- ----------- ------------ ``` Let's look at some rows from the heatmap. For dtype=float16 flush_l2_cache=True and when input shapes are powers of 2, we get the following: image There are 3 columns -- the first shows all data points, the second shows speedups only and the 3rd column shows regressions only. We can see that there are dramatic speedups for M >> N cases and the regressions are not that high (less than 1%, which could just be measurement noise). Here is a small guide I made: ![image](https://github.com/user-attachments/assets/90c26f7c-e3ad-46d2-a6ce-fe4b5fb3d738) For dtype=float32, we get a similar chart: image The new code performs especially well for m >> n cases, and also where m and n are small. 
The m >> n case is special because we run 2 reduction kernels back to back and parallelize in the "M" dimension (the older kernel only parallelized in the "N" dimension). The new code can sometimes have regressions for non-powers of 2. That is because the old code was using block sizes of {16, 32} while we have `threads.x = 32`. For example when N=33, the old code would have 3 blocks and we will have 2 blocks. I wrote some code to specialize for this case, but I think it will add complexity and @ngimel mentioned that non-powers of 2 are rare enough. I am including the regressions here for completeness' sake: image To see this better: 1. Click the image 2. Right click the expanded image and open in a new tab 3. Go to that tab and left click once to zoom in If you want to see the full data, here it is: ![image](https://github.com/user-attachments/assets/54fb60c9-8c0c-4530-a1dd-79ecda1a69a1) I also measured binary size and compile time since those are important for developers: Binary size comparison ![image](https://github.com/user-attachments/assets/ceef5073-1036-47f6-b9dc-cea088beda51) ``` # Original -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so # This PR -rwxr-xr-x 1 ahmads users 307193112 Mar 6 08:46 ./torch/lib/libtorch_cuda.so ``` The diff in bytes is 302kB which is about a 0.1% increase. Compile time difference: ``` # Original real 0m10.931s user 0m9.676s sys 0m1.004s # this PR real 0m16.720s user 0m15.514s sys 0m1.066s # Command I ran time /usr/local/cuda/bin/nvcc -forward-unknown-to-host-compiler -DAT_PER_OPERATOR_HEADERS -DFLASHATTENTION_DISABLE_ALIBI -DFLASHATTENTION_DISABLE_SOFTCAP -DFLASH_NAMESPACE=pytorch_flash -DFMT_HEADER_ONLY=1 -DHAVE_MALLOC_USABLE_SIZE=1 -DHAVE_MMAP=1 -DHAVE_SHM_OPEN=1 -DHAVE_SHM_UNLINK=1 -DMINIZ_DISABLE_ZIP_READER_CRC32_CHECKS -DONNXIFI_ENABLE_EXT=1 -DONNX_ML=1 -DONNX_NAMESPACE=onnx_torch -DTORCH_CUDA_BUILD_MAIN_LIB -DTORCH_CUDA_USE_NVTX3 -DUNFUSE_FMA -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_CUDA -DUSE_CUFILE -DUSE_DISTRIBUTED -DUSE_EXTERNAL_MZCRC -DUSE_FLASH_ATTENTION -DUSE_MEM_EFF_ATTENTION -DUSE_NCCL -DUSE_RPC -DUSE_TENSORPIPE -D_FILE_OFFSET_BITS=64 -Dtorch_cuda_EXPORTS -I/home/ahmads/personal/pytorch/build/aten/src -I/home/ahmads/personal/pytorch/aten/src -I/home/ahmads/personal/pytorch/build -I/home/ahmads/personal/pytorch -I/home/ahmads/personal/pytorch/cmake/../third_party/benchmark/include -I/home/ahmads/personal/pytorch/third_party/onnx -I/home/ahmads/personal/pytorch/build/third_party/onnx -I/home/ahmads/personal/pytorch/nlohmann -I/home/ahmads/personal/pytorch/third_party/flash-attention/csrc/flash_attn/src -I/home/ahmads/personal/pytorch/aten/src/THC -I/home/ahmads/personal/pytorch/aten/src/ATen/cuda -I/home/ahmads/personal/pytorch/third_party/fmt/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/include -I/home/ahmads/personal/pytorch/aten/src/ATen/../../../third_party/cutlass/tools/util/include -I/home/ahmads/personal/pytorch/build/caffe2/aten/src -I/home/ahmads/personal/pytorch/aten/src/ATen/.. -I/home/ahmads/personal/pytorch/build/nccl/include -I/home/ahmads/personal/pytorch/c10/cuda/../.. -I/home/ahmads/personal/pytorch/c10/.. 
-I/home/ahmads/personal/pytorch/third_party/tensorpipe -I/home/ahmads/personal/pytorch/build/third_party/tensorpipe -I/home/ahmads/personal/pytorch/third_party/tensorpipe/third_party/libnop/include -I/home/ahmads/personal/pytorch/torch/csrc/api -I/home/ahmads/personal/pytorch/torch/csrc/api/include -isystem /home/ahmads/personal/pytorch/build/third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/gloo -isystem /home/ahmads/personal/pytorch/cmake/../third_party/tensorpipe/third_party/libuv/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googlemock/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/googletest/googletest/include -isystem /home/ahmads/personal/pytorch/third_party/protobuf/src -isystem /home/ahmads/personal/pytorch/third_party/XNNPACK/include -isystem /home/ahmads/personal/pytorch/third_party/ittapi/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/eigen -isystem /usr/local/cuda/include -isystem /home/ahmads/personal/pytorch/third_party/ideep/mkl-dnn/include/oneapi/dnnl -isystem /home/ahmads/personal/pytorch/third_party/ideep/include -isystem /home/ahmads/personal/pytorch/INTERFACE -isystem /home/ahmads/personal/pytorch/third_party/nlohmann/include -isystem /home/ahmads/personal/pytorch/third_party/NVTX/c/include -isystem /home/ahmads/personal/pytorch/cmake/../third_party/cudnn_frontend/include -DLIBCUDACXX_ENABLE_SIMPLIFIED_COMPLEX_OPERATIONS -D_GLIBCXX_USE_CXX11_ABI=1 -Xfatbin -compress-all -DONNX_NAMESPACE=onnx_torch -gencode arch=compute_90,code=sm_90 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -Wno-deprecated-gpu-targets --expt-extended-lambda -DCUB_WRAPPED_NAMESPACE=at_cuda_detail -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -O3 -DNDEBUG -std=c++17 -Xcompiler=-fPIC -DTORCH_USE_LIBUV -DCAFFE2_USE_GLOO -Xcompiler -Wall -Wextra -Wdeprecated -Wno-unused-parameter -Wno-missing-field-initializers -Wno-array-bounds -Wno-unknown-pragmas -Wno-strict-overflow -Wno-strict-aliasing -Wunused-function -Wunused-variable -Wunused-but-set-variable -Wno-maybe-uninitialized -MD -MT caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o -MF caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o.d -x cu -c /home/ahmads/personal/pytorch/aten/src/ATen/native/cuda/layer_norm_kernel.cu -o caffe2/CMakeFiles/torch_cuda.dir/__/aten/src/ATen/native/cuda/layer_norm_kernel.cu.o ``` So the new PR is 6 seconds longer compile time. 
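For anyone who wants to reproduce a single data point from the sweep described above, here is a minimal timing sketch. It is not the harness used to produce the numbers in this description; the helper name and iteration counts are placeholders. It times only the LayerNorm backward pass for one (M, N, dtype) configuration using CUDA events:

```python
# Minimal sketch (assumes a CUDA device); not the benchmark harness for this PR.
import torch

def time_layernorm_backward(M, N, dtype=torch.half, iters=100):
    ln = torch.nn.LayerNorm(N, device="cuda", dtype=dtype)
    x = torch.randn(M, N, device="cuda", dtype=dtype, requires_grad=True)
    out = ln(x)
    grad = torch.randn_like(out)
    # Warm up so one-time setup cost is excluded, then time only the backward pass.
    for _ in range(10):
        out.backward(grad, retain_graph=True)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        out.backward(grad, retain_graph=True)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per backward pass

print(time_layernorm_backward(1024, 32))
```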
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150625 Approved by: https://github.com/ngimel, https://github.com/atalman --- .../src/ATen/native/cuda/layer_norm_kernel.cu | 526 +++++++++++------- test/test_nn.py | 20 + 2 files changed, 352 insertions(+), 194 deletions(-) diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 9feb30c21941..ee573e2e566f 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -540,191 +540,365 @@ __global__ void GammaBetaBackwardSimpleCUDAKernel( } } -// This implementation gets called if M and N divide with 32. This case should -// be the most common. We can then make better use of warp level intrinsics -// to improve performance. +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsHelper( + int64_t M_start, + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; + + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; + T_ACC warp_mean = 0, warp_rstd = 0; + if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { + warp_mean = mean[mean_index + lane_id]; + warp_rstd = rstd[mean_index + lane_id]; + } + // We do a WARP_SYNC() here because we use WARP_SHFL below to access + // warp_mean and warp_rstd. + WARP_SYNC(); + + T_ACC dY_regs[rows_per_thread_y] = {0}; + T_ACC X_regs[rows_per_thread_y] = {0}; + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + int64_t current_y = M_start + threadIdx.y * rows_per_thread_y + i; + bool active = true; + if (check_x && thread_x >= N) { + active = false; + } + if (check_y && current_y >= M) { + active = false; + } + if (active) { + dY_regs[i] = dY[current_y * N + thread_x]; + X_regs[i] = X[current_y * N + thread_x]; + } + } -template -__global__ void GammaBetaBackwardCUDAKernel_32x32( + #pragma unroll + for (int i = 0; i < rows_per_thread_y; ++i) { + T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); + dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; + db_sum += dY_regs[i]; + } +} + +template +__device__ +__forceinline__ +void +blockReduceGammaBetaBackwardsWithChecks( int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db, + T_ACC &dg_sum, + T_ACC &db_sum +) { + for (int64_t M_start = blockIdx.y * rows_per_block_y; + M_start < M; + M_start += rows_per_block_y * gridDim.y) { + int64_t M_end = M_start + rows_per_block_y - 1; + if (!check_y || M_end < M) { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsHelper + (M_start, M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } +} + +// block_dim_x is the number of threads in 
the x dimension per block. +// block_dim_y is the number of threads in the y dimension per block. +// rows_per_block_y is the size of the tile (number of data elements) +// in the y dimension per block. +// partial_reduction indicates whether we need to reduce across threads +// or not. If set to true, we will not reduce across threads. This can +// be faster in the M >> N case but requires another kernel to do a full +// final reduction. +// aligned_grid means the data size is a multiple of tile size. In that +// case we don't need to check for boundary conditions which can provide +// a further speedup by not needing instructions to check for edge cases +// and not needing predicate registers. +template +__global__ +void +__launch_bounds__(block_dim_x * block_dim_y) + GammaBetaBackwardCUDAKernelTemplate( + int64_t M, + int64_t N, + const T* __restrict__ dY, + const T* __restrict__ X, + const T_ACC* __restrict__ mean, + const T_ACC* __restrict__ rstd, + T* __restrict__ dg, + T* __restrict__ db) { + // This assert is a compile-time check only. + constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; + static_assert(rows_per_thread_y <= kWarpSize); T_ACC dg_sum = 0; T_ACC db_sum = 0; - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; + if (aligned_grid) { + // When N and M align perfectly with block_dim_x and block_dim_y, we + // can skip boundary condition checks that waste instruction issue slots. + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + // In the general case we need to check boundary conditions in the M + // dimension. However, we can still avoid boundary checks in the N dimension + // for the inner blocks. So try to avoid those checks when possible. + if (blockIdx.x * block_dim_x + block_dim_x - 1 < N) { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } else { + blockReduceGammaBetaBackwardsWithChecks + + (M, N, dY, X, mean, rstd, dg, db, dg_sum, db_sum); + } + } - if (j < N) { - constexpr int unroll_factor = 8; - int laneId = threadIdx.x & (C10_WARP_SIZE - 1); - - T_ACC mean_reg, mean_reg_tmp; - T_ACC rstd_reg, rstd_reg_tmp; - T dY_reg; - T X_reg; - - // Main loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); - bcounter++) { - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - - if (laneId < unroll_factor) { - mean_reg_tmp = mean[offset + laneId]; - rstd_reg_tmp = rstd[offset + laneId]; - } - WARP_SYNC(); + int64_t thread_x = ((int64_t)blockIdx.x) * block_dim_x + threadIdx.x; - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = WARP_SHFL(mean_reg_tmp, ii, kWarpSize); - rstd_reg = WARP_SHFL(rstd_reg_tmp, ii, kWarpSize); - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; + // When partial_reduction is requested, we don't reduce within a block. + // We also don't reduce if we are only a single block in the y dimension. 
+ if (partial_reduction || (blockDim.y == 1 && gridDim.y == 1)) { + if (aligned_grid || thread_x < N) { + int64_t thread_y = ((int64_t)blockIdx.y) * blockDim.y + threadIdx.y; + if (dg) { + dg[thread_y * N + thread_x] = dg_sum; } - } - - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++) { - if ((offset + ii) < M) { - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; + if (db) { + db[thread_y * N + thread_x] = db_sum; } } - - // This kernel uses a block of (C10_WARP_SIZE x C10_WARP_SIZE) and - // gets called when M; N divide by 32. We can use warp shuffles - // for the final reduction step. This removes 4 shmem loads and - // stores with their corresponding __syncthreads() - - // This greatly reduces bank conflicts at the expense of a little - // extra shared memory. It does not impact occupancy - int padded_bx = (1 + blockDim.x); - + } else { + // The caller requested a full reduction so we must reduce across + // warps using shared memory and warp shuffles. + static_assert(rows_per_thread_y <= C10_WARP_SIZE); + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + T_ACC* s_dg; + T_ACC* s_db; + int padded_bx = (block_dim_x + 1); + // Transpose dg and db. s_dg = s_data_typed; - s_db = s_data_typed + (padded_bx * blockDim.y); + s_db = s_data_typed + (padded_bx * block_dim_y); s_dg[threadIdx.y * padded_bx + threadIdx.x] = dg_sum; s_db[threadIdx.y * padded_bx + threadIdx.x] = db_sum; __syncthreads(); // Load transposed so that a warp holds an entire column - T_ACC reg_dg = s_dg[threadIdx.x * padded_bx + threadIdx.y]; - T_ACC reg_db = s_db[threadIdx.x * padded_bx + threadIdx.y]; - for (unsigned delta = C10_WARP_SIZE >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); - } - - if (threadIdx.x == 0) { - const int64_t j = blockIdx.x * blockDim.x + threadIdx.y; - if (dg) { - dg[j] = reg_dg; + // Because block_dim_x != block_dim_y in the general case, we need + // some code to handle the general case. + static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); + constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; + int thread_id = threadIdx.y * block_dim_x + threadIdx.x; + int warp_id = thread_id / C10_WARP_SIZE; + int lane_id = thread_id & (C10_WARP_SIZE - 1); + #pragma unroll + for (int i = warp_id; i < block_dim_x; i += warps_available_to_reduce) { + T_ACC reg_db, reg_dg; + if (lane_id < block_dim_y) { + reg_dg = s_dg[lane_id * padded_bx + i]; + reg_db = s_db[lane_id * padded_bx + i]; } - if (db) { - db[j] = reg_db; + #pragma unroll + for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { + reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); + reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + } + // Reduce is done. Now write it out to global memory. 
+ int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; + if (threadIdx.x == 0 && (aligned_grid || out_index < N)) { + if (dg) { + dg[out_index] = reg_dg; + } + if (db) { + db[out_index] = reg_db; + } } } } } -template -__global__ void GammaBetaBackwardCUDAKernel( +template +void LaunchAndCheckGammaBetaBackwardKernel( + bool aligned_grid, + dim3 blocks, + dim3 threads, + size_t shmem_sz, + cudaStream_t cuda_stream, + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + T* dgamma_data, + T* dbeta_data) { +if (aligned_grid) { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } else { + GammaBetaBackwardCUDAKernelTemplate + <<>>( + M, + N, + dY_data, + X_data, + mean_data, + rstd_data, + dgamma_data, + dbeta_data); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void ConfigureAndLaunchGammaBetaBackwardKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, int64_t M, int64_t N, - const T* dY, - const T* X, - const T_ACC* mean, - const T_ACC* rstd, - T* dg, - T* db) { - alignas(sizeof(double)) extern __shared__ char s_data1[]; - T_ACC* s_data_typed = reinterpret_cast(&s_data1); - T_ACC* s_dg; - T_ACC* s_db; - - const int64_t j = ((int64_t) blockIdx.x) * blockDim.x + threadIdx.x; - - T_ACC dg_sum = 0; - T_ACC db_sum = 0; - - if (j < N) { - constexpr int unroll_factor = 8; - - T_ACC mean_reg; - T_ACC rstd_reg; - T dY_reg; - T X_reg; - - // Main Loop - int bcounter; - for (bcounter = 0; bcounter < M / (blockDim.y * unroll_factor); bcounter++){ - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + T* dgamma_data = + dgamma->defined() ? dgamma->template data_ptr() : nullptr; + T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + blocks.y = 1; + size_t shmem_sz = (block_dim_x + 1) * block_dim_y * sizeof(T_ACC) * 2; + if (blocks.y == 1 && threads.y == 1) { + // Optimization: since there is just one thread doing all the summation, we don't need a reduction + // across threads. So we set partial_reduction to true. 
+ LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } else { + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, shmem_sz, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_data, dbeta_data); + } - #pragma unroll - for (int ii = 0; ii < unroll_factor; ++ii) { - dY_reg = dY[(offset + ii) * N + j]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } - } +} - // Remainder loop - int offset = (bcounter * blockDim.y + threadIdx.y) * unroll_factor; - for (int ii = 0; ii < unroll_factor; ii++ ){ - if ((offset + ii) < M) { - dY_reg = dY[(offset + ii) * N + j ]; - X_reg = X[(offset + ii) * N + j]; - mean_reg = mean[offset + ii]; - rstd_reg = rstd[offset + ii]; - dg_sum += dY_reg * (X_reg - mean_reg) * rstd_reg; - db_sum += dY_reg; - } +template +void LaunchGammaBetaBackwardCUDAKernel( + const T* dY_data, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + int64_t M, + int64_t N, + Tensor* dgamma, + Tensor* dbeta, + cudaStream_t cuda_stream) { + constexpr int block_dim_x = 32; + const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { + // We have a situation where M >> N and N is small. + // In this case we can speed up the computation by parallelizing in the M dimension. + // We launch multiple blocks in the y-dimension, and compute partial sums for the + // gradient in the first pass. Then we do a .sum(0) to do a final reduction. + // Although we launch 2 kernels, we can get up to a 10x speedup for large M. 
+ constexpr int block_dim_y = 1; + constexpr int rows_per_block_y = 32; + bool aligned_grid = (M % rows_per_block_y == 0) && (N % block_dim_x == 0); + dim3 threads{block_dim_x, block_dim_y}; + dim3 blocks; + blocks.x = (N + block_dim_x - 1) / block_dim_x; + // int rows_per_block = my_gamma_beta_unroll_factor * + blocks.y = (M + rows_per_block_y - 1) / rows_per_block_y; + constexpr int max_grid_size = 64 * 1024 / 2; + blocks.y = std::min(max_grid_size / blocks.x, blocks.y); + Tensor dgamma_blocks; + Tensor dbeta_blocks; + T * dgamma_blocks_ptr = nullptr; + T * dbeta_blocks_ptr = nullptr; + if (dgamma->defined()) { + auto options = dgamma->options(); + dgamma_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dgamma_blocks_ptr = dgamma_blocks.data_ptr(); } - - // Do the final reduction in shared memory - s_dg = s_data_typed; - s_db = s_data_typed + blockDim.x * blockDim.y; - s_dg[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; - s_db[threadIdx.y * blockDim.x + threadIdx.x] = db_sum; - __syncthreads(); - - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - if (threadIdx.y < offset) { - s_dg[threadIdx.y * blockDim.x + threadIdx.x] += - s_dg[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - s_db[threadIdx.y * blockDim.x + threadIdx.x] += - s_db[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; - } - __syncthreads(); + if (dbeta->defined()) { + auto options = dbeta->options(); + dbeta_blocks = at::empty({blocks.y * threads.y, dgamma->size(-1)}, options); + dbeta_blocks_ptr = dbeta_blocks.data_ptr(); } + LaunchAndCheckGammaBetaBackwardKernel( + aligned_grid, blocks, threads, 0, cuda_stream, dY_data, X_data, mean_data, rstd_data, M, N, dgamma_blocks_ptr, dbeta_blocks_ptr); - if (threadIdx.y == 0) { - if (dg) { - dg[j] = s_dg[threadIdx.x]; - } - if (db) { - db[j] = s_db[threadIdx.x]; - } + *dgamma = dgamma_blocks.sum(0); + *dbeta = dbeta_blocks.sum(0); + } else { + // We are in the normal case where M is not that large. + // We can change the tile shape (which is the last template parameter) in accordance with M. + // For small M it is faster to have a smaller tile, otherwise we could have idle threads. + // For larger M we use a bigger tile size. + if (M < 64) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 128) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else if (M < 256) { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); + } else { + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } } } @@ -1250,6 +1424,7 @@ void LayerNormBackwardKernelImplInternal( dgamma->defined() ? dgamma->template data_ptr() : nullptr; T* dbeta_data = dbeta->defined() ? dbeta->template data_ptr() : nullptr; +#if defined(USE_ROCM) if (M < 128) { // For small batch size, do colwise reduce directly. const int64_t B = (N + kCUDANumThreads - 1) / kCUDANumThreads; @@ -1265,7 +1440,6 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { -#if defined(USE_ROCM) // For small batch size, do colwise reduce directly. 
const int part_size = warp_size; const dim3 threads2(warp_size, 4, 1); @@ -1300,47 +1474,11 @@ void LayerNormBackwardKernelImplInternal( dgamma_data, dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); + } #else - if ((M % kWarpSize == 0) && (N % kWarpSize == 0)) { - // This implementation relies on warp primitives and requires that M and N divide - // exactly to warp size. - dim3 threads{kWarpSize, kWarpSize}; - int blocks = (N + threads.x - 1) / threads.x; - - // If M and N divide by warp_size, we can use warp shuffles for the final reduction. - // That requires transposing values in shared memory, so we apply a padding to - // reduce bank conflicts. - - size_t shmem_sz = 2 * sizeof(T_ACC) * (threads.x + 1) * threads.y; - GammaBetaBackwardCUDAKernel_32x32 - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } else { - dim3 threads{16, 32}; - int blocks = (N + threads.x - 1) / threads.x; - size_t shmem_sz = 2 * sizeof(T_ACC) * threads.x * threads.y; - GammaBetaBackwardCUDAKernel - <<>>( - M, - N, - dY_data, - X_data, - mean_data, - rstd_data, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - } + LaunchGammaBetaBackwardCUDAKernel( + dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); #endif - } } } diff --git a/test/test_nn.py b/test/test_nn.py index ff3950ec32e4..32b0efd40aff 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -7195,6 +7195,26 @@ def test_layer_norm_eps(self): ln = torch.nn.LayerNorm(2, eps=1e-6, elementwise_affine=False) self.assertEqual(ln.forward(x), torch.zeros_like(x)) + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_layer_norm_backwards_eps(self): + dtype = torch.float + m_x_n_list = [(3, 3), (5, 5), (11, 11), (55, 55), + (32, 32), (1024, 32), (1024, 1024), + (33, 33), (1025, 33), (1025, 1025)] + for m, n in m_x_n_list: + x = torch.randn((m, n), dtype=dtype, requires_grad=True) + grad_output = torch.rand_like(x) + x_cuda = x.clone().detach().to("cuda").requires_grad_() + grad_output_cuda = grad_output.clone().detach().to("cuda") + ln = nn.LayerNorm(n, dtype=dtype) + ln_cuda = nn.LayerNorm(n, device="cuda", dtype=dtype) + ln_out = ln(x) + ln_out_cuda = ln_cuda(x_cuda) + ln_out.backward(grad_output) + ln_out_cuda.backward(grad_output_cuda) + self.assertEqual(ln.weight.grad, ln_cuda.weight.grad, f"weight grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + self.assertEqual(ln.bias.grad, ln_cuda.bias.grad, f"bias grad failed: {m=} {n=}", rtol=1e-5, atol=1e-4) + @largeTensorTest("40GB", device="cuda") def test_layer_norm_large_tensor(self): # test for https://github.com/pytorch/pytorch/issues/136291 From 836955bdbdeb299e6937065299564fb44ec422c2 Mon Sep 17 00:00:00 2001 From: atalman Date: Tue, 8 Apr 2025 02:58:28 +0000 Subject: [PATCH 259/332] [Manylinux 2.28] Correct Linux aarch64 cuda binaries wheel name (#150786) Related to: https://github.com/pytorch/pytorch/issues/149044#issuecomment-2784044555 For CPU binaries we run auditwheel however for cuda binaries auditwheel produces invalid results . Hence we need to rename the file. 
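To make the rename concrete, the snippet below is an illustration only; the wheel filename is invented and simply shows the platform-tag swap the CI script performs so that pip treats the wheel as a manylinux_2_28 binary:

```python
# Hypothetical filename; only the platform-tag suffix at the end matters here.
wheel_name = "torch-2.8.0.dev20250408-cp311-cp311-linux_aarch64.whl"
repaired_wheel_name = wheel_name.replace("linux_aarch64", "manylinux_2_28_aarch64")
print(repaired_wheel_name)
# torch-2.8.0.dev20250408-cp311-cp311-manylinux_2_28_aarch64.whl
```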
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150786 Approved by: https://github.com/malfet --- .ci/aarch64_linux/aarch64_wheel_ci_build.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.ci/aarch64_linux/aarch64_wheel_ci_build.py b/.ci/aarch64_linux/aarch64_wheel_ci_build.py index 92dabf0fee48..1cce2836974d 100755 --- a/.ci/aarch64_linux/aarch64_wheel_ci_build.py +++ b/.ci/aarch64_linux/aarch64_wheel_ci_build.py @@ -136,6 +136,9 @@ def complete_wheel(folder: str) -> str: """ wheel_name = list_dir(f"/{folder}/dist")[0] + # Please note for cuda we don't run auditwheel since we use custom script to package + # the cuda dependencies to the wheel file using update_wheel() method. + # However we need to make sure filename reflects the correct Manylinux platform. if "pytorch" in folder and not enable_cuda: print("Repairing Wheel with AuditWheel") check_call(["auditwheel", "repair", f"dist/{wheel_name}"], cwd=folder) @@ -147,7 +150,14 @@ def complete_wheel(folder: str) -> str: f"/{folder}/dist/{repaired_wheel_name}", ) else: - repaired_wheel_name = wheel_name + repaired_wheel_name = wheel_name.replace( + "linux_aarch64", "manylinux_2_28_aarch64" + ) + print(f"Renaming {wheel_name} wheel to {repaired_wheel_name}") + os.rename( + f"/{folder}/dist/{wheel_name}", + f"/{folder}/dist/{repaired_wheel_name}", + ) print(f"Copying {repaired_wheel_name} to artifacts") shutil.copy2( From 7e11089fe5c603f694aabccfff267fec7c122e35 Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Tue, 8 Apr 2025 03:52:18 +0000 Subject: [PATCH 260/332] Optimize dataloader Self typing (#146816) Optimize `dataloader.py` method return type with Self typing Pull Request resolved: https://github.com/pytorch/pytorch/pull/146816 Approved by: https://github.com/albanD --- torch/utils/data/dataloader.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 66a371085b39..15a71c7d7f94 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -5,6 +5,7 @@ functions to be run in multiprocessing. E.g., the data loading worker loop is in `./_utils/worker.py`. 
""" +from __future__ import annotations import functools import itertools @@ -14,8 +15,8 @@ import queue import threading import warnings -from collections.abc import Iterable -from typing import Any, Callable, Generic, Optional, TypeVar, Union +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import Self import torch import torch.distributed as dist @@ -37,6 +38,9 @@ ) +if TYPE_CHECKING: + from collections.abc import Iterable + __all__ = [ "DataLoader", "get_worker_info", @@ -233,7 +237,7 @@ class DataLoader(Generic[_T_co]): sampler: Union[Sampler, Iterable] pin_memory_device: str prefetch_factor: Optional[int] - _iterator: Optional["_BaseDataLoaderIter"] + _iterator: Optional[_BaseDataLoaderIter] __initialized = False def __init__( @@ -256,7 +260,7 @@ def __init__( persistent_workers: bool = False, pin_memory_device: str = "", in_order: bool = True, - ): + ) -> None: torch._C._log_api_usage_once("python.data_loader") if num_workers < 0: @@ -416,7 +420,7 @@ def __init__( torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] - def _get_iterator(self) -> "_BaseDataLoaderIter": + def _get_iterator(self) -> _BaseDataLoaderIter: if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) else: @@ -475,9 +479,7 @@ def __setattr__(self, attr, val): super().__setattr__(attr, val) - # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up - # since '_BaseDataLoaderIter' references 'DataLoader'. - def __iter__(self) -> "_BaseDataLoaderIter": + def __iter__(self) -> _BaseDataLoaderIter: # When using a single worker the returned iterator should be # created everytime to avoid resetting its state # However, in the case of a multiple workers iterator @@ -704,7 +706,7 @@ def __init__(self, loader: DataLoader) -> None: self._num_yielded = 0 self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__" - def __iter__(self) -> "_BaseDataLoaderIter": + def __iter__(self) -> Self: return self def _reset(self, loader, first_iter=False): From c9c0f8eae333179de3f642b325254426de7f83d5 Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Tue, 8 Apr 2025 03:55:33 +0000 Subject: [PATCH 261/332] Add plot for `torch.nn.Threshold` and `torch.nn.GLU` (#150171) Fixes #150170 ## Changes - Add plot for `torch.nn.Threshold` and `torch.nn.GLU` - Add example output make them easier get result by users ## Test Result ![image](https://github.com/user-attachments/assets/f6c5bc46-f9b7-4db7-9797-e08d8423d1b3) ![image](https://github.com/user-attachments/assets/ad4e6c84-7b29-44f1-b7bd-9c81e4a92ef8) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150171 Approved by: https://github.com/albanD --- docs/source/scripts/build_activation_images.py | 12 ++++++++++-- torch/nn/modules/activation.py | 8 ++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/docs/source/scripts/build_activation_images.py b/docs/source/scripts/build_activation_images.py index 27e35f22a810..f7d64b5e4f7d 100644 --- a/docs/source/scripts/build_activation_images.py +++ b/docs/source/scripts/build_activation_images.py @@ -46,6 +46,8 @@ torch.nn.Softsign(), torch.nn.Tanh(), torch.nn.Tanhshrink(), + torch.nn.Threshold(0, 0.5), + torch.nn.GLU(), ] @@ -54,8 +56,14 @@ def plot_function(function, **args): Plot a function on the current plot. The additional arguments may be used to specify color, alpha, etc. 
""" - xrange = torch.arange(-7.0, 7.0, 0.01) # We need to go beyond 6 for ReLU6 - plt.plot(xrange.numpy(), function(xrange).detach().numpy(), **args) + if isinstance(function, torch.nn.GLU): + xrange = torch.arange(-7.0, 7.0, 0.01).unsqueeze(1).repeat(1, 2) + x = xrange.numpy()[:, 0] + else: + xrange = torch.arange(-7.0, 7.0, 0.01) # We need to go beyond 6 for ReLU6 + x = xrange.numpy() + y = function(xrange).detach().numpy() + plt.plot(x, y, **args) # Step through all the functions diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 564a516a2477..54a2dec94e18 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -66,10 +66,12 @@ class Threshold(Module): - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - Output: :math:`(*)`, same shape as the input. + .. image:: ../scripts/activation_images/Threshold.png + Examples:: - >>> m = nn.Threshold(0.1, 20) - >>> input = torch.randn(2) + >>> m = nn.Threshold(0, 0.5) + >>> input = torch.arange(-3, 3) >>> output = m(input) """ @@ -674,6 +676,8 @@ class GLU(Module): dimensions - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2` + .. image:: ../scripts/activation_images/GLU.png + Examples:: >>> m = nn.GLU() From f8aa6404ac916253b853d57120a25ec383f5e96b Mon Sep 17 00:00:00 2001 From: FFFrog Date: Mon, 7 Apr 2025 19:55:10 +0800 Subject: [PATCH 262/332] Refactor: add initialization of math.lcm into torch_c_binding_in_graph_functions (#150766) As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150766 Approved by: https://github.com/aorenste, https://github.com/jansel --- torch/_dynamo/trace_rules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 42bbf9a0623f..9c9a142bda2c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -364,6 +364,7 @@ "math.isinf", "math.isnan", "math.isqrt", + "math.lcm", "math.ldexp", "math.lgamma", "math.log", @@ -2235,7 +2236,6 @@ ) -torch_c_binding_in_graph_functions["math.lcm"] = TorchInGraphFunctionVariable if sys.version_info >= (3, 11): torch_c_binding_in_graph_functions["math.exp2"] = TorchInGraphFunctionVariable torch_c_binding_in_graph_functions["math.cbrt"] = TorchInGraphFunctionVariable From 58ede0cca3e73a05659bd8de8f131f2e83601b5c Mon Sep 17 00:00:00 2001 From: "xinan.lin" Date: Sun, 6 Apr 2025 23:50:14 -0700 Subject: [PATCH 263/332] [Inductor XPU] Refine `test_mkldnn_pattern_matcher.py` to be reusable for XPU. (#150286) This PR extracts some test cases from TestPatternMatcher into a newly created TestPatternMatcherGeneric, and uses instantiate_device_type_tests to make them reusable across multiple devices. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150286 Approved by: https://github.com/jansel --- test/inductor/test_cpu_cpp_wrapper.py | 10 +- test/inductor/test_mkldnn_pattern_matcher.py | 191 ++++++++++++------- 2 files changed, 133 insertions(+), 68 deletions(-) diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 5b10044fb648..8bd687c42a5f 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -189,7 +189,7 @@ class BaseTest(NamedTuple): BaseTest( "test_conv2d_unary", "cpu", - test_mkldnn_pattern_matcher.TestPatternMatcher(), + test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(), condition=torch.backends.mkldnn.is_available(), slow=True, ), @@ -220,9 +220,9 @@ class BaseTest(NamedTuple): ], BaseTest("test_polar"), BaseTest( - "test_linear_binary", + "test_linear_binary_cpu", "", - test_mkldnn_pattern_matcher.TestPatternMatcher(), + test_mkldnn_pattern_matcher.TestPatternMatcherGenericCPU(), torch.backends.mkldnn.is_available() and torch.ops.mkldnn._is_mkldnn_bf16_supported(), ), @@ -359,7 +359,9 @@ class BaseTest(NamedTuple): BaseTest("test_view_as_complex"), BaseTest("test_view_as_real"), BaseTest( - "test_woq_int4", "cpu", test_mkldnn_pattern_matcher.TestPatternMatcher() + "test_woq_int4", + "cpu", + test_mkldnn_pattern_matcher.TestPatternMatcher(), ), ]: make_test_case( diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 52a705911166..4b184aee4aba 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -13,6 +13,7 @@ from torch._inductor.utils import run_and_get_code from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer from torch.nn import functional as F +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_quantization import ( _generate_qdq_quantized_model, skipIfNoDynamoSupport, @@ -33,7 +34,11 @@ TEST_MKL, xfailIfACL, ) -from torch.testing._internal.inductor_utils import _check_has_dynamic_shape, HAS_CPU +from torch.testing._internal.inductor_utils import ( + _check_has_dynamic_shape, + clone_preserve_strides_offset, + HAS_CPU, +) # The dict value is match_nodes(computation_op+unary_op) @@ -91,7 +96,7 @@ def get_default_quantizer(is_qat, is_dynamic): return quantizer -def cal_conv_generated_kernel_number(mod, input, dtype, dim=4): +def cal_conv_generated_kernel_number(mod, input, dtype, dim=4, device="cpu"): # this function is to decide how many kernels are generated # while testing conv2d/3d/deconv2d # the assumption is: @@ -103,11 +108,14 @@ def cal_conv_generated_kernel_number(mod, input, dtype, dim=4): # and force the output to have same stride with eager. 
# So there will be a to_contiguous for output if eager output is contiguouse mod = copy.deepcopy(mod) + mod = mod.to(device=device) input = input.clone() + input = input.to(device) + if dtype == torch.float32: maybe_autocast = contextlib.nullcontext() else: - maybe_autocast = torch.amp.autocast("cpu", dtype=dtype) + maybe_autocast = torch.amp.autocast(device_type=device, dtype=dtype) with torch.no_grad(), maybe_autocast: output = mod(input) input_kernel, output_kernel = 0, 0 @@ -155,26 +163,33 @@ def _test_common( quantizer=None, compile_options={}, # noqa: B006 ): + if not hasattr(self, "device"): + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) + device = "xpu" if has_xpu else "cpu" + else: + device = self.device + + mod = mod.to(device=device) + if device != "cpu": + inputs = tuple( + clone_preserve_strides_offset(x, device=device) for x in inputs + ) counters.clear() torch._dynamo.reset() - has_xpu = any( - isinstance(input, torch.Tensor) and input.device.type == "xpu" - for input in inputs - ) - device_type = "xpu" if has_xpu else "cpu" if check_autocast == torch.bfloat16 and ( - torch.ops.mkldnn._is_mkldnn_bf16_supported() or has_xpu + torch.ops.mkldnn._is_mkldnn_bf16_supported() or device == "xpu" ): maybe_autocast = torch.amp.autocast( - device_type=device_type, dtype=torch.bfloat16 + device_type=device, dtype=torch.bfloat16 ) atol, rtol = 1e-2, 1e-2 elif check_autocast == torch.float16 and ( - torch.ops.mkldnn._is_mkldnn_fp16_supported() or has_xpu + torch.ops.mkldnn._is_mkldnn_fp16_supported() or device == "xpu" ): - maybe_autocast = torch.amp.autocast( - device_type=device_type, dtype=torch.float16 - ) + maybe_autocast = torch.amp.autocast(device_type=device, dtype=torch.float16) atol, rtol = 1e-2, 1e-2 else: assert check_autocast == torch.float32 @@ -233,8 +248,8 @@ def _test_code_common( torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) -class TestPatternMatcher(TestPatternMatcherBase): - def _test_conv_unary_cpu_base(self, dim=4): +class TestPatternMatcherGeneric(TestPatternMatcherBase): + def _test_conv_unary_base(self, dim=4): assert dim == 4 or dim == 5 class M(torch.nn.Module): @@ -304,23 +319,27 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv2d_unary_cpu(self): - self._test_conv_unary_cpu_base(dim=4) + def test_conv2d_unary(self, device): + self.device = device + self._test_conv_unary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv3d_unary_cpu(self): - self._test_conv_unary_cpu_base(dim=5) + def test_conv3d_unary(self, device): + self.device = device + self._test_conv_unary_base(dim=5) + + def test_linear_unary(self, device): + self.device = device - def test_linear_unary(self): class M(torch.nn.Module): def __init__( self, @@ -374,7 +393,9 @@ def matcher_check_fn(): self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1) @unittest.skipIf(not TEST_MKL, "Test requires MKL") - def test_linear_fp32(self): + def test_linear_fp32(self, device): + self.device = device + class M(torch.nn.Module): def __init__(self, bias): super().__init__() @@ -396,7 +417,9 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn) @unittest.skipIf(not TEST_MKL, 
"Test requires MKL") - def test_linear_input_non_contiguous_3D_wo_bias(self): + def test_linear_input_non_contiguous_3D_wo_bias(self, device): + self.device = device + # Activation is 3D, non-contiguous and without Bias class M(torch.nn.Module): def __init__(self): @@ -438,17 +461,19 @@ def forward(self, x): ) torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) - def test_linear_add_bias(self): + def test_linear_add_bias(self, device): + self.device = device + class M(torch.nn.Module): - def __init__(self, dtype, unary_fn, cast_bias): + def __init__(self, device, dtype, unary_fn, cast_bias): super().__init__() self.linear1 = torch.nn.Linear(10, 64, bias=False) - self.bias1 = torch.randn(64) + self.bias1 = torch.randn(64, device=device) self.linear2 = torch.nn.Linear(10, 64, bias=False) - self.bias2 = torch.randn(64) + self.bias2 = torch.randn(64, device=device) if cast_bias: - self.bias1 = self.bias1.to(dtype=dtype) - self.bias2 = self.bias2.to(dtype=dtype) + self.bias1 = self.bias1.to(dtype=dtype, device=device) + self.bias2 = self.bias2.to(dtype=dtype, device=device) self.unary_fn = unary_fn def forward(self, x): @@ -464,7 +489,7 @@ def forward(self, x): options = itertools.product(unary_list, dtypes) for unary_fn, dtype in options: metrics.reset() - fold_mod = M(dtype, unary_fn, cast_bias=True).eval() + fold_mod = M(self.device, dtype, unary_fn, cast_bias=True).eval() v = torch.randn(2, 10) def folder_matcher_check_fn(): @@ -495,7 +520,7 @@ def folder_matcher_check_fn(): # we won't fold the bias if bias is not same dtype with weight # https://github.com/pytorch/pytorch/pull/129138 metrics.reset() - mod = M(dtype, unary_fn, cast_bias=False).eval() + mod = M(self.device, dtype, unary_fn, cast_bias=False).eval() def matcher_check_fn(): self.assertEqual( @@ -575,20 +600,22 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv_transpose2d_unary_cpu(self): + def test_conv_transpose2d_unary(self, device): + self.device = device self._test_conv_transpose_unary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv_transpose3d_unary_cpu(self): + def test_conv_transpose3d_unary(self, device): + self.device = device self._test_conv_transpose_unary_base(dim=5) def _test_conv_binary_base(self, dim=4): @@ -669,20 +696,22 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype) generated_kernel_count = cal_conv_generated_kernel_number( - mod, v, dtype, dim + mod, v, dtype, dim, self.device ) self.assertEqual(metrics.generated_kernel_count, generated_kernel_count) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv2d_binary(self): + def test_conv2d_binary(self, device): + self.device = device self._test_conv_binary_base(dim=4) @skipIfNoDynamoSupport @skipIfNoONEDNN @skipIfRocm - def test_conv3d_binary(self): + def test_conv3d_binary(self, device): + self.device = device self._test_conv_binary_base(dim=5) def _test_conv_binary_broadcast_shapes_base(self, dim=4): @@ -788,7 +817,9 @@ def test_conv2d_binary_broadcast_shapes_cpu(self): def test_conv3d_binary_broadcast_shapes_cpu(self): self._test_conv_binary_broadcast_shapes_base(dim=5) - def test_linear_binary(self): + def test_linear_binary(self, device): + 
self.device = device + class M(torch.nn.Module): def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs): super().__init__() @@ -939,7 +970,9 @@ def matcher_check_fn(): self._test_common(mod, (x1, x2), matcher_check_fn) - def test_multi_linear_share_same_input(self): + def test_multi_linear_share_same_input(self, device): + self.device = device + # llama pattern. class M(torch.nn.Module): def __init__( @@ -979,6 +1012,8 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) + +class TestPatternMatcher(TestPatternMatcherBase): def _qconv2d_test_helper(self, device="cpu", int8_mixed_bf16=False): class M(torch.nn.Module): def __init__( @@ -4119,30 +4154,42 @@ def matcher_check_fn(): self.assertEqual(counters["inductor"]["qlinear_binary_matcher_count"], 1) -# When testing kernel counts, unspecializing float causes wobbling of our tests because -# we end up reusing the same compiled region across tests. Thus we purposely specialize floats -# here since we primarily care about number of kernels generated in the absence of compile -# caching. -@dynamo_config.patch( - { - "dynamic_shapes": True, - "assume_static_by_default": False, - "specialize_float": True, - } -) -class TestDynamicPatternMatcher(TestPatternMatcherBase): - _test_conv_unary_cpu_base = TestPatternMatcher._test_conv_unary_cpu_base - test_conv2d_unary_dynamic_shapes = TestPatternMatcher.test_conv2d_unary_cpu - test_conv3d_unary_dynamic_shapes = TestPatternMatcher.test_conv3d_unary_cpu - _test_conv_binary_base = TestPatternMatcher._test_conv_binary_base - test_conv2d_binary_dynamic_shapes = TestPatternMatcher.test_conv2d_binary - test_conv3d_binary_dynamic_shapes = TestPatternMatcher.test_conv3d_binary - test_linear_unary_dynamic_shapes = TestPatternMatcher.test_linear_unary +class TestDynamicPatternMatcherGeneric(TestPatternMatcherBase): + def setUp(self): + TestCase.setUp(self) + self.ctx_stack = contextlib.ExitStack() + self.ctx_stack.enter_context( + # When testing kernel counts, unspecializing float causes wobbling of our tests because + # we end up reusing the same compiled region across tests. Thus we purposely specialize floats + # here since we primarily care about number of kernels generated in the absence of compile + # caching. + dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } + ) + ) + + def tearDown(self): + TestCase.tearDown(self) + self.ctx_stack.close() + + _test_conv_unary_base = TestPatternMatcherGeneric._test_conv_unary_base + test_conv2d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_unary + test_conv3d_unary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_unary + _test_conv_binary_base = TestPatternMatcherGeneric._test_conv_binary_base + test_conv2d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv2d_binary + test_conv3d_binary_dynamic_shapes = TestPatternMatcherGeneric.test_conv3d_binary + test_linear_unary_dynamic_shapes = TestPatternMatcherGeneric.test_linear_unary test_linear_input_non_contiguous_3D_wo_bias_dynamic_shapes = ( - TestPatternMatcher.test_linear_input_non_contiguous_3D_wo_bias + TestPatternMatcherGeneric.test_linear_input_non_contiguous_3D_wo_bias ) - def test_conv_transpose2d_dynamic_shapes(self): + def test_conv_transpose2d_dynamic_shapes(self, device): + self.device = device + # We don't support conv_transpose2d for now. 
class M(torch.nn.Module): def __init__(self) -> None: @@ -4163,7 +4210,9 @@ def matcher_check_fn(): self._test_common(mod, (v,), matcher_check_fn) - def test_multi_linear_share_same_input_dynamic(self): + def test_multi_linear_share_same_input_dynamic(self, device): + self.device = device + # llama pattern. class M(torch.nn.Module): def __init__( @@ -4206,6 +4255,15 @@ def matcher_check_fn(): v = torch.randn(2, 4, 16).to(dtype) self._test_common(mod, (v,), matcher_check_fn, rtol=1e-2, atol=1e-2) + +@dynamo_config.patch( + { + "dynamic_shapes": True, + "assume_static_by_default": False, + "specialize_float": True, + } +) +class TestDynamicPatternMatcher(TestPatternMatcherBase): @xfailIfACL def test_qconv2d_maxpool2d_linear_dynamic_cpu(self, include_ops=None): r""" @@ -4367,8 +4425,13 @@ def matcher_check_fn(): ) +instantiate_device_type_tests( + TestPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") +) +instantiate_device_type_tests( + TestDynamicPatternMatcherGeneric, globals(), allow_xpu=True, only_for=("cpu") +) instantiate_parametrized_tests(TestPatternMatcher) - if __name__ == "__main__": - if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available(): + if IS_LINUX and (HAS_CPU) and torch.backends.mkldnn.is_available(): run_tests() From a106842ea8be6eb17b368de16d9c107c12b809bc Mon Sep 17 00:00:00 2001 From: "fengqing.lu" Date: Tue, 8 Apr 2025 07:02:40 +0000 Subject: [PATCH 264/332] [XPU] Fix XPU unit test on Windows (#150520) This PR is to resolve issue reported in https://github.com/intel/torch-xpu-ops/issues/1478 There are two cases failing in our Windows CI enabling. - **test_xpu.py::TestXpuXPU::test_lazy_init_xpu** Needs to add `if __name__ == '__main__':` for Windows when using multiprocess. Refer to https://stackoverflow.com/a/18205006 ``` RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase. This probably means that you are not using fork to start your child processes and you have forgotten to use the proper idiom in the main module: if __name__ == '__main__': freeze_support() ... The "freeze_support()" line can be omitted if the program is not going to be frozen to produce an executable. Traceback (most recent call last): File "C:\Users\sdp\lufengqing\torch-xpu-ops\test\xpu\xpu_test_utils.py", line 24, in test_multi_process(model, input) File "C:\Users\sdp\lufengqing\torch-xpu-ops\test\xpu\xpu_test_utils.py", line 16, in test_multi_process assert p.exitcode == 0 AssertionError ``` - **test_xpu.py::TestXpuXPU::test_wrong_xpu_fork_xpu** is a linux only test case, we should skip it on Windows. 
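For the first case, the fix is the standard `spawn` start-method idiom: Windows has no fork, so child processes re-import the main module, and any code that launches workers at import time re-executes before bootstrapping finishes, producing exactly the RuntimeError quoted above. A minimal, self-contained sketch of the guarded pattern (the worker and launcher names here are illustrative, not the actual test helpers from the patch):

```
import multiprocessing as mp

def worker(rank):
    # Trivial child-process work, just to show the structure.
    print(f"worker {rank} running")

def launch():
    # With the "spawn" start method the child re-imports this module, so an
    # unguarded launch() call at module level would try to start yet another
    # process during that re-import and hit the bootstrapping error.
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=worker, args=(0,))
    p.start()
    p.join()
    assert p.exitcode == 0

if __name__ == "__main__":
    # The guard keeps launch() from re-running inside the spawned child.
    launch()
```

The second case only fails under fork, so the patch skips it on Windows with `unittest.skipIf(IS_WINDOWS, ...)`, mirroring the existing multiprocessing test referenced next.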
Refer to https://github.com/pytorch/pytorch/blob/248487f455e943cbba368404119ca9bcb14c0499/test/test_multiprocessing.py#L609 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150520 Approved by: https://github.com/guangyey, https://github.com/EikanWang --- test/test_xpu.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_xpu.py b/test/test_xpu.py index 4208bf6daa5e..1647ad24a75a 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -136,6 +136,7 @@ def test_get_device_properties(self): device_capability["architecture"], ) + @unittest.skipIf(IS_WINDOWS, "not applicable to Windows (only fails with fork)") def test_wrong_xpu_fork(self): stderr = TestCase.runWithPytorchAPIUsageStderr( """\ @@ -192,9 +193,11 @@ def test_multi_process(model, input): torch.nn.ReLU(), torch.nn.MaxPool2d(2, 2), ) -test_multi_process(model, input) -test_multi_process(model, input) -print(torch.xpu.device_count()) + +if __name__ == "__main__": + test_multi_process(model, input) + test_multi_process(model, input) + print(torch.xpu.device_count()) """ rc = check_output(test_script) self.assertEqual(rc, str(torch.xpu.device_count())) From 881d99495ddd1376b000f7c03b050c22646012cf Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 10:59:26 +0800 Subject: [PATCH 265/332] Add more check for torch.ormqr (#150759) As the title statd. Please refer to https://github.com/pytorch/pytorch/issues/150674 for more info. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150759 Approved by: https://github.com/lezcano --- aten/src/ATen/native/BatchLinearAlgebra.cpp | 13 +++++-------- test/test_linalg.py | 19 ++++++++++--------- .../_internal/common_methods_invocations.py | 2 +- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 897e83890c79..d1947435d2bc 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -2693,12 +2693,6 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions."); int64_t left_size_condition = left ? 
-2 : -1; - TORCH_CHECK( - other.size(left_size_condition) >= tau.size(-1), - "torch.ormqr: other.shape[", - left_size_condition, - "] must be greater than or equal to tau.shape[-1]"); - TORCH_CHECK( other.size(left_size_condition) == input.size(-2), "torch.ormqr: other.shape[", @@ -2706,8 +2700,10 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b "] must be equal to input.shape[-2]"); TORCH_CHECK( - tau.size(-1) <= input.size(-1), - "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]"); + std::min(other.size(left_size_condition), input.size(-1)) == tau.size(-1), + "torch.ormqr: tau.shape[-1] must be equal to min(other.shape[", + left_size_condition, + "], input.shape[-1])"); TORCH_CHECK( input.dim() - tau.dim() == 1, @@ -2716,6 +2712,7 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b tau.dim(), " and input.ndim is equal to ", input.dim()); + TORCH_CHECK( input.dim() == other.dim(), "torch.ormqr: ", diff --git a/test/test_linalg.py b/test/test_linalg.py index 97c56796bbb9..649c46b5404c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -5845,20 +5845,21 @@ def run_test(batch, m, n, fortran_contiguous): @dtypes(*floating_and_complex_types()) def test_ormqr_errors_and_warnings(self, device, dtype): test_cases = [ - # input1 size, input2 size, input3 size, error regex - ((10,), (2,), (2,), r"input must have at least 2 dimensions"), - ((2, 2), (2,), (2,), r"other must have at least 2 dimensions"), - ((10, 6), (20,), (10, 6), r"other.shape\[-2\] must be greater than or equal to tau.shape\[-1\]"), - ((6, 6), (5,), (5, 5), r"other.shape\[-2\] must be equal to input.shape\[-2\]"), - ((1, 2, 2), (2, 2), (1, 2, 2), r"batch dimensions of tau to be equal to input.shape\[:-2\]"), - ((1, 2, 2), (1, 2), (2, 2, 2), r"batch dimensions of other to be equal to input.shape\[:-2\]"), + # input1 size, input2 size, input3 size, left, error regex + ((10,), (2,), (2,), True, r"input must have at least 2 dimensions"), + ((2, 2), (2,), (2,), True, r"other must have at least 2 dimensions"), + ((6, 6), (5,), (5, 5), True, r"other.shape\[-2\] must be equal to input.shape\[-2\]"), + ((1, 2, 2), (2, 2), (1, 2, 2), True, r"batch dimensions of tau to be equal to input.shape\[:-2\]"), + ((1, 2, 2), (1, 2), (2, 2, 2), True, r"batch dimensions of other to be equal to input.shape\[:-2\]"), + ((2, 4, 3), (2, 2), (2, 3, 10), True, r"torch.ormqr: other.shape\[-2\] must be equal to input.shape\[-2\]"), + ((2, 4, 3), (2, 2), (2, 3, 10), False, r"torch.ormqr: other.shape\[-1\] must be equal to input.shape\[-2\]") ] - for a_size, tau_size, c_size, error_regex in test_cases: + for a_size, tau_size, c_size, left, error_regex in test_cases: a = make_tensor(a_size, dtype=dtype, device=device) tau = make_tensor(tau_size, dtype=dtype, device=device) c = make_tensor(c_size, dtype=dtype, device=device) with self.assertRaisesRegex(RuntimeError, error_regex): - torch.ormqr(a, tau, c) + torch.ormqr(a, tau, c, left) def test_blas_empty(self, device): def fn(torchfn, *args, test_out=False, **kwargs): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24f651020d75..17391695cdc3 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2775,7 +2775,7 @@ def error_inputs_ormqr(op_info, device, **kwargs): bool_3 = True bool_4 = True yield ErrorInput(SampleInput(tensor_0, args=(tensor_1, tensor_2, 
bool_3, bool_4)), error_type=RuntimeError, - error_regex=r"tau.shape\[-1\] must be less than or equal to input.shape\[-1\]") + error_regex=r"tau.shape\[-1\] must be equal to min\(other.shape\[-2\], input.shape\[-1\]\)") def error_inputs_diag(op_info, device, **kwargs): From 3da14d38bd396f5bbe8494872d1509efa1a6f048 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 14:53:05 +0800 Subject: [PATCH 266/332] Fix the Problems About Defining Static Variable in Inline Function (#147095) Refer to https://github.com/pytorch/pytorch/issues/125465 for more informations - Remove unused header files - Move the inline function that defines the static variable to .cc Pull Request resolved: https://github.com/pytorch/pytorch/pull/147095 Approved by: https://github.com/cyyever, https://github.com/albanD --- ...cpp_extensions_open_device_registration.py | 52 +++++++++---------- torch/csrc/api/src/serialize.cpp | 1 - .../csrc/distributed/rpc/python_remote_call.h | 1 - torch/csrc/distributed/rpc/rref_proto.h | 1 - torch/csrc/distributed/rpc/script_call.h | 1 - .../csrc/distributed/rpc/script_remote_call.h | 1 - torch/csrc/distributed/rpc/script_resp.h | 1 - torch/csrc/jit/serialization/export.cpp | 1 + torch/csrc/jit/serialization/export.h | 1 - torch/csrc/jit/serialization/pickler.cpp | 20 +++++++ torch/csrc/jit/serialization/pickler.h | 19 ++----- torch/csrc/jit/serialization/unpickler.cpp | 1 - 12 files changed, 50 insertions(+), 50 deletions(-) diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 5d1f0c34ee2e..21394218c65b 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -4,7 +4,6 @@ import io import os import sys -import tempfile import unittest from typing import Union from unittest.mock import patch @@ -346,23 +345,22 @@ def test_open_device_storage_pin_memory(self): cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg") self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg")) - @unittest.skip( - "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function" - ) def test_open_device_serialization(self): self.module.set_custom_device_index(-1) storage = torch.UntypedStorage(4, device=torch.device("openreg")) - self.assertEqual(torch.serialization.location_tag(storage), "openreg") + self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") self.module.set_custom_device_index(0) storage = torch.UntypedStorage(4, device=torch.device("openreg")) self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - cpu_storage = torch.empty(4, 4).storage() - openreg_storage = torch.serialization.default_restore_location( - cpu_storage, "openreg:0" - ) - self.assertTrue(openreg_storage.is_openreg) + # TODO(FFFrog): Comment this because openreg.device is missing + # Uncomment this after improving openreg + # cpu_storage = torch.empty(4, 4).storage() + # openreg_storage = torch.serialization.default_restore_location( + # cpu_storage, "openreg:0" + # ) + # self.assertTrue(openreg_storage.is_openreg) # test tensor MetaData serialization x = torch.empty(4, 4).long() @@ -371,22 +369,24 @@ def test_open_device_serialization(self): self.module.custom_set_backend_meta(y) self.assertTrue(self.module.check_backend_meta(y)) - self.module.custom_serialization_registry() - with tempfile.TemporaryDirectory() as tmpdir: - path = 
os.path.join(tmpdir, "data.pt") - torch.save(y, path) - z1 = torch.load(path) - # loads correctly onto the openreg backend device - self.assertTrue(z1.is_openreg) - # loads BackendMeta data correctly - self.assertTrue(self.module.check_backend_meta(z1)) - - # cross-backend - z2 = torch.load(path, map_location="cpu") - # loads correctly onto the cpu backend device - self.assertFalse(z2.is_openreg) - # loads BackendMeta data correctly - self.assertFalse(self.module.check_backend_meta(z2)) + # TODO(FFFrog): Comment this because openreg.device is missing + # Uncomment this after improving openreg + # self.module.custom_serialization_registry() + # with tempfile.TemporaryDirectory() as tmpdir: + # path = os.path.join(tmpdir, "data.pt") + # torch.save(y, path) + # z1 = torch.load(path) + # loads correctly onto the openreg backend device + # self.assertTrue(z1.is_openreg) + # loads BackendMeta data correctly + # self.assertTrue(self.module.check_backend_meta(z1)) + + # cross-backend + # z2 = torch.load(path, map_location="cpu") + # loads correctly onto the cpu backend device + # self.assertFalse(z2.is_openreg) + # loads BackendMeta data correctly + # self.assertFalse(self.module.check_backend_meta(z2)) def test_open_device_storage_resize(self): cpu_tensor = torch.randn([8]) diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index e8497a7f22b5..fae54d124847 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -1,5 +1,4 @@ #include -#include #include #include diff --git a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h index 0a3054b594d2..09d4ba36dc62 100644 --- a/torch/csrc/distributed/rpc/python_remote_call.h +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -3,7 +3,6 @@ #include #include #include -#include namespace torch::distributed::rpc { class TORCH_API PythonRemoteCall : public RpcCommandBase { diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index e6bffd1870b3..a1482b46939b 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -4,7 +4,6 @@ #include #include #include -#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 19e1871ead87..476ee118fe7f 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index 534ac0044599..e18edab64821 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -3,7 +3,6 @@ #include #include #include -#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_resp.h b/torch/csrc/distributed/rpc/script_resp.h index fd8cd4b845d1..53841e3d705c 100644 --- a/torch/csrc/distributed/rpc/script_resp.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -2,7 +2,6 @@ #include #include -#include namespace torch::distributed::rpc { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index ac20016c7bbb..9c10e94141a2 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include 
diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 8b2d6d84716a..6f8e69bf0ca6 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 6ce524293a70..8038aa8ca658 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -807,4 +807,24 @@ bool checkHasValidSetGetState(const std::shared_ptr& cls) { return true; } +std::unordered_set& GetBackendMetaAllowlist() { + static std::unordered_set DeviceTypeAllowlist{ + c10::DeviceType::PrivateUse1}; + return DeviceTypeAllowlist; +} + +std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES>& +GetBackendMetaSerialization() { + // The array to save function pointer for BackendMeta serialization. + // key is the DeviceType, value is std::pair obj. + // value.first represent get function and value.seconde represent set function + static std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES> + BackendMetaSerialization; + return BackendMetaSerialization; +} + } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 8accfa229b84..828f2b3b0521 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -299,27 +299,14 @@ using BackendMetaPtr = std::function< void(const at::Tensor&, std::unordered_map&)>; // A allowlist of device type, currently available is PrivateUse1 -inline std::unordered_set& GetBackendMetaAllowlist() { - static std::unordered_set DeviceTypeAllowlist{ - c10::DeviceType::PrivateUse1}; - return DeviceTypeAllowlist; -} +TORCH_API std::unordered_set& GetBackendMetaAllowlist(); // Dynamically obtain serialization function pairs // that require the corresponding backend. -inline std::array< +TORCH_API std::array< std::optional>, at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization() { - // The array to save function pointer for BackendMeta serialization. - // key is the DeviceType, value is std::pair obj. - // value.first represent get function and value.seconde represent set function - static std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES> - BackendMetaSerialization; - return BackendMetaSerialization; -} +GetBackendMetaSerialization(); // Register function pointer of Tensor BackendMetadata for serialization. 
TORCH_API inline void TensorBackendMetaRegistry( diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 0cbb710f5513..cdd58b8cef3d 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,7 +5,6 @@ #endif #include #include -#include #include #include #include From 3649e2e7bde8ff06f6fe6ac4168e879e9e4f5c0a Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Mon, 7 Apr 2025 12:34:01 +0000 Subject: [PATCH 267/332] Safer bookkeeping of NCCL communicators (#150681) This consists mainly in two changes: - ensure we can reliably obtain the device from a `NCCLComm` object (there was one constructor which didn't set the device) - use a RAII pattern for acquiring the lock to the global dictionary of `NCCLComms` (which ensures the lock is released in case of exceptions) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150681 Approved by: https://github.com/kwen2501 --- torch/csrc/distributed/c10d/NCCLUtils.cpp | 7 +++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 2 + .../distributed/c10d/ProcessGroupNCCL.cpp | 56 +++++++++---------- 3 files changed, 37 insertions(+), 28 deletions(-) diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index dff8a5f78775..faec5bc449ac 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -92,7 +92,9 @@ std::shared_ptr NCCLComm::create_scalable( int numRanks, int rank, std::vector& commIds, + at::DeviceIndex deviceIndex, ncclConfig_t& config) { + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex); auto comm = std::make_shared(); comm->nonBlocking_ = config.blocking == 0; LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: " @@ -112,6 +114,7 @@ std::shared_ptr NCCLComm::create_scalable( // in the log file and in the replay tool. comm->ncclId_ = commIds[0]; comm->rank_ = rank; + comm->deviceIndex_ = deviceIndex; comm->initialized_ = !comm->nonBlocking_; return comm; } @@ -150,6 +153,10 @@ ncclComm_t NCCLComm::getNcclComm() { return ncclComm_; } +at::DeviceIndex NCCLComm::getDeviceIndex() { + return deviceIndex_; +} + // Wait for the communicator to be ready. This is a blocking function. 
// Arguments: // longInterval: if true, wait with sleep of an interval; otherwise, wait diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index c7cd0a30924e..89bf15fc6479 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -221,6 +221,7 @@ class NCCLComm { int numRanks, int rank, std::vector& commIds, + at::DeviceIndex deviceIndex, ncclConfig_t& config); #endif // NCCL_HAS_INIT_RANK_SCALABLE #endif // NCCL_HAS_CONFIG @@ -239,6 +240,7 @@ class NCCLComm { #endif ncclUniqueId getNcclId(); + at::DeviceIndex getDeviceIndex(); // Must not be copyable NCCLComm(const NCCLComm&) = delete; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 9f3e66a5f549..1cd794f684a0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -302,10 +302,8 @@ static void cacheAllocatorRegisterHook( } std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + if (te.device_ == ncclComm->getDeviceIndex()) { // NOLINTNEXTLINE(performance-no-int-to-ptr) ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); } @@ -321,10 +319,8 @@ static void cacheAllocatorDeregisterHook( } std::lock_guard lock(ncclCommDevIdxMapMutex); - for (auto& it : ncclCommDevIdxMap) { - auto& ncclComm = it.first; - auto& devIdx = it.second; - if (te.device_ == devIdx) { + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + if (te.device_ == ncclComm->getDeviceIndex()) { // NOLINTNEXTLINE(performance-no-int-to-ptr) ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); } @@ -345,11 +341,12 @@ static std:: std::vector> allNCCLComms; // within the critical section, we don't want to dump while holding the lock // as dump might hang - ncclCommDevIdxMapMutex.lock(); - for (auto& [ncclComm, _] : ncclCommDevIdxMap) { - allNCCLComms.push_back(ncclComm); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& [ncclComm, _] : ncclCommDevIdxMap) { + allNCCLComms.push_back(ncclComm); + } } - ncclCommDevIdxMapMutex.unlock(); for (auto& ncclComm : allNCCLComms) { std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId()); ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump(); @@ -824,9 +821,10 @@ void ProcessGroupNCCL::WorkNCCL::abort() { // Abort all communicators of this work ncclComm_->abort(); - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm_); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.erase(ncclComm_); + } } ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; @@ -1390,12 +1388,12 @@ bool ProcessGroupNCCL::abortComms( // communicators. Note that ncclCommDevIdxMap is a global container which may // contain other PG's communicators, thus we need to only erase communicators // for the current PG. 
- ncclCommDevIdxMapMutex.lock(); - for (auto& it : devNCCLCommMap_) { - auto& ncclComm = it.second; - ncclCommDevIdxMap.erase(ncclComm); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + for (auto& [_, ncclComm] : devNCCLCommMap_) { + ncclCommDevIdxMap.erase(ncclComm); + } } - ncclCommDevIdxMapMutex.unlock(); std::lock_guard lock(mutex_); abortCommsFromMap(devNCCLCommMap_, abortReason); @@ -2705,9 +2703,10 @@ void ProcessGroupNCCL::destroyNCCLComms(const std::string& devNCCLCommMapKey) { // Clear used device indices. usedDeviceIdxs_.clear(); - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.erase(ncclComm); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.erase(ncclComm); + } } std::shared_ptr ProcessGroupNCCL::initNCCLComm( @@ -2874,8 +2873,8 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( << "ProcessGroupNCCL all-gather unique IDs through store took " << timerDeltaMs << " ms"; #if defined(NCCL_HAS_INIT_RANK_SCALABLE) && defined(NCCL_HAS_CONFIG) - ncclComm = - NCCLComm::create_scalable(numRanks, rank, ncclIDs, options_->config); + ncclComm = NCCLComm::create_scalable( + numRanks, rank, ncclIDs, deviceIndex, options_->config); #else C10_THROW_ERROR( DistBackendError, @@ -2985,9 +2984,10 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( // on the same device. // NOTE: we need remove the communicator from this map when it is // destroyed, otherwise may register onto an invalid communicator. - ncclCommDevIdxMapMutex.lock(); - ncclCommDevIdxMap.emplace(ncclComm, device.index()); - ncclCommDevIdxMapMutex.unlock(); + { + std::lock_guard lock(ncclCommDevIdxMapMutex); + ncclCommDevIdxMap.emplace(ncclComm, device.index()); + } } it = devNCCLCommMap_.find(deviceKey); From 1791b4150b1a71d46b84db9008ac4b737bd75088 Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Mon, 7 Apr 2025 12:34:01 +0000 Subject: [PATCH 268/332] Clarify behavior of TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK (#150682) I still don't really understand the original purpose of that env var, but it appears that its usage is completely disconnected from MemPools and from `ncclMemAlloc`/`Free`. In fact, when that env var is set, we invoke `ncclCommRegister` for _all_ NCCL communicators for _all_ the memory segments managed by the allocator (both the global ones, allocated with `cudaMalloc`, and the ones in private MemPools), and we do that both for the segments that already exist when the PG is initialized and for all segments that will be allocated later. I'm reworking the code a bit, by using a few helper functions, whose name should make this behavior clearer. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150682 Approved by: https://github.com/kwen2501 ghstack dependencies: #150681 --- .../distributed/c10d/ProcessGroupNCCL.cpp | 78 ++++++++++++------- .../distributed/c10d/ProcessGroupNCCL.hpp | 4 - 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 1cd794f684a0..d6f3e0d42e1e 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -275,6 +275,28 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { } } +// When TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK is set, all tensors (no +// matter how they have been allocated) are registered with all NCCL comms. 
+bool shouldAllCommunicatorsRegisterAllTensors() { +#ifdef NCCL_HAS_COMM_REGISTER + static const bool flag = [] { + const bool flag = + getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); + if (flag && + c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + expandable_segments()) { + LOG(INFO) + << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; + return false; + } + return flag; + }(); + return flag; +#else + return false; +#endif // NCCL_HAS_COMM_REGISTER +} + } // namespace // Map from each communicator to its device index. @@ -289,7 +311,6 @@ inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // communicators in all PGs. static std::unordered_map, int> ncclCommDevIdxMap; static std::mutex ncclCommDevIdxMapMutex; -static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); @@ -304,8 +325,10 @@ static void cacheAllocatorRegisterHook( std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& [ncclComm, _] : ncclCommDevIdxMap) { if (te.device_ == ncclComm->getDeviceIndex()) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + if (shouldAllCommunicatorsRegisterAllTensors()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->registerSegment(reinterpret_cast(te.addr_), te.size_); + } } } } @@ -321,12 +344,28 @@ static void cacheAllocatorDeregisterHook( std::lock_guard lock(ncclCommDevIdxMapMutex); for (auto& [ncclComm, _] : ncclCommDevIdxMap) { if (te.device_ == ncclComm->getDeviceIndex()) { - // NOLINTNEXTLINE(performance-no-int-to-ptr) - ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + if (shouldAllCommunicatorsRegisterAllTensors()) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + ncclComm->deregisterSegment(reinterpret_cast(te.addr_)); + } } } } +static void attachAllocatorHooks() { + static c10::once_flag flag; + c10::call_once(flag, [] { + // Attaching hooks fails if CUDACachingAllocator is not initialized, so + // Init for CUDA is called (and is a no-op if CUDA is already + // initialized). 
+ at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorRegisterHook); + c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( + &cacheAllocatorDeregisterHook); + }); +} + static std:: unordered_map> getNCCLCommDumpMap() { @@ -957,17 +996,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( TORCH_WARN_ONCE( "TORCH_NCCL_AVOID_RECORD_STREAMS is the default now, this environment variable is thus deprecated."); } -#ifdef NCCL_HAS_COMM_REGISTER - useTensorRegisterAllocatorHook_ = - getCvarBool(TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK, false); - if (c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: - expandable_segments()) { - useTensorRegisterAllocatorHook_ = false; - LOG(INFO) - << logPrefix() - << "disables TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK because it is not compatible with CUDA allocator expandable segments mode."; - } -#endif // NCCL_HAS_COMM_REGISTER if (blockingWait_) { LOG(INFO) @@ -1020,7 +1048,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug #ifdef NCCL_HAS_COMM_REGISTER << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: " - << useTensorRegisterAllocatorHook_ + << shouldAllCommunicatorsRegisterAllTensors() #endif // NCCL_HAS_COMM_REGISTER << ", TORCH_NCCL_ENABLE_MONITORING: " << monitorThreadEnabled_.load() @@ -1041,17 +1069,9 @@ ProcessGroupNCCL::ProcessGroupNCCL( // action is called. In the following hooks, we register a newly allocated // segment when SEGMENT_ALLOC action occurs, and deregister a segment when // SEGMENT_FREE action occurs. - // We attach hooks only once at the first PG creation. - // Attaching hooks fails if CUDACachingAllocator is not initialized, so - // Init for CUDA is called (and is a no-op if CUDA is already - // initialized). - if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { - at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorRegisterHook); - c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( - &cacheAllocatorDeregisterHook); - allocatorHooksAttached = true; + if (shouldAllCommunicatorsRegisterAllTensors()) { + // This call is idempotent. + attachAllocatorHooks(); } // Enable Desync Debugger per user setting @@ -2966,7 +2986,7 @@ std::shared_ptr ProcessGroupNCCL::initNCCLComm( // Now ncclComms are fully initialized. // Register all active CUDA memory segments in cache allocator to // the new NCCL communicators - if (useTensorRegisterAllocatorHook_) { + if (shouldAllCommunicatorsRegisterAllTensors()) { auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); // Register the segment to a new NCCL communicator if on the same device for (const auto& segmentInfo : snapshot.segments) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 0896dd0de290..82961db0ec17 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1286,10 +1286,6 @@ class TORCH_API ProcessGroupNCCL : public Backend { // for the operation to complete. bool blockingWait_ = false; - // Whether or not to hook the cache allocator to register all allocated - // tensors - bool useTensorRegisterAllocatorHook_ = false; - // Whether or not the workCleanupThread is used to perform async error // handling. 
ErrorHandlingMode asyncErrorHandling_ = NoHandling; From f3b2fb6c66efbdc13ab9c99c6b2190ed10a1c770 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 7 Apr 2025 19:59:29 +0000 Subject: [PATCH 269/332] Allow trace through unittest (#146500) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146500 Approved by: https://github.com/anijain2305 --- test/dynamo/test_ctx_manager.py | 10 +- test/dynamo/test_error_messages.py | 2 +- test/dynamo/test_exceptions.py | 12 +- test/dynamo/test_generator_stop.py | 23 +- test/dynamo/test_raise.py | 95 +-- test/dynamo/test_sys.py | 7 +- test/dynamo/test_unittest.py | 619 ++++++++++++++++++ .../TestScript.test_python_frontend | 0 .../TestScript.test_python_frontend_py3 | 0 torch/_dynamo/config.py | 3 + torch/_dynamo/test_case.py | 19 + torch/_dynamo/trace_rules.py | 7 +- torch/_dynamo/variables/builtin.py | 30 +- torch/_dynamo/variables/dicts.py | 20 +- torch/_dynamo/variables/functions.py | 5 + 15 files changed, 739 insertions(+), 113 deletions(-) create mode 100644 test/dynamo/test_unittest.py create mode 100644 test/dynamo_expected_failures/TestScript.test_python_frontend create mode 100644 test/dynamo_expected_failures/TestScript.test_python_frontend_py3 diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 4e4af9341e75..0ae3ff452967 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1744,10 +1744,13 @@ def fn(x): class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): self._prev = torch._dynamo.config.enable_trace_contextlib + self._u_prev = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_contextlib = True + torch._dynamo.config.enable_trace_unittest = True def tearDown(self): torch._dynamo.config.enable_trace_contextlib = self._prev + torch._dynamo.config.enable_trace_unittest = self._u_prev def test_ctx_basic0(self): @contextlib.contextmanager @@ -2691,7 +2694,7 @@ def fn(t): self.assertEqual(y, t.sin()) -class CPythonContextManagerTestCase(torch._dynamo.test_case.TestCase): +class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py # https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70 @@ -2721,7 +2724,6 @@ def fn(t): self.assertEqual(state, [1, 42, 999]) self.assertEqual(y, t.sum() + 42) - @unittest.expectedFailure def test_contextmanager_finally(self): state = [] @@ -2831,7 +2833,6 @@ def fn(t): self.assertEqual(frames[0].name, "test_contextmanager_traceback") self.assertEqual(frames[0].line, "raise stop_exc") - @unittest.expectedFailure def test_contextmanager_no_reraise(self): @contextmanager def whee(): @@ -2847,7 +2848,6 @@ def fn(t): fn(torch.randn(2, 3)) - @unittest.expectedFailure def test_contextmanager_trap_yield_after_throw(self): @contextmanager def whoo(): @@ -2866,7 +2866,6 @@ def fn(t): fn(torch.randn(2, 3)) - @unittest.expectedFailure def test_contextmanager_trap_no_yield(self): @contextmanager def whoo(): @@ -2882,7 +2881,6 @@ def fn(t): fn(torch.randn(2, 3)) - @unittest.expectedFailure def test_contextmanager_trap_second_yield(self): @contextmanager def whoo(): diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index eef2512bcfe5..71c4c921ae4b 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -307,7 +307,7 @@ def post_munge(s): 
Hint: Remove the function `case.py` from torch/_dynamo/trace_rules.py. More graph breaks may occur as a result of attempting to trace into the function. Hint: Please file an issue to PyTorch. - Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup SKIP_DIRS + Developer debug context: qualname: skip, name: skip, filename: `case.py`, skip reason: skipped according trace_rules.lookup unittest from user code: diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 6c82593e6ec3..bfd1f5352645 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -177,7 +177,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_raise_match(self): a = AttributeError @@ -259,7 +258,6 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") opt_fn(x) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_exception_with_ctx_manager(self): def fn(x): x = torch.cos(x) @@ -853,7 +851,6 @@ def fn(t): t = torch.randn(2) fn(t) - @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_user_defined_exception_with_args(self): @torch.compile(backend="eager", fullgraph=True) def fn(t): @@ -889,6 +886,12 @@ def test_raise_set___context__(self): class CPythonExceptionTests(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_exceptions.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_exceptions.py + def setUp(self): + self._u_prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._u_prev @make_dynamo_test def testChainingAttrs(self): @@ -976,7 +979,6 @@ def test_context_of_exception_in_else_and_finally(self): assert exc is oe assert exc.__context__ is ve - @unittest.expectedFailure @make_dynamo_test def test_raise_does_not_create_context_chain_cycle(self): A = AssertionError @@ -1015,7 +1017,6 @@ def test_raise_does_not_create_context_chain_cycle(self): self.assertIs(c.__context__, b) self.assertIsNone(b.__context__) - @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle1(self): # See issue 25782. Cycle in context chain. @@ -1071,7 +1072,6 @@ def test_no_hang_on_context_chain_cycle2(self): self.assertIs(b.__context__, a) self.assertIs(a.__context__, c) - @unittest.expectedFailure @make_dynamo_test def test_no_hang_on_context_chain_cycle3(self): # See issue 25782. Longer context chain with cycle. 
diff --git a/test/dynamo/test_generator_stop.py b/test/dynamo/test_generator_stop.py index fe6c9961ddf9..7091d3d37137 100644 --- a/test/dynamo/test_generator_stop.py +++ b/test/dynamo/test_generator_stop.py @@ -8,19 +8,9 @@ from torch.testing._internal.common_utils import make_dynamo_test -class TestPEP479(torch._dynamo.test_case.TestCase): +class TestPEP479(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_generator_stop.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_generator_stop.py - - def assertTrue(self, expr, msg=None): - assert bool(expr) is True, msg - - def assertIs(self, expr1, expr2, msg=None): - assert expr1 is expr2, msg - - def assertEqual(self, x, y): - assert x == y - @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") @make_dynamo_test def test_stopiteration_wrapping(self): @@ -30,16 +20,9 @@ def f(): def g(): yield f() - try: + with self.assertRaises(RuntimeError) as cm: next(g()) - except RuntimeError as cm: - self.assertEqual("generator raised StopIteration", cm.args[0]) - except Exception: - self.fail("Error!") - - # with self.assertRaises(RuntimeError) as cm: - # next(g()) - # self.assertEqual("generator raised StopIteration", str(cm.exception)) + self.assertEqual("generator raised StopIteration", str(cm.exception)) @unittest.skipIf(sys.version_info < (3, 12), "Test does not work in Python < 3.12") @make_dynamo_test diff --git a/test/dynamo/test_raise.py b/test/dynamo/test_raise.py index 133ebc142fe4..9a95d23226c0 100644 --- a/test/dynamo/test_raise.py +++ b/test/dynamo/test_raise.py @@ -44,32 +44,9 @@ def __exit__(self, t, v, tb): raise NameError -class TestRaise(torch._dynamo.test_case.TestCase): +class TestRaise(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_raise.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py - - def assertIn(self, member, container, msg=None): - assert member in container, msg - - def assertIs(self, expr1, expr2, msg=None): - assert expr1 is expr2, msg - - def assertRaises(self, expected_exception, *args, **kwargs): - z = 0 - try: - yield - except expected_exception: - z = 1 - except Exception: - z = 2 - assert z == 1 - - def assertIsInstance(self, obj, cls, msg=None): - assert isinstance(obj, cls), msg - - def assertIsNone(self, obj, msg=None): - assert obj is None, msg - @make_dynamo_test def test_invalid_reraise(self): try: @@ -213,34 +190,12 @@ def test_assert_with_tuple_arg(self): class TestCause(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_raise.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True - def assertIn(self, member, container, msg=None): - assert member in container, msg - - def assertIs(self, expr1, expr2, msg=None): - assert expr1 is expr2, msg - - def assertRaises(self, expected_exception, *args, **kwargs): - z = 0 - try: - yield - except expected_exception: - z = 1 - except Exception: - z = 2 - assert z == 1 - - def assertIsInstance(self, obj, cls, msg=None): - assert isinstance(obj, cls), msg - - def assertIsNone(self, obj, msg=None): - assert obj is None, msg - - def assertTrue(self, expr, msg=None): - assert bool(expr) is True, msg - - def assertFalse(self, expr, msg=None): - assert bool(expr) is False, 
msg + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev @make_dynamo_test def testCauseSyntax(self): @@ -303,6 +258,12 @@ def test_erroneous_cause(self): class TestTraceback(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_raise.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev @unittest.expectedFailure # Dynamo doesn't track traceback @make_dynamo_test @@ -330,6 +291,12 @@ def test_accepts_traceback(self): class TestTracebackType(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_raise.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev def raiser(self): raise ValueError @@ -402,28 +369,12 @@ def test_constructor(self): class TestContext(torch._dynamo.test_case.TestCase): # Tests taken from CPython source code in cpython/Lib/test/test_raise.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_raise.py + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True - def assertIn(self, member, container, msg=None): - assert member in container, msg - - def assertIs(self, expr1, expr2, msg=None): - assert expr1 is expr2, msg - - def assertRaises(self, expected_exception, *args, **kwargs): - z = 0 - try: - yield - except expected_exception: - z = 1 - except Exception: - z = 2 - assert z == 1 - - def assertIsInstance(self, obj, cls, msg=None): - assert isinstance(obj, cls), msg - - def assertIsNone(self, obj, msg=None): - assert obj is None, msg + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev @unittest.expectedFailure # missing Exception.__eq__ @make_dynamo_test diff --git a/test/dynamo/test_sys.py b/test/dynamo/test_sys.py index 2f7bd7178695..3b72ecb36d99 100644 --- a/test/dynamo/test_sys.py +++ b/test/dynamo/test_sys.py @@ -25,9 +25,10 @@ def fn(t): self.assertEqual(y, t.sin()) -class CPythonActiveExceptionTests(torch._dynamo.test_case.TestCase): +class CPythonActiveExceptionTests(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_sys.py # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_sys.py + @make_dynamo_test def test_exc_info_no_exception(self): self.assertEqual(sys.exc_info(), (None, None, None)) @@ -37,7 +38,6 @@ def test_exc_info_no_exception(self): def test_sys_exception_no_exception(self): self.assertEqual(sys.exception(), None) - @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_instance(self): def f(): @@ -54,7 +54,6 @@ def f(): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) - @unittest.expectedFailure @make_dynamo_test def test_exc_info_with_exception_type(self): def f(): @@ -71,7 +70,6 @@ def f(): self.assertIs(exc_info[1], e) self.assertIs(exc_info[2], e.__traceback__) - @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_instance(self): 
@@ -87,7 +85,6 @@ def f(): self.assertIsInstance(e, ValueError) self.assertIs(exc, e) - @unittest.expectedFailure @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") @make_dynamo_test def test_sys_exception_with_exception_type(self): diff --git a/test/dynamo/test_unittest.py b/test/dynamo/test_unittest.py new file mode 100644 index 000000000000..244785e01bcd --- /dev/null +++ b/test/dynamo/test_unittest.py @@ -0,0 +1,619 @@ +# Owner(s): ["module: dynamo"] +import sys +import unittest +import warnings +from itertools import product + +import torch +import torch._dynamo.test_case +from torch.testing._internal.common_utils import make_dynamo_test + + +class TestUnittest(torch._dynamo.test_case.TestCase): + def setUp(self): + self._prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self._prev + + @make_dynamo_test + def test_SkipTest(self): + z = 0 + SkipTest = unittest.SkipTest + try: + raise SkipTest("abcd") + except Exception: + z = 1 + self.assertEqual(z, 1) + + +class CPythonTest_Assertions(torch._dynamo.test_case.CPythonTestCase): + # Tests taken from CPython source code in cpython/Lib/test/test_unittest/test_assertions.py + # https://github.com/python/cpython/blob/3.13/Lib/test/test_unittest/test_assertions.py + + @make_dynamo_test + def test_AlmostEqual(self): + self.assertAlmostEqual(1.00000001, 1.0) + self.assertNotAlmostEqual(1.0000001, 1.0) + self.assertRaises(self.failureException, self.assertAlmostEqual, 1.0000001, 1.0) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.00000001, 1.0 + ) + + self.assertAlmostEqual(1.1, 1.0, places=0) + self.assertRaises( + self.failureException, self.assertAlmostEqual, 1.1, 1.0, places=1 + ) + + self.assertAlmostEqual(0, 0.1 + 0.1j, places=0) + self.assertNotAlmostEqual(0, 0.1 + 0.1j, places=1) + self.assertRaises( + self.failureException, self.assertAlmostEqual, 0, 0.1 + 0.1j, places=1 + ) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 0, 0.1 + 0.1j, places=0 + ) + + self.assertAlmostEqual(float("inf"), float("inf")) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, float("inf"), float("inf") + ) + + @make_dynamo_test + def test_AmostEqualWithDelta(self): + self.assertAlmostEqual(1.1, 1.0, delta=0.5) + self.assertAlmostEqual(1.0, 1.1, delta=0.5) + self.assertNotAlmostEqual(1.1, 1.0, delta=0.05) + self.assertNotAlmostEqual(1.0, 1.1, delta=0.05) + + self.assertAlmostEqual(1.0, 1.0, delta=0.5) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.0, 1.0, delta=0.5 + ) + + self.assertRaises( + self.failureException, self.assertAlmostEqual, 1.1, 1.0, delta=0.05 + ) + self.assertRaises( + self.failureException, self.assertNotAlmostEqual, 1.1, 1.0, delta=0.5 + ) + + self.assertRaises( + TypeError, self.assertAlmostEqual, 1.1, 1.0, places=2, delta=2 + ) + self.assertRaises( + TypeError, self.assertNotAlmostEqual, 1.1, 1.0, places=2, delta=2 + ) + + @make_dynamo_test + def test_assertRaises(self): + def _raise(e): + raise e + + self.assertRaises(KeyError, _raise, KeyError) + self.assertRaises(KeyError, _raise, KeyError("key")) + try: + self.assertRaises(KeyError, lambda: None) + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + self.assertRaises(KeyError, _raise, ValueError) + except ValueError: + pass + else: + 
self.fail("assertRaises() didn't let exception pass through") + with self.assertRaises(KeyError) as cm: + try: + raise KeyError + except Exception as e: + exc = e + raise + self.assertIs(cm.exception, exc) + + with self.assertRaises(KeyError): + raise KeyError("key") + try: + with self.assertRaises(KeyError): + pass + except self.failureException as e: + self.assertIn("KeyError not raised", str(e)) + else: + self.fail("assertRaises() didn't fail") + try: + with self.assertRaises(KeyError): + raise ValueError + except ValueError: + pass + else: + self.fail("assertRaises() didn't let exception pass through") + + @make_dynamo_test + def testAssertNotRegex(self): + self.assertNotRegex("Ala ma kota", r"r+") + try: + self.assertNotRegex("Ala ma kota", r"k.t", "Message") + except self.failureException as e: + self.assertIn("Message", e.args[0]) + else: + self.fail("assertNotRegex should have failed.") + + +class CPythonTestLongMessage(torch._dynamo.test_case.CPythonTestCase): + """Test that the individual asserts honour longMessage. + This actually tests all the message behaviour for + asserts that use longMessage.""" + + def setUp(self): + super().setUp() + + class TestableTestFalse(unittest.TestCase): + longMessage = False + failureException = self.failureException + + def testTest(self): + pass + + class TestableTestTrue(unittest.TestCase): + longMessage = True + failureException = self.failureException + + def testTest(self): + pass + + self.testableTrue = TestableTestTrue("testTest") + self.testableFalse = TestableTestFalse("testTest") + + def testDefault(self): + self.assertTrue(unittest.TestCase.longMessage) + + def test_formatMsg(self): + self.assertEqual(self.testableFalse._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableFalse._formatMessage("foo", "bar"), "foo") + + self.assertEqual(self.testableTrue._formatMessage(None, "foo"), "foo") + self.assertEqual(self.testableTrue._formatMessage("foo", "bar"), "bar : foo") + + # This blows up if _formatMessage uses string concatenation + self.testableTrue._formatMessage(object(), "foo") + + def test_formatMessage_unicode_error(self): + one = "".join(chr(i) for i in range(255)) + # this used to cause a UnicodeDecodeError constructing msg + self.testableTrue._formatMessage(one, "\uFFFD") + + def assertMessages(self, methodName, args, errors): + """ + Check that methodName(*args) raises the correct error messages. 
+ errors should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + + def getMethod(i): + useTestableFalse = i < 2 + if useTestableFalse: + test = self.testableFalse + else: + test = self.testableTrue + return getattr(test, methodName) + + for i, expected_regex in enumerate(errors): + testMethod = getMethod(i) + kwargs = {} + withMsg = i % 2 + if withMsg: + kwargs = {"msg": "oops"} + + # with self.assertRaisesRegex( + # self.failureException, expected_regex=expected_regex + # ): + # testMethod(*args, **kwargs) + with self.assertRaises(self.failureException) as cm: + testMethod(*args, **kwargs) + self.assertRegex(str(cm.exception), expected_regex) + + @make_dynamo_test + def testAssertTrue(self): + self.assertMessages( + "assertTrue", + (False,), + [ + "False is not true", + "oops", + "False is not true", + "False is not true : oops", + ], + ) + + @make_dynamo_test + def testAssertFalse(self): + self.assertMessages( + "assertFalse", + (True,), + [ + "True is not false", + "oops", + "True is not false", + "True is not false : oops", + ], + ) + + @make_dynamo_test + def testNotEqual(self): + self.assertMessages( + "assertNotEqual", (1, 1), ["1 == 1", "oops", "1 == 1", "1 == 1 : oops"] + ) + + @make_dynamo_test + def testAlmostEqual(self): + self.assertMessages( + "assertAlmostEqual", + (1, 2), + [ + r"^1 != 2 within 7 places \(1 difference\)$", + "^oops$", + r"^1 != 2 within 7 places \(1 difference\)$", + r"^1 != 2 within 7 places \(1 difference\) : oops$", + ], + ) + + @make_dynamo_test + def testNotAlmostEqual(self): + self.assertMessages( + "assertNotAlmostEqual", + (1, 1), + [ + "^1 == 1 within 7 places$", + "^oops$", + "^1 == 1 within 7 places$", + "^1 == 1 within 7 places : oops$", + ], + ) + + @make_dynamo_test + def test_baseAssertEqual(self): + self.assertMessages( + "_baseAssertEqual", + (1, 2), + ["^1 != 2$", "^oops$", "^1 != 2$", "^1 != 2 : oops$"], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertSequenceEqual(self): + # Error messages are multiline so not testing on full message + # assertTupleEqual and assertListEqual delegate to this method + self.assertMessages( + "assertSequenceEqual", + ([], [None]), + [r"\+ \[None\]$", "^oops$", r"\+ \[None\]$", r"\+ \[None\] : oops$"], + ) + + @make_dynamo_test + def testAssertSetEqual(self): + self.assertMessages( + "assertSetEqual", + (set(), set([None])), # noqa: C405 + ["None$", "^oops$", "None$", "None : oops$"], + ) + + @make_dynamo_test + def testAssertIn(self): + self.assertMessages( + "assertIn", + (None, []), + [ + r"^None not found in \[\]$", + "^oops$", + r"^None not found in \[\]$", + r"^None not found in \[\] : oops$", + ], + ) + + @make_dynamo_test + def testAssertNotIn(self): + self.assertMessages( + "assertNotIn", + (None, [None]), + [ + r"^None unexpectedly found in \[None\]$", + "^oops$", + r"^None unexpectedly found in \[None\]$", + r"^None unexpectedly found in \[None\] : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertDictEqual(self): + self.assertMessages( + "assertDictEqual", + ({}, {"key": "value"}), + [ + r"\+ \{'key': 'value'\}$", + "^oops$", + r"\+ \{'key': 'value'\}$", + r"\+ \{'key': 'value'\} : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertMultiLineEqual(self): + self.assertMessages( + "assertMultiLineEqual", + ("", "foo"), + [r"\+ foo\n$", 
"^oops$", r"\+ foo\n$", r"\+ foo\n : oops$"], + ) + + @make_dynamo_test + def testAssertLess(self): + self.assertMessages( + "assertLess", + (2, 1), + [ + "^2 not less than 1$", + "^oops$", + "^2 not less than 1$", + "^2 not less than 1 : oops$", + ], + ) + + @make_dynamo_test + def testAssertLessEqual(self): + self.assertMessages( + "assertLessEqual", + (2, 1), + [ + "^2 not less than or equal to 1$", + "^oops$", + "^2 not less than or equal to 1$", + "^2 not less than or equal to 1 : oops$", + ], + ) + + @make_dynamo_test + def testAssertGreater(self): + self.assertMessages( + "assertGreater", + (1, 2), + [ + "^1 not greater than 2$", + "^oops$", + "^1 not greater than 2$", + "^1 not greater than 2 : oops$", + ], + ) + + @make_dynamo_test + def testAssertGreaterEqual(self): + self.assertMessages( + "assertGreaterEqual", + (1, 2), + [ + "^1 not greater than or equal to 2$", + "^oops$", + "^1 not greater than or equal to 2$", + "^1 not greater than or equal to 2 : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNone(self): + self.assertMessages( + "assertIsNone", + ("not None",), + [ + "^'not None' is not None$", + "^oops$", + "^'not None' is not None$", + "^'not None' is not None : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNotNone(self): + self.assertMessages( + "assertIsNotNone", + (None,), + [ + "^unexpectedly None$", + "^oops$", + "^unexpectedly None$", + "^unexpectedly None : oops$", + ], + ) + + @make_dynamo_test + def testAssertIs(self): + self.assertMessages( + "assertIs", + (None, "foo"), + [ + "^None is not 'foo'$", + "^oops$", + "^None is not 'foo'$", + "^None is not 'foo' : oops$", + ], + ) + + @make_dynamo_test + def testAssertIsNot(self): + self.assertMessages( + "assertIsNot", + (None, None), + [ + "^unexpectedly identical: None$", + "^oops$", + "^unexpectedly identical: None$", + "^unexpectedly identical: None : oops$", + ], + ) + + @make_dynamo_test + def testAssertRegex(self): + self.assertMessages( + "assertRegex", + ("foo", "bar"), + [ + "^Regex didn't match:", + "^oops$", + "^Regex didn't match:", + "^Regex didn't match: (.*) : oops$", + ], + ) + + @make_dynamo_test + def testAssertNotRegex(self): + self.assertMessages( + "assertNotRegex", + ("foo", "foo"), + [ + "^Regex matched:", + "^oops$", + "^Regex matched:", + "^Regex matched: (.*) : oops$", + ], + ) + + def assertMessagesCM(self, methodName, args, func, errors): + """ + Check that the correct error messages are raised while executing: + with method(*args): + func() + *errors* should be a list of 4 regex that match the error when: + 1) longMessage = False and no msg passed; + 2) longMessage = False and msg passed; + 3) longMessage = True and no msg passed; + 4) longMessage = True and msg passed; + """ + p = product((self.testableFalse, self.testableTrue), ({}, {"msg": "oops"})) + for (cls, kwargs), err in zip(p, errors): + method = getattr(cls, methodName) + # with self.assertRaisesRegex(cls.failureException, err): + with self.assertRaises(cls.failureException) as c: + with method(*args, **kwargs) as cm: # noqa: F841 + func() + self.assertRegex(str(c.exception), err) + + @make_dynamo_test + def testAssertRaises(self): + self.assertMessagesCM( + "assertRaises", + (TypeError,), + lambda: None, + [ + "^TypeError not raised$", + "^oops$", + "^TypeError not raised$", + "^TypeError not raised : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertRaisesRegex(self): + self.assertMessagesCM( + "assertRaisesRegex", + (TypeError, "unused regex"), + lambda: None, + [ + 
"^TypeError not raised$", + "^oops$", + "^TypeError not raised$", + "^TypeError not raised : oops$", + ], + ) + + # test error raised but with wrong message + def raise_wrong_message(): + raise TypeError("foo") + + self.assertMessagesCM( + "assertRaisesRegex", + (TypeError, "regex"), + raise_wrong_message, + [ + '^"regex" does not match "foo"$', + "^oops$", + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$', + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertWarns(self): + self.assertMessagesCM( + "assertWarns", + (UserWarning,), + lambda: None, + [ + "^UserWarning not triggered$", + "^oops$", + "^UserWarning not triggered$", + "^UserWarning not triggered : oops$", + ], + ) + + @unittest.expectedFailure + @unittest.skipIf(sys.version_info < (3, 13), "feature landed in 3.13") + @make_dynamo_test + def test_assertNotWarns(self): + def warn_future(): + warnings.warn("xyz", FutureWarning, stacklevel=2) + + self.assertMessagesCM( + "_assertNotWarns", + (FutureWarning,), + warn_future, + [ + "^FutureWarning triggered$", + "^oops$", + "^FutureWarning triggered$", + "^FutureWarning triggered : oops$", + ], + ) + + @unittest.expectedFailure + @make_dynamo_test + def testAssertWarnsRegex(self): + # test error not raised + self.assertMessagesCM( + "assertWarnsRegex", + (UserWarning, "unused regex"), + lambda: None, + [ + "^UserWarning not triggered$", + "^oops$", + "^UserWarning not triggered$", + "^UserWarning not triggered : oops$", + ], + ) + + # test warning raised but with wrong message + def raise_wrong_message(): + warnings.warn("foo") + + self.assertMessagesCM( + "assertWarnsRegex", + (UserWarning, "regex"), + raise_wrong_message, + [ + '^"regex" does not match "foo"$', + "^oops$", + '^"regex" does not match "foo"$', + '^"regex" does not match "foo" : oops$', + ], + ) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend b/test/dynamo_expected_failures/TestScript.test_python_frontend new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 b/test/dynamo_expected_failures/TestScript.test_python_frontend_py3 new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index b59e1c49e607..870291a43785 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -401,6 +401,9 @@ # Enable tracing through contextlib.contextmanager enable_trace_contextlib = True +# Enable tracing through unittest +enable_trace_unittest = False + # Enable tracing generator functions lazily. If False, Dynamo will exhaust # generators upon first execution. 
And if True, the generator will be accessed lazily enable_faithful_generator_behavior = True diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index e927fc4a1eaf..ac505c0de02a 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -95,3 +95,22 @@ def tearDown(self) -> None: if self._prior_is_grad_enabled is not torch.is_grad_enabled(): log.warning("Running test changed grad mode") torch.set_grad_enabled(self._prior_is_grad_enabled) + + +class CPythonTestCase(TestCase): + _stack: contextlib.ExitStack + + @classmethod + def tearDownClass(cls) -> None: + cls._stack.close() + super().tearDownClass() + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls._stack = contextlib.ExitStack() # type: ignore[attr-defined] + cls._stack.enter_context( # type: ignore[attr-defined] + config.patch( + enable_trace_unittest=True, + ), + ) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 9c9a142bda2c..22fa9344b61f 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -3173,7 +3173,6 @@ def is_numpy_type_info(obj) -> bool: random, traceback, linecache, - unittest, ) # third party libraries skiplist is defined by str, because users may not use these libraries. @@ -3580,6 +3579,12 @@ def check_file(filename, is_inlined_call=False): ): return SkipResult(True, "FBCODE_SKIP_TORCHREC_DIRS") + if ( + filename.startswith(_module_dir(unittest)) + and not torch._dynamo.config.enable_trace_unittest + ): + return SkipResult(True, "unittest") + if bool(SKIP_DIRS_RE.match(filename)): return SkipResult(True, "SKIP_DIRS") diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 2a7d031b7b87..5a19e7076899 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -10,6 +10,7 @@ import sys import types import typing +import unittest from collections import defaultdict, OrderedDict from collections.abc import KeysView, Sequence from typing import Callable, TYPE_CHECKING, Union @@ -1657,7 +1658,10 @@ def call_zip(self, tx: "InstructionTranslator", *args, **kwargs): ) def call_len(self, tx: "InstructionTranslator", *args, **kwargs): - return args[0].call_method(tx, "__len__", args[1:], kwargs) + try: + return args[0].call_method(tx, "__len__", args[1:], kwargs) + except AttributeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) def call_getitem(self, tx: "InstructionTranslator", *args, **kwargs): return args[0].call_method(tx, "__getitem__", args[1:], kwargs) @@ -1871,6 +1875,30 @@ def call_getattr( variables.UserDefinedObjectVariable, ), ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertDictEqual", + "assertSequenceEqual", + "assertWarns", + ) + ): + unimplemented_v2( + gb_type="Failed to trace builtin operator", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace builtin operator `{name}` ", + hints=[ + f"Avoid calling builtin `{name}`. 
" + "Please report an issue to PyTorch.", + ], + ) + try: return obj.var_getattr(tx, name) except NotImplementedError: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 7c38539bd217..2703d8c4eb7d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -22,7 +22,9 @@ import collections import functools +import inspect import types +from collections.abc import Hashable as py_Hashable from typing import Optional, TYPE_CHECKING from torch._subclasses.fake_tensor import is_fake @@ -53,6 +55,10 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl +def was_instancecheck_override(obj): + return type(obj).__dict__.get("__instancecheck__", False) + + def is_hashable(x): # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at @@ -72,6 +78,13 @@ def is_hashable(x): return x.as_proxy().node.meta.get("example_value") is not None elif isinstance(x, variables.TupleVariable): return all(is_hashable(e) for e in x.items) + elif ( + isinstance(x, variables.UserDefinedObjectVariable) + and not was_instancecheck_override(x.value) + and inspect.getattr_static(x.value, "__hash__") is int.__hash__ + and isinstance(x.value, int) + ): + return isinstance(x.value, py_Hashable) else: return isinstance( x, @@ -80,7 +93,7 @@ def is_hashable(x): variables.SymNodeVariable, variables.ConstantVariable, variables.EnumVariable, - variables.user_defined.UserDefinedClassVariable, + variables.UserDefinedClassVariable, variables.UserFunctionVariable, variables.SkipFunctionVariable, variables.misc.NumpyVariable, @@ -140,6 +153,11 @@ def underlying_value(self): # Access the underlying value inside the referent_vt for the key representation Hashable = ConstDictVariable._HashableTracker return Hashable(self.vt.referent_vt).underlying_value + elif isinstance(self.vt, variables.UserDefinedObjectVariable): + # The re module in Python 3.13+ has a dictionary (_cache2) with + # an object as key (`class _ZeroSentinel(int): ...`): + # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual + return self.vt.value else: x = self.vt.as_python_constant() return x diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index d8beec6aaeb2..701b067710de 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -1080,6 +1080,11 @@ def call_method(self, tx, name, args, kwargs): def has_closure(self): return self.closure is not None + def const_getattr(self, tx, name): + if name == "__name__": + return self.fn_name.as_python_constant() + return super().const_getattr(tx, name) + def has_self(self): return False From ad516180e08818d69294fa80e456373cb1dbe057 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 7 Apr 2025 19:59:29 +0000 Subject: [PATCH 270/332] Update CPython tests for ctx manager to use unittest (#146501) Pull Request resolved: https://github.com/pytorch/pytorch/pull/146501 Approved by: https://github.com/zou3519 ghstack dependencies: #146500 --- test/dynamo/test_ctx_manager.py | 411 ++++++++++++++++---------------- 1 file changed, 204 insertions(+), 207 deletions(-) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index 0ae3ff452967..74ff84dbb9e3 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -9,12 +9,18 @@ import torch._dynamo.test_case import torch._dynamo.testing from 
torch._dynamo.exc import InternalTorchDynamoError -from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same +from torch._dynamo.testing import ( + EagerAndRecordGraphs, + normalize_gm, + same, + skipIfNotPy311, +) from torch._dynamo.utils import counters from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + make_dynamo_test, parametrize, ) @@ -2696,8 +2702,9 @@ def fn(t): class CPythonContextManagerTestCase(torch._dynamo.test_case.CPythonTestCase): # Tests taken from CPython source code in cpython/Lib/test/test_contextlib.py - # https://github.com/python/cpython/blob/d48cc82ed25e26b02eb97c6263d95dcaa1e9111b/Lib/test/test_contextlib.py#L70 + # https://github.com/python/cpython/blob/v3.13.1/Lib/test/test_contextlib.py + @make_dynamo_test def test_contextmanager_plain(self): state = [] @@ -2707,23 +2714,14 @@ def woohoo(): yield 42 state.append(999) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - y = t.sum() - with woohoo() as x: - assert state == [1] - assert x == 42 - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - y += x - return y - - t = torch.randn(2, 3) - y = fn(t) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) self.assertEqual(state, [1, 42, 999]) - self.assertEqual(y, t.sum() + 42) + @skipIfNotPy311 + @make_dynamo_test def test_contextmanager_finally(self): state = [] @@ -2735,166 +2733,66 @@ def woohoo(): finally: state.append(999) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - _y = t.sum() - with self.assertRaises(ZeroDivisionError): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError - - fn(torch.randn(2, 3)) + with self.assertRaises(ZeroDivisionError): + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_traceback(self): @contextmanager def f(): yield - frames = [] + try: + with f(): + 1 / 0 + except ZeroDivisionError as e: + frames = traceback.extract_tb(e.__traceback__) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - 1 / 0 - except ZeroDivisionError as e: - frames = traceback.extract_tb(e.__traceback__) - - fn(torch.randn(2, 3)) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "1 / 0") - - @unittest.expectedFailure - def test_contextmanager_traceback2(self): - @contextmanager - def f(): - yield + self.assertEqual(frames[0].line, "1/0") # Repeat with RuntimeError (which goes through a different code path) - class RuntimeErrorSubclass(RuntimeError): - pass - - frames = [] + try: + with f(): + raise NotImplementedError(42) + except NotImplementedError as e: + frames = traceback.extract_tb(e.__traceback__) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - raise RuntimeErrorSubclass(42) - except RuntimeErrorSubclass as e: - frames = traceback.extract_tb(e.__traceback__) - - fn(torch.randn(2, 3)) self.assertEqual(len(frames), 1) self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, 
"raise RuntimeErrorSubclass(42)") - - @unittest.expectedFailure - def test_contextmanager_traceback3(self): - @contextmanager - def f(): - yield - - frames = [] - - class StopIterationSubclass(StopIteration): - pass - - for stop_exc in ( - StopIteration("spam"), - StopIterationSubclass("spam"), - ): - with self.subTest(type=type(stop_exc)): - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - nonlocal frames - _y = t.sum() - try: - with f(): - raise stop_exc - except type(stop_exc) as e: - self.assertIs(e, stop_exc) - frames = traceback.extract_tb(e.__traceback__) - else: - self.fail(f"{stop_exc} was suppressed") - - fn(torch.randn(2, 3)) - self.assertEqual(len(frames), 1) - self.assertEqual(frames[0].name, "test_contextmanager_traceback") - self.assertEqual(frames[0].line, "raise stop_exc") + self.assertEqual(frames[0].line, "raise NotImplementedError(42)") + @make_dynamo_test def test_contextmanager_no_reraise(self): @contextmanager def whee(): yield - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whee() - ctx.__enter__() - # Calling __exit__ should not result in an exception - self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) - return t.sum() - - fn(torch.randn(2, 3)) + ctx = whee() + ctx.__enter__() + # Calling __exit__ should not result in an exception + self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) + @make_dynamo_test def test_contextmanager_trap_yield_after_throw(self): @contextmanager def whoo(): try: yield - except Exception: + except Exception: # noqa: E722 yield - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(TypeError, TypeError("foo"), None) - return t.sum() - - fn(torch.randn(2, 3)) - - def test_contextmanager_trap_no_yield(self): - @contextmanager - def whoo(): - if False: - yield - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - ctx = whoo() - with self.assertRaises(RuntimeError): - ctx.__enter__() - return t.sum() - - fn(torch.randn(2, 3)) - - def test_contextmanager_trap_second_yield(self): - @contextmanager - def whoo(): - yield - yield - - @torch.compile(backend="eager", fullgraph=True) - def f(t): - ctx = whoo() - ctx.__enter__() - with self.assertRaises(RuntimeError): - ctx.__exit__(None, None, None) - - f(torch.randn(2)) + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(TypeError, TypeError("foo"), None) @unittest.skipIf(sys.version_info < (3, 11), "Python 3.11+") def test_contextmanager_except(self): @@ -2909,18 +2807,58 @@ def woohoo(): state.append(e.args[0]) self.assertEqual(state, [1, 42, 999]) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with woohoo() as x: - self.assertEqual(state, [1]) - self.assertEqual(x, 42) - state.append(x) - raise ZeroDivisionError(999) - - fn(torch.randn(2, 3)) + with woohoo() as x: + self.assertEqual(state, [1]) + self.assertEqual(x, 42) + state.append(x) + raise ZeroDivisionError(999) self.assertEqual(state, [1, 42, 999]) @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_except_stopiter(self): + @contextmanager + def woohoo(): + yield + + class StopIterationSubclass(StopIteration): + pass + + for stop_exc in (StopIteration("spam"), StopIterationSubclass("spam")): + with self.subTest(type=type(stop_exc)): + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail(f"{stop_exc} was suppressed") + + 
@unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_except_pep479(self): + code = """\ +from __future__ import generator_stop +from contextlib import contextmanager +@contextmanager +def woohoo(): + yield +""" + locals = {} + exec(code, locals, locals) + woohoo = locals["woohoo"] + + stop_exc = StopIteration("spam") + try: + with woohoo(): + raise stop_exc + except Exception as ex: + self.assertIs(ex, stop_exc) + else: + self.fail("StopIteration was suppressed") + + @unittest.expectedFailure + @make_dynamo_test def test_contextmanager_do_not_unchain_non_stopiteration_exceptions(self): @contextmanager def test_issue29692(): @@ -2929,71 +2867,77 @@ def test_issue29692(): except Exception as exc: raise RuntimeError("issue29692:Chained") from exc - @torch.compile(backend="eager", fullgraph=True) - def f(t): - try: - with test_issue29692(): - raise ZeroDivisionError - except Exception as ex: - self.assertIs(type(ex), RuntimeError) - self.assertEqual(ex.args[0], "issue29692:Chained") - self.assertIsInstance(ex.__cause__, ZeroDivisionError) + try: + with test_issue29692(): + raise ZeroDivisionError + except Exception as ex: + self.assertIs(type(ex), RuntimeError) + self.assertEqual(ex.args[0], "issue29692:Chained") + self.assertIsInstance(ex.__cause__, ZeroDivisionError) + + try: + with test_issue29692(): + raise StopIteration("issue29692:Unchained") + except Exception as ex: + self.assertIs(type(ex), StopIteration) + self.assertEqual(ex.args[0], "issue29692:Unchained") + self.assertIsNone(ex.__cause__) - try: - with test_issue29692(): - raise StopIteration("issue29692:Unchained") - except Exception as ex: - self.assertIs(type(ex), StopIteration) - self.assertEqual(ex.args[0], "issue29692:Unchained") - self.assertIsNone(ex.__cause__) + @unittest.expectedFailure + @make_dynamo_test + def _create_contextmanager_attribs(self): + def attribs(**kw): + def decorate(func): + for k, v in kw.items(): + setattr(func, k, v) + return func - f(torch.randn(2)) + return decorate - @unittest.expectedFailure - def test_contextmanager_wrap_runtimeerror(self): @contextmanager - def woohoo(): - try: - yield - except Exception as exc: - raise RuntimeError(f"caught {exc}") from exc - - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with self.assertRaises(RuntimeError): - with woohoo(): - 1 / 0 + @attribs(foo="bar") + def baz(spam): + """Whee!""" - fn(torch.randn(2, 3)) + return baz - # If the context manager wrapped StopIteration in a RuntimeError, - # we also unwrap it, because we can't tell whether the wrapping was - # done by the generator machinery or by the generator itself. 
- with self.assertRaises(StopIteration): - with woohoo(): - raise StopIteration + @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_attribs(self): + baz = self._create_contextmanager_attribs() + self.assertEqual(baz.__name__, "baz") + self.assertEqual(baz.foo, "bar") + @make_dynamo_test def test_keywords(self): # Ensure no keyword arguments are inhibited @contextmanager def woohoo(self, func, args, kwds): yield (self, func, args, kwds) - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - with woohoo(self=11, func=22, args=33, kwds=44) as target: - self.assertEqual(target, (11, 22, 33, 44)) + with woohoo(self=11, func=22, args=33, kwds=44) as target: + self.assertEqual(target, (11, 22, 33, 44)) + + @unittest.expectedFailure + @make_dynamo_test + def test_param_errors(self): + @contextmanager + def woohoo(a, *, b): + yield - fn(torch.randn(2, 3)) + with self.assertRaises(TypeError): + woohoo() + with self.assertRaises(TypeError): + woohoo(3, 5) + with self.assertRaises(TypeError): + woohoo(b=3) + @make_dynamo_test def test_recursive(self): depth = 0 - ncols = 0 @contextmanager def woohoo(): - nonlocal ncols - ncols += 1 nonlocal depth before = depth depth += 1 @@ -3006,14 +2950,67 @@ def recursive(): if depth < 10: recursive() - @torch.compile(backend="eager", fullgraph=True) - def fn(t): - recursive() + recursive() + self.assertEqual(depth, 0) - fn(torch.randn(2, 3)) + @skipIfNotPy311 + @make_dynamo_test + def test_contextmanager_trap_no_yield(self): + @contextmanager + def whoo(): + if False: + yield - self.assertEqual(ncols, 10) - self.assertEqual(depth, 0) + ctx = whoo() + with self.assertRaises(RuntimeError): + ctx.__enter__() + + @make_dynamo_test + def test_contextmanager_trap_second_yield(self): + @contextmanager + def whoo(): + yield + yield + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(RuntimeError): + ctx.__exit__(None, None, None) + + @unittest.expectedFailure + @make_dynamo_test + def test_contextmanager_wrap_runtimeerror(self): + @contextmanager + def woohoo(): + try: + yield + except Exception as exc: + raise RuntimeError(f"caught {exc}") from exc + + with self.assertRaises(RuntimeError): + with woohoo(): + 1 / 0 + + # If the context manager wrapped StopIteration in a RuntimeError, + # we also unwrap it, because we can't tell whether the wrapping was + # done by the generator machinery or by the generator itself. + with self.assertRaises(StopIteration): + with woohoo(): + raise StopIteration + + @make_dynamo_test + def test_contextmanager_non_normalised(self): + @contextmanager + def whoo(): + try: + yield + except RuntimeError: + raise SyntaxError # noqa: B904 + + ctx = whoo() + ctx.__enter__() + with self.assertRaises(SyntaxError): + ctx.__exit__(RuntimeError, None, None) instantiate_parametrized_tests(CtxManagerTests) From a402c2f203ba129dc67b7878a6372b2fdd93cbb0 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 11:53:31 +0000 Subject: [PATCH 271/332] Remove redundant code in cuda/__init__.py (#150529) As the title stated. 
Follow: https://github.com/pytorch/pytorch/pull/147078 Fix issue: https://github.com/pytorch/pytorch/issues/150519 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150529 Approved by: https://github.com/eqy --- torch/cuda/__init__.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 7e1c6c15b175..cb5c4d0919d2 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1210,8 +1210,7 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) + handle = _get_amdsmi_handler(device) # amdsmi_get_gpu_vram_usage returns mem usage in megabytes mem_mega_bytes = amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] mem_bytes = mem_mega_bytes * 1024 * 1024 @@ -1219,16 +1218,12 @@ def _get_amdsmi_device_memory_used(device: Optional[Union[Device, int]] = None) def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) - handle = amdsmi.amdsmi_get_processor_handles()[device] + handle = _get_amdsmi_handler(device) return amdsmi.amdsmi_get_gpu_activity(handle)["umc_activity"] def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: - handle = _get_amdsmi_handler() - device = _get_amdsmi_device_index(device) - handle = amdsmi.amdsmi_get_processor_handles()[device] + handle = _get_amdsmi_handler(device) return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] From 05365e380d01683d3f415cd7997c7425a80d0427 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 11:53:20 +0000 Subject: [PATCH 272/332] Remove torch functions that do not support device arguments from _device_constructor (#150290) As the title stated In Addition: - I have checked all the functions in _device_constructor and found ``torch.vander`` also don`t support device arguments - Remove the duplicated function such as torch.ones and torch.asarray Related issue:https://github.com/pytorch/pytorch/issues/150284 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150290 Approved by: https://github.com/albanD --- torch/utils/_device.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch/utils/_device.py b/torch/utils/_device.py index d7903fc3b465..e16505791b9d 100644 --- a/torch/utils/_device.py +++ b/torch/utils/_device.py @@ -24,7 +24,6 @@ def _device_constructors(): torch.fft.fftfreq, torch.fft.rfftfreq, torch.full, - torch.fill, torch.hamming_window, torch.hann_window, torch.kaiser_window, @@ -33,7 +32,6 @@ def _device_constructors(): torch.nested.nested_tensor, # This function doesn't actually take a device argument # torch.normal, - torch.ones, torch.rand, torch.randn, torch.randint, @@ -47,14 +45,12 @@ def _device_constructors(): torch.sparse_bsc_tensor, torch.tril_indices, torch.triu_indices, - torch.vander, torch.zeros, torch.asarray, # weird ones torch.tensor, torch.as_tensor, - torch.scalar_tensor, - torch.asarray, + torch.scalar_tensor } # NB: This is directly called from C++ in torch/csrc/Device.cpp From da7322548be76ce5303cde8ff84b2ec2c3871992 Mon Sep 17 00:00:00 2001 From: Yan Zhiwei Date: Tue, 8 Apr 2025 02:21:34 +0000 Subject: [PATCH 273/332] [Intel GPU] int4 WOQ gemm XPU Support (#137566) Pull Request resolved: 
https://github.com/pytorch/pytorch/pull/137566 Approved by: https://github.com/liangan1, https://github.com/guangyey, https://github.com/EikanWang Co-authored-by: xiaolil1 --- aten/src/ATen/native/mkldnn/xpu/Blas.cpp | 49 +++++ .../native/mkldnn/xpu/detail/WoQMatmul.cpp | 179 ++++++++++++++++++ .../ATen/native/mkldnn/xpu/detail/oneDNN.h | 8 + 3 files changed, 236 insertions(+) create mode 100644 aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index cc3d4ec9555d..d2abeda0e6ff 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -418,4 +418,53 @@ TORCH_IMPL_FUNC(addmv_out_xpu) xpu::addmv_out(self, mat, vec, beta, alpha, const_cast(result)); } +Tensor _weight_int4pack_mm_xpu( + const Tensor& A, + const Tensor& B, + int64_t qGroupSize, + const Tensor& qScale, + const Tensor& qZeros) { + auto M = A.size(0); // M + auto N = B.size(0); // N1=LCM(N, K) + TORCH_CHECK( + A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat, + __func__, + " : expect A to be either 32-bit or 16-bit float tensor."); + TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous."); + TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor."); + + TORCH_CHECK(B.dtype() == kInt, __func__, " : expect B to be int32 tensor."); + TORCH_CHECK( + qZeros.dtype() == kChar, + __func__, + " : expect qZeros to be int8 tensor currently."); + TORCH_CHECK(B.dim() == 2, __func__, " : expect B to 2d tensor."); + + TORCH_CHECK( + qGroupSize > 1 && qGroupSize % 32 == 0, + __func__, + " : expect qGroupSize to be multiple of 32 and greater than 1, got ", + qGroupSize); + + TORCH_CHECK( + qScale.dim() == 2 && qScale.size(1) == N, + __func__, + ": expect qScale to be 2d tensor with sizes [:, ", + N, + "]"); + TORCH_CHECK( + qZeros.dim() == 2 && qZeros.size(1) == N, + __func__, + ": expect qZeros to be 2d tensor with sizes [:, ", + N, + "]"); + + auto C = at::empty({M, N}, A.options()); + + // qscale:[K/qGroupSize, N] + // qzp:[K/qGroupSize, N] + at::native::onednn::woq_matmul_int4(C, A, B, qScale, qZeros, qGroupSize); + + return C; +} } // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp new file mode 100644 index 000000000000..66d4ffaa9b8a --- /dev/null +++ b/aten/src/ATen/native/mkldnn/xpu/detail/WoQMatmul.cpp @@ -0,0 +1,179 @@ +#include + +#include +#include + +#include +#include + +namespace at::native::onednn { + +void woq_matmul_int4( + Tensor& result, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat1_, // torchao: [M, K], dtype: fp16,bf16 + const Tensor& mat2_, // torchao quantized weight, [K/8, N], dtype: uint4x8 + const Tensor& scale, // torchao: [K/group_size, N], dtype: fp16,bf16 + const Tensor& zp, // torchao: [K/group_size, N], dtype: int8 + int64_t group_size) { + size_t dims = result.dim(); + TORCH_CHECK( + dims == 2, "INT4 matmul at XPU only works with 2D input, got ", dims); + TORCH_CHECK(result.defined(), "oneDNN matmul result should be defined"); + + at::Device cur_device = at::Device(at::kXPU, at::xpu::current_device()); + TORCH_CHECK( + cur_device == mat1_.device(), + "_weight_int4pack_mm_with_scales_and_zeros input should be on current device."); + auto& engine = GpuEngineManager::Instance().get_engine(); + auto& stream = GpuStreamManager::Instance().get_stream(); + + Tensor m1 = mat1_; + Tensor m2 = 
mat2_; + Tensor scale_ = scale; + Tensor zp_ = zp; + Tensor dst = result; + + int m = m1.size(-2); // M + int n = dst.size(-1); // N + int k = m1.size(-1); // K + + // Construct usr md from input + // xxx_usr_md would describe the real layout of inputs + auto m1_usr_dt = get_onednn_dtype(m1); // e.g., half <==> f16 + auto m2_usr_dt = get_onednn_dtype(m2); // int32 tensor, pack 8 int4 + auto scale_usr_dt = get_onednn_dtype(scale_); // bf16 + auto zp_usr_dt = get_onednn_dtype(zp_); // s8 expected currently + auto dst_usr_dt = get_onednn_dtype(dst); // bf16 + + dnnl::memory::dims m1_usr_dims, m2_usr_dims, scale_usr_dims, zp_usr_dims, + dst_usr_dims; + dnnl::memory::dims m1_usr_strides, m2_usr_strides, scale_usr_strides, + zp_usr_strides, dst_usr_strides; + int compressed_k = (int)(k / 8); + int num_groups = (int)(k / group_size); + m1_usr_dims = {m, k}; + m1_usr_strides = {m1.stride(0), m1.stride(1)}; + m2_usr_dims = {compressed_k, n}; + m2_usr_strides = {1, compressed_k}; // k dim contiguous, 4bit pack into s32 + + scale_usr_dims = {num_groups, n}; + scale_usr_strides = {n, 1}; + zp_usr_dims = {num_groups, n}; + zp_usr_strides = {n, 1}; + dst_usr_dims = {m, n}; + dst_usr_strides = {dst.stride(0), dst.stride(1)}; + + dnnl::memory::desc m1_usr_md, m2_usr_md, scale_usr_md, zp_usr_md, dst_usr_md; + + m1_usr_md = dnnl::memory::desc(m1_usr_dims, m1_usr_dt, m1_usr_strides); + m2_usr_md = dnnl::memory::desc(m2_usr_dims, m2_usr_dt, m2_usr_strides); + scale_usr_md = + dnnl::memory::desc(scale_usr_dims, scale_usr_dt, scale_usr_strides); + zp_usr_md = dnnl::memory::desc(zp_usr_dims, zp_usr_dt, zp_usr_strides); + dst_usr_md = dnnl::memory::desc(dst_usr_dims, dst_usr_dt, dst_usr_strides); + + // create usr memory + auto dst_usr_m = make_onednn_memory(dst_usr_md, engine, dst.data_ptr()); + auto scale_usr_m = make_onednn_memory(scale_usr_md, engine, scale.data_ptr()); + auto zp_usr_m = make_onednn_memory(zp_usr_md, engine, zp.data_ptr()); + + // Construct md for primitive creation + // The xxx_md describes what kinds of matmul the oneDNN does. + // The problem for this op is [m, k] x [k, n] => [m, n] matmul. + auto m1_dt = m1_usr_dt; // bf16 + // Tell oneDNN the weight dtype we want manipulate is u4, + // library needs infer how to unpack u4 data based on the m2_usr_md (s32). + auto m2_dt = dnnl::memory::data_type::u4; + auto scale_dt = scale_usr_dt; // bf16 + // Tell oneDNN the zp dtype we want manipulate is s8 + // library needs infer how to unpack s8 data based on the m2_usr_md. 
+ auto zp_dt = zp_usr_dt; // should be s8, currently + auto dst_dt = dst_usr_dt; + + dnnl::memory::desc m1_md, m2_md, scale_md, zp_md, dst_md; + dnnl::memory::dims m1_dims, m2_dims, scale_dims, zp_dims, dst_dims; + dnnl::memory::dims m1_strides, m2_strides, scale_strides, zp_strides, + dst_strides; + + m1_dims = m1_usr_dims; // {m, k} + m1_strides = m1_usr_strides; // {k, 1} + m2_dims = {k, n}; + m2_strides = {n, 1}; + scale_dims = scale_usr_dims; // {k//group_size, n} + scale_strides = scale_usr_strides; + zp_dims = zp_usr_dims; + zp_strides = zp_usr_strides; + dst_dims = dst_usr_dims; + dst_strides = dst_usr_strides; + + m1_md = dnnl::memory::desc(m1_dims, m1_dt, m1_strides); + m2_md = dnnl::memory::desc(m2_dims, m2_dt, m2_strides); + scale_md = dnnl::memory::desc(scale_dims, scale_dt, scale_strides); + zp_md = dnnl::memory::desc(zp_dims, zp_dt, zp_strides); + dst_md = dnnl::memory::desc(dst_dims, dst_dt, dst_strides); + + std::unordered_map args; + + dnnl::matmul matmul_p; + dnnl::matmul::primitive_desc matmul_pd; + + auto m1_usr_m = make_onednn_memory(m1_usr_md, engine, m1.data_ptr()); + auto m2_usr_m = make_onednn_memory(m2_usr_md, engine, m2.data_ptr()); + + void* handle_b = m2_usr_m.get_data_handle(); + // reinterpret m2_usr_memory as u4 + dnnl::memory m2_u4_m( + {{k, n}, dnnl::memory::data_type::u4, dnnl::memory::format_tag::ba}, + engine, + handle_b); + + dnnl::primitive_attr pattr; + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); +#if ONEDNN_SUPPORT_DETERMINISTIC + if (at::globalContext().deterministicAlgorithms() || + at::globalContext().deterministicMkldnn()) { + pattr.set_deterministic(true); + } +#endif + + // Set scales with multiple scales along K dimension and with groups along K. + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 0) + (1 << 1), + {group_size, 1}, + scale_dt); + // Set a single zero point with s8 data type. 
+ pattr.set_zero_points( + DNNL_ARG_WEIGHTS, + (1 << 0) + (1 << 1), + {group_size, 1}, + dnnl::memory::data_type::s8); + + if (m1_dt == dnnl::memory::data_type::f16) + pattr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); + else if (m1_dt == dnnl::memory::data_type::bf16) + pattr.set_fpmath_mode(dnnl::fpmath_mode::bf16, true); + + matmul_pd = dnnl::matmul::primitive_desc( + engine, m1_md, m2_u4_m.get_desc(), dst_md, pattr); + matmul_p = dnnl::matmul(matmul_pd); + + dnnl::memory m1_m = m1_usr_m, m2_m = m2_u4_m, dst_m = dst_usr_m; + dnnl::memory scale_m = scale_usr_m; // zp_m = zp_u4_m; + Tensor m1_, m2_, zp_new, dst_; + + int scratchpad_size = matmul_pd.scratchpad_desc().get_size(); + Tensor scratchpad_tensor = + at::empty({scratchpad_size}, m1.options().dtype(at::kByte), c10::nullopt); + auto scratchpad_memory = make_onednn_memory( + matmul_pd.scratchpad_desc(), engine, scratchpad_tensor.data_ptr()); + args.insert({DNNL_ARG_SCRATCHPAD, scratchpad_memory}); + + args.insert({DNNL_ARG_SRC, m1_m}); + args.insert({DNNL_ARG_WEIGHTS, m2_u4_m}); + args.insert({DNNL_ARG_DST, dst_m}); + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_m}); + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp_usr_m}); + dnnl::sycl_interop::execute(matmul_p, stream, args); +} +} // namespace at::native::onednn diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h index a4f993eebcd6..9d8e9fe50df5 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h @@ -89,6 +89,14 @@ TORCH_API sycl::event deconvolution_backward_weights( int64_t groups, const std::vector& deps = {}); +TORCH_API void woq_matmul_int4( + at::Tensor& result, // dst, [M, N] + const at::Tensor& mat1_, // src, [M, K] + const at::Tensor& mat2_, // quantized weight, [K/8, N] + const at::Tensor& scale, // [K/group_size, N] + const at::Tensor& zp, // [k/group_size, N] + int64_t group_size); + dnnl::memory::dims conv_dst_size( int64_t ndim, IntArrayRef src_tz, From 52d172eafd3198be899c2e86b89f20636ead71f3 Mon Sep 17 00:00:00 2001 From: ZhiweiYan-96 Date: Tue, 8 Apr 2025 02:21:35 +0000 Subject: [PATCH 274/332] Facilitate at::_weight_int4pack_mm_with_scale_and_zeros related registration (#147962) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147962 Approved by: https://github.com/jerryzh168, https://github.com/guangyey, https://github.com/EikanWang ghstack dependencies: #137566 Co-authored-by: xiaolil1 --- aten/src/ATen/native/native_functions.yaml | 4 + ...asDecompTest.test_has_decomposition.expect | 1 + test/xpu/test_gemm.py | 85 ++++++++++++++++++- torch/_dynamo/trace_rules.py | 1 + torch/_meta_registrations.py | 15 ++++ 5 files changed, 105 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a29fee8c7066..48be73ac5eea 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4165,6 +4165,10 @@ MPS: _weight_int4pack_mm_mps CUDA: _weight_int4pack_mm_cuda +- func: _weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor + dispatch: + XPU: _weight_int4pack_mm_xpu + # Split int4 pack weight between cpu and other devices due to # https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756. 
- func: _convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 9eb7c572228f..7cb72bda99ae 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -651,6 +651,7 @@ aten::_values_copy aten::_values_copy.out aten::_weight_int4pack_mm aten::_weight_int4pack_mm_for_cpu +aten::_weight_int4pack_mm_with_scales_and_zeros aten::_weight_int8pack_mm aten::_weight_norm_interface_backward aten::_weight_norm_interface_backward.out diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index cf3d68add29e..138729261652 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -15,7 +15,12 @@ instantiate_device_type_tests, precisionOverride, ) -from torch.testing._internal.common_utils import iter_indices, run_tests, TestCase +from torch.testing._internal.common_utils import ( + iter_indices, + parametrize, + run_tests, + TestCase, +) class TestBasicGEMM(TestCase): @@ -1119,6 +1124,84 @@ def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): with torch.no_grad(): torch.matmul(a, b, out=c) + def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16): + # w [k, n] = [32, 48] + assert w.dim() == 2 + # w [n, k] = [48, 32] + w = w.transpose(0, 1).contiguous() + assert q_group_size > 1 + assert w.shape[-1] % q_group_size == 0 + + # to_quant: [n * k / group_size, group_size] + to_quant = w.reshape(-1, q_group_size) + assert torch.isnan(to_quant).sum() == 0 + + max_val = to_quant.amax(dim=1, keepdim=True) + min_val = to_quant.amin(dim=1, keepdim=True) + max_int = 2**n_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-6) / max_int + assert torch.isnan(scales).sum() == 0 + + zeros = min_int - min_val.div(scales).round() + zeros = torch.clamp(zeros, min_int, max_int) + zeros = zeros.to(torch.int8) + assert torch.isnan(zeros).sum() == 0 + + out = to_quant.div(scales).add(zeros).round().clamp_(min_int, max_int) + assert torch.isnan(out).sum() == 0 + + # [n, k] + out = out.to(dtype=torch.int32).reshape(w.shape) + if out.device != torch.device("cpu"): + out = (out[::, 1::2] << 4 | out[::, 0::2]).to(torch.uint8) + + # Scales and zeros for the same q-group should be contiguous, so we can + # load as a 32-bit word + scales = scales.view(w.shape[0], -1).transpose(0, 1).contiguous() + zeros = zeros.view(w.shape[0], -1).transpose(0, 1).contiguous() + + return out, scales, zeros + + @parametrize("m", [128]) + @parametrize("k", [512, 1024]) + @parametrize("n", [512, 1024]) + def test__int4_mm(self, device, m, k, n): + q_group = 32 + inner_k_tiles = 2 + + torch.manual_seed(1) + a_bf16 = torch.rand((m, k), dtype=torch.float32, device=device) + b_bf16 = torch.rand((k, n), dtype=torch.float32, device=device) + + def convert_weight_to_int4pack(b): + # b_uint8 [n, k //2] + b_uint8, scales, zeros = self._group_quantize_tensor( + b, n_bit=4, q_group_size=q_group + ) + # b_int4pack [k//8, n] + b_int4pack = torch._convert_weight_to_int4pack(b_uint8, inner_k_tiles) + + return b_int4pack, scales, zeros + + def weight_int4pack_mm(a, b_int4pack, qscale, qzeros): + return torch._weight_int4pack_mm_with_scales_and_zeros( + a, b_int4pack, q_group, qscale, qzeros + ) + + b_int4pack, b_scales, zeros_int8 = convert_weight_to_int4pack(b_bf16) + + for dtype in [torch.bfloat16, torch.float16]: + a = a_bf16.to(dtype=dtype) + b = b_bf16.to(dtype=dtype) + b_scales = 
b_scales.to(dtype=dtype) + ref = torch.mm(a, b) + + res = weight_int4pack_mm(a, b_int4pack, b_scales, zeros_int8) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + instantiate_device_type_tests(TestBasicGEMM, globals(), only_for="xpu", allow_xpu=True) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 22fa9344b61f..aaae72c86228 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1624,6 +1624,7 @@ "torch._values_copy", "torch._weight_int4pack_mm", "torch._weight_int4pack_mm_for_cpu", + "torch._weight_int4pack_mm_with_scales_and_zeros", "torch._weight_int8pack_mm", "torch._weight_norm_interface", "torch._weight_norm", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index dab0e92558fc..93274d1060be 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3636,6 +3636,21 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros): return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) +@register_meta([aten._weight_int4pack_mm_with_scales_and_zeros]) +def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros): + torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") + torch._check(w.dim() == 2, lambda: "w must be a 2D tensor") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + lambda: f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check( + w.dtype is torch.int32, + lambda: f"expected w to be int32, got {w.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + def kai_roundup(a: int, b: int) -> int: return ((a + b - 1) // b) * b From ec5f2e30282d1e29116eadb9e5c532759a3bbd2f Mon Sep 17 00:00:00 2001 From: Nikita Shulga <2453524+malfet@users.noreply.github.com> Date: Tue, 8 Apr 2025 16:03:40 +0000 Subject: [PATCH 275/332] [Build] Fix fbgemm build with gcc-12+ (#150847) By suppressing more warnings TODO: fbgemm pin really needs to get updated Pull Request resolved: https://github.com/pytorch/pytorch/pull/150847 Approved by: https://github.com/atalman, https://github.com/Skylion007 --- cmake/Dependencies.cmake | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 7627c3d9c7bb..b6c51e639eee 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -737,6 +737,12 @@ if(USE_FBGEMM) set_property(TARGET fbgemm_avx2 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET fbgemm_avx512 PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET fbgemm PROPERTY POSITION_INDEPENDENT_CODE ON) + # TODO: Remove next two lines after fbgemm pin is updated + + # For more details see https://github.com/pytorch/pytorch/issues/150846 + target_compile_options_if_supported(fbgemm_avx512 -Wno-maybe-uninitialized) + target_compile_options_if_supported(fbgemm_avx512 -Wno-uninitialized) + if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 13.0.0) # See https://github.com/pytorch/pytorch/issues/74352 target_compile_options_if_supported(asmjit -Wno-deprecated-copy) From 1239260a0eed8f64cc87cb46b60355658df0ebea Mon Sep 17 00:00:00 2001 From: Yuanhao Ji Date: Tue, 8 Apr 2025 16:05:03 +0000 Subject: [PATCH 276/332] [Accelerator][Chore] Use existing `acc` when raising an error (#150829) As the title said, `acc` already exists so we just use it instead of calling `current_accelerator()` again. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150829
Approved by: https://github.com/guangyey, https://github.com/Skylion007
---
 torch/accelerator/_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/accelerator/_utils.py b/torch/accelerator/_utils.py
index 3a29acd240cd..730f2a82543d 100644
--- a/torch/accelerator/_utils.py
+++ b/torch/accelerator/_utils.py
@@ -16,7 +16,7 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
         raise RuntimeError("Accelerator expected")
     if acc.type != device.type:
         raise ValueError(
-            f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
+            f"{device.type} doesn't match the current accelerator {acc}."
         )
     device_index = device.index
     if device_index is None:

From 97f34f012534ac381e13f430a91112cc027eee7f Mon Sep 17 00:00:00 2001
From: ikalinic
Date: Tue, 8 Apr 2025 16:18:11 +0000
Subject: [PATCH 277/332] [ROCm][Windows] Include AOTriton dependent sources in Windows build (#150521)

Includes ATen native transformers hipified sources in the ROCm+Windows build.
These were removed because Triton is not available on Windows, but excluding
them causes further linker errors. Setting `USE_FLASH_ATTENTION=0` and
`USE_MEM_EFF_ATTENTION=0` during the build will mitigate the missing headers
while also avoiding the linker errors, so we will use this approach for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150521
Approved by: https://github.com/jeffdaily
---
 aten/src/ATen/CMakeLists.txt | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index e10fdb7e88ee..d939b7b7b084 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -384,12 +384,11 @@ endif()
       ${native_quantized_hip_hip}
       ${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
     )
-    if(WIN32) # Windows doesn't support Composable Kernels and Triton
+    if(WIN32) # Windows doesn't support Composable Kernels
       file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
       file(GLOB native_hip_ck "native/hip/ck*.hip")
       exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
-        ${native_hip_bgemm} ${native_hip_ck}
-        ${native_transformers_hip_hip} ${native_transformers_hip_cpp})
+        ${native_hip_bgemm} ${native_hip_ck})
     endif()
     # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
     list(APPEND all_hip_cpp
@@ -408,9 +407,6 @@ endif()
       ${miopen_cpp}
       ${all_hip_cpp}
     )
-    if(WIN32) # Windows doesn't support Triton
-      exclude(all_hip_cpp "${all_hip_cpp}" ${native_transformers_hip_cpp})
-    endif()
   endif()

   if(USE_XPU)

From 4447352e6499c28c17f4d48d40c2b1cc3d2863a5 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Tue, 8 Apr 2025 16:29:05 +0000
Subject: [PATCH 278/332] Revert "[CUDA] Only use vec128 if CUDA version is newer than 12.8 (#150705)"

This reverts commit 5228986c395dc79f90d2a2b991deea1eef188260.
Reverted https://github.com/pytorch/pytorch/pull/150705 on behalf of https://github.com/atalman due to break periodic tests ([comment](https://github.com/pytorch/pytorch/pull/150705#issuecomment-2787017751)) --- aten/src/ATen/native/cuda/CUDALoops.cuh | 6 ++---- aten/src/ATen/native/cuda/MemoryAccess.cuh | 4 +--- aten/src/ATen/native/cuda/thread_constants.h | 5 +---- aten/src/ATen/test/cuda_vectorized_test.cu | 19 +++++++------------ 4 files changed, 11 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index fb71dc5488f5..82d0defd972b 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -78,7 +78,7 @@ constexpr auto sum_of_sizes(args_t args, std::index_sequence) { } } -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080) +#ifdef USE_ROCM template constexpr auto elems_per_thread(){ if constexpr (io_sizes == 1) { @@ -219,7 +219,7 @@ static inline void launch_vectorized_kernel( constexpr auto io_size = calc_io_size(); int64_t grid = (N + io_block_work_size() - 1) / io_block_work_size(); auto stream = at::cuda::getCurrentCUDAStream(); -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION < 12080) +#ifdef USE_ROCM int vec_size = memory::can_vectorize_up_to(data); #else using cpp_type = typename function_traits::result_type; @@ -241,13 +241,11 @@ static inline void launch_vectorized_kernel( C10_CUDA_KERNEL_LAUNCH_CHECK(); break; #endif -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) case 8: vectorized_elementwise_kernel<8, func_t, array_t> <<>>(N, f, data); C10_CUDA_KERNEL_LAUNCH_CHECK(); break; -#endif case 4: vectorized_elementwise_kernel<4, func_t, array_t> <<>>(N, f, data); diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index 3e46f873c61d..fd88df3f8b17 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -486,9 +486,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { uint64_t address = reinterpret_cast(pointer); constexpr int vec2_alignment = std::alignment_of_v>; constexpr int vec4_alignment = std::alignment_of_v>; -#if defined(USE_ROCM) || (defined(CUDA_VERSION) && CUDA_VERSION >= 12080) constexpr int vec8_alignment = std::alignment_of_v>; -#endif #ifdef USE_ROCM constexpr int vec16_alignment = std::alignment_of_v>; constexpr int type_size = sizeof(scalar_t); @@ -497,7 +495,7 @@ inline C10_HOST_DEVICE int can_vectorize_up_to(const char *pointer) { } else if (type_size <= 2 && (address % vec8_alignment == 0)) { return 8; } else -#elif defined(CUDA_VERSION) && CUDA_VERSION >= 12080 +#else if (address % vec8_alignment == 0) { return 8; } else diff --git a/aten/src/ATen/native/cuda/thread_constants.h b/aten/src/ATen/native/cuda/thread_constants.h index 9299b79916cf..bcc797a26e1c 100644 --- a/aten/src/ATen/native/cuda/thread_constants.h +++ b/aten/src/ATen/native/cuda/thread_constants.h @@ -18,11 +18,8 @@ constexpr int thread_work_size() { return 4; } constexpr uint32_t num_threads() { return C10_WARP_SIZE * 4; } -#if defined(CUDA_VERSION) && CUDA_VERSION < 12080 -constexpr int thread_work_size() { return 4; } -#else + constexpr int thread_work_size() { return 8; } #endif -#endif constexpr int block_work_size() { return thread_work_size() * num_threads(); } diff --git a/aten/src/ATen/test/cuda_vectorized_test.cu b/aten/src/ATen/test/cuda_vectorized_test.cu 
index 4e0c14b17337..6b120f7eb304 100644 --- a/aten/src/ATen/test/cuda_vectorized_test.cu +++ b/aten/src/ATen/test/cuda_vectorized_test.cu @@ -46,17 +46,12 @@ TEST(TestLoops, HasSameArgTypes) { TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { char *ptr = reinterpret_cast(buffer1); -#if defined(CUDA_VERSION) && CUDA_VERSION < 12080 - constexpr auto vectorize_limit = 4; -#else - constexpr auto vectorize_limit= 8; -#endif - ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); - ASSERT_EQ(memory::can_vectorize_up_to(ptr), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr), 8); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 1), 1); @@ -70,8 +65,8 @@ TEST(TestVectorizedMemoryAccess, CanVectorizeUpTo) { ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 4), 1); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), vectorize_limit); - ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), vectorize_limit); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); + ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 8); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 4); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 2); ASSERT_EQ(memory::can_vectorize_up_to(ptr + 8), 1); From 173f126068991a2888a54b64e02a93e93a4d636b Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Mon, 7 Apr 2025 16:08:23 -0700 Subject: [PATCH 279/332] [invoke_subgraph] Preserve node meta (#150782) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150782 Approved by: https://github.com/bdhirsh ghstack dependencies: #150666 --- test/dynamo/test_graph_deduplication.py | 11 +++++++++++ test/higher_order_ops/test_invoke_subgraph.py | 4 ++++ torch/_higher_order_ops/invoke_subgraph.py | 13 ++++++++++++- torch/_higher_order_ops/utils.py | 6 ++++++ 4 files changed, 33 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 99ed2f5a8dd2..2ff363a5f5c7 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -112,7 +112,9 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None @@ -205,7 +207,9 @@ def forward(self, primals_1: "f32[10, 10]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]"): mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None + add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None return (add_1,) """, @@ -349,7 +353,9 @@ def 
forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None @@ -358,7 +364,9 @@ def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): class ___forward_subgraph_1_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2) + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3) + cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None @@ -416,6 +424,7 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): class ___forward_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"): add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None return (add_1,) @@ -564,7 +573,9 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None + mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None + sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index db585afaafd7..69394e0e6428 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -921,6 +921,7 @@ class ___backward_invoke_subgraph_0_post_graph(torch.nn.Module): def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3) mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None return (add,) """, @@ -1030,7 +1031,9 @@ def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"): class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module): def forward(self, primals_0: "f32[8, 8]", primals_1: "f32[8, 8]"): mm: "f32[8, 8]" = torch.ops.aten.mm.default(primals_0, primals_1) + sin: "f32[8, 8]" = torch.ops.aten.sin.default(mm) + t: "f32[8, 8]" = torch.ops.aten.t.default(primals_0); primals_0 = None t_1: "f32[8, 8]" = torch.ops.aten.t.default(primals_1); primals_1 = None return (sin, mm, t, t_1) @@ -1055,6 +1058,7 @@ class ___backward_invoke_subgraph_0_post_graph(torch.nn.Module): def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: "f32[8, 8]"): cos: "f32[8, 8]" = torch.ops.aten.cos.default(mm); mm = None mul: "f32[8, 8]" = 
torch.ops.aten.mul.Tensor(tangents_0, cos); tangents_0 = cos = None + mm_1: "f32[8, 8]" = torch.ops.aten.mm.default(t, mul); t = None mm_2: "f32[8, 8]" = torch.ops.aten.mm.default(mul, t_1); mul = t_1 = None return (mm_2, mm_1) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 833b04e78e43..c899370b8d5a 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -298,7 +298,7 @@ def get_output_metadata(subgraph, operands): def trace_joint_graph_as_bwd( - fn, num_primals, joint_operands, include_key_set, exclude_key_set + subgraph, num_primals, joint_operands, include_key_set, exclude_key_set ): """ Naively trace out a joint graph. This simplifies the reconstruction of joint @@ -308,6 +308,17 @@ def trace_joint_graph_as_bwd( dummy_aot_config = get_dummy_aot_autograd_config() + if isinstance(subgraph, torch.fx.GraphModule): + + def graph_with_interpreter(*args): + # Running graph with interpreter is needed for propagating the stack_trace + with torch.fx.traceback.preserve_node_meta(): + return torch.fx.Interpreter(subgraph).run(*args) + + fn = graph_with_interpreter + else: + fn = subgraph + # This joint_fn is inserted as the backward graph as is. This simplifies the # min-cut partitioner work later on. # Input signature - (*primals, *tangents) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index ca9884687f3c..27f4e739eb41 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -821,4 +821,10 @@ def __repr__(self): return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})" def __call__(self, *args, **kwargs): + if isinstance(self.subgraph, torch.fx.GraphModule): + # Running graph with interpreter is needed for propagating the stack_trace + with fx_traceback.preserve_node_meta(): + return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)( + *args, **kwargs + ) return self.ctx.functionalize(self.subgraph)(*args, **kwargs) From 3e0038ae85246c1d3ffd95a618bb8fdebd9dd513 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 11:53:48 +0000 Subject: [PATCH 280/332] Fix torch.matmul related out dtype check (#148174) ---- - torch.matmul -> CompositeImplicitAutograd -> dot_out (when left_dim == 1 & right_dim == 1) -> mv_out (when left_dim == 2 & right_dim == 1) -> mm_out (when left_dim == 1 & right_dim == 2) -> ... 
- torch.dot - torch.vdot - torch.mm - torch.mv ISSUE related: https://github.com/pytorch/pytorch/issues/138399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148174 Approved by: https://github.com/jansel --- test/test_ops.py | 7 ------- torch/_decomp/decompositions.py | 2 +- torch/_meta_registrations.py | 2 +- torch/_refs/__init__.py | 4 ++-- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index d15fa7c6659d..09992fff10a7 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -135,7 +135,6 @@ def reduction_dtype_filter(op): xfail("cummin"), xfail("diag"), xfail("diagonal_copy"), - xfail("dot"), xfail("expand_copy"), xfail("fft.ihfft2"), xfail("fft.ihfftn"), @@ -159,7 +158,6 @@ def reduction_dtype_filter(op): xfail("linalg.lu_factor"), xfail("linalg.lu_factor_ex"), xfail("linalg.lu_solve"), - xfail("linalg.matrix_power"), xfail("linalg.qr"), xfail("linalg.slogdet"), xfail("linalg.solve"), @@ -168,12 +166,9 @@ def reduction_dtype_filter(op): xfail("logcumsumexp"), xfail("lu_solve"), xfail("lu_unpack"), - xfail("matmul"), - xfail("mm"), xfail("mode"), xfail("msort"), xfail("multinomial"), - xfail("mv"), xfail("nan_to_num"), xfail("nanmean"), xfail("narrow_copy"), @@ -182,7 +177,6 @@ def reduction_dtype_filter(op): xfail("nn.functional.avg_pool3d"), xfail("nn.functional.gelu"), xfail("nn.functional.hardshrink"), - xfail("nn.functional.linear"), xfail("nn.functional.logsigmoid"), xfail("nn.functional.softplus"), xfail("nn.functional.softshrink"), @@ -210,7 +204,6 @@ def reduction_dtype_filter(op): xfail("triu"), xfail("unfold_copy"), xfail("unsqueeze_copy"), - xfail("vdot"), xfail("view_copy"), xfail("where"), # Output has dynamic shape. diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 6be4f1d276ef..c2dc7e510833 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -4338,7 +4338,7 @@ def grid_sampler_2d( @register_decomposition(aten.mv) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def mv(self, vec): torch._check( diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 93274d1060be..9466f7430348 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2227,7 +2227,7 @@ def meta__fused_moving_avg_obs_fq_helper( @register_meta(aten.mm) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_mm(a, b): torch._check(a.dim() == 2, lambda: "a must be 2D") torch._check(b.dim() == 2, lambda: "b must be 2D") diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 3cf2a0b52146..c9080a01ede3 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -6301,7 +6301,7 @@ def wrapper(self, other): @register_decomposition(aten.dot) -@out_wrapper() +@out_wrapper(exact_dtype=True) @_dot_check_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("self", "other"), @@ -6321,7 +6321,7 @@ def dot(self, other): @register_decomposition(aten.vdot) -@out_wrapper() +@out_wrapper(exact_dtype=True) @_dot_check_wrapper @elementwise_type_promotion_wrapper( type_promoting_args=("self", "other"), From 4926bd60040cb453aad726dc9b155743e149f11c Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 8 Apr 2025 17:10:36 +0000 Subject: [PATCH 281/332] Revert "Fix the Problems About Defining Static Variable in Inline Function (#147095)" This reverts commit 3da14d38bd396f5bbe8494872d1509efa1a6f048. 
Reverted https://github.com/pytorch/pytorch/pull/147095 on behalf of https://github.com/atalman due to breaks internally ([comment](https://github.com/pytorch/pytorch/pull/147095#issuecomment-2787129770)) --- ...cpp_extensions_open_device_registration.py | 52 +++++++++---------- torch/csrc/api/src/serialize.cpp | 1 + .../csrc/distributed/rpc/python_remote_call.h | 1 + torch/csrc/distributed/rpc/rref_proto.h | 1 + torch/csrc/distributed/rpc/script_call.h | 1 + .../csrc/distributed/rpc/script_remote_call.h | 1 + torch/csrc/distributed/rpc/script_resp.h | 1 + torch/csrc/jit/serialization/export.cpp | 1 - torch/csrc/jit/serialization/export.h | 1 + torch/csrc/jit/serialization/pickler.cpp | 20 ------- torch/csrc/jit/serialization/pickler.h | 19 +++++-- torch/csrc/jit/serialization/unpickler.cpp | 1 + 12 files changed, 50 insertions(+), 50 deletions(-) diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 21394218c65b..5d1f0c34ee2e 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -4,6 +4,7 @@ import io import os import sys +import tempfile import unittest from typing import Union from unittest.mock import patch @@ -345,22 +346,23 @@ def test_open_device_storage_pin_memory(self): cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("openreg") self.assertTrue(cpu_untyped_storage_pinned.is_pinned("openreg")) + @unittest.skip( + "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function" + ) def test_open_device_serialization(self): self.module.set_custom_device_index(-1) storage = torch.UntypedStorage(4, device=torch.device("openreg")) - self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") + self.assertEqual(torch.serialization.location_tag(storage), "openreg") self.module.set_custom_device_index(0) storage = torch.UntypedStorage(4, device=torch.device("openreg")) self.assertEqual(torch.serialization.location_tag(storage), "openreg:0") - # TODO(FFFrog): Comment this because openreg.device is missing - # Uncomment this after improving openreg - # cpu_storage = torch.empty(4, 4).storage() - # openreg_storage = torch.serialization.default_restore_location( - # cpu_storage, "openreg:0" - # ) - # self.assertTrue(openreg_storage.is_openreg) + cpu_storage = torch.empty(4, 4).storage() + openreg_storage = torch.serialization.default_restore_location( + cpu_storage, "openreg:0" + ) + self.assertTrue(openreg_storage.is_openreg) # test tensor MetaData serialization x = torch.empty(4, 4).long() @@ -369,24 +371,22 @@ def test_open_device_serialization(self): self.module.custom_set_backend_meta(y) self.assertTrue(self.module.check_backend_meta(y)) - # TODO(FFFrog): Comment this because openreg.device is missing - # Uncomment this after improving openreg - # self.module.custom_serialization_registry() - # with tempfile.TemporaryDirectory() as tmpdir: - # path = os.path.join(tmpdir, "data.pt") - # torch.save(y, path) - # z1 = torch.load(path) - # loads correctly onto the openreg backend device - # self.assertTrue(z1.is_openreg) - # loads BackendMeta data correctly - # self.assertTrue(self.module.check_backend_meta(z1)) - - # cross-backend - # z2 = torch.load(path, map_location="cpu") - # loads correctly onto the cpu backend device - # self.assertFalse(z2.is_openreg) - # loads BackendMeta data correctly - # 
self.assertFalse(self.module.check_backend_meta(z2)) + self.module.custom_serialization_registry() + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "data.pt") + torch.save(y, path) + z1 = torch.load(path) + # loads correctly onto the openreg backend device + self.assertTrue(z1.is_openreg) + # loads BackendMeta data correctly + self.assertTrue(self.module.check_backend_meta(z1)) + + # cross-backend + z2 = torch.load(path, map_location="cpu") + # loads correctly onto the cpu backend device + self.assertFalse(z2.is_openreg) + # loads BackendMeta data correctly + self.assertFalse(self.module.check_backend_meta(z2)) def test_open_device_storage_resize(self): cpu_tensor = torch.randn([8]) diff --git a/torch/csrc/api/src/serialize.cpp b/torch/csrc/api/src/serialize.cpp index fae54d124847..e8497a7f22b5 100644 --- a/torch/csrc/api/src/serialize.cpp +++ b/torch/csrc/api/src/serialize.cpp @@ -1,4 +1,5 @@ #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/python_remote_call.h b/torch/csrc/distributed/rpc/python_remote_call.h index 09d4ba36dc62..0a3054b594d2 100644 --- a/torch/csrc/distributed/rpc/python_remote_call.h +++ b/torch/csrc/distributed/rpc/python_remote_call.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch::distributed::rpc { class TORCH_API PythonRemoteCall : public RpcCommandBase { diff --git a/torch/csrc/distributed/rpc/rref_proto.h b/torch/csrc/distributed/rpc/rref_proto.h index a1482b46939b..e6bffd1870b3 100644 --- a/torch/csrc/distributed/rpc/rref_proto.h +++ b/torch/csrc/distributed/rpc/rref_proto.h @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 476ee118fe7f..19e1871ead87 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/distributed/rpc/script_remote_call.h b/torch/csrc/distributed/rpc/script_remote_call.h index e18edab64821..534ac0044599 100644 --- a/torch/csrc/distributed/rpc/script_remote_call.h +++ b/torch/csrc/distributed/rpc/script_remote_call.h @@ -3,6 +3,7 @@ #include #include #include +#include #include namespace torch::distributed::rpc { diff --git a/torch/csrc/distributed/rpc/script_resp.h b/torch/csrc/distributed/rpc/script_resp.h index 53841e3d705c..fd8cd4b845d1 100644 --- a/torch/csrc/distributed/rpc/script_resp.h +++ b/torch/csrc/distributed/rpc/script_resp.h @@ -2,6 +2,7 @@ #include #include +#include namespace torch::distributed::rpc { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 9c10e94141a2..ac20016c7bbb 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h index 6f8e69bf0ca6..8b2d6d84716a 100644 --- a/torch/csrc/jit/serialization/export.h +++ b/torch/csrc/jit/serialization/export.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 8038aa8ca658..6ce524293a70 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -807,24 +807,4 @@ bool 
checkHasValidSetGetState(const std::shared_ptr& cls) { return true; } -std::unordered_set& GetBackendMetaAllowlist() { - static std::unordered_set DeviceTypeAllowlist{ - c10::DeviceType::PrivateUse1}; - return DeviceTypeAllowlist; -} - -std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization() { - // The array to save function pointer for BackendMeta serialization. - // key is the DeviceType, value is std::pair obj. - // value.first represent get function and value.seconde represent set function - static std::array< - std::optional>, - at::COMPILE_TIME_MAX_DEVICE_TYPES> - BackendMetaSerialization; - return BackendMetaSerialization; -} - } // namespace torch::jit diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 828f2b3b0521..8accfa229b84 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -299,14 +299,27 @@ using BackendMetaPtr = std::function< void(const at::Tensor&, std::unordered_map&)>; // A allowlist of device type, currently available is PrivateUse1 -TORCH_API std::unordered_set& GetBackendMetaAllowlist(); +inline std::unordered_set& GetBackendMetaAllowlist() { + static std::unordered_set DeviceTypeAllowlist{ + c10::DeviceType::PrivateUse1}; + return DeviceTypeAllowlist; +} // Dynamically obtain serialization function pairs // that require the corresponding backend. -TORCH_API std::array< +inline std::array< std::optional>, at::COMPILE_TIME_MAX_DEVICE_TYPES>& -GetBackendMetaSerialization(); +GetBackendMetaSerialization() { + // The array to save function pointer for BackendMeta serialization. + // key is the DeviceType, value is std::pair obj. + // value.first represent get function and value.seconde represent set function + static std::array< + std::optional>, + at::COMPILE_TIME_MAX_DEVICE_TYPES> + BackendMetaSerialization; + return BackendMetaSerialization; +} // Register function pointer of Tensor BackendMetadata for serialization. TORCH_API inline void TensorBackendMetaRegistry( diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index cdd58b8cef3d..0cbb710f5513 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -5,6 +5,7 @@ #endif #include #include +#include #include #include #include From 97759614c2733c4fb85f19ca9521cc4163af5935 Mon Sep 17 00:00:00 2001 From: William Wen Date: Mon, 7 Apr 2025 17:35:57 -0700 Subject: [PATCH 282/332] [dynamo] reconstruct functions decorated in the compiled region properly (#150645) We were previously unable to reconstruct functions that were decorated in the compiled region. 
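For reference, a minimal example of the pattern this patch enables, adapted from the new tests added below (all names here come from the test diff, nothing new is assumed):

```python
import torch

def inner(x):
    x = x + 1
    torch._dynamo.graph_break()  # Dynamo must reconstruct the wrapped callable when resuming
    return x + 2

@torch.compile(backend="eager")
def outer(x):
    # Wrapping `inner` with a decorator inside the compiled region used to hit
    # a reconstruction failure (and a skipped frame); it now compiles.
    return torch.no_grad()(inner)(x)

outer(torch.randn(3))
```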
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150645 Approved by: https://github.com/jansel --- test/dynamo/test_error_messages.py | 8 +- test/dynamo/test_reconstruct.py | 97 +++++++++++++++++++ .../TestAutograd.test_backward_with_inputs | 0 .../TestAutograd.test_set_grad_coroutines | 0 ...TestAutograd.test_set_grad_coroutines_exit | 0 ...Autograd.test_set_grad_generator_functions | 0 ...est_set_grad_generator_functions_recursive | 0 ...nce_mode_inf_tensor_in_inf_mode_inplace_op | 0 torch/_dynamo/codegen.py | 4 +- torch/_dynamo/testing.py | 6 ++ torch/_dynamo/variables/ctx_manager.py | 19 +++- torch/_dynamo/variables/functions.py | 76 ++++++++++++++- 12 files changed, 201 insertions(+), 9 deletions(-) delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_backward_with_inputs delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines_exit delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions delete mode 100644 test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions_recursive delete mode 100644 test/dynamo_expected_failures/TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 71c4c921ae4b..8310d3df974d 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -690,9 +690,9 @@ def post_munge(s): """\ Reconstruction failure Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). - Hint: If Dynamo attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. + Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. - Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't havereconstruction rules may be fundamentally unreconstructable. + Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) @@ -744,9 +744,9 @@ def post_munge(s): """\ Reconstruction failure Explanation: Dynamo has no bytecode reconstruction implemented for sourceless variable UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)). - Hint: If Dynamo attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. + Hint: If Dynamo is attempting to trace a return statement and your code is attempting to return a variable that Dynamo cannot reconstruct, then remove it from the return statement. Hint: This graph break may have been caused by an earlier graph break. Resolving the earlier graph break may resolve this one. - Hint: Report an issue to PyTorch if you need reconstrtuction support. 
Note that objects that don't havereconstruction rules may be fundamentally unreconstructable. + Hint: Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have reconstruction rules may be fundamentally unreconstructable. Developer debug context: UserMethodVariable(.Foo.meth at 0xmem_addr>, UserDefinedObjectVariable(Foo)) diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 4eecfdf13989..662f5420bfcb 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -300,6 +300,103 @@ def fn(model, states, x): got = opt_fn(model, states, x) self.assertEqual(expected, got) + def test_graph_break_in_wrapped_user_function(self): + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + 2 + + @torch.compile(backend="eager") + def gn(x): + x = torch.no_grad()(fn)(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_user_method(self): + class Foo: + def __init__(self): + self.a = 1 + self.b = 2 + + def fn(self, x): + x = x + self.a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + self.b + + obj = Foo() + + @torch.compile(backend="eager") + def gn(x): + obj.fn = torch.no_grad()(obj.fn) + x = obj.fn(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_nested_function(self): + @torch.compile(backend="eager") + def gn(x): + a = 1 + b = 2 + + @torch.no_grad() + def fn(x): + x = x + a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + b + + x = fn(x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + + def test_graph_break_in_wrapped_skipped_function(self): + from torch._dynamo import trace_rules + from torch._dynamo.testing import _skipped_function_for_test_reconstruct + from torch._dynamo.variables import SkipFunctionVariable + + self.assertIs( + trace_rules.lookup(_skipped_function_for_test_reconstruct), + SkipFunctionVariable, + ) + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() + assert not torch.is_grad_enabled() + return x + 2 + + @torch.compile(backend="eager") + def gn(x): + x = torch.no_grad()(_skipped_function_for_test_reconstruct)(fn, x) + # reconstruction failure would cause a skipped frame + assert torch.compiler.is_compiling() + assert torch.is_grad_enabled() + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo_expected_failures/TestAutograd.test_backward_with_inputs b/test/dynamo_expected_failures/TestAutograd.test_backward_with_inputs deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines b/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines deleted file mode 100644 index e69de29bb2d1..000000000000 
diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines_exit b/test/dynamo_expected_failures/TestAutograd.test_set_grad_coroutines_exit deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions b/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions_recursive b/test/dynamo_expected_failures/TestAutograd.test_set_grad_generator_functions_recursive deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/test/dynamo_expected_failures/TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op b/test/dynamo_expected_failures/TestAutogradInferenceMode.test_inference_mode_inf_tensor_in_inf_mode_inplace_op deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index b065c188bcbc..1a1f44609112 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -349,10 +349,10 @@ def gen_fn(): context=str(value), explanation=f"Dynamo has no bytecode reconstruction implemented for sourceless variable {value}.", hints=[ - "If Dynamo attempting to trace a return statement and your code is attempting to return a variable " + "If Dynamo is attempting to trace a return statement and your code is attempting to return a variable " "that Dynamo cannot reconstruct, then remove it from the return statement.", *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK, - "Report an issue to PyTorch if you need reconstrtuction support. Note that objects that don't have" + "Report an issue to PyTorch if you need reconstrtuction support. 
Note that objects that don't have " "reconstruction rules may be fundamentally unreconstructable.", ], ) diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index d44ad4b2408d..ce25a2969050 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -524,3 +524,9 @@ def reset_rng_state(use_xla: bool = False) -> None: import torch_xla.core.xla_model as xm xm.set_rng_state(1337, str(xm.xla_device())) + + +def _skipped_function_for_test_reconstruct( + f: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs +) -> _T: + return f(*args, **kwargs) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 7cbed617d823..26d87113089a 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -41,8 +41,11 @@ from .base import VariableTracker from .functions import ( NestedUserFunctionVariable, + SkipFunctionVariable, UserFunctionVariable, UserMethodVariable, + WrappedNestedUserFunctionVariable, + WrappedSkipFunctionVariable, WrappedUserFunctionVariable, WrappedUserMethodVariable, ) @@ -112,9 +115,21 @@ def call_function( kwargs: "dict[str, VariableTracker]", ) -> "VariableTracker": assert len(args) == 1 + assert isinstance( + args[0], + ( + NestedUserFunctionVariable, + SkipFunctionVariable, + UserMethodVariable, + UserFunctionVariable, + ), + ) + if isinstance(args[0], NestedUserFunctionVariable): - args[0] = UserFunctionVariable(args[0].get_function()) - assert isinstance(args[0], (UserMethodVariable, UserFunctionVariable)) + return WrappedNestedUserFunctionVariable(args[0], self) + + if isinstance(args[0], SkipFunctionVariable): + return WrappedSkipFunctionVariable(args[0], self) if isinstance(args[0], UserMethodVariable): return WrappedUserMethodVariable(args[0], self) diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 701b067710de..fcbfb22c6d33 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -958,11 +958,15 @@ def call_function( self.context.exit(tx) return result + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + class WrappedUserFunctionVariable(UserFunctionVariable): def __init__(self, wrapped, context, **kwargs) -> None: kwargs.pop("fn", None) - kwargs.pop("obj", None) super().__init__(wrapped.fn, **kwargs) self.wrapped = wrapped self.context = context @@ -978,6 +982,11 @@ def call_function( self.context.exit(tx) return result + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + def invoke_and_store_as_constant(tx: "InstructionTranslator", fn, name, args, kwargs): def convert(x): @@ -1174,6 +1183,46 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.store_attr(name) +class WrappedNestedUserFunctionVariable(NestedUserFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("fn_name", None) + kwargs.pop("code", None) + kwargs.pop("f_globals", None) + kwargs.pop("defaults", None) + kwargs.pop("kwdefaults", None) + kwargs.pop("annotations", None) + kwargs.pop("closure", None) + kwargs.pop("wrapped_fn", None) + super().__init__( + wrapped.fn_name, + wrapped.code, + wrapped.f_globals, + wrapped.defaults, + wrapped.kwdefaults, + wrapped.annotations, + wrapped.closure, + wrapped.wrapped_fn, + ) + self.wrapped = wrapped + 
self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + class SkipFunctionVariable(VariableTracker): _nonvar_fields = { "value", @@ -1323,6 +1372,31 @@ def var_getattr(self, tx: "InstructionTranslator", name: str): return fn_var_getattr(tx, self.value, self.source, name) +class WrappedSkipFunctionVariable(SkipFunctionVariable): + def __init__(self, wrapped, context, **kwargs) -> None: + kwargs.pop("value", None) + kwargs.pop("reason", None) + super().__init__(wrapped.value, reason=wrapped.reason, **kwargs) + self.wrapped = wrapped + self.context = context + + def call_function( + self, + tx: "InstructionTranslator", + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + self.context.enter(tx) + result = super().call_function(tx, args, kwargs) + self.context.exit(tx) + return result + + def reconstruct(self, codegen): + codegen.add_push_null(lambda: codegen(self.context)) + codegen(self.wrapped) + codegen.extend_output(create_call_function(1, False)) + + class WrapperUserFunctionVariable(VariableTracker): """ Used to represent a wrapper object that contains the actual callable as an From e6bd133866892958eebbe2d2ca799628d98008f6 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Tue, 8 Apr 2025 14:49:17 +0000 Subject: [PATCH 283/332] add batching rule for `torch.Tensor.scatter_add_` (#150543) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150543 Approved by: https://github.com/zou3519 --- aten/src/ATen/functorch/BatchRulesScatterOps.cpp | 10 ++++++++++ test/functorch/test_vmap.py | 1 - torch/testing/_internal/common_methods_invocations.py | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index e512efad59bb..14f03bd17f4d 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -773,6 +773,15 @@ std::tuple> scatter_add_batch_rule( self, self_bdim, dim, index, index_bdim, src, src_bdim); } +std::tuple> scatter_add__batch_rule( + const Tensor& self, std::optional self_bdim, + int64_t dim, + const Tensor& index, std::optional index_bdim, + const Tensor& src, std::optional src_bdim) { + return scatter_batch_rule(ATEN_FN(scatter_add_), + self, self_bdim, dim, index, index_bdim, src, src_bdim); +} + std::tuple> scatter_reduce_batch_rule( const Tensor& self, std::optional self_bdim, int64_t dim, @@ -1278,6 +1287,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT2(scatter, value, scatter_value_batch_rule); VMAP_SUPPORT2(scatter, src, scatter_src_batch_rule); VMAP_SUPPORT(scatter_add, scatter_add_batch_rule); + VMAP_SUPPORT(scatter_add_, scatter_add__batch_rule); VMAP_SUPPORT2(scatter, reduce, scatter_reduce_batch_rule); VMAP_SUPPORT2(scatter, value_reduce, scatter_value_reduce_batch_rule); VMAP_SUPPORT2(scatter_reduce, two, scatter_reduce_two_batch_rule); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index d552179fc9dc..894aa6f544d7 100644 --- a/test/functorch/test_vmap.py +++ 
b/test/functorch/test_vmap.py @@ -4598,7 +4598,6 @@ def test_op_has_batch_rule(self, device, dtype, op): "polygamma", "pow", "remainder", - "scatter_add", "scatter", "square", "sub", diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 17391695cdc3..a715b2bbd28d 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -18386,6 +18386,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): ), OpInfo('index_fill', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), + inplace_variant=torch.Tensor.index_fill_, supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -18421,6 +18422,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), OpInfo('index_add', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + inplace_variant=torch.Tensor.index_add_, supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -19342,6 +19344,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))), OpInfo('scatter_add', dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + inplace_variant=torch.Tensor.scatter_add_, sample_inputs_func=sample_inputs_scatter_add, error_inputs_func=error_inputs_scatter_and_scatter_add, supports_forward_ad=True, @@ -21506,6 +21509,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): OpInfo( 'scatter_reduce', variant_test_name='sum', + inplace_variant=torch.Tensor.scatter_reduce_, # complex not added to dtypes as complex gradients are not properly handled # and scatter_reduce hasn't been added to the whitelist in gen_variable_type yet dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), From aafc4b6188b70cf808f756f23b1a05355bcb7696 Mon Sep 17 00:00:00 2001 From: Basil Wong Date: Tue, 8 Apr 2025 18:12:53 +0000 Subject: [PATCH 284/332] Do not depend on numpy during the import (#150816) Summary: Related issue: https://github.com/pytorch/pytorch/issues/149681 We can follow up with a different implementation that does not use numpy(potentially with Torch primitives). 
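The fix is the usual deferred-import pattern, sketched generically below (the helper name is illustrative, not one of the real functions touched in this diff):

```python
# The module imports cleanly even when numpy is absent; numpy is only
# required at the point the helper is actually called.
def _helper_that_needs_numpy(x):
    import numpy as np  # moved from module level into the function body
    return np.asarray(x)
```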
Test Plan: pending: contbuild & OSS CI Differential Revision: D72609835 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150816 Approved by: https://github.com/jerryzh168, https://github.com/cyyever, https://github.com/albanD --- torch/nn/utils/_expanded_weights/conv_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/nn/utils/_expanded_weights/conv_utils.py b/torch/nn/utils/_expanded_weights/conv_utils.py index eb14df567095..7b7f58b5ff5f 100644 --- a/torch/nn/utils/_expanded_weights/conv_utils.py +++ b/torch/nn/utils/_expanded_weights/conv_utils.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs from typing import Optional -import numpy as np - import torch import torch.nn.functional as F @@ -213,6 +211,8 @@ def conv_unfold_weight_grad_sample( groups, func, ): + import numpy as np + n = input.shape[0] in_channels = input.shape[1] @@ -318,6 +318,9 @@ def unfold3d( >>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape torch.Size([3, 32, 120]) """ + + import numpy as np + if len(tensor.shape) != 5: raise ValueError( f"Input tensor must be of the shape [B, C, D, H, W]. Got{tensor.shape}" From c36d9b0d8dbc2d22b0344a18ab60e71ef3c43640 Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Tue, 8 Apr 2025 18:49:22 +0000 Subject: [PATCH 285/332] [Codemod][AddExplicitStrictExportForTrainingInferenceArg] caffe2/torch/ao (#150826) Differential Revision: D72615631 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150826 Approved by: https://github.com/ydwu4 --- torch/ao/quantization/pt2e/lowering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/ao/quantization/pt2e/lowering.py b/torch/ao/quantization/pt2e/lowering.py index 76b6f365a84c..587cee22560d 100644 --- a/torch/ao/quantization/pt2e/lowering.py +++ b/torch/ao/quantization/pt2e/lowering.py @@ -50,7 +50,7 @@ def _node_replace(m): # type: ignore[no-untyped-def] m.recompile() lowered_model = ( - torch.export.export_for_training(model, example_inputs) + torch.export.export_for_training(model, example_inputs, strict=True) .run_decompositions(_post_autograd_decomp_table()) .module() ) From 901b02cf16b61824ef6662af4549b81898a127fd Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Mon, 7 Apr 2025 13:50:00 -0700 Subject: [PATCH 286/332] [Inductor] fix alignement assumption for fallback (#150777) Inductor right now only works properly for fallback kernels producing aligned output. When Inductor create layout for fallback kernel output, Inductor does not add the tensor offset to the layout [link](https://github.com/pytorch/pytorch/blob/2a1e2b88ed7bf7d7436b741ee0c3a2297d7d7bc2/torch/_inductor/ir.py#L6935-L6941). Thus unaligned output will be treated as aligned. Adding the offset to the layout directly does not work since that change the index expression in the generated kernel and we may 'double' applying the offset. Triton already considers the offset when passing in the data_ptr. To solve this issue, we track the unaligned buffer names instead. 
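A simplified sketch of that bookkeeping follows; apart from the real set added in this diff (`V.graph.unaligned_buffers`), the names are illustrative, and the real code records fallback outputs in `ir.py` and consults the set during Triton codegen:

```python
# Illustrative only: remember which buffers cannot be assumed 16-byte aligned,
# instead of folding the storage offset into the buffer's layout.
unaligned_buffers: set[str] = set()

def record_fallback_output(name: str, t, alignment: int = 16) -> None:
    # A fallback kernel may return a view whose storage offset breaks the
    # alignment Inductor otherwise assumes for buffers it allocates itself.
    if (t.storage_offset() * t.element_size()) % alignment != 0:
        unaligned_buffers.add(name)

def may_mark_divisible_by_16(name: str) -> bool:
    # Codegen checks the set rather than assuming every buffer is aligned.
    return name not in unaligned_buffers
```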
This potentially can fix the internal issues we are debugging here: https://fb.workplace.com/groups/1075192433118967/permalink/1618308128807392/ Differential Revision: [D72600784](https://our.internmc.facebook.com/intern/diff/D72600784) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150777 Approved by: https://github.com/eellison, https://github.com/jansel --- test/inductor/test_torchinductor.py | 62 +++++++++++++++++++++++++ torch/_inductor/codegen/common.py | 2 + torch/_inductor/codegen/triton_utils.py | 7 ++- torch/_inductor/codegen/wrapper.py | 4 +- torch/_inductor/config.py | 5 +- torch/_inductor/graph.py | 9 ++-- torch/_inductor/ir.py | 17 ++++++- 7 files changed, 99 insertions(+), 7 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index bbec7ab9bbe4..8fa9be1966b1 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -12895,6 +12895,68 @@ def test_special_polygamma(self): self.common(fn, (1, x)) self.common(fn, (2, x)) + def test_unaligned_input(self): + def fn(x): + return torch.nn.functional.relu(x) + + x = torch.randn(1024 + 16, device=self.device)[1:-15] + self.common(fn, (x,), check_lowp=False) + + def test_unaligned_input_2d(self): + def fn(x): + return torch.nn.functional.relu(x) + + x = torch.randn(1024, 1024 + 16, device=self.device)[:, 1:-15] + self.common(fn, (x,), check_lowp=False) + + def test_alignment_without_custom_op(self): + def fn(x): + a = torch.nn.functional.relu(x) + b = (3 * a)[1:-15] + c = torch.cos(b) + return c + + x = torch.randn(1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + + @config.patch(implicit_fallbacks=True) + def test_no_align_for_custom_op(self): + def slice1d(x): + return (3 * x)[1:-15] + + def slice1d_meta(x): + return torch.empty_like(x)[1:-15] + + define_custom_op_for_test("slice1d", slice1d, slice1d_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.slice1d(a) + c = torch.cos(b) + return c + + x = torch.randn(1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + + @config.patch(implicit_fallbacks=True) + def test_no_align_for_custom_op_2d(self): + def slice2d(x): + return (3 * x)[..., 1:-15] + + def slice2d_meta(x): + return torch.empty_like(x)[..., 1:-15] + + define_custom_op_for_test("slice2d", slice2d, slice2d_meta) + + def fn(x): + a = torch.nn.functional.relu(x) + b = torch.ops.test.slice2d(a) + c = torch.cos(b) + return c + + x = torch.randn(1024, 1024 + 16, device=self.device) + self.common(fn, (x,), check_lowp=False) + @dataclasses.dataclass class TestFailure: diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 7fce40e869ef..417e215d4f57 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1397,6 +1397,8 @@ def output(self, name: str) -> str: return self._lookup("out_ptr", self.output_buffers, name) def make_inplace(self, input_name: str, output_name: str) -> None: + if input_name in V.graph.unaligned_buffers: + V.graph.unaligned_buffers.add(output_name) assert output_name not in self.inplace_buffers if input_name in self.inplace_buffers: buf = self.inplace_buffers[input_name] diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 2d5f6a55b4cc..ddd4ec515516 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -122,9 +122,14 @@ def signature_to_meta( def 
is_unaligned_buffer(arg: TensorArg): buf_name = arg.buffer + if buf_name in V.graph.unaligned_buffers: + return True + if buf_name in V.graph.graph_inputs: # See Note: [Input Alignment handling in Inductor] - return buf_name not in V.graph.aligned_inputs + # For graph inputs that is not recorded in V.graph.unaligned_buffers, + # we know for sure the tensor is aligned. + return False if buf_name in V.graph.constants: # all constants are assumed to be aligned diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index d7de4b4f24a6..c10831bd8278 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -78,12 +78,13 @@ pexpr = PythonPrinter().doprint -ReuseKey = tuple[torch.device, torch.dtype, str] +ReuseKey = tuple[torch.device, torch.dtype, str, bool] BufferLike = Union[ir.Buffer, WorkspaceArg] def buffer_reuse_key(node: BufferLike) -> ReuseKey: storage_size = V.graph.get_allocation_storage_size(node) + alignment = node.get_name() not in V.graph.unaligned_buffers return ( node.get_device_or_error(), node.get_dtype(), @@ -91,6 +92,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey: # for s0 for s1, just because they happen to share the same # size hint sympy_str(V.graph.sizevars.simplify(storage_size)), + alignment, ) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index c210af25c16d..27b77d199f09 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -500,6 +500,9 @@ def use_autoheuristic(name: str) -> bool: # automatically create fallbacks when encountering an unhandled op implicit_fallbacks = True +assume_unaligned_fallback_output = ( + os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" +) # fuse even in cases without common reads aggressive_fusion = False @@ -1129,7 +1132,7 @@ class triton: ) # type: ignore[assignment] # hint to Triton when arguments are divisible by 16 - divisible_by_16 = True + divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1" # Minimum R0_BLOCK to be used for a TritonSplitScanKernel # NOTE: This also indirectly controls the size of workspace buffer required diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 38f359c5f255..bfe4d1e960b3 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -430,7 +430,10 @@ def __init__( self.get_backend_features = functools.lru_cache(None)(get_backend_features) self.effectful_ops: dict[_EffectType, ir.Buffer] = {} - self.aligned_inputs: OrderedSet[str] = OrderedSet() + # Track the buffers that we know is unaligned + # This can either be a graph input or the output of fallback + # kernels. + self.unaligned_buffers: OrderedSet[str] = OrderedSet() self.no_fuse_buffer_names = OrderedSet[str]() self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet() @@ -1116,8 +1119,8 @@ def placeholder( # expensive and cause recompiles; Instead, we're generating code # based on the alignment of the example input without guarding. 
with maybe_get_suppress_shape_guards_ctx(): - if should_assume_input_aligned(example): - self.aligned_inputs.add(target) + if not should_assume_input_aligned(example): + self.unaligned_buffers.add(target) return tensor def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 84069bbdf829..a312ea3ca11c 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -95,6 +95,7 @@ sympy_index_symbol_with_prefix, sympy_product, sympy_subs, + tensor_is_aligned, ) from .virtualized import ops, OpsValue, V @@ -6996,11 +6997,16 @@ def generate_output(output, indices): # type: ignore[no-untyped-def] for key, val in output.items() } elif isinstance(output, torch.Tensor): - return MultiOutput( + buf = MultiOutput( cls.tensor_to_layout(output), packed, indices, ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + output + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] + return buf elif isinstance(output, int): return output elif isinstance(output, torch.SymInt): @@ -8051,6 +8057,11 @@ def create_out_of_place( # type: ignore[no-untyped-def] ) for i, tensor in enumerate(example_output) ] + for buf, tensor in zip(packed.outputs, example_output): + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + tensor + ): + V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type] return packed.outputs else: packed = cls( @@ -8060,6 +8071,10 @@ def create_out_of_place( # type: ignore[no-untyped-def] non_tensor_args, unflatten_args, ) + if config.assume_unaligned_fallback_output or not tensor_is_aligned( + example_output + ): + V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type] packed.outputs = [packed] return packed From 17f9276e29477f7a24bf431fa905f8bdd1464cb3 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 19:49:37 +0800 Subject: [PATCH 287/332] Code Clean: Remove python3.8 specific code because PyTorch now need Python3.9 and later (#150834) As the title stated. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150834 Approved by: https://github.com/Skylion007, https://github.com/albanD --- test/torch_np/numpy_tests/core/test_dtype.py | 6 ------ test/torch_np/numpy_tests/core/test_scalar_methods.py | 10 ---------- 2 files changed, 16 deletions(-) diff --git a/test/torch_np/numpy_tests/core/test_dtype.py b/test/torch_np/numpy_tests/core/test_dtype.py index aeb9710832f9..d548f49b4cc4 100644 --- a/test/torch_np/numpy_tests/core/test_dtype.py +++ b/test/torch_np/numpy_tests/core/test_dtype.py @@ -3,7 +3,6 @@ import functools import operator import pickle -import sys import types from itertools import permutations from typing import Any @@ -325,11 +324,6 @@ def test_keyword_argument(self): # test for https://github.com/numpy/numpy/pull/16574#issuecomment-642660971 assert np.dtype(dtype=np.float64) == np.dtype(np.float64) - @skipif(sys.version_info >= (3, 9), reason="Requires python 3.9") - def test_class_getitem_38(self) -> None: - with pytest.raises(TypeError): - np.dtype[Any] - class TestFromDTypeAttribute(TestCase): def test_simple(self): diff --git a/test/torch_np/numpy_tests/core/test_scalar_methods.py b/test/torch_np/numpy_tests/core/test_scalar_methods.py index e1e92de7d6c6..36ac89a02c29 100644 --- a/test/torch_np/numpy_tests/core/test_scalar_methods.py +++ b/test/torch_np/numpy_tests/core/test_scalar_methods.py @@ -5,7 +5,6 @@ """ import fractions import functools -import sys import types from typing import Any from unittest import skipIf as skipif, SkipTest @@ -222,15 +221,6 @@ def test_subscript_scalar(self) -> None: assert np.number[Any] -@instantiate_parametrized_tests -class TestClassGetitemMisc(TestCase): - @skipif(sys.version_info >= (3, 9), reason="Requires python 3.8") - @parametrize("cls", [np.number, np.complexfloating, np.int64]) - def test_class_getitem_38(self, cls: type[np.number]) -> None: - with pytest.raises(TypeError): - cls[Any] - - @skip(reason="scalartype(...).bit_count() not implemented") @instantiate_parametrized_tests class TestBitCount(TestCase): From 89505f4498f799f120aea725a793151fc9178b0f Mon Sep 17 00:00:00 2001 From: Yiming Zhou Date: Tue, 8 Apr 2025 22:35:28 +0000 Subject: [PATCH 288/332] [AOTI] Always use oss schema for ExternKernelNodes serialization (#150197) Summary: Added a field `protocol` to `ExternKernelNodes` and all the lowering pass will always use the oss schema to serialize external kernel nodes from now on. 
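For a quick look at the user-visible effect, the sketch below uses the serializer touched in this diff to show the new marker in the payload (the exact printed formatting is an assumption):

```python
import json

from torch._export.serde.schema import ExternKernelNodes
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder

# An empty node list is enough to see the new field in the output.
payload = json.dumps(
    _dataclass_to_dict(ExternKernelNodes(nodes=[], protocol="json")),
    cls=EnumEncoder,
    indent=2,
)
print(payload)  # roughly: {"nodes": [], "protocol": "json"}
```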
Test Plan: CI Differential Revision: D72020444 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150197 Approved by: https://github.com/zhxchen17 --- test/inductor/test_aot_inductor_utils.py | 9 ------ test/inductor/test_torchbind.py | 3 +- torch/_export/serde/aoti_schema.py | 14 --------- torch/_export/serde/export_schema.thrift | 3 +- torch/_export/serde/schema.py | 3 +- torch/_export/serde/schema.yaml | 7 +++-- torch/_inductor/extern_node_serializer.py | 7 +++-- torch/_inductor/graph.py | 4 +-- .../aoti_runner/model_container_runner.cpp | 2 +- .../aoti_torch/oss_proxy_executor.cpp | 31 ++++++++++++------- .../inductor/aoti_torch/oss_proxy_executor.h | 17 +--------- .../csrc/inductor/aoti_torch/proxy_executor.h | 15 +++++++++ .../utils/generated_serialization_types.h | 20 ++++++++++-- 13 files changed, 70 insertions(+), 65 deletions(-) delete mode 100644 torch/_export/serde/aoti_schema.py diff --git a/test/inductor/test_aot_inductor_utils.py b/test/inductor/test_aot_inductor_utils.py index 6868928957a2..04a268abc3cb 100644 --- a/test/inductor/test_aot_inductor_utils.py +++ b/test/inductor/test_aot_inductor_utils.py @@ -58,15 +58,6 @@ def legacy_compile( restore_fqn=False, ) - if IS_FBCODE: - from deeplearning.aot_inductor.extern_node_thrift_serializer import ( - thrift_serializer, - ) - - if options is None: - options = {} - options["extern_node_serializer"] = thrift_serializer - with torch.no_grad(): so_path = torch._inductor.aot_compile(gm, example_inputs, options=options) # type: ignore[arg-type] diff --git a/test/inductor/test_torchbind.py b/test/inductor/test_torchbind.py index b94ba8ef1556..6f4e9fb876d4 100644 --- a/test/inductor/test_torchbind.py +++ b/test/inductor/test_torchbind.py @@ -275,7 +275,8 @@ def test_torchbind_aot_compile(self): "is_hop_single_tensor_return": None, }, }, - ] + ], + "protocol": "json", }, ) diff --git a/torch/_export/serde/aoti_schema.py b/torch/_export/serde/aoti_schema.py deleted file mode 100644 index d19add43705c..000000000000 --- a/torch/_export/serde/aoti_schema.py +++ /dev/null @@ -1,14 +0,0 @@ -from dataclasses import dataclass - -from torch._export.serde.schema import Node - - -@dataclass -class ExternKernelNode: - name: str - node: Node - - -@dataclass -class ExternKernelNodes: - nodes: list[ExternKernelNode] diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index fbf0be7d78f6..4274fc431dda 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<<3a8a6be8158821263b71ad9018c921664cd32c2f9b4deeac119e2292d186a02b>> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -358,4 +358,5 @@ struct ExternKernelNode { struct ExternKernelNodes { 10: list nodes; + 20: optional string protocol; } diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 0fbaf8644d74..d1d74c624c43 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -8,7 +8,7 @@ from torch._export.serde.union import _Union # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 7) +SCHEMA_VERSION = (8, 8) TREESPEC_VERSION = 1 @@ -484,3 +484,4 @@ class ExternKernelNode: @dataclass class ExternKernelNodes: nodes: Annotated[list[ExternKernelNode], 10] + protocol: Annotated[Optional[str], 20] = None diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 
3898303bda4b..e5f9ad4f8e28 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>> +# checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>> AOTInductorModelPickleData: kind: struct fields: @@ -141,6 +141,9 @@ ExternKernelNodes: fields: nodes: type: List[ExternKernelNode] + protocol: + type: Optional[str] + default: None GradientToParameterSpec: kind: struct fields: @@ -530,5 +533,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 7 +- 8 TREESPEC_VERSION: 1 diff --git a/torch/_inductor/extern_node_serializer.py b/torch/_inductor/extern_node_serializer.py index ffd390152034..19bf39fdd2e7 100644 --- a/torch/_inductor/extern_node_serializer.py +++ b/torch/_inductor/extern_node_serializer.py @@ -1,6 +1,6 @@ import json -from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node +from torch._export.serde.schema import ExternKernelNode, ExternKernelNodes, Node from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder from torch._inductor.ir import ExternKernelNode as inductor_ExternKernelNode @@ -19,6 +19,7 @@ def extern_node_json_serializer( extern_kernel_nodes: list[inductor_ExternKernelNode], ) -> str: serialized_nodes = ExternKernelNodes( - nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] + nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes], + protocol="json", ) - return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder) + return json.dumps(_dataclass_to_dict(serialized_nodes), cls=EnumEncoder, indent=2) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index bfe4d1e960b3..bc669580397e 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -366,9 +366,7 @@ def __init__( from torch._inductor.extern_node_serializer import extern_node_json_serializer self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = ( - extern_node_serializer - if config.is_fbcode() and extern_node_serializer - else extern_node_json_serializer + extern_node_json_serializer ) self.current_node: torch.fx.Node = None # type: ignore[assignment] diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 10ea643ae18b..9123c942754f 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -109,7 +109,7 @@ consider rebuild your model with the latest AOTInductor."); if (file_exists(json_filename)) { proxy_executor_ = std::make_unique( - json_filename, device_str == "cpu"); + json_filename, device_str); proxy_executor_handle_ = reinterpret_cast(proxy_executor_.get()); } else { diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp index 99d9045a63b0..fc25970c00b3 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -18,6 +18,19 @@ bool has_key( return map.find(key) != map.end(); } +c10::Device normalize_device(const c10::Device& device) { + // cpu device doesn't have an index + // cuda device must have an index + if (device.is_cpu()) { + return c10::Device(c10::DeviceType::CPU); + } else if (device.is_cuda()) { + return c10::Device( + c10::DeviceType::CUDA, device.has_index() ? 
device.index() : 0); + } else { + TORCH_CHECK(false, "Unsupported device type", device); + } +} + #ifdef _WIN32 const std::string k_separator = "\\"; #else @@ -211,12 +224,11 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( serialized_arg_val["index"].is_number()) { auto index = serialized_arg_val["index"].get(); device_string += ":" + std::to_string(index); - device_->set_index(static_cast(index)); } c10::Device device(device_string); - if (device.type() != device_->type()) { + if (device != *device_) { VLOG(1) << "ProxyExecutor is using " << *device_ << " for " << op_kernel->target_ << " argument #" << index << ", which is different from the one serialized in thrift: " @@ -579,15 +591,12 @@ std::unique_ptr OSSProxyExecutor:: OSSProxyExecutor::OSSProxyExecutor( const std::string& json_path, - bool is_cpu, + const std::string& device_str, std::optional> custom_objs) { - if (is_cpu) { - device_ = std::make_unique(c10::DeviceType::CPU); - } else { - int device_idx = -1; - device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); - } - + // CUDA device must have an index as a kernel may require + // an explicit device index. e.g., merge_pooled_embeddings + c10::Device normalized_device = normalize_device(c10::Device(device_str)); + device_ = std::make_unique(normalized_device); // If custom_objs is provided, use it instead of loading from // custom_objs_config.json If custom_objs is not provided, try to load from // custom_objs_config.json @@ -617,7 +626,7 @@ OSSProxyExecutor::OSSProxyExecutor( for (auto& [customObjName, file_name] : custom_objs_json.items()) { std::string customObjPath = folder_path + k_separator + file_name.get(); - LOG(INFO) << "Loading custom object to FbProxyExecutor from: " + LOG(INFO) << "Loading custom object to OSSProxyExecutor from: " << customObjPath; std::ifstream custom_obj_file(customObjPath, std::ios::binary); diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h index d20ef2e52186..551c89a3b793 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -12,26 +12,11 @@ namespace torch::aot_inductor { -enum class DynamicArgType : int { - TensorType = 0, - ListTensorType = 1, - ListOptionalTensorType = 2, - IntType = 3, - ListIntType = 4, - NoneType = 5, -}; - inline std::ostream& operator<<(std::ostream& os, DynamicArgType arg_type) { os << static_cast(arg_type); return os; } -inline bool isTensorType(DynamicArgType arg_type) { - return arg_type == DynamicArgType::TensorType || - arg_type == DynamicArgType::ListTensorType || - arg_type == DynamicArgType::ListOptionalTensorType; -} - struct OSSDynamicArg { OSSDynamicArg( int arg_index, @@ -118,7 +103,7 @@ class OSSProxyExecutor : public ProxyExecutor { public: explicit OSSProxyExecutor( const std::string& json_path, - bool is_cpu, + const std::string& device_str, std::optional> custom_objs = std::nullopt); diff --git a/torch/csrc/inductor/aoti_torch/proxy_executor.h b/torch/csrc/inductor/aoti_torch/proxy_executor.h index 6943bca5df49..5ce5d0d4f69c 100644 --- a/torch/csrc/inductor/aoti_torch/proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/proxy_executor.h @@ -6,6 +6,21 @@ namespace torch::aot_inductor { +enum DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, + NoneType = 5, +}; + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + 
arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + class ProxyExecutor { public: ProxyExecutor() = default; diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f348069b4fbb..8ba2f37d99b5 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<31c433c768b3f1bb61a5e8f4ceffc40c857bd80cf4fa0fc33fd03fa5ebb6c4d8>> +// checksum<<9ce65dfb56cd253e43e4f529501c8158869aaf36048f8849fde36713c2039a57>> // clang-format off #pragma once @@ -54,9 +54,9 @@ class ForwardRef { public: ForwardRef(): ptr_(std::make_unique()) {} - ForwardRef(ForwardRef&&) = default; + ForwardRef(ForwardRef&&); ForwardRef(const ForwardRef& other): ptr_(std::make_unique(*other.ptr_)) {} - ForwardRef& operator=(ForwardRef&&) = default; + ForwardRef& operator=(ForwardRef&&); ForwardRef& operator=(const ForwardRef& other) { ptr_ = std::make_unique(*other.ptr_); return *this; @@ -3216,6 +3216,7 @@ class ExternKernelNode { class ExternKernelNodes { private: std::vector nodes; + std::optional protocol = std::nullopt; public: @@ -3227,6 +3228,14 @@ class ExternKernelNodes { nodes = std::move(def); } + const std::optional& get_protocol() const { + return protocol; + } + + void set_protocol(std::optional def) { + protocol = std::move(def); + } + friend void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t); friend void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t); }; @@ -3315,11 +3324,13 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNode& n inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNodes& nlohmann_json_t) { nlohmann_json_j["nodes"] = nlohmann_json_t.nodes; + nlohmann_json_j["protocol"] = nlohmann_json_t.protocol; } inline void from_json(const nlohmann::json& nlohmann_json_j, ExternKernelNodes& nlohmann_json_t) { ExternKernelNodes nlohmann_json_default_obj; nlohmann_json_t.nodes = nlohmann_json_j.value("nodes", nlohmann_json_default_obj.nodes); + nlohmann_json_t.protocol = nlohmann_json_j.value("protocol", nlohmann_json_default_obj.protocol); } inline void to_json(nlohmann::json& nlohmann_json_j, const GradientToParameterSpec& nlohmann_json_t) { @@ -3688,6 +3699,9 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, UserOutputSpec& nlo nlohmann_json_t.arg = nlohmann_json_j.value("arg", nlohmann_json_default_obj.arg); } + +template ForwardRef::ForwardRef(ForwardRef&&) = default; +template ForwardRef& ForwardRef::operator=(ForwardRef&&) = default; } // namespace _export } // namespace torch From 27ded359a5dcbe8f92e01a24bec258bbfe1a73d6 Mon Sep 17 00:00:00 2001 From: eellison Date: Tue, 8 Apr 2025 10:46:02 -0700 Subject: [PATCH 289/332] Fix inplacing with multiple, fused uses (#150845) We had `can_inplace` defined on a single use. When that buffer has multiple uses inside a fused node, we need to check if the other accesses have the same index. Otherwise we may read memory that has already been written to from inplacing. 
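A minimal, framework-free sketch of the hazard this guards against (plain Python, not Inductor code): when the output aliases an input and a second, disjoint-index read of that input is fused into the same loop, later iterations observe values that were already overwritten.

```python
buf = [1.0, 2.0, 3.0, 4.0]
out = buf                     # in-place reuse: out aliases buf

for i in range(len(buf)):
    a = buf[i]                # read at the same index as the write: safe on its own
    b = buf[-1 - i]           # disjoint-index read fused into the same "kernel"
    out[i] = a + b            # clobbers buf[i] before later iterations read it

print(out)                    # [5.0, 5.0, 8.0, 9.0] instead of [5.0, 5.0, 5.0, 5.0]
```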
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150845 Approved by: https://github.com/zou3519, https://github.com/exclamaforte, https://github.com/atalman, https://github.com/jansel --- test/inductor/test_cuda_repro.py | 179 +++++++++++++++++++++++++++++++ torch/_inductor/scheduler.py | 33 ++++++ 2 files changed, 212 insertions(+) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 2f28257731af..e7722a1eee8f 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1327,6 +1327,185 @@ def fn(x, y, z): self.assertEqual(ref, res) + @torch._inductor.config.patch(emulate_precision_casts=True) + def test_dont_inplace_disjoint_accesses(self): + # TODO - would not need mms if we could annotate donated buffer.. + def forward( # noqa: F821, F722 + arg0_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0", # noqa: F821, F722 + arg2_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg3_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 + arg4_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 + arg5_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 + arg6_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 + arg7_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 + ): + permute = torch.ops.aten.permute.default(arg0_1, [1, 0]) + arg0_1 = None + view = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + mm = torch.ops.aten.mm.default(view, permute) + view = permute = None + view_1 = torch.ops.aten.view.default(mm, [8, 4096, 2048]) + mm = None + permute_1 = torch.ops.aten.permute.default(arg2_1, [1, 0]) + arg2_1 = None + view_2 = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + mm_1 = torch.ops.aten.mm.default(view_2, permute_1) + view_2 = permute_1 = None + view_3 = torch.ops.aten.view.default(mm_1, [8, 4096, 2048]) + mm_1 = None + permute_2 = torch.ops.aten.permute.default(arg3_1, [1, 0]) + arg3_1 = None + view_4 = torch.ops.aten.view.default(arg1_1, [32768, 2048]) + arg1_1 = None + mm_2 = torch.ops.aten.mm.default(view_4, permute_2) + view_4 = permute_2 = None + view_5 = torch.ops.aten.view.default(mm_2, [8, 4096, 2048]) + mm_2 = None + convert_element_type_6 = torch.ops.prims.convert_element_type.default( + view_1, torch.float32 + ) + view_1 = None + pow_1 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_6, 2) + mean = torch.ops.aten.mean.dim(pow_1, [-1], True) + pow_1 = None + add = torch.ops.aten.add.Tensor(mean, 1e-06) + mean = None + rsqrt = torch.ops.aten.rsqrt.default(add) + add = None + mul = torch.ops.aten.mul.Tensor(convert_element_type_6, rsqrt) + convert_element_type_6 = rsqrt = None + convert_element_type_7 = torch.ops.prims.convert_element_type.default( + arg4_1, torch.float32 + ) + arg4_1 = None + mul_1 = torch.ops.aten.mul.Tensor(convert_element_type_7, mul) + convert_element_type_7 = mul = None + convert_element_type_8 = torch.ops.prims.convert_element_type.default( + mul_1, torch.bfloat16 + ) + mul_1 = None + convert_element_type_9 = torch.ops.prims.convert_element_type.default( + view_3, torch.float32 + ) + view_3 = None + pow_2 = torch.ops.aten.pow.Tensor_Scalar(convert_element_type_9, 2) + mean_1 = torch.ops.aten.mean.dim(pow_2, [-1], True) + pow_2 = None + add_1 = torch.ops.aten.add.Tensor(mean_1, 1e-06) + mean_1 = None + rsqrt_1 = torch.ops.aten.rsqrt.default(add_1) + add_1 = None + mul_2 = torch.ops.aten.mul.Tensor(convert_element_type_9, rsqrt_1) + convert_element_type_9 
= rsqrt_1 = None + convert_element_type_10 = torch.ops.prims.convert_element_type.default( + arg5_1, torch.float32 + ) + arg5_1 = None + mul_3 = torch.ops.aten.mul.Tensor(convert_element_type_10, mul_2) + convert_element_type_10 = mul_2 = None + convert_element_type_11 = torch.ops.prims.convert_element_type.default( + mul_3, torch.bfloat16 + ) + mul_3 = None + view_6 = torch.ops.aten.view.default( + convert_element_type_8, [8, 4096, -1, 128] + ) + convert_element_type_8 = None + view_7 = torch.ops.aten.view.default( + convert_element_type_11, [8, 4096, -1, 128] + ) + convert_element_type_11 = None + view_8 = torch.ops.aten.view.default(view_5, [8, 4096, -1, 128]) + view_5 = None + convert_element_type_12 = torch.ops.prims.convert_element_type.default( + view_6, torch.float32 + ) + view_6 = None + convert_element_type_13 = torch.ops.prims.convert_element_type.default( + view_7, torch.float32 + ) + view_7 = None + unsqueeze = torch.ops.aten.unsqueeze.default(arg6_1, 0) + unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2) + unsqueeze = None + unsqueeze_2 = torch.ops.aten.unsqueeze.default(arg7_1, 0) + unsqueeze_3 = torch.ops.aten.unsqueeze.default(unsqueeze_2, 2) + unsqueeze_2 = None + mul_4 = torch.ops.aten.mul.Tensor(convert_element_type_12, unsqueeze_3) + unsqueeze_3 = None + view_9 = torch.ops.aten.view.default( + convert_element_type_12, [8, 4096, 16, 2, 64] + ) + convert_element_type_12 = None + unbind = torch.ops.aten.unbind.int(view_9, -2) + view_9 = None + getitem = unbind[0] + getitem_1 = unbind[1] + unbind = None + neg = torch.ops.aten.neg.default(getitem_1) + getitem_1 = None + cat = torch.ops.aten.cat.default([neg, getitem], -1) + neg = getitem = None + mul_5 = torch.ops.aten.mul.Tensor(cat, unsqueeze_1) + cat = unsqueeze_1 = None + add_2 = torch.ops.aten.add.Tensor(mul_4, mul_5) + mul_4 = mul_5 = None + unsqueeze_4 = torch.ops.aten.unsqueeze.default(arg6_1, 0) + arg6_1 = None + unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 2) + unsqueeze_4 = None + unsqueeze_6 = torch.ops.aten.unsqueeze.default(arg7_1, 0) + arg7_1 = None + unsqueeze_7 = torch.ops.aten.unsqueeze.default(unsqueeze_6, 2) + unsqueeze_6 = None + mul_6 = torch.ops.aten.mul.Tensor(convert_element_type_13, unsqueeze_7) + unsqueeze_7 = None + view_10 = torch.ops.aten.view.default( + convert_element_type_13, [8, 4096, 16, 2, 64] + ) + convert_element_type_13 = None + unbind_1 = torch.ops.aten.unbind.int(view_10, -2) + view_10 = None + getitem_2 = unbind_1[0] + getitem_3 = unbind_1[1] + unbind_1 = None + neg_1 = torch.ops.aten.neg.default(getitem_3) + getitem_3 = None + cat_1 = torch.ops.aten.cat.default([neg_1, getitem_2], -1) + neg_1 = getitem_2 = None + mul_7 = torch.ops.aten.mul.Tensor(cat_1, unsqueeze_5) + cat_1 = unsqueeze_5 = None + add_3 = torch.ops.aten.add.Tensor(mul_6, mul_7) + mul_6 = mul_7 = None + convert_element_type_14 = torch.ops.prims.convert_element_type.default( + add_2, torch.bfloat16 + ) + add_2 = None + convert_element_type_15 = torch.ops.prims.convert_element_type.default( + add_3, torch.bfloat16 + ) + add_3 = None + permute_3 = torch.ops.aten.permute.default( + convert_element_type_14, [0, 2, 1, 3] + ) + convert_element_type_14 = None + permute_4 = torch.ops.aten.permute.default( + convert_element_type_15, [0, 2, 1, 3] + ) + convert_element_type_15 = None + permute_5 = torch.ops.aten.permute.default(view_8, [0, 2, 1, 3]) + view_8 = None + return (permute_3, permute_4, permute_5) + + from torch._dynamo.debug_utils import aot_graph_input_parser + + kwargs = 
aot_graph_input_parser(forward) + out, code = run_and_get_code(torch.compile(forward), **kwargs) + # ignore tiny values.. prior to this fix absolute error was ~28 + self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2) + FileCheck().check_not("in_out").run(code[0]) + # https://github.com/pytorch/pytorch/issues/104937 def test_linear_with_zero_infeature_size(self): m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda") diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index b756cdc4aa98..34df7f21b595 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -464,6 +464,38 @@ def decide_inplace_update(self) -> None: | self.scheduler.completed_operations ) + def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: + # Inside of NodeUser, we track that the read and write are equivalent + # before deciding if the use can be inplace. + # But if that use is fused into a larger kernel, we need to check equivalence + # of other accesses in fused scheduler node as well. + fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self) + buf_name = buf_to_be_inplaced.get_name() + # Dedup read/writes with equivalent indices + # TODO - would be nice if we could just cache accesses on ReadWrites, + # and inforce variant that this class & members are functional.. + deps: OrderedSet[Dep] = OrderedSet() + for user in buf_to_be_inplaced.users: + user_node = user.node + if not isinstance(user_node, BaseSchedulerNode): + continue + + if ( + buf_to_be_inplaced.scheduler.get_fused_node(user_node) + is not fused_node + ): + continue + + deps |= ( + o + for o in user_node.read_writes.reads_and_writes() + if o.name == buf_name + ) + if len(deps) > 1: + return False + + return True + for buf in self.get_outputs(): buf_node = buf.node assert buf_node is not None @@ -515,6 +547,7 @@ def decide_inplace_update(self) -> None: and len(input_buf.node.get_inputs_that_alias_output()) > 0 ) and can_match_buffer_size(input_buf.node, buf.node) + and single_index_in_fused_node(input_buf) ): # if there isn't a triton kernel, then we don't need to call triton-specific things. # but TODO this might be a convenient place to signal to the Collective kernels to inplace From d9f47c75ded194b9992a5e0e42bfdccc1e2f2e85 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 9 Apr 2025 00:06:30 +0000 Subject: [PATCH 290/332] Revert "Fixing NCCL abort hang issue when a ProcessGroupNCCL manages multiple ncclComms (#150690)" This reverts commit 91173ff89aab5f632d483c736d11d5dcf60decac. Reverted https://github.com/pytorch/pytorch/pull/150690 on behalf of https://github.com/atalman due to failing internal test ([comment](https://github.com/pytorch/pytorch/pull/150690#issuecomment-2787905966)) --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d6f3e0d42e1e..2da127e5b267 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1371,9 +1371,6 @@ void ProcessGroupNCCL::abortCommsFromMap( const std::optional& abortReason) { // The process may control multiple devices, loop through the communicators on // each device - // NCCL expects Group abort when there are multiple communicators created in a - // device. 
- groupStart(); for (auto& it : ncclCommsMap) { auto& devName = it.first; auto& ncclComm = it.second; @@ -1394,7 +1391,6 @@ void ProcessGroupNCCL::abortCommsFromMap( VLOG(2) << logPrefix() << "ProcessGroupNCCL destroyed " << " communicator on CUDA device: " << devName; } - groupEnd(); } // Abort all communicators on this rank From 5f18b7d8770b38e73aa59ea1ff371ae7772d2e58 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Wed, 9 Apr 2025 02:07:48 +0000 Subject: [PATCH 291/332] [docs] remove --recursive flag from readme (#150785) Fixes #150745 See https://github.com/pytorch/pytorch/issues/150745#issuecomment-2784216663 Cloning with `--recursive` as shown in the docs prevents users from checking out commits from before NCCL was removed as a submodule. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150785 Approved by: https://github.com/atalman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 00c515408528..5085abc87b7d 100644 --- a/README.md +++ b/README.md @@ -221,7 +221,7 @@ Other potentially useful environment variables may be found in `setup.py`. #### Get the PyTorch Source ```bash -git clone --recursive https://github.com/pytorch/pytorch +git clone https://github.com/pytorch/pytorch cd pytorch # if you are updating an existing checkout git submodule sync From 44deb67830f4f5dc26d6f707bf39a064670d1706 Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 8 Apr 2025 10:24:00 -0700 Subject: [PATCH 292/332] Fix _del_library (#150495) On library deletion, we need to clear fx's schema cache. Test Plan: - top PR in the stack, I don't have a good test case for this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150495 Approved by: https://github.com/eellison --- torch/library.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/library.py b/torch/library.py index fbe6f3ea1cd3..4caa6a698f66 100644 --- a/torch/library.py +++ b/torch/library.py @@ -443,6 +443,21 @@ def _del_library( op_defs, registration_handles, ): + import torch.fx + + for op_def in op_defs: + name = op_def + overload_name = "" + if "." 
in op_def: + name, overload_name = op_def.split(".") + if ( + name, + overload_name, + ) in torch.fx.operator_schemas._SCHEMA_TO_SIGNATURE_CACHE: + del torch.fx.operator_schemas._SCHEMA_TO_SIGNATURE_CACHE[ + (name, overload_name) + ] + captured_impls -= op_impls captured_defs -= op_defs for handle in registration_handles: From 2e7c9d33e7f933ac3b723cb3bb05b9c88432c25c Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 8 Apr 2025 10:24:01 -0700 Subject: [PATCH 293/332] Refactor layout constraint selection logic (#148104) This PR: - cleans up some existing comments that don't make sense anymore - hooks up the "custom_op_default_layout_constraint" back (that seems to have broken) - cleans up the "lazy registration path" which seems to never get hit anymore - adds dislike_padding to nodes that require exact strides Test Plan: - tests + CI disable padding Pull Request resolved: https://github.com/pytorch/pytorch/pull/148104 Approved by: https://github.com/shunting314, https://github.com/eellison ghstack dependencies: #150495 --- torch/_inductor/config.py | 2 +- torch/_inductor/graph.py | 56 ++++++++++++++------------- torch/_inductor/lowering.py | 45 +++++++++++---------- torch/fx/experimental/proxy_tensor.py | 4 +- 4 files changed, 57 insertions(+), 50 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 27b77d199f09..040b91917398 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -126,7 +126,7 @@ def prologue_fusion_enabled() -> bool: # If the custom op does not have a layout constraint tag already # then we assume the following applies. custom_op_default_layout_constraint: Literal[ - "needs_fixed_stride_order", "flexible_layout" + "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout" ] = "needs_fixed_stride_order" # The default layout constraint for user-defined triton kernels. diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index bc669580397e..041a2d8c149d 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -80,11 +80,13 @@ FALLBACK_ALLOW_LIST, fallback_handler, fallback_node_due_to_unsupported_type, + get_layout_constraint_tag, lowerings, make_fallback, maybe_layout_constraints, needs_realized_inputs, require_contiguous, + tag_to_layout_constraint, unsupported_output_tensor, ) from .runtime import autotune_cache @@ -244,6 +246,14 @@ def _get_overload_packet( cur.meta["dislike_padding"] = True continue + if ( + isinstance(cur.target, torch._ops.OpOverload) + and get_layout_constraint_tag(cur.target) + == torch._C.Tag.needs_exact_strides + ): + cur.meta["dislike_padding"] = True + continue + op = _get_overload_packet(cur) if not op: continue @@ -1150,34 +1160,26 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> error.operator_str(target, args, kwargs), ) - # use contiguous unless the (custom) op asks something else - # explicitly - if torch._C.Tag.needs_exact_strides in target.tags: - decided_constraint = constrain_to_fake_tensors # type: ignore[assignment] - elif torch._C.Tag.needs_fixed_stride_order in target.tags: - decided_constraint = constrain_to_fx_strides # type: ignore[assignment] - elif torch._C.Tag.flexible_layout in target.tags: - decided_constraint = None # type: ignore[assignment] - else: - # If there are no tags, we do different things depending on - # if it's a builtin ATen/prim ops or custom ops. 
- # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452 - # For custom ops, we constrain_to_fx_strides to maintain the - # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356 + tag = get_layout_constraint_tag(target, with_default=False) + if ( + tag is None + and torch._library.utils.is_builtin(target) + and self.is_backward + ): + # for implicit fallback ATen ops during backward, if there + # is no layout constraint tag, we conservatively require contiguous + # input since some eager kernels do not + # support non-contiguous inputs. Otherwise they may silently cause + # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452 + # We only do this For ATen ops and for backward. # - # For ATen ops, only apply the constraint for backward - # ops since fwd ops should work for any strides. - if torch._library.utils.is_builtin(target) and self.is_backward: - decided_constraint = require_contiguous # type: ignore[assignment] - else: - # maybe_layout_constraints will decide the layout constraint for the custom op - # lazily - decided_constraint = None # type: ignore[assignment] - - # for implicitly fallback ops, we conservatively requires - # contiguous input since some eager kernels does not - # support non-contiguous inputs. They may silently cause - # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452 + # TODO: should really switch to "needs_fixed_stride" constraint on these + # and identify them one by one. + decided_constraint = require_contiguous # type: ignore[assignment] + else: + tag = get_layout_constraint_tag(target, with_default=True) + decided_constraint = tag_to_layout_constraint(tag) + make_fallback(target, layout_constraint=decided_constraint) elif get_decompositions([target]): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 24520887f6aa..d9e0fb03d004 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -157,37 +157,40 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A return None if fn in _maybe_layout_constraints: return _maybe_layout_constraints[fn] - # OpOverload with custom lowerings override tag-based layout constraints - if fn in lowerings: - _maybe_layout_constraints[fn] = None - return None - # We lazily register tag-based layout constraints. 
- - def handle_layout_constraint_tag(tag): - if tag is torch._C.Tag.needs_fixed_stride_order: - _maybe_layout_constraints[fn] = constrain_to_fx_strides - return _maybe_layout_constraints[fn] - elif tag is torch._C.Tag.flexible_layout: - _maybe_layout_constraints[fn] = None - return None - else: - raise AssertionError(f"Unknown layout constraint tag: {tag}") + return None + - tag = get_layout_constraint_tag(fn) - return handle_layout_constraint_tag(tag) +tags_by_priority = [ + torch._C.Tag.needs_exact_strides, + torch._C.Tag.needs_fixed_stride_order, + torch._C.Tag.flexible_layout, +] -def get_layout_constraint_tag(fn): +def get_layout_constraint_tag(fn, *, with_default=True): tags_by_priority = [ + torch._C.Tag.needs_exact_strides, torch._C.Tag.needs_fixed_stride_order, torch._C.Tag.flexible_layout, ] for tag in tags_by_priority: if tag in fn.tags: return tag - if torch._library.utils.is_builtin(fn): - return torch._C.Tag.flexible_layout - return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + if with_default: + if torch._library.utils.is_builtin(fn): + return torch._C.Tag.flexible_layout + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) + return None + + +def tag_to_layout_constraint(tag): + if tag == torch._C.Tag.needs_exact_strides: + return constrain_to_fake_tensors + if tag == torch._C.Tag.needs_fixed_stride_order: + return constrain_to_fx_strides + if tag == torch._C.Tag.flexible_layout: + return None + raise AssertionError(f"Unknown layout constraint tag: {tag}") def assert_nyi(cond, msg): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 4193606d849d..9bbc16a895b6 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1169,7 +1169,9 @@ def _should_save_eager_input_vals( f"propagate the FakeTensor vals. Please file an issue." ) if isinstance(target, torch._ops.OpOverload): - return torch._C.Tag.needs_exact_strides in target.tags + from torch._inductor.lowering import get_layout_constraint_tag + + return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides return False From bc47d539fc380f521dfcc25e895e46e6d5a1fd52 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Tue, 8 Apr 2025 19:16:55 -0700 Subject: [PATCH 294/332] [MPS] Support ArgumentBuffer bindings from C++/Python (#150780) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit To workaround limitation of 32-arguments per kernel and being able to eventually compile something like ```python import torch def foo(*args): rc = torch.empty_like(args[0]) for arg in args: rc += arg return rc tensors = torch.rand(100, 32, device='mps').unbind(0) print(torch.compile(foo)(*tensors)) ``` For now, introduce `at::native::metal::get_tensor_gpu_address` and use it from both C++ test and compile_shader to convert list of tensors to list of pointers valid on GPU. 
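As a usage sketch (condensed from the `test_argument_buffers` test added in this PR; the `metal::array` template arguments are spelled out and the buffer count is shrunk for brevity, so treat the exact shader text as illustrative), a Python list of MPS tensors can now be passed as a single argument-buffer struct:

```python
import torch

lib = torch.mps.compile_shader("""
    constant constexpr auto nbuffers = 4;
    struct Inputs {
        metal::array<device float*, nbuffers> args;
    };

    kernel void sum_all(device float* output,
                        constant Inputs& inputs,
                        uint idx [[thread_position_in_grid]]) {
        auto rc = inputs.args[0][idx];
        for (auto i = 1; i < nbuffers; ++i) {
            rc += inputs.args[i][idx];
        }
        output[idx] = rc;
    }
""")

inputs = torch.rand(4, 32, device="mps").unbind(0)  # needs an MPS device (macOS 13+)
output = torch.empty_like(inputs[0])
lib.sum_all(output, inputs)   # the tensor list is encoded as raw gpuAddress values
print(torch.allclose(output, sum(inputs)))
```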
Initially this binding were done via `id< MTLArgumentEncoder>`, but according to [Improving CPU Performance by Using Argument Buffers](https://developer.apple.com/documentation/metal/improving-cpu-performance-by-using-argument-buffers?language=objc#Encode-Resources-into-Argument-Buffers) article, this is not necessary when targeting Tier2-only devices (which is true of all devices on MacOS-13 or newer): > To directly encode the argument buffer resources on these Tier 2 devices, write the [MTLBuffer](https://developer.apple.com/documentation/metal/mtlbuffer?language=objc).[gpuAddress](https://developer.apple.com/documentation/metal/mtlbuffer/gpuaddress?language=objc) property — and for other resource types (samplers, textures, and acceleration structures), the [gpuResourceID](https://developer.apple.com/documentation/metal/mtlcomputepipelinestate/gpuresourceid?language=objc) property — into the corresponding structure member. To encode offsets, treat these property values as uint64 types and add the offset to them. Add both C++ and PyThon unittests that validate that this works. Please note, that using either ArgumentEncoder or directly encoding the data does not guarantee buffer will not be freed until shader execution is complete. On the other hand, this should already be guaranteed by MPSCachingAllocator that would only free the memory after all streams completed its execution. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150780 Approved by: https://github.com/dcci --- aten/src/ATen/native/mps/MetalShaderLibrary.h | 8 +++- aten/src/ATen/native/mps/OperationUtils.mm | 11 +++++- aten/src/ATen/test/mps_test_metal_library.cpp | 39 +++++++++++++++++++ test/test_mps.py | 23 +++++++++++ torch/csrc/mps/Module.cpp | 9 +++++ 5 files changed, 86 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index dff66520ccfc..acd2bf66101f 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -46,9 +46,12 @@ constexpr bool has_size_type_v = has_size_type::value; } // namespace detail +// Returns `gpuAddress` of respective `id` plus storage offset +void* get_tensor_gpu_address(const at::TensorBase&); + class MetalKernelFunction { public: - MetalKernelFunction(MTLComputePipelineState_t cps_); + MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_); ~MetalKernelFunction(); MetalKernelFunction(MetalKernelFunction&) = delete; // Shader properties @@ -56,7 +59,7 @@ class MetalKernelFunction { uint64_t getThreadExecutionWidth() const; uint64_t getStaticThreadGroupMemoryLength() const; void runCommandBlock(std::function f); - // Methods below should be called from runCommandBlock functionT + // Methods below should be called from runCommandBlock function void startEncoding(); void setArg(unsigned idx, const at::TensorBase& t); void setArg(unsigned idx, const void* ptr, uint64_t size); @@ -88,6 +91,7 @@ class MetalKernelFunction { private: MTLComputePipelineState_t cps; + MTLFunction_t func; MTLComputeCommandEncoder_t encoder = nullptr; }; diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 9655988e082a..57fa278b01d8 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -922,7 +922,8 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} } std::shared_ptr 
MetalShaderLibrary::getKernelFunction(const std::string& name) { - return std::make_shared(getPipelineStateForFunc(name)); + auto [cpl, func] = getLibraryPipelineState(getLibrary(), name); + return std::make_shared(cpl, func); } class BundledShaderLibary : public MetalShaderLibrary { @@ -1088,10 +1089,12 @@ static dispatch_data_t getSectionData(const std::string& name) { } // MetalKernelFunction implementation -MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_) : cps([cps_ retain]) {} +MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_) + : cps([cps_ retain]), func([f_ retain]) {} MetalKernelFunction::~MetalKernelFunction() { [cps release]; + [func release]; } void MetalKernelFunction::runCommandBlock(std::function run) { @@ -1152,6 +1155,10 @@ static dispatch_data_t getSectionData(const std::string& name) { return [cps staticThreadgroupMemoryLength]; } +void* get_tensor_gpu_address(const at::TensorBase& t) { + return reinterpret_cast(getMTLBufferStorage(t).gpuAddress + t.storage_offset() * t.element_size()); +} + } // namespace at::native::mps // Check that c10::metal::ScalarType is strict subset (with matching values) of c10::ScalarType diff --git a/aten/src/ATen/test/mps_test_metal_library.cpp b/aten/src/ATen/test/mps_test_metal_library.cpp index baee8964364d..3f91516e5a66 100644 --- a/aten/src/ATen/test/mps_test_metal_library.cpp +++ b/aten/src/ATen/test/mps_test_metal_library.cpp @@ -54,6 +54,7 @@ TEST(MPSTestMetalLibrary, ArangeWithArgsShader) { }); ASSERT_TRUE((x==y).all().item().toBool()); } + TEST(MPSTestMetalLibrary, Arange2DShader) { const auto size = 16; auto x = torch::empty({size, size}, at::device(at::kMPS)); @@ -71,3 +72,41 @@ TEST(MPSTestMetalLibrary, Arange2DShader) { }); ASSERT_EQ(x.sum().item().to(), 65280); } + +TEST(MPSTestMetalLibrary, ArgumentBuffers) { + constexpr auto nbuffers = 64; + const auto size = 32; + std::vector ibuffers; + std::vector ibuffers_gpu_ptrs; + for([[maybe_unused]] auto idx: c10::irange(nbuffers)) { + ibuffers.push_back(torch::rand({size}, at::device(at::kMPS))); + ibuffers_gpu_ptrs.push_back(get_tensor_gpu_address(ibuffers.back())); + } + auto output = torch::empty({size}, at::device(at::kMPS)); + DynamicMetalShaderLibrary lib(R"MTL( + constant constexpr auto nbuffers = 64; + struct Inputs { + metal::array args; + }; + + kernel void sum_all(device float* output, constant Inputs& inputs, uint idx [[thread_position_in_grid]]) { + output[idx] = 0; + for(auto i = 0; i < nbuffers; ++i) { + output[idx] += inputs.args[i][idx]; + } + } + )MTL"); + auto func = lib.getKernelFunction("sum_all"); + func->runCommandBlock([&] { + func->startEncoding(); + func->setArg(0, output); + func->setArg(1, ibuffers_gpu_ptrs); + func->dispatch(size); + }); + // Compute sum of all 64 input tensors + auto result = torch::zeros({size}, at::device(at::kMPS)); + for(auto buf: ibuffers) { + result += buf; + } + ASSERT_EQ(result.sum().item().to(), output.sum().item().to()); +} diff --git a/test/test_mps.py b/test/test_mps.py index 576659ae29d6..0f1790a1808d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -13078,6 +13078,29 @@ def test_reduction_utils(self, dtype): self.assertLess(max_err, 1e-2 if dtype == torch.float16 else 1e-5, f"results are {y}, but all elements should have been {x_sum.item()}") + def test_argument_buffers(self): + lib = torch.mps.compile_shader(""" + constant constexpr auto nbuffers = 64; + struct Inputs { + metal::array args; + }; + + kernel void sum_all(device float* output, constant 
Inputs& inputs, uint idx [[thread_position_in_grid]]) { + auto rc = inputs.args[0][idx]; + for(auto i = 1; i < nbuffers; ++i) { + rc += inputs.args[i][idx]; + } + output[idx] = rc; + } + """) + inputs = torch.rand(64, 32, device="mps").unbind(0) + output = torch.empty_like(inputs[0]) + lib.sum_all(output, inputs) + correct = torch.zeros_like(inputs[0]) + for inp in inputs: + correct += inp + self.assertEqual(correct, output) + @unittest.skipIf(not torch.mps.profiler.is_metal_capture_enabled(), "Set MTL_CAPTURE_ENABLED and try again") def test_metal_capture(self): lib = torch.mps.compile_shader("kernel void full(device float* x, uint idx [[thread_position_in_grid]]) { x[idx] = 1.0; }") diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 3694cd194179..3cd75cedada7 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -394,6 +394,15 @@ struct OptionalArgCaster { } else if (py::isinstance(element)) { auto values = arg.cast>(); setValue(f, idx, values); + } else if (THPVariable_Check(element.ptr())) { + /* List of tensors, most often to overcome the limits of 32-args per + * kernel */ + auto tensorlist = py::cast>(arg); + std::vector tl_ptrs; + for (auto& t : tensorlist) { + tl_ptrs.push_back(at::native::mps::get_tensor_gpu_address(t)); + } + f.setArg(idx, tl_ptrs); } else { TORCH_CHECK(false, "Unexpected argument types"); } From 4d6ff6ca5c8752d45cb77fab993f9c13202a32a3 Mon Sep 17 00:00:00 2001 From: James Wu Date: Tue, 8 Apr 2025 11:12:59 -0700 Subject: [PATCH 295/332] Fill config2launcher with correct launchers during cache hit coordinate descent (#150860) This bug was crazy hard to reproduce, so I can't seem to get a unit test written that isn't the internal one I used for debugging. Here's a short TLDR of the bug: - Due to D71983456(OSS: https://github.com/pytorch/pytorch/pull/149910), we cache CachingAutotuners in memory. - Importantly: **Saving stuff in PyCodeCache in memory is not semantically equivalent to writing to disk**. By saving it in memory, CachingAutotuners do not reset global state. - It's possible through recompiles for different dynamo frames to compile down to exactly the same inductor output code. This involves models that run multiple times, but differ very subtley, or in ways that cause a dynamo guard failure but not a different inductor output code. - Because of this, we reuse CachingAutotuners for a second compile (with different example inputs, just the same triton kernel code) - CachingAutotuners have a Coordinate Descent class on them, which has a cache: https://fburl.com/code/4igrsams (OSS: https://github.com/pytorch/pytorch/blob/aafc4b6188b70cf808f756f23b1a05355bcb7696/torch/_inductor/runtime/coordinate_descent_tuner.py#L69) - Because we are caching these in memory and not on disk, this cache is **not cleared** between runs. - However, this variable is *not* saved on the class, and is reinitialized every time we do autotuning: https://fburl.com/code/n2o8tmje (OSS: https://github.com/pytorch/pytorch/blob/aafc4b6188b70cf808f756f23b1a05355bcb7696/torch/_inductor/runtime/triton_heuristics.py#L933) - `config2launcher` is added when we call `benchmark_one_config`, but on a CoorDesc *cache hit*, we never call `benchmark_one_config`! So we end up returning None, and erroring with: ``` AttributeError: 'NoneType' object has no attribute 'store_cubin' ``` This fixes the problem for now by just recompiling the launcher. 
Technically, we might be able to save config2launcher on the class to avoid this, but I don't want to risk another weird cache safety bug here, so taking the simpler approach for now. Note that this error only reproduces if: - None of AOTAutogradCache, FXgraphCache hit on the second entry: otherwise, the CachingAutotuner will go through a pickling and then not be saved in memory - We haven't spawned parallel compile workers. If there are parallel compile workers, we pickle the autotuner on the way from the worker to the parent process, once again resetting the Autotuner. - The autotune cache doesn't already have the best config stored in it So it was extraordinarily hard to debug/reproduce. Because of this, I have a complicated internal unit test but no OSS test that can trigger the exact problem. I'll work on a separate test later, but this needs to go in to fix a sev, so we're landing it based on an internal test only. Differential Revision: [D72655382](https://our.internmc.facebook.com/intern/diff/D72655382/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D72655382/)! Pull Request resolved: https://github.com/pytorch/pytorch/pull/150860 Approved by: https://github.com/oulgen --- torch/_inductor/runtime/triton_heuristics.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index be02d43c28f8..daf1afa8ec28 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -978,7 +978,15 @@ def benchmark_one_config(config): self.autotune_time_taken_ns + coordesc_time_taken_ns, found_by_coordesc=True, ) - return config2launcher.get(best_config) + + if best_config not in config2launcher: + # On a Coordesc cache hit, we might not have loaded the launcher + # This can happen because PyCodeCache saves CachingAutotuners in memory, + # even for separate compile IDs (which can have different inputs without changing output code) + config2launcher[best_config] = self._precompile_config( + best_config + ).make_launcher() + return config2launcher[best_config] def run( self, From b01877aa1389a10bb3c17c9ecb8a79ff1c9a8f79 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Mon, 17 Mar 2025 11:01:31 +0800 Subject: [PATCH 296/332] Fix addbmm & addmv & baddbmm out dtype check (#148176) ---- - torch.addbmm - torch.addmv - torch.baddbmm ISSUE related: https://github.com/pytorch/pytorch/issues/138399 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148176 Approved by: https://github.com/jansel ghstack dependencies: #148174 --- test/test_ops.py | 3 --- torch/_decomp/decompositions.py | 4 ++-- torch/_meta_registrations.py | 4 ++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 09992fff10a7..c8079ea71255 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -118,8 +118,6 @@ def reduction_dtype_filter(op): aten = torch.ops.aten meta_consistency_out_dtype_mismatch_xfails = { - xfail("addbmm"), - xfail("addmv"), xfail("alias_copy"), xfail("all"), xfail("amax"), @@ -127,7 +125,6 @@ def reduction_dtype_filter(op): xfail("aminmax"), xfail("any"), xfail("as_strided_copy"), - xfail("baddbmm"), xfail("bucketize"), xfail("conj_physical"), xfail("cross"), diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 
c2dc7e510833..94cf3aeeb1d2 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1482,7 +1482,7 @@ def _addmm_activation( @register_decomposition(aten.addmv) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1): if not self.is_floating_point() and not self.is_complex(): @@ -5031,7 +5031,7 @@ def inplace_op(*args, **kwargs): @register_decomposition([aten.baddbmm]) -@out_wrapper() +@out_wrapper(exact_dtype=True) @pw_cast_for_opmath def baddbmm(self, batch1, batch2, beta=1, alpha=1): if not self.is_floating_point() and not self.is_complex(): diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 9466f7430348..86f5b522bebf 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2152,7 +2152,7 @@ def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> @register_meta([aten.baddbmm.default, aten.baddbmm.out]) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(0) dim2 = batch1.size(1) @@ -3460,7 +3460,7 @@ def meta_convolution_backward( @register_meta([aten.addbmm.default, aten.addbmm.out]) -@out_wrapper() +@out_wrapper(exact_dtype=True) def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1): dim1 = batch1.size(1) dim2 = batch2.size(2) From 604467de208646f0c3b2663e45f2ff6a655a6716 Mon Sep 17 00:00:00 2001 From: FFFrog Date: Tue, 8 Apr 2025 19:49:38 +0800 Subject: [PATCH 297/332] Code Clean: Remove specific bytecode support in dynamo for python3.8 (#150838) Related Bytecode: - CALL_FINALLy - END_FINALLy - POP_FINALLy The bytecodes above were removed before python3.9, refer to [this](https://github.com/python/cpython/blob/53908bd7905b849e110d2c6f4bce739bff037146/Misc/NEWS.d/3.9.0a2.rst) for more infos. 
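A quick illustrative check (not part of the patch) of why the removed handlers are dead code on the interpreter versions Dynamo still supports:

```python
import dis
import sys

legacy_opcodes = {"CALL_FINALLY", "END_FINALLY", "POP_FINALLY"}
still_present = legacy_opcodes & set(dis.opmap)
# On CPython >= 3.9 none of these opcodes exist, so the removed
# symbolic_convert handlers could never be reached.
print(sys.version_info[:2], still_present)  # e.g. (3, 11) set()
```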
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150838 Approved by: https://github.com/Skylion007, https://github.com/jansel ghstack dependencies: #150834 --- test/dynamo/test_repros.py | 35 ------------------------------- torch/_dynamo/symbolic_convert.py | 28 ------------------------- 2 files changed, 63 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index e03e14b78799..3bd52f981c6b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4075,41 +4075,6 @@ def forward(self, **inp): res = torch.compile(mod, backend="eager", fullgraph=True)(**inputs) self.assertEqual(ref, res) - def test_call_finally_python_3_8(self): - # Issue - https://github.com/pytorch/pytorch/issues/97811 - def make_fn(g): - def fn(): - while True: - try: - print(g) - break - except Exception as _: - break - - return torch.compile(fn, backend="eager") - - make_fn(None)() - - def test_call_finally_python_3_8_2(self): - def f(x): - while x: - try: - pass - except Exception as _: - continue - - torch.compile(f, backend="eager")(0) - - def test_call_finally_opcode_python_3_8(self): - def fn(): - try: - return torch.zeros(4) - finally: - return torch.ones(4) # noqa: SIM107, B012 - - result = torch.compile(fn, backend="aot_eager")() - self.assertEqual(result, torch.ones(4)) - def test_string_format(self): s = "temp{i}" diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 6b9067a91830..fe634614db4e 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -1709,34 +1709,6 @@ def WITH_CLEANUP_FINISH(self, inst): self.popn(2) self.push(None) - def CALL_FINALLY(self, inst): - """ - pushes the address of the next instruction onto the stack and increments - bytecode counter by delta - """ - # Python 3.8 only - addr = self.indexof[self.next_instruction] - self.push(ConstantVariable.create(addr)) - self.jump(inst) - - def END_FINALLY(self, inst): - # Python 3.8 only - # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY - tos = self.pop() - if isinstance(tos, ConstantVariable): - self.instruction_pointer = tos.as_python_constant() - else: - pass - - def POP_FINALLY(self, inst): - # Python 3.8 only - preserve_tos = inst.argval - if preserve_tos: - tos = self.pop() - _ = self.pop() - if preserve_tos: - self.push(tos) # type: ignore[possibly-undefined] - def FOR_ITER(self, inst): it = self.pop().realize() try: From 81f60f38800319d8fdf437929f31c2099877e788 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 9 Apr 2025 11:01:45 +0000 Subject: [PATCH 298/332] Expand allowed_getattr_types_for_subgm to torch.Tensor (#150867) Summary: att regular weight has the type of torch.nn.parameter.Parameter buffer and tensor constant has the type of torch.Tensor both types are valid. 
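For context, an illustrative check (not part of the patch): parameters reach `get_attr` as `torch.nn.parameter.Parameter`, while registered buffers and lifted tensor constants are plain `torch.Tensor`, which is why both types have to be accepted for subgraphs.

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(2))
        self.register_buffer("scale", torch.ones(2))

m = M()
print(type(m.weight))  # <class 'torch.nn.parameter.Parameter'>
print(type(m.scale))   # <class 'torch.Tensor'> -> must be in allowed_getattr_types_for_subgm
```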
Test Plan: CI Differential Revision: D72657275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150867 Approved by: https://github.com/zhxchen17 --- test/export/test_export.py | 1 - torch/_export/verifier.py | 9 +++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/test/export/test_export.py b/test/export/test_export.py index 343118c715c6..f4898783fb3e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -7301,7 +7301,6 @@ def forward(self, init, accum): self.assertEqual(ep.module()(init, xs), M()(init, xs)) # map_fn references module outside the module hierarchy - @unittest.expectedFailure def test_map_buffers(self): class M1(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 8ba1132ca668..8f80f2a6bcc4 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -149,7 +149,12 @@ def allowed_getattr_types(self) -> tuple[type[Any], ...]: def allowed_getattr_types_for_subgm(self) -> tuple[type[Any], ...]: # subgm in HOP's argument could has have getattr(weight) nodes, thus stateful - return (torch.fx.GraphModule, torch.nn.parameter.Parameter, torch.utils._pytree.TreeSpec) + return ( + torch.fx.GraphModule, + torch.nn.parameter.Parameter, + torch.Tensor, # for buffer and constant tensor + torch.utils._pytree.TreeSpec + ) def check_valid_op(self, op): pass @@ -276,7 +281,7 @@ def _is_type(name, ty): if not isinstance(attr, _allowed_getattr_types(is_toplevel_gm)): raise SpecViolationError( - f"Invalid get_attr type {type(attr)}. \n" + f"Invalid get_attr type {type(attr)} on target {node.target}. \n" f"Valid get_attr types: {_allowed_getattr_types(is_toplevel_gm)}" ) From 142f0f86ce054f401d9d5145e4291629cafba45f Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 9 Apr 2025 11:57:24 +0000 Subject: [PATCH 299/332] Enable modernize-use-default-member-init (#149046) ``modernize-use-default-member-init`` prefers initialisation in class members, that make more ``= default`` constructors possible. Some violations or modernize rules have been fixed. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149046 Approved by: https://github.com/zou3519 --- .clang-tidy | 1 - aten/src/ATen/core/Dict.h | 5 +---- .../src/ATen/core/dispatch/DispatchKeyExtractor.h | 5 ++--- aten/src/ATen/cuda/tunable/Tunable.h | 2 +- aten/src/ATen/native/RangeUtils.h | 6 +++--- aten/src/ATen/native/SpectralOps.cpp | 2 +- aten/src/ATen/native/UnaryOps.cpp | 2 +- aten/src/ATen/native/cuda/CuFFTPlanCache.h | 6 +++--- aten/src/ATen/native/cuda/MiscUtils.h | 7 +++---- aten/src/ATen/native/cuda/Resize.h | 4 ++-- .../native/cuda/linalg/BatchLinearAlgebraLib.h | 6 +++--- aten/src/ATen/native/cudnn/RNNUtils.h | 9 +++------ aten/src/ATen/native/mkldnn/MKLDNNCommon.h | 4 ++-- aten/src/ATen/native/mkldnn/xpu/detail/Attr.h | 2 +- aten/src/ATen/quantized/QTensorImpl.h | 12 ++++++------ torch/_inductor/codegen/cpp_prefix.h | 2 +- torch/csrc/autograd/profiler_python.cpp | 15 ++++++--------- torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp | 2 +- torch/csrc/jit/runtime/interpreter/code_impl.h | 9 +++------ torch/lib/libshm/socket.h | 4 ++-- 20 files changed, 45 insertions(+), 60 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index df40a6df91c0..4b1548d646b2 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -52,7 +52,6 @@ modernize-*, -modernize-macro-to-enum, -modernize-return-braced-init-list, -modernize-use-auto, --modernize-use-default-member-init, -modernize-use-using, -modernize-use-trailing-return-type, -modernize-use-nodiscard, diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h index d187d7b7c116..96cd25fec10b 100644 --- a/aten/src/ATen/core/Dict.h +++ b/aten/src/ATen/core/Dict.h @@ -116,10 +116,7 @@ class DictIterator final { DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} - DictIterator& operator=(const DictIterator& rhs) { - entryRef_ = rhs.entryRef_; - return *this; - } + DictIterator& operator=(const DictIterator& rhs) = default; DictIterator& operator=(DictIterator&& rhs) noexcept { entryRef_ = std::move(rhs.entryRef_); return *this; diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 27438b926db5..61a3c1801294 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -225,8 +225,7 @@ struct TORCH_API DispatchKeyExtractor final { explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse) : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse), - nonFallthroughKeys_(DispatchKeySet::FULL), - requiresBitsetPerBackend_(false) { + nonFallthroughKeys_(DispatchKeySet::FULL) { for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) { nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL; } @@ -252,7 +251,7 @@ struct TORCH_API DispatchKeyExtractor final { // Flag to tell us if we can use the single set of nonFallthroughKeys_ (fast // path), or if we need to fall back to the slower path and check // nonFallthroughKeysPerBackend_ - bool requiresBitsetPerBackend_; + bool requiresBitsetPerBackend_{false}; }; } // namespace c10 diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index 3ea292582f60..5e885d4764d2 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -40,7 +40,7 @@ enum TORCH_CUDA_CPP_API TuningStatus { class TORCH_CUDA_CPP_API ResultEntry { public: explicit 
ResultEntry(std::string key, double time) : key_(std::move(key)), time_(time) {} - explicit ResultEntry(std::string key, double time, const std::string& blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(blas_sig) {} + explicit ResultEntry(std::string key, double time, std::string blas_sig ) : key_(std::move(key)), time_(time), blas_sig_(std::move(blas_sig)) {} bool operator==(const ResultEntry& other) const { return key_ == other.key_; } bool operator!=(const ResultEntry& other) const { return key_ != other.key_; } operator std::string () { return key_; } diff --git a/aten/src/ATen/native/RangeUtils.h b/aten/src/ATen/native/RangeUtils.h index d1756db75016..d3ad1c6ab7df 100644 --- a/aten/src/ATen/native/RangeUtils.h +++ b/aten/src/ATen/native/RangeUtils.h @@ -2,9 +2,9 @@ #include #include -namespace at { -namespace native { + +namespace at::native { template int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar& step) { @@ -42,4 +42,4 @@ int64_t compute_arange_size(const Scalar& start, const Scalar& end, const Scalar return static_cast(size_d); } -}} // namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 0658ed6f27bd..79aaac48034a 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -756,7 +756,7 @@ static DimVector default_alldims(const Tensor& self, at::OptionalIntArrayRef dim IntArrayRef dim_unwrapped = *dim_opt; dim.resize(dim_unwrapped.size()); for (const auto i : c10::irange(dim.size())) { - dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalars=*/false); + dim[i] = maybe_wrap_dim(dim_unwrapped[i], self.dim(), /*wrap_scalar=*/false); } } else { dim.resize(self.dim()); diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index ce0057909830..420a81767fba 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -887,7 +887,7 @@ static inline void mvlgamma_check(const Tensor& self, int64_t p) { Tensor mvlgamma(const Tensor& self, int64_t p) { mvlgamma_check(self, p); auto dtype = c10::scalarTypeToTypeMeta(self.scalar_type()); - if (at::isIntegralType(self.scalar_type(), /*include_bool=*/true)) { + if (at::isIntegralType(self.scalar_type(), /*includeBool=*/true)) { // int -> float promotion dtype = c10::get_default_dtype(); } diff --git a/aten/src/ATen/native/cuda/CuFFTPlanCache.h b/aten/src/ATen/native/cuda/CuFFTPlanCache.h index 08d07c4b45a5..06276c72c53a 100644 --- a/aten/src/ATen/native/cuda/CuFFTPlanCache.h +++ b/aten/src/ATen/native/cuda/CuFFTPlanCache.h @@ -16,7 +16,7 @@ #include #include -namespace at { namespace native { namespace detail { +namespace at::native::detail { // Enum representing the FFT type enum class CuFFTTransformType : int8_t { @@ -58,7 +58,7 @@ struct CuFFTParams } }; -static_assert(std::is_trivial_v, ""); +static_assert(std::is_trivial_v ); // Returns true if the transform type has complex input inline bool cufft_complex_input(CuFFTTransformType type) { @@ -491,4 +491,4 @@ void cufft_set_plan_cache_max_size_impl(DeviceIndex device_index, int64_t max_si int64_t cufft_get_plan_cache_size_impl(DeviceIndex device_index); void cufft_clear_plan_cache_impl(DeviceIndex device_index); -}}} // namespace at::native::detail +} // namespace at::native::detail diff --git a/aten/src/ATen/native/cuda/MiscUtils.h b/aten/src/ATen/native/cuda/MiscUtils.h index e616a7d1fcfb..f733f3a38099 100644 --- a/aten/src/ATen/native/cuda/MiscUtils.h 
+++ b/aten/src/ATen/native/cuda/MiscUtils.h @@ -4,8 +4,8 @@ #include #include -namespace at { -namespace native { + +namespace at::native { static inline int cuda_int_cast(int64_t value, const char* varname) { auto result = static_cast(value); @@ -28,5 +28,4 @@ static inline Storage pin_memory(int64_t size) { /*resizable=*/false); } -} // namespace native -} // namespace at +} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Resize.h b/aten/src/ATen/native/cuda/Resize.h index d5de128cac1d..b2c3efe5a719 100644 --- a/aten/src/ATen/native/cuda/Resize.h +++ b/aten/src/ATen/native/cuda/Resize.h @@ -5,7 +5,7 @@ #include -namespace at { namespace native { +namespace at::native { TORCH_CUDA_CPP_API void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes); @@ -50,4 +50,4 @@ inline TensorImpl* resize_impl_cuda_( return self; } -}} +} diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h index 5e1f255ebe08..4ab411d9a025 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.h @@ -36,8 +36,8 @@ // The current pytorch implementation sets gesvdj tolerance to epsilon of a C++ data type to target the best possible precision. constexpr int cusolver_gesvdj_max_sweeps = 400; -namespace at { -namespace native { + +namespace at::native { void geqrf_batched_cublas(const Tensor& input, const Tensor& tau); void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular); @@ -90,4 +90,4 @@ C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&); }} // namespace cuda::detail #endif -}} // namespace at::native +} // namespace at::native diff --git a/aten/src/ATen/native/cudnn/RNNUtils.h b/aten/src/ATen/native/cudnn/RNNUtils.h index 7e2869a80574..841164622172 100644 --- a/aten/src/ATen/native/cudnn/RNNUtils.h +++ b/aten/src/ATen/native/cudnn/RNNUtils.h @@ -6,9 +6,8 @@ #include // Declares utilities used by RNN.cpp and also needed by external consumers -namespace at { -namespace native { -namespace cudnn_rnn { + +namespace at::native::cudnn_rnn { TORCH_CUDA_CPP_API std::tuple> copy_weights_to_flat_buf_views( @@ -27,6 +26,4 @@ copy_weights_to_flat_buf_views( bool allow_type_change = false, bool include_bias = true); -} // namespace cudnn_rnn -} // namespace native -} // namespace at +} // namespace at::native::cudnn_rnn diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h index cc5739825d7e..03ef7ce450c1 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.h +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.h @@ -20,7 +20,7 @@ #endif #endif -namespace at { namespace native { +namespace at::native { // Mapping ScalarType to ideep tensor data_type TORCH_API ideep::tensor::data_type get_mkldnn_dtype(ScalarType type); @@ -62,6 +62,6 @@ TORCH_API ideep::tensor itensor_from_tensor(const Tensor& tensor, bool from_cons // Set MKLDNN verbose level TORCH_API int set_verbose(int level); -}} +} #endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h index df14020466f5..eb09d37c4b75 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Attr.h @@ -131,7 +131,7 @@ struct PostOpParam { class Attr { public: - Attr() : q_scale_(1.f), q_zero_point_(0) {} + Attr() : q_scale_(1.f) {} Attr(float q_scale, int64_t zp = 0) : 
q_scale_(q_scale), q_zero_point_(zp) {} /***** eltwise *****/ diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index 127fa78de12d..1763d90cc94e 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -51,8 +51,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { auto impl = c10::make_intrusive( Storage(storage()), key_set(), data_type_, quantizer_); copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), + /*src_q_impl=*/this, + /*dest_q_impl=*/impl.get(), /*version_counter=*/version_counter, /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); impl->refresh_numel(); @@ -72,8 +72,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { auto impl = c10::make_intrusive( Storage(storage()), key_set(), data_type_, quantizer_); copy_tensor_metadata( - /*src_impl=*/this, - /*dest_impl=*/impl.get(), + /*src_q_impl=*/this, + /*dest_q_impl=*/impl.get(), /*version_counter=*/std::move(version_counter), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); impl->refresh_numel(); @@ -91,8 +91,8 @@ struct TORCH_API QTensorImpl : public c10::TensorImpl { AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); auto q_impl = static_cast(impl.get()); copy_tensor_metadata( - /*src_impl=*/q_impl, - /*dest_impl=*/this, + /*src_q_impl=*/q_impl, + /*dest_q_impl=*/this, /*version_counter=*/version_counter(), /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); refresh_numel(); diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 8254363cbdcb..9d9b19b79da9 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -86,7 +86,7 @@ struct WelfordHelper { std::vector> welford_stk; uint64_t depth; // depth of welford_stk. uint64_t num_chunks; // number of chunks stored in welford_stk. - WelfordHelper() {} + WelfordHelper() = default; WelfordHelper(uint64_t N) { uint64_t m = (N + kChunkSize - 1) / kChunkSize; //div up depth = m > 0 ? 
ceil(log2(m)) : 0; diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 045c47902516..acbc7bdc0d16 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -1152,16 +1152,13 @@ std::vector> PythonTracer::getEvents( // Assuming python_tracer::PythonMemoryTracerBase is defined elsewhere class PythonMemoryTracer final : public python_tracer::PythonMemoryTracerBase { public: - explicit PythonMemoryTracer(); - ~PythonMemoryTracer() override; + explicit PythonMemoryTracer() = default; + ~PythonMemoryTracer() override = default; void start() override; void stop() override; void export_memory_history(const std::string path) override; }; -PythonMemoryTracer::PythonMemoryTracer() {} -PythonMemoryTracer::~PythonMemoryTracer() {} - static void toggle_memory_tracing(bool enable) { PyGILState_STATE gil_state = PyGILState_Ensure(); THPObjectPtr torch_cuda_memory_module( @@ -1182,9 +1179,9 @@ static void toggle_memory_tracing(bool enable) { PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries PyTuple_SetItem(args, 4, Py_None); // device (None) PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False) - PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); + PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); Py_DECREF(args); - if (result == NULL) { + if (result == nullptr) { return; } PyGILState_Release(gil_state); @@ -1209,9 +1206,9 @@ void PythonMemoryTracer::export_memory_history(const std::string path) { PyObject* py_filename = PyUnicode_FromString(path.c_str()); // Call the function with arguments (e.g., a file path) PyObject* args = PyTuple_Pack(1, py_filename); - PyObject* result = PyObject_Call(snapshot_func.get(), args, NULL); + PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); Py_DECREF(args); - if (result == NULL) { + if (result == nullptr) { return; } PyGILState_Release(gil_state); diff --git a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp index d07e1fd2309e..47454c6eca25 100644 --- a/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp +++ b/torch/csrc/jit/codegen/onednn/LlgaTensorImpl.cpp @@ -31,7 +31,7 @@ dnnl::engine& Engine::getEngine() { static dnnl::graph::allocator alloc{ pytorch_default_allocator, pytorch_default_deallocator}; static dnnl::engine cpu_engine = dnnl::graph::make_engine_with_allocator( - dnnl::engine::kind::cpu, /* device_id = */ 0, alloc); + dnnl::engine::kind::cpu, /* index = */ 0, alloc); return cpu_engine; } diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 905c69a47966..02e64d196151 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -18,9 +18,7 @@ TORCH_DECLARE_bool(torch_jit_enable_expanded_stacks); TORCH_DECLARE_bool(torch_jit_expanded_stacks_mangled); -namespace torch::jit { - -namespace interpreter { +namespace torch::jit::interpreter { template Ttarget safe_narrow_cast(Tsource v) { @@ -64,7 +62,7 @@ struct NodeSourceInfo { const char* func_name_{nullptr}; const char* file_name_{nullptr}; size_t line_{0}; - NodeSourceInfo() {} + NodeSourceInfo() = default; }; struct CodeImpl { @@ -1060,5 +1058,4 @@ struct MobileCodeImpl : CodeImpl { bool emit_promoted_ops_; }; -} // namespace interpreter -} // namespace torch::jit +} // namespace torch::jit::interpreter diff --git a/torch/lib/libshm/socket.h 
b/torch/lib/libshm/socket.h index e3ff98cbc9fb..6b7207eb70a8 100644 --- a/torch/lib/libshm/socket.h +++ b/torch/lib/libshm/socket.h @@ -17,12 +17,12 @@ class Socket { public: int socket_fd; + Socket(const Socket& other) = delete; protected: Socket() { SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0)); } - Socket(const Socket& other) = delete; Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) { other.socket_fd = -1; }; @@ -122,7 +122,7 @@ class ManagerServerSocket : public Socket { SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str())); } - virtual ~ManagerServerSocket() { + ~ManagerServerSocket() override { unlink(socket_path.c_str()); } From 64ac41f68d4c1c156c356de3093488e1ee920997 Mon Sep 17 00:00:00 2001 From: Richard Howell Date: Wed, 9 Apr 2025 12:59:24 +0000 Subject: [PATCH 300/332] [pytorch] add header docs for TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT (#150854) Summary: Add header docs for the experimental TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT feature, and guard behind C10_MOBILE. Reviewed By: albanD Differential Revision: D72572345 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150854 Approved by: https://github.com/larryliu0820, https://github.com/zou3519 --- aten/src/ATen/core/library.cpp | 2 +- torch/csrc/jit/mobile/import.cpp | 2 +- torch/library.h | 10 ++++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/core/library.cpp b/aten/src/ATen/core/library.cpp index bdc525dca08c..5dcac2b0e2fb 100644 --- a/aten/src/ATen/core/library.cpp +++ b/aten/src/ATen/core/library.cpp @@ -58,7 +58,7 @@ void Library::reset() { #define ERROR_CONTEXT "(Error occurred while processing ", toString(kind_), " block at ", file_, ":", line_, ")" -#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) namespace detail { std::vector torch_library_initializers; } // namespace detail diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 94f49ac67dc2..089a0df564a0 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -647,7 +647,7 @@ mobile::Module _load_for_mobile( std::optional device, ExtraFilesMap& extra_files, uint64_t module_load_options) { -#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) torch::initialize_torch_libraries(); #endif auto observer = torch::observerConfig().getModuleObserver(); diff --git a/torch/library.h b/torch/library.h index 653a45361a1b..5f6b94439b84 100644 --- a/torch/library.h +++ b/torch/library.h @@ -884,13 +884,19 @@ class TORCH_API Library final { at::OperatorName _parseNameForLib(const char* name_str) const; }; -#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) void initialize_torch_libraries(); #endif namespace detail { -#ifdef TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT +#if defined(TORCH_LIBRARY_THREAD_UNSAFE_LAZY_INIT) && defined(C10_MOBILE) +// This is an experimental feature to defer TorchLibraryInit cost to run either +// at model load time, or when a client application explicitly calls +// torch::initialize_torch_libraries(). +// +// This is not thread safe, the client is required to ensure that libraries +// containing TORCH_LIBRARY initializers are loaded in a thread safe manner. 
extern std::vector torch_library_initializers; class TorchLibraryInit final { private: From 886d9acb0d0e92b96fafa14b0a2817531a7e9edb Mon Sep 17 00:00:00 2001 From: Antoine Broyelle Date: Wed, 9 Apr 2025 13:10:21 +0000 Subject: [PATCH 301/332] [docs] Add 32-bit complex to the list of dtypes (#144590) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144590 Approved by: https://github.com/janeyx99 --- docs/source/tensor_attributes.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/tensor_attributes.rst b/docs/source/tensor_attributes.rst index 46388033db5b..deb85a8773f8 100644 --- a/docs/source/tensor_attributes.rst +++ b/docs/source/tensor_attributes.rst @@ -15,13 +15,14 @@ torch.dtype .. class:: dtype A :class:`torch.dtype` is an object that represents the data type of a -:class:`torch.Tensor`. PyTorch has twelve different data types: +:class:`torch.Tensor`. PyTorch has several different data types: ========================== =========================================== =========================== Data type dtype Legacy Constructors ========================== =========================================== =========================== 32-bit floating point ``torch.float32`` or ``torch.float`` ``torch.*.FloatTensor`` 64-bit floating point ``torch.float64`` or ``torch.double`` ``torch.*.DoubleTensor`` +32-bit complex ``torch.complex32`` or ``torch.chalf`` 64-bit complex ``torch.complex64`` or ``torch.cfloat`` 128-bit complex ``torch.complex128`` or ``torch.cdouble`` 16-bit floating point [1]_ ``torch.float16`` or ``torch.half`` ``torch.*.HalfTensor`` From 229908722096e096b237f28e70820061f62f0c80 Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 9 Apr 2025 14:34:30 +0000 Subject: [PATCH 302/332] [ROCm] Introduce AMD specific inductor gemm tuning (#147315) Replaces https://github.com/pytorch/pytorch/pull/143286 Adds ROCm specific MM configs for max-autotune incorporating ROCm specific triton tuning kernargs such as waves_per_eu, kpack, matrix_instr_nonkdim. This PR also introduces behavior to allow tuning for GROUP_M in triton gemm case. 
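For illustration only (simplified stand-ins, not the exact in-tree classes), a minimal sketch of how such a ROCm GEMM config carries the AMD-specific Triton kernel args (matrix_instr_nonkdim, waves_per_eu, kpack) and the tunable GROUP_M alongside the usual block sizes, and how it would be flattened into the kwargs handed to a Triton GEMM template:

```python
# Illustrative sketch only -- simplified stand-ins for the real config classes.
from dataclasses import dataclass


@dataclass
class ROCmGemmConfigSketch:
    block_m: int
    block_n: int
    block_k: int
    num_stages: int
    num_warps: int
    group_m: int = 8
    # AMD-specific tunable Triton kernel args
    matrix_instr_nonkdim: int = 16
    waves_per_eu: int = 0
    kpack: int = 2


def to_template_kwargs(cfg: ROCmGemmConfigSketch) -> dict:
    """Flatten a config into the kwargs a Triton GEMM template would receive."""
    return {
        "BLOCK_M": cfg.block_m,
        "BLOCK_N": cfg.block_n,
        "BLOCK_K": cfg.block_k,
        "GROUP_M": cfg.group_m,
        "num_stages": cfg.num_stages,
        "num_warps": cfg.num_warps,
        "matrix_instr_nonkdim": cfg.matrix_instr_nonkdim,
        "waves_per_eu": cfg.waves_per_eu,
        "kpack": cfg.kpack,
    }


print(to_template_kwargs(ROCmGemmConfigSketch(128, 128, 64, 2, 8, group_m=16, waves_per_eu=2)))
```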
Dynamo huggingface inference benchmarks: `TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS="TRITON" python huggingface.py --performance --inference --bfloat16 --backend=inductor`

GEOMEAN speedup (before): 1.35x
GEOMEAN speedup (after): 1.42x

name | Eager - abs latency | old - abs_latency | old - speedup | new - abs_latency | new - speedup
-- | -- | -- | -- | -- | --
AlbertForMaskedLM | 26.22 | 26.52 | 98.86% | 24.58 | 106.67%
AlbertForQuestionAnswering | 25.96 | 26.40 | 98.33% | 24.10 | 107.73%
AllenaiLongformerBase | 21.03 | 10.65 | 197.50% | 10.49 | 200.58%
BartForCausalLM | 7.77 | 9.76 | 79.63% | 8.79 | 88.46%
BartForConditionalGeneration | 14.44 | 12.86 | 112.26% | 11.96 | 120.70%
BertForMaskedLM | 8.10 | 8.82 | 91.89% | 8.57 | 94.53%
BertForQuestionAnswering | 6.82 | 7.32 | 93.20% | 7.10 | 96.18%
BlenderbotForCausalLM | 10.97 | 11.39 | 96.34% | 10.10 | 108.65%
BlenderbotSmallForCausalLM | 5.91 | 5.44 | 108.72% | 4.82 | 122.67%
BlenderbotSmallForConditionalGeneration | 12.64 | 9.65 | 130.94% | 9.11 | 138.83%
CamemBert | 8.35 | 9.15 | 91.24% | 8.86 | 94.27%
DebertaForMaskedLM | 10.92 | 6.09 | 179.44% | 5.90 | 185.05%
DebertaForQuestionAnswering | 14.29 | 7.70 | 185.59% | 7.26 | 196.75%
DebertaV2ForMaskedLM | 15.47 | 10.22 | 151.32% | 9.34 | 165.55%
DebertaV2ForQuestionAnswering | 14.98 | 6.11 | 245.28% | 6.28 | 238.40%
DistilBertForMaskedLM | 8.37 | 8.70 | 96.30% | 8.22 | 101.92%
DistilBertForQuestionAnswering | 10.21 | 10.54 | 96.88% | 10.39 | 98.36%
DistillGPT2 | 8.77 | 6.78 | 129.40% | 6.31 | 138.88%
ElectraForCausalLM | 10.32 | 4.70 | 219.45% | 4.60 | 224.29%
ElectraForQuestionAnswering | 11.48 | 5.62 | 204.20% | 5.44 | 210.95%
GPT2ForSequenceClassification | 6.21 | 5.72 | 108.50% | 5.58 | 111.26%
GoogleFnet | 26.51 | 20.81 | 127.37% | 19.91 | 133.11%
LayoutLMForMaskedLM | 12.09 | 7.99 | 151.28% | 7.66 | 157.80%
LayoutLMForSequenceClassification | 10.62 | 6.49 | 163.67% | 6.25 | 169.95%
M2M100ForConditionalGeneration | 14.98 | 10.20 | 146.79% | 9.89 | 151.42%
MBartForCausalLM | 7.67 | 9.78 | 78.44% | 8.87 | 86.55%
MBartForConditionalGeneration | 13.45 | 12.69 | 105.99% | 12.03 | 111.82%
MT5ForConditionalGeneration | 19.96 | 5.32 | 375.37% | 5.08 | 393.01%
MegatronBertForCausalLM | 13.22 | 7.86 | 168.07% | 7.18 | 184.01%
MegatronBertForQuestionAnswering | 15.62 | 11.81 | 132.21% | 11.02 | 141.68%
MobileBertForMaskedLM | 26.63 | 10.82 | 245.99% | 11.95 | 222.73%
MobileBertForQuestionAnswering | 23.53 | 7.55 | 311.51% | 9.53 | 247.03%
OPTForCausalLM | 7.33 | 7.64 | 95.93% | 7.56 | 96.90%
PLBartForCausalLM | 8.73 | 7.63 | 114.40% | 7.37 | 118.58%
PLBartForConditionalGeneration | 10.46 | 8.50 | 122.98% | 8.16 | 128.13%
PegasusForCausalLM | 7.18 | 7.37 | 97.42% | 6.64 | 108.22%
PegasusForConditionalGeneration | 16.47 | 16.66 | 98.87% | 14.18 | 116.13%
RobertaForCausalLM | 10.30 | 9.95 | 103.52% | 9.52 | 108.25%
RobertaForQuestionAnswering | 6.37 | 7.13 | 89.28% | 6.79 | 93.87%
T5ForConditionalGeneration | 12.40 | 6.72 | 184.51% | 6.48 | 191.16%
T5Small | 12.02 | 6.66 | 180.55% | 6.32 | 190.33%
TrOCRForCausalLM | 14.12 | 13.31 | 106.11% | 12.45 | 113.41%
XGLMForCausalLM | 16.48 | 6.23 | 264.52% | 6.35 | 259.51%
XLNetLMHeadModel | 74.87 | 62.23 | 120.32% | 57.95 | 129.19%
YituTechConvBert | 20.21 | 10.50 | 192.48% | 9.97 | 202.72%

We are also seeing a ~9% improvement on an internal addmm benchmark. This PR will also slightly reduce the compilation time for AMD max-autotune: before this change we assessed every config with both matrix_instr_nonkdim values [0, 16], but we remove this and
use 16 for all configs with this update. No CI to test the max-autotune perf currently but this will be enabled via https://github.com/pytorch/pytorch/pull/148672 after which we can investigate more tuning updates and config pruning Pull Request resolved: https://github.com/pytorch/pytorch/pull/147315 Approved by: https://github.com/jansel, https://github.com/eellison --- torch/_inductor/kernel/mm_common.py | 10 +- torch/_inductor/select_algorithm.py | 1 + torch/_inductor/template_heuristics.py | 687 +++++++++++++++---------- 3 files changed, 418 insertions(+), 280 deletions(-) diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 663e78dc199c..079d6e83d623 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -72,8 +72,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): not inductor_config.force_same_precision or ((sym_m % 16) == 0 and (sym_n % 16) == 0 and (sym_k % 8) == 0) ) - return dict( - GROUP_M=8, + options_dict = dict( EVEN_K=even_k_symbolic, ALLOW_TF32=allow_tf32, USE_FAST_ACCUM=False, # Option for _scaled_mm @@ -83,6 +82,13 @@ def mm_options(config, sym_m, sym_n, sym_k, layout): **config.kwargs, ) + # If GROUP_M not specified then default to 8 + if "GROUP_M" not in config.kwargs: + group_m = config.kwargs.get("GROUP_M", 8) + options_dict["GROUP_M"] = group_m + + return options_dict + def persistent_mm_options(mat1, mat2): return dict( diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 1fbb9aff8580..35dae177bac7 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -1282,6 +1282,7 @@ def make_kernel_render(out_node): ), "num_stages": num_stages, "num_warps": num_warps, + "GROUP_M": kwargs.get("GROUP_M", -1), "allow_tf32": str(kwargs.get("ALLOW_TF32", None)), "acc_type": str(kwargs.get("ACC_TYPE", None)), }, diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index 400d1ad2b6de..fe6476f317f3 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -1,7 +1,7 @@ from __future__ import annotations +import dataclasses import itertools -from collections import namedtuple from functools import partial from threading import Lock from typing import Any, Callable, TYPE_CHECKING @@ -14,12 +14,59 @@ if TYPE_CHECKING: - from collections.abc import Generator, Sequence + from collections.abc import Generator from triton import Config as TritonConfig -class BaseConfigSingleton(type): +@dataclasses.dataclass +class BaseConfig: + """ + Base Gemm configuration used for most backends (CPU, CUDA) + """ + + block_m: int + block_n: int + block_k: int + num_stages: int + num_warps: int + + +@dataclasses.dataclass +class GemmConfig(BaseConfig): + """ + Gemm configuration used for most backends (CPU, CUDA) + """ + + group_m: int = 8 + + +ConvConfig = BaseConfig + + +@dataclasses.dataclass +class ROCmGemmConfig(GemmConfig): + """ + ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +@dataclasses.dataclass +class ROCmConvConfig(ConvConfig): + """ + ROCm subclass for Conv, with AMD backend specific tuneable kernargs + """ + + matrix_instr_nonkdim: int = 16 + waves_per_eu: int = 0 + kpack: int = 2 + + +class BaseHeuristicSingleton(type): """ Thread-safe implementation of single to be used in the config heuristic subclasses to 
ensure heavy __init__ calls are not repeatedly run @@ -29,7 +76,7 @@ class BaseConfigSingleton(type): _lock: Lock = Lock() def __call__( - cls: BaseConfigSingleton, *args: Any, **kwargs: Any + cls: BaseHeuristicSingleton, *args: Any, **kwargs: Any ) -> BaseConfigHeuristic: with cls._lock: if cls not in cls._instances: @@ -38,12 +85,7 @@ def __call__( return cls._instances[cls] -Config = namedtuple( - "Config", ["block_m", "block_n", "block_k", "num_stages", "num_warps"] -) - - -class BaseConfigHeuristic(metaclass=BaseConfigSingleton): +class BaseConfigHeuristic(metaclass=BaseHeuristicSingleton): """ Base class for mm_configs, device specific triton kernels config inherit from here """ @@ -52,36 +94,37 @@ def __init__(self) -> None: # List of dictionaries to store the kernel configs. Configs that evaluate to true # will be utilised on the target platform. The configs are as follows: # (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) - self.mm_configs = [ - Config(32, 32, 16, 1, 2), - Config(32, 32, 128, 2, 4), - Config(32, 64, 32, 5, 8), - Config(64, 32, 32, 5, 8), - Config(64, 32, 128, 5, 4), - Config(64, 64, 16, 2, 4), - Config(64, 64, 32, 2, 4), - Config(64, 64, 64, 3, 8), - Config(64, 64, 128, 5, 4), - Config(64, 128, 32, 3, 4), - Config(64, 128, 32, 4, 8), - Config(64, 128, 64, 3, 4), - Config(64, 128, 128, 4, 4), - Config(128, 64, 32, 3, 4), - Config(128, 64, 32, 4, 8), - Config(128, 128, 32, 2, 8), - Config(128, 128, 32, 3, 4), - Config(128, 128, 64, 3, 4), - Config(128, 128, 64, 5, 8), + self.mm_configs: list[BaseConfig] = [ + GemmConfig(32, 32, 16, 1, 2), + GemmConfig(32, 32, 128, 2, 4), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(64, 32, 128, 5, 4), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(64, 64, 128, 5, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(64, 128, 64, 3, 4), + GemmConfig(64, 128, 128, 4, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(128, 128, 32, 3, 4), + GemmConfig(128, 128, 64, 3, 4), + GemmConfig(128, 128, 64, 5, 8), ] # Exhaustive search for mm configs - self.exhaustive_configs = [ - Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.exhaustive_configs: list[BaseConfig] = [ + GemmConfig(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps, group_m) for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( [16, 32, 64, 128, 256], repeat=3 ) for num_stages in [1, 2, 3, 4, 5] for num_warps in [2, 4, 8] + for group_m in [8] ] # these are only used in tuned_mm when AutoHeuristic is enabled @@ -89,220 +132,237 @@ def __init__(self) -> None: # when the learned heuristic is used, the learned heuristic reduces the number of configs down to 10 # which saves compilation time (since less configs are autotuned) and potentially increase performance # because the learned heuristic might predict a config that is not part mm_configs - self.extra_mm_configs = [ - Config(16, 32, 16, 3, 2), - Config(16, 32, 32, 4, 2), - Config(16, 32, 32, 5, 2), - Config(64, 64, 128, 3, 4), - Config(128, 64, 32, 2, 2), - Config(128, 64, 64, 3, 8), - Config(128, 64, 128, 4, 8), - Config(128, 128, 32, 4, 4), - Config(128, 128, 64, 3, 8), - Config(128, 128, 64, 5, 4), + self.extra_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 32, 16, 3, 2), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(64, 64, 128, 3, 4), + GemmConfig(128, 64, 32, 2, 2), + GemmConfig(128, 64, 
64, 3, 8), + GemmConfig(128, 64, 128, 4, 8), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 64, 5, 4), ] - self.int8_mm_configs = [ - Config(64, 64, 32, 2, 4), - Config(64, 128, 32, 3, 4), - Config(128, 64, 32, 3, 4), - Config(64, 128, 32, 4, 8), - Config(128, 64, 32, 4, 8), - Config(64, 32, 32, 5, 8), - Config(32, 64, 32, 5, 8), - Config(128, 128, 32, 2, 8), - Config(64, 64, 64, 3, 8), - Config(128, 256, 128, 3, 8), - Config(256, 128, 128, 3, 8), + self.int8_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 128, 32, 3, 4), + GemmConfig(128, 64, 32, 3, 4), + GemmConfig(64, 128, 32, 4, 8), + GemmConfig(128, 64, 32, 4, 8), + GemmConfig(64, 32, 32, 5, 8), + GemmConfig(32, 64, 32, 5, 8), + GemmConfig(128, 128, 32, 2, 8), + GemmConfig(64, 64, 64, 3, 8), + GemmConfig(128, 256, 128, 3, 8), + GemmConfig(256, 128, 128, 3, 8), ] - self.mixed_mm_configs = [ - Config(16, 128, 256, 3, 4), - Config(16, 128, 256, 5, 8), + self.mixed_mm_configs: list[BaseConfig] = [ + GemmConfig(16, 128, 256, 3, 4), + GemmConfig(16, 128, 256, 5, 8), ] - self.persistent_mm_configs = [ - Config(128, 256, 64, 3, 8), - Config(128, 128, 64, 3, 8), - Config(128, 128, 128, 3, 8), - Config(128, 128, 128, 3, 4), - Config(128, 128, 64, 4, 8), + self.persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 64, 3, 8), + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 64, 4, 8), ] - self.scaled_mm_configs = [ - Config(128, 256, 32, 3, 8), - Config(256, 128, 32, 3, 8), - Config(256, 64, 32, 4, 4), - Config(64, 256, 32, 4, 4), - Config(128, 128, 32, 4, 4), - Config(128, 64, 32, 4, 4), - Config(64, 128, 32, 4, 4), - Config(128, 32, 32, 4, 4), - Config(64, 32, 32, 5, 2), - Config(256, 128, 128, 3, 8), - Config(256, 64, 128, 4, 4), - Config(64, 256, 128, 4, 4), - Config(128, 128, 128, 4, 4), - Config(128, 64, 64, 4, 4), - Config(64, 128, 64, 4, 4), - Config(128, 32, 64, 4, 4), - Config(64, 32, 64, 5, 2), - Config(16, 32, 32, 2, 2), - Config(16, 64, 32, 2, 2), - Config(16, 128, 32, 2, 4), - Config(16, 256, 32, 2, 4), - Config(16, 32, 64, 2, 2), - Config(16, 64, 64, 2, 2), - Config(16, 128, 64, 2, 4), - Config(16, 256, 64, 2, 4), - Config(32, 32, 32, 2, 2), - Config(32, 64, 32, 2, 2), - Config(32, 128, 32, 2, 4), - Config(32, 256, 32, 2, 4), - Config(32, 32, 64, 2, 2), - Config(32, 64, 64, 2, 2), - Config(32, 128, 64, 2, 4), - Config(32, 256, 64, 2, 4), - Config(16, 32, 32, 3, 2), - Config(16, 64, 32, 3, 2), - Config(16, 128, 32, 3, 4), - Config(16, 256, 32, 3, 4), - Config(16, 32, 64, 3, 2), - Config(16, 64, 64, 3, 2), - Config(16, 128, 64, 3, 4), - Config(16, 256, 64, 3, 4), - Config(32, 32, 32, 3, 2), - Config(32, 64, 32, 3, 2), - Config(32, 128, 32, 3, 4), - Config(32, 256, 32, 3, 4), - Config(32, 32, 64, 3, 2), - Config(32, 64, 64, 3, 2), - Config(32, 128, 64, 3, 4), - Config(32, 256, 64, 3, 4), - Config(16, 32, 32, 4, 2), - Config(16, 64, 32, 4, 2), - Config(16, 128, 32, 4, 4), - Config(16, 256, 32, 4, 4), - Config(16, 32, 64, 4, 2), - Config(16, 64, 64, 4, 2), - Config(16, 128, 64, 4, 4), - Config(16, 256, 64, 4, 4), - Config(32, 32, 32, 4, 2), - Config(32, 64, 32, 4, 2), - Config(32, 128, 32, 4, 4), - Config(32, 256, 32, 4, 4), - Config(32, 32, 64, 4, 2), - Config(32, 64, 64, 4, 2), - Config(32, 128, 64, 4, 4), - Config(32, 256, 64, 4, 4), - Config(16, 32, 32, 5, 2), - Config(16, 64, 32, 5, 2), - Config(16, 128, 32, 5, 4), - Config(16, 256, 32, 5, 4), - Config(16, 32, 64, 
5, 2), - Config(16, 64, 64, 5, 2), - Config(16, 128, 64, 5, 4), - Config(16, 256, 64, 5, 4), - Config(32, 32, 32, 5, 2), - Config(32, 64, 32, 5, 2), - Config(32, 128, 32, 5, 4), - Config(32, 256, 32, 5, 4), - Config(32, 32, 64, 5, 2), - Config(32, 64, 64, 5, 2), - Config(32, 128, 64, 5, 4), - Config(32, 256, 64, 5, 4), - Config(16, 32, 32, 6, 2), - Config(16, 64, 32, 6, 2), - Config(16, 128, 32, 6, 4), - Config(16, 256, 32, 6, 4), - Config(16, 32, 64, 6, 2), - Config(16, 64, 64, 6, 2), - Config(16, 128, 64, 6, 4), - Config(16, 256, 64, 6, 4), - Config(32, 32, 32, 6, 2), - Config(32, 64, 32, 6, 2), - Config(32, 128, 32, 6, 4), - Config(32, 256, 32, 6, 4), - Config(32, 32, 64, 6, 2), - Config(32, 64, 64, 6, 2), - Config(32, 128, 64, 6, 4), - Config(32, 256, 64, 6, 4), + self.scaled_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 256, 32, 3, 8), + GemmConfig(256, 128, 32, 3, 8), + GemmConfig(256, 64, 32, 4, 4), + GemmConfig(64, 256, 32, 4, 4), + GemmConfig(128, 128, 32, 4, 4), + GemmConfig(128, 64, 32, 4, 4), + GemmConfig(64, 128, 32, 4, 4), + GemmConfig(128, 32, 32, 4, 4), + GemmConfig(64, 32, 32, 5, 2), + GemmConfig(256, 128, 128, 3, 8), + GemmConfig(256, 64, 128, 4, 4), + GemmConfig(64, 256, 128, 4, 4), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 64, 64, 4, 4), + GemmConfig(64, 128, 64, 4, 4), + GemmConfig(128, 32, 64, 4, 4), + GemmConfig(64, 32, 64, 5, 2), + GemmConfig(16, 32, 32, 2, 2), + GemmConfig(16, 64, 32, 2, 2), + GemmConfig(16, 128, 32, 2, 4), + GemmConfig(16, 256, 32, 2, 4), + GemmConfig(16, 32, 64, 2, 2), + GemmConfig(16, 64, 64, 2, 2), + GemmConfig(16, 128, 64, 2, 4), + GemmConfig(16, 256, 64, 2, 4), + GemmConfig(32, 32, 32, 2, 2), + GemmConfig(32, 64, 32, 2, 2), + GemmConfig(32, 128, 32, 2, 4), + GemmConfig(32, 256, 32, 2, 4), + GemmConfig(32, 32, 64, 2, 2), + GemmConfig(32, 64, 64, 2, 2), + GemmConfig(32, 128, 64, 2, 4), + GemmConfig(32, 256, 64, 2, 4), + GemmConfig(16, 32, 32, 3, 2), + GemmConfig(16, 64, 32, 3, 2), + GemmConfig(16, 128, 32, 3, 4), + GemmConfig(16, 256, 32, 3, 4), + GemmConfig(16, 32, 64, 3, 2), + GemmConfig(16, 64, 64, 3, 2), + GemmConfig(16, 128, 64, 3, 4), + GemmConfig(16, 256, 64, 3, 4), + GemmConfig(32, 32, 32, 3, 2), + GemmConfig(32, 64, 32, 3, 2), + GemmConfig(32, 128, 32, 3, 4), + GemmConfig(32, 256, 32, 3, 4), + GemmConfig(32, 32, 64, 3, 2), + GemmConfig(32, 64, 64, 3, 2), + GemmConfig(32, 128, 64, 3, 4), + GemmConfig(32, 256, 64, 3, 4), + GemmConfig(16, 32, 32, 4, 2), + GemmConfig(16, 64, 32, 4, 2), + GemmConfig(16, 128, 32, 4, 4), + GemmConfig(16, 256, 32, 4, 4), + GemmConfig(16, 32, 64, 4, 2), + GemmConfig(16, 64, 64, 4, 2), + GemmConfig(16, 128, 64, 4, 4), + GemmConfig(16, 256, 64, 4, 4), + GemmConfig(32, 32, 32, 4, 2), + GemmConfig(32, 64, 32, 4, 2), + GemmConfig(32, 128, 32, 4, 4), + GemmConfig(32, 256, 32, 4, 4), + GemmConfig(32, 32, 64, 4, 2), + GemmConfig(32, 64, 64, 4, 2), + GemmConfig(32, 128, 64, 4, 4), + GemmConfig(32, 256, 64, 4, 4), + GemmConfig(16, 32, 32, 5, 2), + GemmConfig(16, 64, 32, 5, 2), + GemmConfig(16, 128, 32, 5, 4), + GemmConfig(16, 256, 32, 5, 4), + GemmConfig(16, 32, 64, 5, 2), + GemmConfig(16, 64, 64, 5, 2), + GemmConfig(16, 128, 64, 5, 4), + GemmConfig(16, 256, 64, 5, 4), + GemmConfig(32, 32, 32, 5, 2), + GemmConfig(32, 64, 32, 5, 2), + GemmConfig(32, 128, 32, 5, 4), + GemmConfig(32, 256, 32, 5, 4), + GemmConfig(32, 32, 64, 5, 2), + GemmConfig(32, 64, 64, 5, 2), + GemmConfig(32, 128, 64, 5, 4), + GemmConfig(32, 256, 64, 5, 4), + GemmConfig(16, 32, 32, 6, 2), + GemmConfig(16, 64, 32, 6, 2), + GemmConfig(16, 
128, 32, 6, 4), + GemmConfig(16, 256, 32, 6, 4), + GemmConfig(16, 32, 64, 6, 2), + GemmConfig(16, 64, 64, 6, 2), + GemmConfig(16, 128, 64, 6, 4), + GemmConfig(16, 256, 64, 6, 4), + GemmConfig(32, 32, 32, 6, 2), + GemmConfig(32, 64, 32, 6, 2), + GemmConfig(32, 128, 32, 6, 4), + GemmConfig(32, 256, 32, 6, 4), + GemmConfig(32, 32, 64, 6, 2), + GemmConfig(32, 64, 64, 6, 2), + GemmConfig(32, 128, 64, 6, 4), + GemmConfig(32, 256, 64, 6, 4), ] - self.scaled_persistent_mm_configs = [ - Config(128, 128, 64, 3, 8), - Config(128, 128, 128, 3, 8), - Config(128, 128, 128, 4, 8), - Config(128, 128, 128, 4, 4), - Config(128, 128, 128, 3, 4), - Config(128, 128, 128, 5, 4), - Config(128, 128, 128, 5, 8), - Config(128, 128, 128, 6, 8), - Config(128, 128, 64, 4, 8), + self.scaled_persistent_mm_configs: list[BaseConfig] = [ + GemmConfig(128, 128, 64, 3, 8), + GemmConfig(128, 128, 128, 3, 8), + GemmConfig(128, 128, 128, 4, 8), + GemmConfig(128, 128, 128, 4, 4), + GemmConfig(128, 128, 128, 3, 4), + GemmConfig(128, 128, 128, 5, 4), + GemmConfig(128, 128, 128, 5, 8), + GemmConfig(128, 128, 128, 6, 8), + GemmConfig(128, 128, 64, 4, 8), ] # TODO: Unify with other gemm patterns, mm_plus_mm currently follows # slightly different pattern than rest - self.mm_plus_mm_configs = [ - Config(64, 64, 32, 2, 4), - Config(64, 64, 32, 3, 8), - Config(64, 64, 32, 4, 16), - Config(64, 32, 32, 4, 8), - Config(32, 64, 32, 4, 8), - Config(128, 128, 32, 1, 8), - Config(64, 64, 64, 1, 8), - Config(32, 32, 128, 1, 8), - Config(64, 64, 16, 2, 4), - Config(32, 32, 16, 1, 2), + self.mm_plus_mm_configs: list[BaseConfig] = [ + GemmConfig(64, 64, 32, 2, 4), + GemmConfig(64, 64, 32, 3, 8), + GemmConfig(64, 64, 32, 4, 16), + GemmConfig(64, 32, 32, 4, 8), + GemmConfig(32, 64, 32, 4, 8), + GemmConfig(128, 128, 32, 1, 8), + GemmConfig(64, 64, 64, 1, 8), + GemmConfig(32, 32, 128, 1, 8), + GemmConfig(64, 64, 16, 2, 4), + GemmConfig(32, 32, 16, 1, 2), ] - self.conv_configs = [ - Config(64, 256, 16, 2, 4), - Config(256, 64, 16, 2, 4), - Config(1024, 16, 16, 1, 8), - Config(128, 128, 32, 2, 8), - Config(64, 64, 32, 2, 4), - Config(64, 256, 32, 2, 8), - Config(256, 64, 32, 2, 8), + self.conv_configs: list[BaseConfig] = [ + ConvConfig(64, 256, 16, 2, 4), + ConvConfig(256, 64, 16, 2, 4), + ConvConfig(1024, 16, 16, 1, 8), + ConvConfig(128, 128, 32, 2, 8), + ConvConfig(64, 64, 32, 2, 4), + ConvConfig(64, 256, 32, 2, 8), + ConvConfig(256, 64, 32, 2, 8), ] def _finalize_mm_configs( self, - configs: list[Config], + configs: list[BaseConfig], ) -> Generator[TritonConfig, None, None]: """ Finalizes configs after scaling, applying additional constraints. """ - used = OrderedSet[Config]() + used: OrderedSet[tuple[int, ...]] = OrderedSet() max_mm_configs = config.test_configs.max_mm_configs - for block_m, block_n, block_k, num_stages, num_warps in configs: + for conf in configs: # Each warp computes a 16x16 tile = 256 elements - num_warps = min(num_warps, block_m * block_n // 256) - - if ( - Config(block_m, block_n, block_k, num_stages, num_warps) - ) not in used and (max_mm_configs is None or len(used) < max_mm_configs): - used.add(Config(block_m, block_n, block_k, num_stages, num_warps)) - yield self.triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - ) + num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) + + # Construct key for finding duplicate configs + key: tuple[int, ...] 
= ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + num_warps, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": num_warps, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) def _scale_mm_configs( self, m: int, n: int, k: int, - configs: Sequence[Config], + configs: list[BaseConfig], scale: float, has_int8_tensor: bool, exclude: Callable[[int, int, int], bool], - ) -> list[Config]: + ) -> list[BaseConfig]: """ Scales and filters matrix multiplication configs based on input size. """ @@ -341,7 +401,8 @@ def _scale_mm_configs( scaled_configs = [] for c in configs: - scaled_config = c._replace( + scaled_config = dataclasses.replace( + c, block_m=max(min(int(c.block_m * scale), m), min_block_size), block_n=max(min(int(c.block_n * scale), n), min_block_size), block_k=max(min(int(c.block_k * scale), k), min_block_size_k), @@ -359,7 +420,7 @@ def preprocess_mm_configs( m: int, n: int, k: int, - configs: Sequence[Config], + configs: list[BaseConfig], has_int8_tensor: bool = False, scale: int = 1, exclude: Callable[[int, int, int], bool] = lambda m, n, k: False, @@ -430,90 +491,160 @@ def __init__(self) -> None: self.default_num_stages = get_backend_num_stages() + self.mm_configs: list[BaseConfig] = [ + ROCmGemmConfig( + 16, 16, 256, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(32, 16, 256, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig( + 32, 32, 16, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(32, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(32, 64, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 64, 16, 128, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(64, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 32, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 32, 128, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 16, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(64, 64, 64, self.default_num_stages, 4, group_m=4), + ROCmGemmConfig(64, 64, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(64, 64, 256, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 64, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(64, 128, 32, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(64, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(64, 128, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(128, 32, 32, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig(128, 32, 64, self.default_num_stages, 4, group_m=8), + ROCmGemmConfig( + 128, 64, 32, self.default_num_stages, 4, group_m=8, waves_per_eu=2 + ), + ROCmGemmConfig(128, 64, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 64, 128, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 128, 128, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 128, 32, 
self.default_num_stages, 8, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 4, group_m=16), + ROCmGemmConfig(128, 128, 64, self.default_num_stages, 8, group_m=8), + ROCmGemmConfig(128, 128, 128, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig( + 128, 256, 32, self.default_num_stages, 4, group_m=16, waves_per_eu=2 + ), + ROCmGemmConfig(128, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 64, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig( + 256, 128, 32, self.default_num_stages, 4, group_m=4, waves_per_eu=2 + ), + ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), + ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ] + # Exhaustive search for mm configs - self.exhaustive_configs = [ - Config(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps) + self.exhaustive_configs: list[BaseConfig] = [ + ROCmGemmConfig( + BLOCK_M, + BLOCK_N, + BLOCK_K, + num_stages, + num_warps, + group_m, + matrix_instr_nonkdim, + waves_per_eu, + kpack, + ) for BLOCK_M, BLOCK_N, BLOCK_K in itertools.product( [16, 32, 64, 128, 256], repeat=3 ) for num_stages in [1, self.default_num_stages] for num_warps in [4, 8] + for group_m in [4, 8, 16] + for matrix_instr_nonkdim in [0, 16] + for waves_per_eu in [0, 2] + for kpack in [2] ] def _filter_configs( - self, configs: list[Config], new_num_stages: int - ) -> list[Config]: - filtered_configs = [ - c._replace(num_stages=self.default_num_stages) for c in configs - ] - return filtered_configs + self, configs: list[BaseConfig], new_num_stages: int + ) -> list[BaseConfig]: + # TODO: _filter_configs can be removed once backend specific configs are added + # for all methods + for c in configs: + c.num_stages = self.default_num_stages + return configs def _finalize_mm_configs( self, - configs: list[Config], + configs: list[BaseConfig], ) -> Generator[TritonConfig, None, None]: - used = OrderedSet[tuple[Config, int, int]]() + """ + Finalizes configs after scaling, applying additional constraints. 
+ """ + used: OrderedSet[tuple[int, ...]] = OrderedSet() max_mm_configs = config.test_configs.max_mm_configs - for block_m, block_n, block_k, num_stages, num_warps in configs: - # each warp computes 16x16 tile = 256 - num_warps = min(num_warps, block_m * block_n // 256) - kpack = 2 - for matrix_instr_nonkdim in [0, 16]: - if matrix_instr_nonkdim != 0 and ( - block_m % matrix_instr_nonkdim != 0 - or block_n % matrix_instr_nonkdim != 0 - ): - # block_m and block_n must be a multiple of matrix_instr_nonkdim - continue - if ( - Config( - block_m, - block_n, - block_k, - num_stages, - num_warps, - ), - matrix_instr_nonkdim, - kpack, - ) not in used and ( - max_mm_configs is None or len(used) < max_mm_configs - ): - used.add( - ( - Config( - block_m, - block_n, - block_k, - num_stages, - num_warps, - ), - matrix_instr_nonkdim, - kpack, - ) - ) - - yield self.triton_config( - BLOCK_M=block_m, - BLOCK_N=block_n, - BLOCK_K=block_k, - num_stages=num_stages, - num_warps=num_warps, - matrix_instr_nonkdim=matrix_instr_nonkdim, - kpack=kpack, - ) - def get_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.mm_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) + for conf in configs: + # Each warp computes a 16x16 tile = 256 elements + conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) - def get_exhaustive_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: - filtered_configs = self._filter_configs( - self.exhaustive_configs, self.default_num_stages - ) - return partial(self.preprocess_mm_configs, configs=filtered_configs) + # Defaults for AMD triton backend kern args if not set + matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu = getattr(conf, "waves_per_eu", 0) + kpack = getattr(conf, "kpack", 2) + + if matrix_instr_nonkdim != 0 and ( + conf.block_m % matrix_instr_nonkdim != 0 + or conf.block_n % matrix_instr_nonkdim != 0 + ): + # block_m and block_n must be a multiple of matrix_instr_nonkdim + continue + + # Construct key for finding duplicate configs + key: tuple[int, ...] = ( + conf.block_m, + conf.block_n, + conf.block_k, + conf.num_stages, + conf.num_warps, + waves_per_eu, + matrix_instr_nonkdim, + kpack, + ) + + # Check if gemm specific arg exists - add to key if does + group_m = getattr(conf, "group_m", None) + if group_m is not None: + key += (group_m,) + + if waves_per_eu != 0: + waves_per_eu = int(8 // conf.num_warps) + + if key not in used and ( + max_mm_configs is None or len(used) < max_mm_configs + ): + used.add(key) + kwargs = { + "BLOCK_M": conf.block_m, + "BLOCK_N": conf.block_n, + "BLOCK_K": conf.block_k, + "num_stages": conf.num_stages, + "num_warps": conf.num_warps, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + } + if group_m is not None: + kwargs["GROUP_M"] = group_m + yield self.triton_config(**kwargs) def get_extra_mm_configs(self) -> partial[Generator[TritonConfig, None, None]]: filtered_configs = self._filter_configs( From 246f3b6530bb027409efef666ec255c58c49e950 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 8 Apr 2025 22:37:31 -0700 Subject: [PATCH 303/332] [Quant][PT2E][X86] enable qconv1d-relu fusion (#150751) **Summary** As the title. - The `conv1d - relu` pattern will be annotated by the `X86InductorQuantizer`. - The pattern will be fused as `qconv_pointwise` during lowering. 
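For context (not part of this patch), a rough end-to-end sketch of the flow these bullets describe: quantizing a Conv1d -> ReLU module with the X86InductorQuantizer and compiling it with Inductor. The capture/export API may differ slightly between PyTorch versions.

```python
# Hedged sketch: exact capture API names vary across PyTorch releases.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 16, kernel_size=3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))


example_inputs = (torch.randn(1, 3, 8),)
m = M().eval()

# Capture the model for PT2E quantization (export_for_training on recent versions).
exported = torch.export.export_for_training(m, example_inputs).module()

quantizer = xiq.X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration
converted = convert_pt2e(prepared)

# With this change, the dequant -> conv1d -> relu chain can be lowered to a
# single fused quantized conv kernel (onednn.qconv_pointwise) by Inductor.
compiled = torch.compile(converted)
compiled(*example_inputs)
```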
**Test plan** ``` python test/inductor/test_mkldnn_pattern_matcher.py -k test_qconv1d_relu_cpu ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150751 Approved by: https://github.com/jerryzh168, https://github.com/leslie-fang-intel --- aten/src/ATen/native/quantized/cpu/qconv.cpp | 37 +++-- aten/src/ATen/native/quantized/library.cpp | 2 + test/inductor/test_cpu_cpp_wrapper.py | 2 +- test/inductor/test_mkldnn_pattern_matcher.py | 141 ++++++++++++------ test/quantization/core/test_quantized_op.py | 62 +++++++- .../pt2e/test_x86inductor_quantizer.py | 4 +- torch/_inductor/fx_passes/quantization.py | 74 ++++----- torch/_inductor/graph.py | 2 +- torch/_inductor/mkldnn_ir.py | 41 +++-- torch/_inductor/mkldnn_lowerings.py | 4 +- torch/_meta_registrations.py | 7 +- .../quantizer/x86_inductor_quantizer.py | 9 +- torch/csrc/inductor/aoti_torch/c/shim_cpu.h | 2 +- torch/csrc/inductor/aoti_torch/shim_cpu.cpp | 2 +- 14 files changed, 260 insertions(+), 129 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 46b58e9a38a8..06196043e08d 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1758,23 +1758,26 @@ namespace at::native { std::optional algorithm) { #if AT_MKLDNN_ENABLED() - if (act.dim() == 3 || act.dim() == 5) { - // Conv1D/3D post op check - TORCH_CHECK( - attr == "none", - "quantized pointwise conv", - act.dim()-2, - "d doesn't support unary_post_op fusion. Got unary_post_op: ", - attr, - ".") - } else { - // Conv2D post op check - TORCH_CHECK( - attr == "none" || attr == "relu" || attr == "hardtanh" || attr == "hardswish" || attr == "swish", - "none post_op or post_op relu/hardtanh/hardswish is supported for quantized pointwise conv2d. 
Got unary_post_op: ", - attr, - ".") + std::vector supported_postop = { + "none" + }; + if (act.dim() == 3) { + // Conv1D post op + supported_postop.push_back("relu"); + } else if (act.dim() == 4) { + // Conv2D post op + supported_postop.push_back("relu"); + supported_postop.push_back("hardtanh"); + supported_postop.push_back("hardswish"); + supported_postop.push_back("swish"); } + TORCH_CHECK( + std::find(supported_postop.begin(), supported_postop.end(), attr) != supported_postop.end(), + "Unsupported post op ", + attr, + " for quantized pointwise conv", + act.dim()-2, + "d.") return _quantized_convolution_onednn( act, act_scale, act_zero_point, weight, weight_scales, weight_zero_points, @@ -2079,6 +2082,8 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise"), at::native::QConvoneDNN::run_pointwise); m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor); m.impl(TORCH_SELECTIVE_NAME("onednn::qconv3d_pointwise"), at::native::QConvoneDNN::run_pointwise); + m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise"), at::native::QConvoneDNN::run_pointwise); + m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor); // Conv2D with binary postop m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"), at::native::QConvoneDNN::run_pointwise_binary); diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 27c484c62bb9..8a70fbffc00d 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -258,6 +258,8 @@ TORCH_LIBRARY(onednn, m) { m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? 
algorithm) -> Tensor")); // Conv2D with binary postop m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 8bd687c42a5f..7716898c5424 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -297,7 +297,7 @@ class BaseTest(NamedTuple): condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, func_inputs=[ [ - "aoti_torch_cpu__qconv2d_pointwise_tensor", + "aoti_torch_cpu__qconv_pointwise_tensor", "torch.ops.quantized.max_pool2d", "aoti_torch_cpu__qlinear_pointwise_tensor", ] diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 4b184aee4aba..e3727df7dc87 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1043,14 +1043,14 @@ def matcher_check_fn(): # int8_mixed_bf16: [dequant_node, optional(convert_element_type_4), # dequantize_per_channel, optional(convert_element_type_3), clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 18 if int8_mixed_bf16 else 12, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 3 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 3 ) self._test_common( @@ -1104,7 +1104,7 @@ def _qconv2d_unary_test_helper( device="cpu", int8_mixed_bf16=False, unary_op=torch.nn.ReLU(), - qconv2d_unary_matcher_nodes=None, + qconv_unary_matcher_nodes=None, ): class M(torch.nn.Module): def __init__( @@ -1133,20 +1133,20 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. 
QConv2D Unary fusion in post-grad fusion pass * 2 self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) - if qconv2d_unary_matcher_nodes: + if qconv_unary_matcher_nodes: self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_nodes"], - 0 if TEST_ACL else qconv2d_unary_matcher_nodes, + counters["inductor"]["qconv_unary_matcher_nodes"], + 0 if TEST_ACL else qconv_unary_matcher_nodes, ) self._test_common( @@ -1230,7 +1230,7 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardtanh(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1248,7 +1248,7 @@ def test_qconv2d_hardtanh_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.Hardtanh(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1282,7 +1282,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.Hardswish(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=17, + qconv_unary_matcher_nodes=17, ) @skipIfNoDynamoSupport @@ -1301,7 +1301,7 @@ def test_qconv2d_hardswish_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.Hardswish(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=17, + qconv_unary_matcher_nodes=17, ) @skipIfNoDynamoSupport @@ -1335,7 +1335,7 @@ def test_qconv2d_silu_int8_mixed_bf16_cpu(self): self._qconv2d_unary_test_helper( unary_op=torch.nn.SiLU(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) @skipIfNoDynamoSupport @@ -1354,7 +1354,7 @@ def test_qconv2d_silu_int8_mixed_bf16_xpu(self): device="xpu", unary_op=torch.nn.SiLU(), int8_mixed_bf16=True, - qconv2d_unary_matcher_nodes=11, + qconv_unary_matcher_nodes=11, ) def _qconv2d_add_test_helper( @@ -1415,7 +1415,7 @@ def forward(self, x): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 4 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 ) # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 self.assertEqual( @@ -1512,7 +1512,7 @@ def forward(self, x, x2, x3): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 self.assertEqual( @@ -1611,7 +1611,7 @@ def forward(self, x1, x2): def matcher_check_fn(): # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 1 self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) # 2. 
Qconv2d Binary Unary fusion in post-grad fusion pass * 0 self.assertEqual( @@ -1667,14 +1667,14 @@ def forward(self, x: torch.Tensor): def matcher_check_fn(): self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 4 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 3, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 4 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 4 ) self._test_common( @@ -1840,23 +1840,23 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 4 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 4 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, quantize_per_tensor] self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_nodes"], + counters["inductor"]["qconv_unary_matcher_nodes"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 1 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1 ) self._test_common( @@ -1895,16 +1895,16 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 1 # [convert_element_type_1, sub, mul_1, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) # 2. QConv2D Unary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default, relu, div_1, round_2, add_1, clamp_min_1, clamp_max_1, convert_element_type_2] self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) self._test_common( @@ -1994,10 +1994,10 @@ def matcher_check_fn(): # 1. Dequant-conv pattern matched in quantization weight prepack * 2 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, quantize_per_tensor] @@ -2063,10 +2063,10 @@ def matcher_check_fn(): # 1. 
Dequant-conv pattern matched in quantization weight prepack * 2 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 8 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 8 ) # 2. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, dequantize_per_tensor, add_3, relu, quantize_per_tensor] @@ -2135,10 +2135,10 @@ def matcher_check_fn(): # 2. Dequant-conv pattern matched in quantization weight prepack * 3 # [dequantize_per_tensor, dequantize_per_channel, clone, convolution] self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 3 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 3 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"], 12 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"], 12 ) # 3. Qconv2d Binary fusion in post-grad fusion pass * 1 # [qconv2d_pointwise_default_1, add_3] @@ -2175,6 +2175,59 @@ def test_qconv2d_dequant_promotion_cpu(self): def test_qconv2d_dequant_promotion_xpu(self): self._test_qconv2d_dequant_promotion_helper(device="xpu") + @skipIfNoDynamoSupport + @skipIfNoONEDNN + def test_qconv1d_relu_cpu(self): + r""" + This testcase will quantize Conv1d->ReLU pattern. + """ + device = "cpu" + unary_op = torch.nn.ReLU() + + class M(torch.nn.Module): + def __init__( + self, + ): + super().__init__() + self.conv = torch.nn.Conv1d(3, 128, kernel_size=3, stride=1) + self.unary_fn = copy.deepcopy(unary_op) + self.conv2 = torch.nn.Conv1d( + 128, 128, kernel_size=3, stride=1, bias=False + ) + self.unary_fn2 = copy.deepcopy(unary_op) + + def forward(self, x): + tmp = self.unary_fn(self.conv(x)) + return self.unary_fn2(self.conv2(tmp)) + + mod = M().eval().to(device=device) + v = ( + torch.randn((1, 3, 8), dtype=torch.float32, requires_grad=False) + .add(1) + .to(device=device) + ) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 + ) + # 2. 
QConv2D Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv_unary_matcher_count"], + 0 if TEST_ACL else 2, + ) + self.assertEqual( + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 + ) + + self._test_common( + mod, + (v,), + check_quantization=True, + matcher_check_fn=matcher_check_fn, + ) + def _qlinear_test_helper( self, inputs, @@ -3211,14 +3264,14 @@ def matcher_check_fn(): 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 1, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 1, ) @@ -3310,14 +3363,14 @@ def matcher_check_fn(): counters["inductor"]["qcat_matcher_count"], 0 if TEST_ACL else 1 ) self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 2 ) self.assertEqual( - counters["inductor"]["qconv2d_unary_matcher_count"], + counters["inductor"]["qconv_unary_matcher_count"], 0 if TEST_ACL else 2, ) self.assertEqual( - counters["inductor"]["qconv2d_unary_lower_count"], 0 if TEST_ACL else 2 + counters["inductor"]["qconv_unary_lower_count"], 0 if TEST_ACL else 2 ) self._test_common( @@ -4296,7 +4349,7 @@ def forward(self, x): v = torch.randn((2, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(1) if include_ops is None: include_ops = [ - "torch.ops.onednn.qconv2d_pointwise", + "torch.ops.onednn.qconv_pointwise", "torch.ops.quantized.max_pool2d", "torch.ops.onednn.qlinear_pointwise", ] @@ -4335,7 +4388,7 @@ def forward(self, x): def matcher_check_fn(): self.assertEqual( - counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1 + counters["inductor"]["qconv_weight_prepack_matcher_count"], 1 ) self._test_common( diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 33c0c932ea05..070f341faf13 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -7010,7 +7010,7 @@ def test_qconv2d_pt2e(self): if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise): # Remove some test combination to reduce UT test time continue - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7123,7 +7123,7 @@ def test_qconv2d_relu_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7174,7 +7174,7 @@ def test_qconv2d_hardtanh_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = 
torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7225,7 +7225,7 @@ def test_qconv2d_silu_pt2e(self): output_dtype_list = [None, torch.float32, torch.bfloat16] options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7277,7 +7277,7 @@ def test_qconv2d_hardswish_pt2e(self): options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) for groups, use_bias, use_channelwise, output_dtype in options: - qconv = torch.ops.onednn.qconv2d_pointwise + qconv = torch.ops.onednn.qconv_pointwise qconv_prepack = torch.ops.onednn.qconv_prepack conv_op = torch.nn.Conv2d( input_channels_per_group * groups, @@ -7480,6 +7480,58 @@ def test_qconv2d_sum_relu_float_output_pt2e(self): qconv_x2_dtype=qconv_x2_dtype, ) + # Test qconv1d with post op relu + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv1d_relu_pt2e(self): + input_channels_per_group = 2 + output_channels_per_group = 2 + groups_list = [1, 10] + input_feature_map_shape = (10,) + kernels = (3,) + strides = (2,) + pads = (1,) + dilations = (1,) + W_scale = [1.5] + W_zero_point = [0] + use_bias_list = [False, True] + use_channelwise_list = [False, True] + output_dtype_list = [None, torch.float32, torch.bfloat16] + options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list) + for groups, use_bias, use_channelwise, output_dtype in options: + qconv = torch.ops.onednn.qconv_pointwise + qconv_prepack = torch.ops.onednn.qconv_prepack + conv_op = torch.nn.Conv1d( + input_channels_per_group * groups, + output_channels_per_group * groups, + kernels, + strides, + pads, + dilations, + groups, + ) + pointwise_post_op = PointwisePostOp(unary_attr="relu") + self._test_qconv_impl_cpu_tensor( + qconv, + qconv_prepack, + conv_op, + input_channels_per_group=input_channels_per_group, + input_feature_map_shape=input_feature_map_shape, + output_channels_per_group=output_channels_per_group, + groups=groups, + kernels=kernels, + strides=strides, + pads=pads, + dilations=dilations, + W_scale=W_scale, + W_zero_point=W_zero_point, + use_bias=use_bias, + post_op=pointwise_post_op, + use_channelwise=use_channelwise, + qconv_output_dtype=output_dtype, + ) + + class TestPadding(TestCase): @given(batch_size=st.integers(1, 64), channels=st.integers(1, 64), diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 1c14ded72fe9..e0fcbbc9b515 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2840,13 +2840,13 @@ def test_lowering_to_x86(self): ) node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, - torch.ops.onednn.qconv2d_pointwise.default: 6, + torch.ops.onednn.qconv_pointwise.default: 6, torch.ops.onednn.qconv2d_pointwise.binary: 3, torch.ops.onednn.qlinear_pointwise.default: 1, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, - torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, torch.ops.onednn.qconv2d_pointwise.binary, 
torch.ops.onednn.qlinear_pointwise.default, ] diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 65eacb32dff4..8df1c1e1f2a6 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -163,9 +163,9 @@ def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False): ) -def get_qconv2d_pt2e_pattern(users=1): +def get_qconv_pt2e_pattern(users=1): return CallFunction( - torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, KeywordArg("x"), KeywordArg("x_scale"), KeywordArg("x_zp"), @@ -345,13 +345,13 @@ def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_val return actual_value == expected_value -def _is_valid_quantized_conv2d_optimization_pattern(): +def _is_valid_quantized_conv_optimization_pattern(): def fn(match): output_dtype = _get_pattern_output_dtype(match) if output_dtype in [torch.float32, torch.bfloat16]: # Only keep matched pattern with same output_dtype qconv_node_after_weight_prepack = filter_nodes( - match.nodes, torch.ops.onednn.qconv2d_pointwise + match.nodes, torch.ops.onednn.qconv_pointwise )[0] return _check_node_kwarg_arg_value( qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype @@ -365,7 +365,7 @@ def _is_valid_qconv_post_op_fusion_pattern(has_binary_post_op=False): return ( _is_valid_qconv_binary_optimization_pattern() if has_binary_post_op - else _is_valid_quantized_conv2d_optimization_pattern() + else _is_valid_quantized_conv_optimization_pattern() ) @@ -374,8 +374,8 @@ def fn(match): if len(match.nodes) != 1: return False return match.nodes[0].target in ( - torch.ops.onednn.qconv2d_pointwise.default, - torch.ops.onednn.qconv2d_pointwise.tensor, + torch.ops.onednn.qconv_pointwise.default, + torch.ops.onednn.qconv_pointwise.tensor, torch.ops.onednn.qconv2d_pointwise.binary, torch.ops.onednn.qconv2d_pointwise.binary_tensor, ) @@ -444,8 +444,8 @@ def qconv(match: Match, *args, **kwargs): postop_args, postop_algorithm, ) - counters["inductor"]["qconv2d_unary_lower_count"] += 1 - counters["inductor"]["qconv2d_unary_lower_nodes"] += len(match.nodes) + counters["inductor"]["qconv_unary_lower_count"] += 1 + counters["inductor"]["qconv_unary_lower_nodes"] += len(match.nodes) return L[computation_op](*computation_args) return qconv @@ -630,7 +630,7 @@ def qlinear_binary(match: Match, *args, **kwargs): def _is_valid_qconv_binary_optimization_pattern(): return _is_valid_quantized_op_binary_optimization_pattern( - torch.ops.onednn.qconv2d_pointwise + torch.ops.onednn.qconv_pointwise ) @@ -801,11 +801,11 @@ def qconv_binary(match: Match, *args, **kwargs): def _register_quantization_unary_lowering(): # QConv2d for users in [1, 2]: - qconv_pattern = get_qconv2d_pt2e_pattern(users) + qconv_pattern = get_qconv_pt2e_pattern(users) _register_quantized_conv_lowering( qconv_pattern, 2, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op ) # QLinear @@ -1375,7 +1375,7 @@ def _find_first_node_in_dequant_pattern(_node): counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes) -def _is_valid_dequant_conv2d_pattern(dtype): +def _is_valid_dequant_conv_pattern(dtype): def _inner(match): # Here we do some further check to ensure: # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now. 
@@ -1390,9 +1390,9 @@ def _inner(match): if ( meta_value is None or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu") - or meta_value.dim() != 4 + or meta_value.dim() not in [3, 4] ): - # Only support conv2d now + # Only support conv1d/2d now return False assert dtype in [torch.float32, torch.bfloat16] @@ -1415,7 +1415,7 @@ def _inner(match): def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32): @register_freezing_graph_pattern( pattern, - extra_check=_is_valid_dequant_conv2d_pattern(dtype), + extra_check=_is_valid_dequant_conv_pattern(dtype), pass_number=pass_number, ) def qconv_weight_prepack(match: Match, *args, **kwargs): @@ -1430,7 +1430,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): Insert weight prepack node and change the pattern to: int8 activation | - onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight + onednn.qconv_pointwise <- onednn.qconv_prepack <- int8_weight """ assert dtype in [torch.float32, torch.bfloat16] conv_node = match.output_node() @@ -1532,7 +1532,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): "", # algorithm ) new_conv_node = graph.call_function( - torch.ops.onednn.qconv2d_pointwise.default, args=new_args + torch.ops.onednn.qconv_pointwise.default, args=new_args ) conv_node.replace_all_uses_with(new_conv_node) new_conv_node.meta.update(conv_node.meta) @@ -1549,8 +1549,8 @@ def qconv_weight_prepack(match: Match, *args, **kwargs): if dtype == torch.bfloat16: graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type] graph.erase_node(dequant_per_channel) # type: ignore[arg-type] - counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1 - counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len( + counters["inductor"]["qconv_weight_prepack_matcher_count"] += 1 + counters["inductor"]["qconv_weight_prepack_matcher_nodes"] += len( match.nodes ) @@ -2803,12 +2803,12 @@ def qconv(match: Match, *args, **kwargs): count_key = ( "qconv2d_binary_matcher_count" if has_binary_post_op - else "qconv2d_unary_matcher_count" + else "qconv_unary_matcher_count" ) nodes_key = ( "qconv2d_binary_matcher_nodes" if has_binary_post_op - else "qconv2d_unary_matcher_nodes" + else "qconv_unary_matcher_nodes" ) counters["inductor"][count_key] += 1 counters["inductor"][nodes_key] += len(match.nodes) @@ -2828,13 +2828,13 @@ def _register_qconv_unary_fusion(): PostOpAttr( "none", None, "none", [], "" ): generate_pattern_with_output_quant( - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), ), PostOpAttr( "none", None, "relu", [], "" ): generate_pattern_with_output_quant( generate_pattern_with_unary( - get_qconv2d_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), ), PostOpAttr( @@ -2842,7 +2842,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), @@ -2853,7 +2853,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _hardswish_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2864,7 +2864,7 @@ def _register_qconv_unary_fusion(): ): generate_pattern_with_output_quant( _unary_fusion_pattern( _silu_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2877,21 +2877,21 @@ def 
_register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 3, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output conv_unary_replace_float_out_patterns = { PostOpAttr("none", None, "relu", [], ""): generate_pattern_with_unary( - get_qconv2d_pt2e_pattern(1), aten.relu.default + get_qconv_pt2e_pattern(1), aten.relu.default ), PostOpAttr( "none", None, "hardtanh", [], "" ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardtanh_fusion, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), 1, is_bf16, ), @@ -2903,7 +2903,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _hardswish_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2915,7 +2915,7 @@ def _register_qconv_unary_fusion(): ): _may_generate_pattern_with_dtype_convert( _unary_fusion_pattern( _silu_fusion, - get_qconv2d_pt2e_pattern(1 if is_bf16 else 2), + get_qconv_pt2e_pattern(1 if is_bf16 else 2), 2, is_bf16, ), @@ -2929,7 +2929,7 @@ def _register_qconv_unary_fusion(): _register_qconv_post_op_fusion_pass( patterns, 4, # pass_number - torch.ops.onednn.qconv2d_pointwise.default, # computation_op + torch.ops.onednn.qconv_pointwise.default, # computation_op unary_attr, # unary_attr ) @@ -2947,7 +2947,7 @@ def _register_qconv_binary_fusion(): ): generate_pattern_with_output_quant( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2959,7 +2959,7 @@ def _register_qconv_binary_fusion(): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), dequantize_accum_pattern, int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -2986,7 +2986,7 @@ def _register_qconv_binary_fusion(): PostOpAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, @@ -3024,7 +3024,7 @@ def _register_qconv_binary_fusion(): "sum", 1.0, "none", [], "" ): generate_pattern_with_binary( aten.add.Tensor, - get_qconv2d_pt2e_pattern(1), + get_qconv_pt2e_pattern(1), KeywordArg("accum_after_dequant"), int8_mixed_bf16_with_inplace_add, swap_inputs=swap_inputs, diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 041a2d8c149d..9063df455b0a 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1694,7 +1694,7 @@ def debug(msg: str) -> None: torch.ops.mkldnn._convolution_pointwise.binary, torch.ops.mkldnn._convolution_pointwise_.binary, torch.ops.mkldnn._convolution_transpose_pointwise.default, - torch.ops.onednn.qconv2d_pointwise.default, + torch.ops.onednn.qconv_pointwise.default, torch.ops.onednn.qconv2d_pointwise.binary, ] if torch._C.has_mkl: diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 422b256ca96f..74999462abc8 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -73,6 +73,22 @@ def _conv_input_size( input_size.append(input_size_d) return list(map(int, input_size)) + # Port from 
aten/src/ATen/native/ConvUtils.h: _conv_output_size + def _conv_output_size(input_size, weight_size, padding, stride, dilation=None): + has_dilation = dilation is not None + dim = len(input_size) + output_size = [] + output_size.append(input_size[0]) + output_size.append(weight_size[0]) + for d in range(2, dim): + dilation_ = dilation[d - 2] if has_dilation else 1 + kernel = dilation_ * (weight_size[d] - 1) + 1 + output_size_d = (input_size[d] + (2 * padding[d - 2]) - kernel) // stride[ + d - 2 + ] + 1 + output_size.append(output_size_d) + return output_size + # The size of prepacked_weight is the prepacked weight size of deconv: # Groups > 1: [g*o, i/g, ...] # Groups == 1: [o, i, ...] @@ -130,21 +146,18 @@ def _original_deconv_weight_size( groups, ) else: - bias_fake = ( - ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias - ) - output = torch.ops.aten.convolution( - x_fake, - weight_fake, - bias_fake, - stride, + x_shape = list(x_fake.shape) + weight_shape = list(weight_fake.shape) + if len(x_shape) != len(weight_shape): + assert len(x_shape) == 3 and len(weight_shape) == 4 + weight_shape.pop(2) + output_size = _conv_output_size( + x_shape, + weight_shape, padding, + stride, dilation, - transposed, - output_padding, - groups, ) - output_size = output.size() req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) req_stride_order = [len(req_stride_order)] + req_stride_order @@ -562,8 +575,8 @@ def __init__( inputs, constant_args, None, - op_overload=torch.ops.onednn.qconv2d_pointwise.default, - cpp_kernel_name="aoti_torch_cpu__qconv2d_pointwise_tensor", + op_overload=torch.ops.onednn.qconv_pointwise.default, + cpp_kernel_name="aoti_torch_cpu__qconv_pointwise_tensor", ) def codegen(self, wrapper): diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index d665aa3b892d..7ac5ee02ac43 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -130,7 +130,7 @@ def register_onednn_fusion_ops(): torch.ops.mkldnn._convolution_transpose_pointwise, torch.ops.mkldnn._linear_pointwise, aten.mkldnn_rnn_layer.default, - torch.ops.onednn.qconv2d_pointwise, + torch.ops.onednn.qconv_pointwise, ] @register_lowering(torch.ops.mkldnn._convolution_pointwise) @@ -428,7 +428,7 @@ def mkldnn_rnn_layer( ), ) - @register_lowering(torch.ops.onednn.qconv2d_pointwise, type_promotion_kind=None) + @register_lowering(torch.ops.onednn.qconv_pointwise, type_promotion_kind=None) def qconvolution_unary( x: TensorBox, x_scale, diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 86f5b522bebf..bd0a2f7b9728 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2508,7 +2508,8 @@ def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size): ) @register_meta(torch.ops.onednn.qconv2d_pointwise.default) - def meta_qconv2d_pointwise( + @register_meta(torch.ops.onednn.qconv_pointwise.default) + def meta_qconv_pointwise( x, x_scale, x_zp, @@ -2539,7 +2540,9 @@ def meta_qconv2d_pointwise( ) assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8] out = x.new_empty(shape_out, dtype=output_dtype) - out = out.to(memory_format=torch.channels_last) + assert len(shape_out) in [3, 4], "only conv1d/2d are supported" + format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format + out = out.to(memory_format=format) return out @register_meta(torch.ops.onednn.qconv2d_pointwise.binary) diff --git 
a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index 25a5dfc4a193..3f91c2ddd13b 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -84,6 +84,7 @@ class _X86InductorQuantizationAnnotation(QuantizationAnnotation): # Operators support the int8 data type # and recipe is configured by default in X86InductorQuantizer. default_quantizable_ops = propagation_quantizable_ops | { + torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.linear.default, } @@ -185,6 +186,7 @@ def _global_config_filter(nodes: list[Node]) -> bool: def _map_module_function_to_aten_operator_type(): module_function_to_aten_operator: dict[Callable, torch._ops.OpOverloadPacket] = {} map_list = ( + ([torch.nn.Conv2d, F.conv1d], torch.ops.aten.conv1d.default), ([torch.nn.Conv2d, F.conv2d], torch.ops.aten.conv2d.default), ([torch.nn.Linear, F.linear], torch.ops.aten.linear.default), ([torch.nn.MaxPool2d, F.max_pool2d], torch.ops.aten.max_pool2d.default), @@ -1156,6 +1158,7 @@ def _annotate_conv2d_unary( [torch.nn.Conv2d, torch.nn.Hardswish], [torch.nn.Conv2d, torch.nn.ReLU6], [torch.nn.Conv2d, torch.nn.SiLU], + [torch.nn.Conv1d, torch.nn.ReLU], ] for unary_pattern in unary_patterns: partitions = find_sequential_partitions(gm, unary_pattern) @@ -1168,9 +1171,9 @@ def _annotate_conv2d_unary( conv_node, unary_node = self._get_output_nodes_of_partitions( [conv_partition, unary_partition] ) - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.conv2d.default + if conv_node.op != "call_function" or conv_node.target not in ( + torch.ops.aten.conv2d.default, + torch.ops.aten.conv1d.default, ): continue if _skip_annotate([unary_node, conv_node], filter_fn): diff --git a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h index 86f09416f9fe..c7b713bf7f87 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_cpu.h @@ -170,7 +170,7 @@ aoti_torch_cpu__qlinear_pointwise_binary_tensor( const char* unary_post_op_algorithm, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qconv2d_pointwise_tensor( +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qconv_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, AtenTensorHandle act_zero_point, diff --git a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp index 153ee9e0ddbe..9d1bb914db5c 100644 --- a/torch/csrc/inductor/aoti_torch/shim_cpu.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_cpu.cpp @@ -372,7 +372,7 @@ AOTITorchError aoti_torch_cpu__qlinear_pointwise_binary_tensor( }); } -AOTITorchError aoti_torch_cpu__qconv2d_pointwise_tensor( +AOTITorchError aoti_torch_cpu__qconv_pointwise_tensor( AtenTensorHandle X, AtenTensorHandle act_scale, AtenTensorHandle act_zero_point, From 5a422150c32f2c861061dc8be5f8d19fb6d80155 Mon Sep 17 00:00:00 2001 From: zeshengzong Date: Wed, 9 Apr 2025 15:03:24 +0000 Subject: [PATCH 304/332] Add `torch.triu_indices`, `torch.tril_indices` dtype description (#150749) Fixes #150675 ## Test Result ![image](https://github.com/user-attachments/assets/f30a0de0-6475-4d07-b441-15fffd453ba1) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150749 Approved by: https://github.com/bdhirsh --- torch/_torch_docs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff 
--git a/torch/_torch_docs.py b/torch/_torch_docs.py index 59fcc6213f30..1e5d7d340f5a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -11488,8 +11488,8 @@ def merge_dicts(*dicts): Default: if not provided, 0. Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, ``torch.long``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. {device} layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. @@ -11613,8 +11613,8 @@ def merge_dicts(*dicts): Default: if not provided, 0. Keyword args: - dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. - Default: if ``None``, ``torch.long``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor, + only support ``torch.int``, ``torch.long``. Default: if ``None``, ``torch.long``. {device} layout (:class:`torch.layout`, optional): currently only support ``torch.strided``. From d0e34822663b759f17ef5e6ec574cbf820c23b85 Mon Sep 17 00:00:00 2001 From: atalman Date: Wed, 9 Apr 2025 15:26:07 +0000 Subject: [PATCH 305/332] Update triton wheel build, setuptools pin (#150931) Observing failure in release workflow: https://github.com/pytorch/pytorch/actions/runs/14346340202/job/40216804374 ``` Traceback (most recent call last): File "/opt/python/cp311-cp311/lib/python3.11/site-packages/wheel/bdist_wheel.py", line 11, in from setuptools.command.bdist_wheel import bdist_wheel as bdist_wheel ModuleNotFoundError: No module named 'setuptools.command.bdist_wheel' The above exception was the direct cause of the following exception: Traceback (most recent call last): File "/tmp/tmppwpqef_x/triton/python/setup.py", line 27, in from wheel.bdist_wheel import bdist_wheel File "/opt/python/cp311-cp311/lib/python3.11/site-packages/wheel/bdist_wheel.py", line 13, in raise ImportError(ERROR) from exc ImportError: The 'wheel.bdist_wheel' module has been removed. Please update your setuptools to v70.1 or later. If you're explicitly importing 'wheel.bdist_wheel', please update your import to point to 'setuptools.command.bdist_wheel' instead. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150931 Approved by: https://github.com/Skylion007 --- .github/workflows/build-triton-wheel.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index b4e9ec34f3da..99d71c7082b7 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -138,7 +138,7 @@ jobs: fi docker exec -t "${container_name}" yum install -y zlib-devel zip - docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==67.4.0 pybind11==2.13.1 auditwheel wheel + docker exec -t "${container_name}" "${PYTHON_EXECUTABLE}" -m pip install -U setuptools==78.1.0 pybind11==2.13.1 auditwheel wheel if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "rocm" || "${{ matrix.device }}" == "aarch64" ) ]]; then # With this install, it gets clang 16.0.6. 
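A note on the failure mode addressed in #150931 above: the quoted traceback comes from newer setuptools dropping the `wheel.bdist_wheel` re-export, which is why the workflow now pins `setuptools==78.1.0`. As an illustrative sketch only (not part of the patch, and using a hypothetical package name), a `setup.py` that subclasses `bdist_wheel` can stay compatible with both old and new setuptools by following the hint in the error message, which points at `setuptools.command.bdist_wheel` for setuptools v70.1 or later:

```
# Hypothetical setup.py snippet, not taken from the PyTorch tree.
# Prefer the new setuptools location (setuptools >= 70.1, per the error
# message above); fall back to the legacy wheel import on older installs.
try:
    from setuptools.command.bdist_wheel import bdist_wheel
except ImportError:
    from wheel.bdist_wheel import bdist_wheel

from setuptools import setup


class VersionedBdistWheel(bdist_wheel):
    """Example subclass; a real project would add its customizations here."""


setup(
    name="example-pkg",  # hypothetical package name
    version="0.0.1",
    cmdclass={"bdist_wheel": VersionedBdistWheel},
)
```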
From 97a5e5c6b3522411f9bbdec118e70a4be07ed3c6 Mon Sep 17 00:00:00 2001 From: pralay Date: Wed, 9 Apr 2025 15:48:09 +0000 Subject: [PATCH 306/332] Added _fused_sdp_choice_stub dispatcher support for HPU device (#149512) Currently there is no `_fused_sdp_choice_stub` dispatcher support for the HPU device, so `scaled_dot_product_attention` selects the `MATH` backend by default on HPU. With this PR, `_fused_sdp_choice_stub` is supported for the HPU device, so that any backend (for example math, flash_attention, efficient_attention, cudnn_attention, overrideable) can be invoked on HPU according to the user's choice. Pull Request resolved: https://github.com/pytorch/pytorch/pull/149512 Approved by: https://github.com/drisspg --- aten/src/ATen/native/DispatchStub.cpp | 4 ++++ aten/src/ATen/native/DispatchStub.h | 17 +++++++++++++++ aten/src/ATen/native/transformers/attention.cpp | 16 ++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 1be4ec37dfef..e1d329fbf30f 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -147,6 +147,7 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( c10::DeviceType::MPS, c10::DeviceType::MTIA, c10::DeviceType::XPU, + c10::DeviceType::HPU, c10::DeviceType::PrivateUse1 ); // Check if the device type is supported. @@ -203,6 +204,9 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; #endif + case DeviceType::HPU: + return hpu_dispatch_ptr != nullptr ? DispatchResult(hpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; + case DeviceType::PrivateUse1: return privateuse1_dispatch_ptr != nullptr ? 
DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel; diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 725d0d08bae1..cbe4b23c6711 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -44,6 +44,7 @@ // - MPS: Apple Silicon GPUs (Metal Performance Shaders) // - MTIA: Meta Training and Inference Devices // - XPU: Intel GPUs +// - HPU: Reserved for HPU (Intel Gaudi) device types // - PrivateUse1: Reserved for private/custom device types // // If you want to update the list of supported devices, add a new dispatch_ptr @@ -196,6 +197,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr; #endif + void* hpu_dispatch_ptr; void* privateuse1_dispatch_ptr; #else std::atomic cpu_dispatch_ptr{nullptr}; @@ -206,6 +208,7 @@ struct TORCH_API DispatchStubImpl { #if defined(USE_XPU) void* xpu_dispatch_ptr = nullptr; #endif + void* hpu_dispatch_ptr = nullptr; void* privateuse1_dispatch_ptr = nullptr; #endif }; @@ -259,6 +262,10 @@ struct DispatchStub { } #endif + void set_hpu_dispatch_ptr(FnPtr fn_ptr) { + impl.hpu_dispatch_ptr = reinterpret_cast(fn_ptr); + } + void set_hip_dispatch_ptr(FnPtr fn_ptr) { impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); } @@ -337,6 +344,13 @@ struct RegisterXPUDispatch { } }; +template +struct RegisterHPUDispatch { + RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){ + stub.set_hpu_dispatch_ptr(value); + } +}; + template struct RegisterMPSDispatch { RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { @@ -437,6 +451,9 @@ struct RegisterPRIVATEUSE1Dispatch { #define REGISTER_XPU_DISPATCH(name, fn) \ static RegisterXPUDispatch name ## __register(name, fn); +#define REGISTER_HPU_DISPATCH(name, fn) \ + static RegisterHPUDispatch name ## __register(name, fn); + #define REGISTER_HIP_DISPATCH(name, fn) \ static RegisterHIPDispatch name ## __register(name, fn); diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 27397bf78898..66bdaa0baa89 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -28,6 +28,7 @@ #include #else #include +#include #include #include #include @@ -448,6 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp) +REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta); int64_t _fused_sdp_choice_meta( const Tensor& query_, @@ -459,6 +461,20 @@ int64_t _fused_sdp_choice_meta( std::optional scale, bool enable_gqa) { auto query_key_set = query_.key_set(); + bool has_hpu = query_key_set.has(c10::DispatchKey::HPU); + if (has_hpu) { + auto choice_int = at::_ops::_fused_sdp_choice::redispatch( + c10::DispatchKeySet(DispatchKey::HPU), + query_, + key, + value, + attn_mask_, + dropout_p, + is_causal, + scale, + enable_gqa); + return choice_int; + } #if defined(USE_ROCM) bool has_rocm = query_key_set.has(c10::DispatchKey::HIP); if (has_rocm) { From 1a56609e75a31f92c55f5f74c7af92b6f1531aa5 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 9 Apr 2025 16:03:46 +0000 Subject: [PATCH 307/332] [ONNX] Supporting different opset versions for torchlib registry (#149901) - Allows opset_version to determine which 
onnx decomposition to choose - Adds a cleanup function to modify the registry after it is built Pull Request resolved: https://github.com/pytorch/pytorch/pull/149901 Approved by: https://github.com/justinchuby, https://github.com/titaiwangms --- test/onnx/exporter/test_api.py | 25 +++++++++++++++++ test/onnx/torchlib/ops_test_common.py | 8 +++--- test/onnx/torchlib/ops_test_data.py | 11 +++++++- test/onnx/torchlib/test_ops.py | 4 ++- torch/onnx/_internal/exporter/_compat.py | 5 ++-- torch/onnx/_internal/exporter/_ir_passes.py | 6 ++++- .../onnx/_internal/exporter/_registration.py | 27 ++++++++++++++++++- .../exporter/_torchlib/_torchlib_registry.py | 2 ++ .../exporter/_torchlib/ops/__init__.py | 4 +-- .../_internal/exporter/_torchlib/ops/nn.py | 26 ++++++++++++++++++ 10 files changed, 106 insertions(+), 12 deletions(-) create mode 100644 torch/onnx/_internal/exporter/_torchlib/ops/nn.py diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index a4dc1c97772d..3ebf00eccec0 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -246,6 +246,31 @@ def test_dynamic_shapes_supports_nested_input_model_with_input_names_assigned(se ) ) + def test_upgraded_torchlib_impl(self): + class GeluModel(torch.nn.Module): + def forward(self, input): + # Use GELU activation function + return torch.nn.functional.gelu(input, approximate="tanh") + + input = torch.randn(1, 3, 4, 4) + onnx_program_op18 = torch.onnx.export( + GeluModel(), + input, + dynamo=True, + ) + all_nodes_op18 = [n.op_type for n in onnx_program_op18.model.graph] + self.assertIn("Tanh", all_nodes_op18) + self.assertNotIn("Gelu", all_nodes_op18) + + onnx_program_op20 = torch.onnx.export( + GeluModel(), + input, + opset_version=20, + dynamo=True, + ) + all_nodes_op20 = [n.op_type for n in onnx_program_op20.model.graph] + self.assertIn("Gelu", all_nodes_op20) + def test_refine_dynamic_shapes_with_onnx_export(self): # NOTE: From test/export/test_export.py diff --git a/test/onnx/torchlib/ops_test_common.py b/test/onnx/torchlib/ops_test_common.py index 73c00de388fa..884b66d4e02f 100644 --- a/test/onnx/torchlib/ops_test_common.py +++ b/test/onnx/torchlib/ops_test_common.py @@ -52,6 +52,7 @@ torch.float64, ) + TEST_OPSET_VERSION = 18 IS_MACOS = sys.platform.startswith("darwin") IS_WINDOWS = os.name == "nt" @@ -487,6 +488,7 @@ def dtype_op_schema_compatible(dtype: torch.dtype, schema: onnx.defs.OpSchema) - def graph_executor( test_name: str, outputs: Sequence[Any], + opset_version: int = TEST_OPSET_VERSION, ) -> Callable[[Callable[..., Any], tuple[Any], dict[str, Any]], None]: """Eagerly executes a function.""" @@ -500,10 +502,10 @@ def _capture_graph_and_evaluate_torch_script_evaluator( (), (), nodes=(), - opset_imports={"": 18, "pkg.torch.onnx": 1}, + opset_imports={"": opset_version, "pkg.torch.onnx": 1}, name="main_graph", ) - opset = onnxscript.opset18 + opset = onnxscript.values.Opset("", opset_version) tracer = _building.OpRecorder(opset, {}) ort_inputs = {} onnxscript_args: list[Any] = [] @@ -590,7 +592,7 @@ def _capture_graph_and_evaluate_torch_script_evaluator( proto = onnxscript_function.to_function_proto() ir_function = ir.serde.deserialize_function(proto) onnx_model.functions[identifier] = ir_function - _ir_passes.add_torchlib_common_imports(onnx_model) + _ir_passes.add_torchlib_common_imports(onnx_model, opset_version=opset_version) _ir_passes.add_opset_imports(onnx_model) # Make sure the model is valid model_proto = ir.to_proto(onnx_model) diff --git 
a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py index b255f07640b8..a69d7a4bec1e 100644 --- a/test/onnx/torchlib/ops_test_data.py +++ b/test/onnx/torchlib/ops_test_data.py @@ -46,7 +46,7 @@ import ops_test_common import torch -from torch.onnx._internal.exporter._torchlib.ops import core as core_ops +from torch.onnx._internal.exporter._torchlib.ops import core as core_ops, nn as nn_ops from torch.testing._internal import common_methods_invocations from torch.testing._internal.opinfo import definitions as opinfo_definitions @@ -78,6 +78,12 @@ class TorchLibOpInfo: compare_shape_only_for_output: tuple[int, ...] = () # Whether the function is designed for complex inputs complex: bool = False + # The ONNX opset version in which the function was introduced. + # Its specifies the minimum ONNX opset version required to use the function. + # It ensures that the function is only used when the target ONNX opset version + # is compatible. For example, if `opset_introduced=20`, the function will only + # be used when exporting to ONNX models targeting opset version 20 or higher. + opset_introduced: int = 18 # The acceptable tolerance of the inference result difference between PyTorch and ORT. # Format: {dtype: (rtol, atol)}. # For example: {torch.float16: (1e-3, 1e-3)} @@ -447,8 +453,10 @@ def _where_input_wrangler( TorchLibOpInfo("abs", core_ops.aten_abs_complex, complex=True), TorchLibOpInfo("add", core_ops.aten_add, tolerance={torch.float16: (1e-3, 1e-3)}), TorchLibOpInfo("add", core_ops.aten_add_complex, complex=True), + TorchLibOpInfo("gelu_op20", nn_ops.aten_gelu_opset20, opset_introduced=20), ) + ops_test_common.duplicate_opinfo(OPS_DB, "all", ("all_dim", "all_dims")) ops_test_common.duplicate_opinfo(OPS_DB, "any", ("any_dim", "any_dims")) ops_test_common.duplicate_opinfo( @@ -500,6 +508,7 @@ def _where_input_wrangler( "nn.functional.replication_pad3d", ), ) +ops_test_common.duplicate_opinfo(OPS_DB, "nn.functional.gelu", ("gelu_op20",)) ops_test_common.duplicate_opinfo( OPS_DB, "nn.functional.scaled_dot_product_attention", diff --git a/test/onnx/torchlib/test_ops.py b/test/onnx/torchlib/test_ops.py index 74cbeeca3138..a7a52698cd23 100644 --- a/test/onnx/torchlib/test_ops.py +++ b/test/onnx/torchlib/test_ops.py @@ -220,7 +220,9 @@ def run_test_output_match( test_name = test_suite.id() function_output, model_proto = function_executor( - test_name, reference_torch_outputs + test_name, + reference_torch_outputs, + opset_version=torchlib_op_info.opset_introduced, )(onnx_function, input_onnx, kwargs_onnx) # Finally we re-flatten everything # TODO: add pytree structure comparison. 
diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index a38203d2314d..b570b20bd02c 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -50,7 +50,7 @@ def export_compat( verbose: bool | None = None, input_names: Sequence[str] | None = None, output_names: Sequence[str] | None = None, - opset_version: int | None = None, + opset_version: int | None = _constants.TORCHLIB_OPSET, custom_translation_table: dict[Callable, Callable | Sequence[Callable]] | None = None, dynamic_axes: Mapping[str, Mapping[int, str]] @@ -105,8 +105,7 @@ def export_compat( dynamic_shapes_with_export_dim, need_axis_mapping = ( _dynamic_shapes.convert_str_to_export_dim(dynamic_shapes) ) - - registry = _registration.ONNXRegistry.from_torchlib() + registry = _registration.ONNXRegistry().from_torchlib(opset_version=opset_version) if custom_translation_table is not None: for torch_op, onnx_ops in custom_translation_table.items(): # TODO(justinchuby): Support complex inputs with annotations diff --git a/torch/onnx/_internal/exporter/_ir_passes.py b/torch/onnx/_internal/exporter/_ir_passes.py index 804e93acbd6f..8a715e245597 100644 --- a/torch/onnx/_internal/exporter/_ir_passes.py +++ b/torch/onnx/_internal/exporter/_ir_passes.py @@ -90,7 +90,9 @@ def rename_axis(model: ir.Model, rename_mapping: dict[str, str]) -> None: value.shape = ir.Shape(new_shape) -def add_torchlib_common_imports(model: ir.Model) -> None: +def add_torchlib_common_imports( + model: ir.Model, opset_version: int = _constants.TORCHLIB_OPSET +) -> None: """Hack to add torchlib common imports to the model.""" try: @@ -99,9 +101,11 @@ def add_torchlib_common_imports(model: ir.Model) -> None: model.opset_imports["pkg.onnxscript.torch_lib.common"] = 1 rank_func = ir.serde.deserialize_function(common_ops.Rank.to_function_proto()) + rank_func.opset_imports[""] = opset_version is_scalar_func = ir.serde.deserialize_function( common_ops.IsScalar.to_function_proto() ) + is_scalar_func.opset_imports[""] = opset_version model.functions[rank_func.identifier()] = rank_func model.functions[is_scalar_func.identifier()] = is_scalar_func except Exception: diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index ac81d2301cc2..fefc8022d7e8 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -42,6 +42,9 @@ class OnnxDecompMeta: signature: The ONNX signature of the function. When None, the signature is inferred. is_custom: Whether the function is a custom function. is_complex: Whether the function is a function that handles complex valued inputs. + opset_introduced: + The ONNX opset version in which the function was introduced. + Its specifies the minimum ONNX opset version required to use the function. device: The device the function is registered to. If None, it is registered to all devices. skip_signature_inference: Whether to skip signature inference for the function. 
""" @@ -51,6 +54,7 @@ class OnnxDecompMeta: signature: _schemas.OpSignature | None is_custom: bool = False is_complex: bool = False + opset_introduced: int = 18 device: Literal["cuda", "cpu"] | str | None = None # noqa: PYI051 skip_signature_inference: bool = False @@ -150,13 +154,14 @@ def opset_version(self) -> int: return self._opset_version @classmethod - def from_torchlib(cls) -> ONNXRegistry: + def from_torchlib(cls, opset_version=_constants.TORCHLIB_OPSET) -> ONNXRegistry: """Populates the registry with ATen functions from torchlib. Args: torchlib_registry: The torchlib registry to use for populating the registry. """ registry = cls() + registry._opset_version = opset_version for meta in _torchlib_registry.get_torchlib_ops(): registry._register(meta.fx_target, meta) @@ -185,6 +190,7 @@ def from_torchlib(cls) -> ONNXRegistry: logger.exception("Failed to register '%s'. Skipped", qualified_name) continue + registry._cleanup_registry_based_on_opset_version() return registry def _register( @@ -274,5 +280,24 @@ def is_registered(self, target: TorchOp) -> bool: """ return bool(self.get_decomps(target)) + def _cleanup_registry_based_on_opset_version(self) -> None: + """Pick the implementation with the highest opset version valid until the current opset version.""" + cleaned_functions = {} + for target_or_name, decomps in self.functions.items(): + # Filter decompositions to only include those with opset_introduced <= opset_version + decomps = [d for d in decomps if d.opset_introduced <= self.opset_version] + + # Keep only the decomposition with the highest opset_introduced + if decomps: + # Find the maximum opset_introduced + max_opset = max(d.opset_introduced for d in decomps) + + # Keep all decompositions with the maximum opset_introduced + cleaned_functions[target_or_name] = [ + d for d in decomps if d.opset_introduced == max_opset + ] + + self.functions = cleaned_functions + def __repr__(self) -> str: return f"{self.__class__.__name__}(functions={self.functions})" diff --git a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py index e71bdeb0c68e..039eeb3e2fc2 100644 --- a/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py +++ b/torch/onnx/_internal/exporter/_torchlib/_torchlib_registry.py @@ -30,6 +30,7 @@ def onnx_impl( *, trace_only: bool = False, complex: bool = False, + opset_introduced: int = 18, no_compile: bool = False, private: bool = False, ) -> Callable[[_T], _T]: @@ -74,6 +75,7 @@ def wrapper( fx_target=t, signature=None, is_complex=complex, + opset_introduced=opset_introduced, skip_signature_inference=no_compile, ) ) diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py index d07768f252ba..bff8860fcb1f 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["core", "hop", "symbolic"] +__all__ = ["core", "hop", "nn", "symbolic"] -from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic +from torch.onnx._internal.exporter._torchlib.ops import core, hop, nn, symbolic diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py new file mode 100644 index 000000000000..4ca21662d696 --- /dev/null +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -0,0 +1,26 @@ +"""torch.ops.aten operators under the `core` 
module.""" +# mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" +# ruff: noqa: TCH001,TCH002 +# flake8: noqa + +from __future__ import annotations + +import math + +from onnxscript.onnx_opset import opset20 as op20 + +import torch +from torch.onnx._internal.exporter._torchlib._tensor_typing import TReal +from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl + + +aten = torch.ops.aten + + +@onnx_impl(aten.gelu.default, trace_only=True, opset_introduced=20) +def aten_gelu_opset20( + self: TReal, + approximate: str = "none", +) -> TReal: + """gelu(Tensor self, *, bool approximate=False) -> Tensor""" + return op20.Gelu(self, approximate=approximate) From c8d37b9c85b76d6b612907a0016d8cb525a131f3 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 8 Apr 2025 11:27:41 -0700 Subject: [PATCH 308/332] [ez][c10d] Disable start event recording for coalesced collectives and improve profile title (#150863) While looking at enabling FR analysis for coalesced collectives, I found that for the slow-path coalescing (collectives which are not all-gather, all-reduce or reduce-scatter), we still record a start event for them. This is wrong; we should do the same thing as we already do for end-event recording. I also made the profiler title more informative when we pass in the opType for coalesced all-gather and reduce-scatter. Pull Request resolved: https://github.com/pytorch/pytorch/pull/150863 Approved by: https://github.com/eqy, https://github.com/d4l3k, https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 2da127e5b267..ecfb2b5d10d4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -3281,6 +3281,9 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { // `getKeyFromDevice` is how we get keys for both collectives and batch P2P const auto key = getKeyFromDevice(device); auto ncclStream = ncclStreams_.at(key); + auto opProfilerTitle = optype != OpType::COALESCED + ? "nccl:" + opTypeToString(optype) + "_coalesced" + : "nccl:coalesced"; // Create Work object c10::cuda::CaptureStatus capture_status = @@ -3292,7 +3295,7 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { rank_, optype, coalescing_state_ & CoalP2P, - "nccl:coalesced", + opProfilerTitle.c_str(), {}, {}, enqueue); @@ -3448,7 +3451,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( } // Start event should only be recorded before the ncclGroupStart() - if (work->timingEnabled_) { + if (work->timingEnabled_ && !coalescing_state_) { work->ncclStartEvent_->record(ncclStream); } From 8aaf296efceb56384f6c9e9532c3b8477f17e8c5 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 8 Apr 2025 20:34:08 -0700 Subject: [PATCH 309/332] [c10d][fr] Refactor analysis script for modularization and reuse for coalesced collectives (#150881) This makes the FR analysis code more reusable and modular by splitting the core error-analysis logic into separate functions. The PR is mostly a shuffle of existing code. 
Differential Revision: [D72690120](https://our.internmc.facebook.com/intern/diff/D72690120) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150881 Approved by: https://github.com/wz337 --- tools/flight_recorder/components/builder.py | 194 +++++--------------- tools/flight_recorder/components/types.py | 23 +++ tools/flight_recorder/components/utils.py | 153 +++++++++++++++ 3 files changed, 227 insertions(+), 143 deletions(-) diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index bb61ac3e8216..d239ab1d43a0 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -16,8 +16,7 @@ Database, EntryState, Group, - MatchInfo, - MatchState, + MatchStateRecord, Membership, NCCLCall, Op, @@ -25,15 +24,14 @@ ) from tools.flight_recorder.components.utils import ( align_trace_from_beginning, + check_current_entry_match, check_no_missing_dump_files, - check_size_alltoall, check_version, + error_analysis, find_coalesced_group, - format_frames, get_version_detail, just_print_entries, match_coalesced_groups, - match_one_event, ) @@ -161,7 +159,6 @@ def build_collectives( ] } """ - major_v, minor_v = get_version_detail(version) tracebacks: list[Traceback] = [] collectives: list[Collective] = [] @@ -194,17 +191,23 @@ def build_collectives( # lets match the first collective! we need to know which ranks are involved, and ensure that this same # collective is also the first one on those ranks within that group entries = all_entries[first_rank] - desc = entries[0]["process_group"][1] + current_entry = entries[0] + desc = current_entry["process_group"][1] # For db build and logs printing, we want to use the original pg_name, not the hash one. - original_pg_name = entries[0]["process_group"][0] + original_pg_name = current_entry["process_group"][0] pg_name = _pg_guids[(original_pg_name, first_rank)] expected_ranks = set(_memberships[pg_name]) - entry_state = EntryState(entries[0], expected_ranks) - candidate_ranks = {first_rank} - candidate_idx = {} - found_ranks = set() - found_idx = {} - errors = set() + entry_state = EntryState(current_entry, expected_ranks) + match_record = MatchStateRecord( + expected_ranks=expected_ranks, + other_ranks=other_ranks, + entry_state=entry_state, + candidate_ranks={first_rank}, + candidate_idx={}, + found_ranks=set(), + found_idx={}, + errors=set(), + ) if find_coalesced_group(pg_name, entries, _pg_guids, first_rank): expected_ranks.add(first_rank) @@ -256,137 +259,42 @@ def build_collectives( ) ) else: - has_undecided_case = False - for o in expected_ranks.intersection(set(other_ranks)): - for i, e in enumerate(all_entries[o]): # type: ignore[index] - # step over ops from other PGs - # only check match state when seq_id matches - if ( - _pg_guids[(e["process_group"][0], o)] == pg_name - and e["process_group"][1] == desc - and e["collective_seq_id"] == entry_state.collective_seq_id - ): - match_info = match_one_event( - entries[0], e, _memberships, pg_name - ) - if ( - match_info.state - in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] - and mismatch[pg_name] == 0 - ): - found_ranks.add(o) - found_idx[o] = i - has_undecided_case = ( - match_info.state == MatchState.UNDECIDED - ) - else: - candidate_ranks.add(o) - candidate_idx[o] = i - if match_info.state not in [ - MatchState.FULLY_MATCHED, - MatchState.UNDECIDED, - ]: - # Here we assume the current rank is not the source of the error. 
- # But it's possible that the current rank is the culprit, then users will - # see lots of normal ranks reported as culprit. - # TODO: we need to figure out a better way to handle the case mentioned above. - errors.add((o, match_info)) - break - - # case one: not every rank join the collective or in the flight recorder. - if (candidate_ranks | found_ranks) != expected_ranks and expected_ranks - ( - candidate_ranks | found_ranks - ) <= dumps_ranks: - mismatch[pg_name] += 1 - logger_msg = "Not all ranks joining collective, sequence number: %s" - missing_ranks = expected_ranks - (candidate_ranks | found_ranks) - entry_state.log( - logger, logger_msg, format_frames, missing_ranks=missing_ranks - ) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - elif len(candidate_ranks) == 1 and dumps_ranks == expected_ranks: - # case two: alltoall or alltoall_base case. - if has_undecided_case: - alltoall_cases = [entries[0]] + [ - all_entries[o][found_idx[o]] for o in found_ranks - ] - fail_check, total_input_numel, total_output_numel = ( - check_size_alltoall(alltoall_cases) - ) - if major_v <= 2 and minor_v <= 3: - # We don't log the input/output sizes for alltoall before v2.4, - # so we don't consider the size mismatch as an error for now. - fail_check = False - if fail_check: - # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. - mismatch[pg_name] += 1 - logger_msg = "Input/output mismatch in the collective sequence number: %s" - entry_state.log( - logger, - logger_msg, - format_frames, - total_numel=(total_input_numel, total_output_numel), - ) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - errors.add( - (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)) - ) - else: - found_ranks.update(candidate_ranks) - found_idx.update(candidate_idx) - candidate_idx.clear() - candidate_ranks.clear() - # case three: all joined and everything matches on all ranks. - else: - found_ranks.update(candidate_ranks) - found_idx.update(candidate_idx) - candidate_idx.clear() - candidate_ranks.clear() - # case four: mismatch cases due to not same type, size mismatch or state mismatch. - elif len(errors) > 0: - mismatch[pg_name] += 1 - logger_msg = "Collective sequence number: %s has errors" - entry_state.log(logger, logger_msg, format_frames, errors=errors) - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - # partial analysis case when we cannot decide what's wrong with this collective entry. - else: - candidate_ranks.update(found_ranks) - candidate_idx.update(found_idx) - found_idx.clear() - found_ranks.clear() - if expected_ranks - dumps_ranks: - mismatch[pg_name] += 1 - logger.info( - "We cannot decide what's wrong with this collective entry " - "because we missed FR dumps from ranks (%s) so we don't have enough " - "information. If you want to debug further use -j to dump all raw trace", - str(expected_ranks - dumps_ranks), - ) - else: - logger.info( - "No errors found for this collective entry, There could be some " - "other reasons why we see collective timeout." - ) + # Iterate through all the ranks and check if there is a mis-match for the current entry. 
+ check_current_entry_match( + all_entries, + _pg_guids, + (pg_name, desc), + current_entry, + _memberships, + mismatch, + match_record, + ) + + # Use heuristics to decide what type of errors and error messages we should print. + error_analysis( + all_entries, + match_record, + dumps_ranks, + first_rank, + current_entry, + mismatch, + get_version_detail(version), + pg_name, + ) # at this point there are 3 possibilities # 1. we found a match on all the ranks that are members of the group # -> we create a Collective and remove the individual entries from their original lists - if found_ranks == expected_ranks and mismatch[pg_name] == 0: - collectives.append(entry_state.to_collective(len(collectives))) + if match_record.found_ranks == expected_ranks and mismatch[pg_name] == 0: + collectives.append( + match_record.entry_state.to_collective(len(collectives)) + ) idx_map = { - r: found_idx[r] if r != first_rank else 0 for r in found_ranks + r: match_record.found_idx[r] if r != first_rank else 0 + for r in match_record.found_ranks } nccl_calls.extend( - entry_state.to_nccl_call( + match_record.entry_state.to_nccl_call( all_entries, idx_map, len(nccl_calls), collectives[-1].id ) ) @@ -398,19 +306,19 @@ def build_collectives( else: logger.debug("appending a non-matching collective") idx_map = { - r: candidate_idx[r] if r != first_rank else 0 - for r in candidate_ranks + r: match_record.candidate_idx[r] if r != first_rank else 0 + for r in match_record.candidate_ranks } collectives.append( - entry_state.to_collective( + match_record.entry_state.to_collective( len(collectives), - errors=errors, + errors=match_record.errors, idx_map=idx_map, all_entries=all_entries, ) ) nccl_calls.extend( - entry_state.to_nccl_call( + match_record.entry_state.to_nccl_call( all_entries, idx_map, len(nccl_calls), None ) ) diff --git a/tools/flight_recorder/components/types.py b/tools/flight_recorder/components/types.py index dbad6a93790c..5587e7179c77 100644 --- a/tools/flight_recorder/components/types.py +++ b/tools/flight_recorder/components/types.py @@ -560,3 +560,26 @@ def match(self, other: "Op") -> MatchInfo: else MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH) ) return MatchInfo(MatchState.FULLY_MATCHED) + + +class MatchStateRecord: + def __init__( + self, + expected_ranks: set[int], + other_ranks: list[int], + entry_state: EntryState, + candidate_ranks: set[int], + candidate_idx: dict[int, int], + found_ranks: set[int], + found_idx: dict[int, int], + errors: set[tuple[int, MatchInfo]], + ) -> None: + self.expected_ranks = expected_ranks + self.other_ranks = other_ranks + self.entry_state = entry_state + self.candidate_ranks = candidate_ranks + self.candidate_idx = candidate_idx + self.found_ranks = found_ranks + self.found_idx = found_idx + self.errors = errors + self.has_undecided_case = False diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 02787d3e43c6..0973ec1c17bb 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -13,6 +13,7 @@ Group, MatchInfo, MatchState, + MatchStateRecord, Membership, Op, P2P, @@ -184,6 +185,158 @@ def check_size_alltoall(alltoall_cases: list[dict[str, Any]]) -> tuple[bool, int return input_numel != output_numel, input_numel, output_numel +def check_current_entry_match( + all_entries: dict[int, list[dict[str, Any]]], + _pg_guids: dict[tuple[str, int], str], + pg_info: tuple[str, str], + current_entry: dict[str, Any], + _memberships: dict[str, set[Any]], + mismatch: dict[str, 
int], + match_record: MatchStateRecord, +) -> None: + pg_name, desc = pg_info[0], pg_info[1] + for o in match_record.expected_ranks.intersection(set(match_record.other_ranks)): + for i, e in enumerate(all_entries[o]): # type: ignore[index] + # step over ops from other PGs + # only check match state when seq_id matches + if ( + _pg_guids[(e["process_group"][0], o)] == pg_name + and e["process_group"][1] == desc + and e["collective_seq_id"] == match_record.entry_state.collective_seq_id + ): + match_info = match_one_event(current_entry, e, _memberships, pg_name) + if ( + match_info.state in [MatchState.FULLY_MATCHED, MatchState.UNDECIDED] + and mismatch[pg_name] == 0 + ): + match_record.found_ranks.add(o) + match_record.found_idx[o] = i + match_record.has_undecided_case = ( + match_info.state == MatchState.UNDECIDED + ) + else: + match_record.candidate_ranks.add(o) + match_record.candidate_idx[o] = i + if match_info.state not in [ + MatchState.FULLY_MATCHED, + MatchState.UNDECIDED, + ]: + # Here we assume the current rank is not the source of the error. + # But it's possible that the current rank is the culprit, then users will + # see lots of normal ranks reported as culprit. + # TODO: we need to figure out a better way to handle the case mentioned above. + match_record.errors.add((o, match_info)) + break + + +def error_analysis( + all_entries: dict[int, list[dict[str, Any]]], + match_record: MatchStateRecord, + dumps_ranks: set[int], + first_rank: int, + current_entry: dict[str, Any], + mismatch: dict[str, int], + version: tuple[int, int], + pg_name: str, +) -> None: + major_v, minor_v = version[0], version[1] + # case one: not every rank join the collective or in the flight recorder. + if ( + match_record.candidate_ranks | match_record.found_ranks + ) != match_record.expected_ranks and match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) <= dumps_ranks: + mismatch[pg_name] += 1 + logger_msg = "Not all ranks joining collective, sequence number: %s" + missing_ranks = match_record.expected_ranks - ( + match_record.candidate_ranks | match_record.found_ranks + ) + match_record.entry_state.log( + logger, logger_msg, format_frames, missing_ranks=missing_ranks + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + elif ( + len(match_record.candidate_ranks) == 1 + and dumps_ranks == match_record.expected_ranks + ): + # case two: alltoall or alltoall_base case. + if match_record.has_undecided_case: + alltoall_cases = [current_entry] + [ + all_entries[o][match_record.found_idx[o]] + for o in match_record.found_ranks + ] + fail_check, total_input_numel, total_output_numel = check_size_alltoall( + alltoall_cases + ) + if major_v <= 2 and minor_v <= 3: + # We don't log the input/output sizes for alltoall before v2.4, + # so we don't consider the size mismatch as an error for now. + fail_check = False + if fail_check: + # When we see errors in all_to_all, it's hard to tell which rank is the source of the error. 
+ mismatch[pg_name] += 1 + logger_msg = ( + "Input/output mismatch in the collective sequence number: %s" + ) + match_record.entry_state.log( + logger, + logger_msg, + format_frames, + total_numel=(total_input_numel, total_output_numel), + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + match_record.errors.add( + (first_rank, MatchInfo(MatchState.SIZE_OR_SYNTAX_MISMATCH)) + ) + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case three: all joined and everything matches on all ranks. + else: + match_record.found_ranks.update(match_record.candidate_ranks) + match_record.found_idx.update(match_record.candidate_idx) + match_record.candidate_idx.clear() + match_record.candidate_ranks.clear() + # case four: mismatch cases due to not same type, size mismatch or state mismatch. + elif len(match_record.errors) > 0: + mismatch[pg_name] += 1 + logger_msg = "Collective sequence number: %s has errors" + match_record.entry_state.log( + logger, logger_msg, format_frames, errors=match_record.errors + ) + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + # partial analysis case when we cannot decide what's wrong with this collective entry. + else: + match_record.candidate_ranks.update(match_record.found_ranks) + match_record.candidate_idx.update(match_record.found_idx) + match_record.found_idx.clear() + match_record.found_ranks.clear() + if match_record.expected_ranks - dumps_ranks: + mismatch[pg_name] += 1 + logger.info( + "We cannot decide what's wrong with this collective entry " + "because we missed FR dumps from ranks (%s) so we don't have enough " + "information. If you want to debug further use -j to dump all raw trace", + str(match_record.expected_ranks - dumps_ranks), + ) + else: + logger.info( + "No errors found for this collective entry, There could be some " + "other reasons why we see collective timeout." + ) + + def find_coalesced_group( pg_name: str, entries: list[dict[str, Any]], From 72755a4b7a3499f31b2959eeaa7beebd214e7a80 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 9 Apr 2025 16:32:11 +0000 Subject: [PATCH 310/332] Avoid circular imports in tracing_state_functions (#150325) tracing_state_functions references some torch functions from submodules like `torch.onnx.is_in_onnx_export` that could trigger module initialization & circular imports. I turned the mapping into a function so that the dictionary is not initialized at torch import. 
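In a minimal hedged sketch, the deferred-initialization pattern looks like the following (the actual change in the diff below caches the full mapping and adds type annotations):

```
import functools

import torch


@functools.lru_cache(None)
def tracing_state_functions():
    # torch.onnx is only touched on the first call, not at `import torch` time,
    # so submodule initialization (and the circular-import risk) is deferred.
    return {torch.onnx.is_in_onnx_export: False}
```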
(discovered in https://github.com/pytorch/pytorch/pull/149646) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150325 Approved by: https://github.com/zou3519 --- torch/_dynamo/variables/torch.py | 36 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 429b3b572774..c85d5b5c577c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -34,7 +34,7 @@ import math import re from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import torch._C import torch._refs @@ -169,19 +169,23 @@ constant_fold_functions = dict.fromkeys(constant_fold_functions) -tracing_state_functions = { - torch.jit.is_scripting: False, - torch.jit.is_tracing: False, - torch._C._get_tracing_state: None, - torch.fx._symbolic_trace.is_fx_tracing: False, - torch.onnx.is_in_onnx_export: False, - torch._dynamo.external_utils.is_compiling: True, - torch._utils.is_compiling: True, - torch.compiler.is_compiling: True, - torch.compiler.is_dynamo_compiling: True, - torch.compiler.is_exporting: True, - torch.nn.modules.activation._is_make_fx_tracing: False, -} +@functools.lru_cache(None) +def tracing_state_functions() -> dict[Callable[[], Any], Optional[bool]]: + # Defined as a function to avoid circular import like torch.onnx + return { + torch.jit.is_scripting: False, + torch.jit.is_tracing: False, + torch._C._get_tracing_state: None, + torch.fx._symbolic_trace.is_fx_tracing: False, + torch.onnx.is_in_onnx_export: False, + torch._dynamo.external_utils.is_compiling: True, + torch._utils.is_compiling: True, + torch.compiler.is_compiling: True, + torch.compiler.is_dynamo_compiling: True, + torch.compiler.is_exporting: True, + torch.nn.modules.activation._is_make_fx_tracing: False, + } + bin_ops = dict.fromkeys(["add", "sub", "mul", "div", "sqrt"]) @@ -456,7 +460,7 @@ def _register(handler): ) from .builder import wrap_fx_proxy, wrap_fx_proxy_cls - @register(*tracing_state_functions) + @register(*tracing_state_functions()) def handle_tracing_state_functions( self, tx: "InstructionTranslator", *args, **kwargs ): @@ -470,7 +474,7 @@ def handle_tracing_state_functions( torch.compiler.is_exporting, ): tx.mark_inconsistent_side_effects() - return ConstantVariable.create(tracing_state_functions[self.value]) + return ConstantVariable.create(tracing_state_functions()[self.value]) @register(*dispatch_key_set_functions) def handle_dispatch_key_set_functions( From c714d2fc0ee853c87a6537820830c22f6f6c862e Mon Sep 17 00:00:00 2001 From: Yidi Wu Date: Mon, 7 Apr 2025 14:28:08 -0700 Subject: [PATCH 311/332] [hop] support base_hop._gen_schema (#149688) This PR creates two utils for generating a schema for hops from example inputs and use base hop as an exmaple. 1. HopArgumentInfoGen creates an argument or an output schema with mutation information. 2. CFuncitonSchemaGen piece together the argument info of inputs and outputs and produces torch._C.FunctionSchema. is_write attribute of argument info can be computed. Note that the is_write annotation only works when the inputs are flattened (e.g. cannot support mutation inside tuple). We need special handling the case where we have tuple inputs like cond. 
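Roughly, the intended usage is sketched below; `hop`, `subgraph_gm`, `x`, and `y` are placeholders, and the real end-to-end path is exercised by the new tests through Dynamo-captured subgraphs:

```
# `hop` is a HigherOrderOperator such as torch.ops.higher_order.invoke_quant_test;
# `subgraph_gm` is a torch.fx.GraphModule whose placeholders carry "example_value" meta.
schema = hop.gen_schema(subgraph_gm, x, y, scheme="nf4")
assert isinstance(schema, torch._C.FunctionSchema)
print(schema)
# invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> ((Tensor))
# A mutated tensor input would instead be rendered as Tensor(!); aliasing inside
# the subgraph is rejected while the schema is being generated.
```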
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149688 Approved by: https://github.com/zou3519 --- aten/src/ATen/core/alias_info.h | 11 ++ test/dynamo/test_base_hop.py | 189 ++++++++++++++++++++++++++++ torch/_C/__init__.pyi.in | 24 ++++ torch/_higher_order_ops/base_hop.py | 63 +++++++++- torch/_higher_order_ops/schema.py | 154 +++++++++++++++++++++++ torch/_higher_order_ops/utils.py | 27 +++- torch/_ops.py | 20 +++ torch/csrc/jit/python/init.cpp | 15 +++ 8 files changed, 501 insertions(+), 2 deletions(-) create mode 100644 torch/_higher_order_ops/schema.py diff --git a/aten/src/ATen/core/alias_info.h b/aten/src/ATen/core/alias_info.h index a8a55bb782c4..bf0ff6ee72d3 100644 --- a/aten/src/ATen/core/alias_info.h +++ b/aten/src/ATen/core/alias_info.h @@ -1,4 +1,6 @@ #pragma once +#include +#include #include #include #include @@ -18,6 +20,15 @@ namespace c10 { */ class AliasInfo { public: + AliasInfo() = default; + AliasInfo(bool is_write, const std::set& before_qual_strings, const std::set& after_qual_strings) : isWrite_(is_write) { + for (const auto& s: before_qual_strings) { + beforeSets_.insert(Symbol::fromQualString(s)); + } + for (const auto& s : after_qual_strings) { + afterSets_.insert(Symbol::fromQualString(s)); + } + } // Symbol for the set that can alias anything static Symbol wildcardSet() { static const Symbol wc = Symbol::fromQualString("alias::*"); diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index 3f9c23efc1d1..b42c56b21ced 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import unittest +from typing import Any import torch import torch._dynamo.test_case @@ -73,6 +74,194 @@ def forward(self, l_x_: "f32[3, 3]", l_y_: "f32[3, 3]"): """, # NOQA: B950 ) + def _find_hop_schema( + self, gm: torch.fx.GraphModule, target: Any + ) -> list[torch._C.FunctionSchema]: + import torch.utils._pytree as pytree + + schemas = [] + for node in gm.graph.find_nodes(op="call_function", target=target): + + def _get_example_value(node: torch.fx.Node) -> Any: + if node.op == "get_attr": + return getattr(gm, node.target) + else: + return node.meta["example_value"] + + fake_args, fake_kwargs = pytree.tree_map_only( + torch.fx.Node, + _get_example_value, + (node.args, node.kwargs), + ) + schema = node.target.gen_schema(*fake_args, **fake_kwargs) + schemas.append(schema) + return schemas + + def test_schema_gen_single_return(self): + def inner(x, y): + return (x @ y).sin().cos() + + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=False) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + out = f(x.clone(), y) + self.assertEqual(out, inner(x.clone(), y)) + schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test) + self.assertEqual(len(schemas), 1) + self.assertExpectedInline( + str(schemas[0]), + """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> ((Tensor))""", # noqa: B950 + ) + + def test_schema_gen_pytree_in_out(self): + def inner(x_y): + x, y = x_y + return [ + (x @ y).sin().cos(), + (x + y, x - y), + {"out": (x @ y,)}, + ] + + # make x not require grad because we want to inplace mutate it + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=True) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend) + def f(x, y): + return invoke_quant_test(inner, [x, 
y], scheme="nf4") + + out = f(x.clone(), y) + self.assertEqual(out, inner([x.clone(), y])) + schemas = self._find_hop_schema(backend.graphs[0], invoke_quant_test) + self.assertEqual(len(schemas), 1) + self.assertExpectedInline( + str(schemas[0]), + """invoke_quant_test(Any subgraph, Tensor arg0, Tensor arg1, str scheme="nf4") -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950 + ) + + def test_schema_gen_single_return_with_mutation(self): + def inner(x, y): + x.add_(1) + y.mul_(-1) + return (x @ y).sin().cos() + + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=False) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + with self.assertRaisesRegex( + RuntimeError, + "Encountered input mutation during higher order op tracing for HOP", + ): + f(x.clone(), y) + + def test_schema_gen_pytree_in_out_with_mutation(self): + def inner(x_y): + x, y = x_y + x.add_(1) + return [ + (x @ y).sin().cos(), + (x + y, x - y), + {"out": (x @ y,)}, + ] + + # make x not require grad because we want to inplace mutate it + x = torch.randn(3, 3, requires_grad=False) + y = torch.randn(3, 3, requires_grad=True) + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, [x, y], scheme="nf4") + + with self.assertRaisesRegex( + RuntimeError, + "Encountered input mutation during higher order op tracing for HOP", + ): + f(x.clone(), y) + + def test_none_input(self): + def inner(x, y): + if x is not None: + return y.sin() + return y.cos() + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + x = None + y = torch.randn(3, 4) + out = f(x, y) + self.assertEqual(out, inner(x, y)) + self.assertExpectedInline( + normalize_graph(backend.graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_y_: "f32[3, 4]"): + l_y_ = L_y_ + + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4'); subgraph_0 = l_y_ = None + getitem: "f32[3, 4]" = invoke_quant_test[0]; invoke_quant_test = None + return (getitem,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_y_: "f32[3, 4]"): + cos: "f32[3, 4]" = l_y_.cos(); l_y_ = None + return (cos,) +""", + ) + + def test_int_input(self): + def inner(x, y): + return x + y + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def f(x, y): + return invoke_quant_test(inner, x, y, scheme="nf4") + + x = 1 + y = torch.randn(3, 4) + out = f(x, y) + self.assertEqual(out, inner(x, y)) + self.assertExpectedInline( + normalize_graph(backend.graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_y_: "f32[3, 4]"): + l_y_ = L_y_ + + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_y_, scheme = 'nf4'); subgraph_0 = l_y_ = None + getitem: "f32[3, 4]" = invoke_quant_test[0]; invoke_quant_test = None + return (getitem,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_y_: "f32[3, 4]"): + add: "f32[3, 4]" = 1 + l_y_; l_y_ = None + return (add,) +""", + ) + @torch._dynamo.config.patch(assume_static_by_default=True) def test_aot_eager(self): def inner(x, y): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index c6003fe63fcc..3c487e321c8c 100644 --- 
a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -913,6 +913,12 @@ class AliasInfo: is_write: _bool before_set: Set[str] after_set: Set[str] + def __init__( + self, + is_write: _bool, + before_set: Set[str], + after_set: Set[str] + ) -> None: ... # Defined in torch/aten/src/ATen/core/function_schema.h class Argument: @@ -925,6 +931,15 @@ class Argument: alias_info: Optional[AliasInfo] is_write: _bool real_type: JitType + def __init__( + self, + name: str, + type: JitType, + N: Optional[_int], + defualt_value: Optional[Any], + kwarg_only: _bool, + alias_info: Optional[AliasInfo] + ) -> None: ... class FunctionSchema: arguments: List[Argument] @@ -932,6 +947,15 @@ class FunctionSchema: name: str overload_name: str is_mutable: _bool + def __init__( + self, + name: str, + overload_name: str, + arguments: List[Argument], + returns: List[Argument], + is_vararg: _bool, + is_varret: _bool + ) -> None: ... class _UpgraderEntry: bumped_at_version: _int diff --git a/torch/_higher_order_ops/base_hop.py b/torch/_higher_order_ops/base_hop.py index a8fc106214b7..af47b3e5fdc5 100644 --- a/torch/_higher_order_ops/base_hop.py +++ b/torch/_higher_order_ops/base_hop.py @@ -6,7 +6,10 @@ import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization -from torch._higher_order_ops.utils import reenter_make_fx +from torch._higher_order_ops.utils import ( + check_input_alias_and_mutation_return_ouputs, + reenter_make_fx, +) from torch._ops import HigherOrderOperator from torch._subclasses import FakeTensorMode from torch._subclasses.functional_tensor import disable_functional_mode @@ -126,6 +129,64 @@ def _call_Functionalize(self, ctx, subgraph, *operands, **kwargs): out = self(functionalized_subgraph, *unwrapped_operands, **kwargs) return ctx.wrap_tensors(out) + def gen_schema(self, *args, **kwargs): + from .schema import CFunctionSchemaGen, HopArgumentInfoGen + + subgraph, *operands = args + + assert isinstance( + subgraph, torch.fx.GraphModule + ), f"NYI non GraphModule subgraph got {subgraph}" + + fake_args = [ + ph.meta["example_value"] + for ph in subgraph.graph.find_nodes(op="placeholder") + ] + ( + mutated_inp_idx, + inp_inp_alias, + inp_out_alias, + out_out_alias, + output, + ) = check_input_alias_and_mutation_return_ouputs(subgraph, fake_args) + + assert ( + len(inp_inp_alias) == 0 + and len(inp_out_alias) == 0 + and len(out_out_alias) == 0 + ), "Aliasing is not suppported for HOP subgraph." 
+ args = [ + HopArgumentInfoGen.from_example( + subgraph, name="subgraph", default_value=None, is_mutated=False + ) + ] + for idx, arg in enumerate((*operands, *kwargs.items())): + if isinstance(arg, tuple): + # kwargs value are treated as default argument + arg_name, example_value = arg + default = example_value + else: + arg_name = f"arg{idx}" + example_value = arg + default = None + args.append( + HopArgumentInfoGen.from_example( + example_value=example_value, + name=arg_name, + default_value=default, + is_mutated=idx in mutated_inp_idx, + ) + ) + + # The output is represented as a single argument + out = HopArgumentInfoGen.from_example( + example_value=output, + name="out", + default_value=None, + is_mutated=False, + ) + return CFunctionSchemaGen.from_hop_argument_info(str(self), args, out) + class BaseHOPFunction(torch.autograd.Function): @staticmethod diff --git a/torch/_higher_order_ops/schema.py b/torch/_higher_order_ops/schema.py new file mode 100644 index 000000000000..1cf4e9a5032c --- /dev/null +++ b/torch/_higher_order_ops/schema.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Any, Optional + +import torch + + +# Below is an implementation of generating FunctionSchema from example values. +# This is helpful for generating FunctionSchema for HigherOrderOperator, where +# we don't have a function to inspect and each call of the higher order operator +# would have different schema. +@dataclass(frozen=True) +class HopArgumentInfo: + # Could give a name to the operand by default it's empty string. + name: str + example_value: Any + # Provide an default_value + default_value: Any + # Whether this arugment gets mutated in the hop subgraph. + # For output, this should always be False + is_mutated: bool + + +class HopArgumentInfoGen: + @staticmethod + def from_example( + example_value: Any, + *, + name: str = "", + default_value: Optional[Any], + is_mutated: bool = False, + ) -> HopArgumentInfo: + if default_value is not None: + assert type(example_value) == type(default_value) + return HopArgumentInfo( + name=name, + example_value=example_value, + default_value=default_value, + is_mutated=is_mutated, + ) + + +class CTypeGen: + convert_to_base_ty = { + int: torch._C.IntType.get(), + float: torch._C.FloatType.get(), + str: torch._C.StringType.get(), + bool: torch._C.BoolType.get(), + } + + # should return torch._C.JitType but that annotation is busted + @staticmethod + def from_example(obj: Any) -> Any: + import torch + + if isinstance(obj, torch.fx.GraphModule): + return torch._C.AnyType.get() + return torch._C._jit_try_infer_type(obj).type() + + +class CArgumentGen: + @staticmethod + def from_hop_argument_info( + arg_idx: int, arg_info: HopArgumentInfo, is_output: bool = False + ) -> Any: + typ = CTypeGen.from_example(arg_info.example_value) + if is_output: + return torch._C.Argument("", typ, None, None, False, None) + + alias_set = set({f"alias::a{arg_idx}"}) if arg_info.is_mutated else set() + alias_info = torch._C._AliasInfo(arg_info.is_mutated, alias_set, alias_set) # type: ignore[attr-defined] + return torch._C.Argument( + arg_info.name, typ, None, arg_info.default_value, False, alias_info + ) + + +class CFunctionSchemaGen: + """ + Note: [HigherOrderOperator schema generation] + Each invocation of a HigherOrderOperator will have a different schema. + For example, the schema of torch.cond varies depending on the true_fn and + false_fn. So we need a way to generate the schema for each invocation of a HOP. 
+ + We want to enforce the following invariants for HOP's schema: + 1. Flattened inputs. There should be no pytree structure in it. + 2. Flattened outputs. Note even if the hop returns a single value, it should be wrapped as a tuple. + 3. No aliasing. This includes inp-inp aliasing, inp-out aliasing and out-out aliasing. + + By enforcing these invariants, we could make HOP's schema meets the requirement of schema parser + and makes hop easier to handle downstream. For example, suppose we have an invoke_quant_test HOP: + + class GraphModule(torch.nn.Module): + def forward(self, l_x_, l_y_): + subgraph_0 = self.subgraph_0 + invoke_quant_test = torch.ops.higher_order.invoke_quant_test(subgraph_0, l_x_, l_y_, scheme = 'nf4'); + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_, l_y_): + add_ = l_x_.add_(1) + matmul = l_x_ @ l_y_ + sin = matmul.sin() + child = sin.cos() + child_1 = l_x_ + l_y_ + child_2 = l_x_ - l_y_ + child_3 = l_x_ @ l_y_ + return (child, child_1, child_2, child_3) + + By encoding the inputs of hop into a list of HopArgumentInfo and output as a single HopArgumentInfo, + we would get the following schema: + invoke_quant_test(Any arg0, Tensor(!) arg1, Tensor arg2, str scheme="\\"nf4\\"") -> (Tensor, Tensor, Tensor, Tensor) + """ + + @staticmethod + def from_hop_argument_info( + op_name: str, + inp_argument_info: list[HopArgumentInfo], + out_argument_info: HopArgumentInfo, + ) -> Any: + args = [] + for i, arg_info in enumerate(inp_argument_info): + args.append(CArgumentGen.from_hop_argument_info(i, arg_info)) + + # NOTE: we want the output to always be a single argument with torch._C.TupleType. + assert isinstance( + out_argument_info.example_value, tuple + ), f"expect out_argument_info's example_value to be a tuple but got {out_argument_info.example_value}" + assert ( + not out_argument_info.is_mutated + ), "out_argument_info.is_mutated should always be set to False." + rets = None + if len(out_argument_info.example_value) == 1: + rets = [CArgumentGen.from_hop_argument_info(0, out_argument_info, True)] + else: + rets = [ + CArgumentGen.from_hop_argument_info( + i, + HopArgumentInfoGen.from_example( + name=f"out{i}", + example_value=val, + default_value=None, + is_mutated=False, + ), + is_output=True, + ) + for i, val in enumerate(out_argument_info.example_value) + ] + + return torch._C.FunctionSchema( + op_name, + "", + args, + rets, + False, + False, + ) diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 27f4e739eb41..fd3f327a68ae 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -703,6 +703,25 @@ def check_input_alias_and_mutation( gm: torch.fx.GraphModule, fake_args: list[FakeTensor], ) -> tuple[list[int], dict[int, int], dict[int, int], dict[int, int]]: + ( + mutated_inputs, + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + ) = check_input_alias_and_mutation_return_ouputs(gm, fake_args)[:-1] + return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map + + +def check_input_alias_and_mutation_return_ouputs( + gm: torch.fx.GraphModule, + fake_args: list[FakeTensor], +) -> tuple[ + list[int], + dict[int, int], + dict[int, int], + dict[int, int], + Union[tuple[Any, ...], list[Any]], +]: with disable_proxy_modes_tracing(): """This function returns mutated inputs, inp-inp alias, inp-out alias, out-out alias in the graph module gm. 
It checks whether input tensor versions have @@ -765,7 +784,13 @@ def _tensor_storage(t) -> StorageWeakRef: for i, inp in enumerate(cloned) if isinstance(inp, torch.Tensor) and _tensor_storage(inp) in out_storage_map } - return mutated_inputs, inp_inp_alias_map, inp_out_alias_map, out_out_alias_map + return ( + mutated_inputs, + inp_inp_alias_map, + inp_out_alias_map, + out_out_alias_map, + outputs, + ) registered_hop_fake_fns: dict[torch._ops.OpOverload, Callable] = {} diff --git a/torch/_ops.py b/torch/_ops.py index 0842f57fbff7..4e4c346e25eb 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -470,6 +470,26 @@ def wrapper(): return wrapper() + # NOTE [HigherOrderOprator Schema] + # Each invocation of a HigherOrderOperator (hop) should have its own schema because + # the subgraphs and the arguments can be different even for the same hop. + # + # Each hop should implement its own gen_schema method, which should + # take the same input as the __call__ method and returns a FunctionSchema. + # The schema provides a unified way to check if the hop mutates its inputs, + # which can be useful in implementing optimizations. + # + # If the hop doesn't implement the gen_schema method, + # we expect it to be functional. It should not mutate its inputs and there + # are no input, output aliasing via views or direct referencing. + def gen_schema(self, *args, **kwargs): + raise NotImplementedError( + f"HigherOrderOperator {self._name} does not implement a gen_schema. " + f"This is OK as long as the hop is functional. " + f"e.g. it should not mutate its inputs and there are no input, output aliasing " + f"via views or direct referencing." + ) + def __str__(self): return f"{self.name()}" diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 5911064b22f2..5c46e936a4ec 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1936,6 +1936,13 @@ void initJITBindings(PyObject* module) { self.addArgumentValues(value_map); }); py::class_(m, "FunctionSchema") + .def(py::init< + std::string, + std::string, + std::vector, + std::vector, + bool, + bool>()) .def_property_readonly( "name", [](FunctionSchema& self) { return self.name(); }) .def_property_readonly( @@ -1993,6 +2000,13 @@ void initJITBindings(PyObject* module) { .def_property_readonly( "is_mutable", [](FunctionSchema& self) { return self.is_mutable(); }); py::class_(m, "Argument") + .def(py::init< + std::string, + const TypePtr&, + std::optional, + std::optional, + bool, + std::optional>()) .def_property_readonly("name", [](Argument& self) { return self.name(); }) .def_property_readonly("type", [](Argument& self) { return self.type(); }) .def_property_readonly( @@ -2032,6 +2046,7 @@ void initJITBindings(PyObject* module) { return self.kwarg_only(); }); py::class_(m, "_AliasInfo") + .def(py::init, std::set>()) .def_property_readonly( "is_write", [](AliasInfo& self) { return self.isWrite(); }) .def_property_readonly( From a4bb2f106f8cc642539d4698b6d869a87adca92f Mon Sep 17 00:00:00 2001 From: rzou Date: Tue, 8 Apr 2025 10:24:01 -0700 Subject: [PATCH 312/332] Inductor respects exact strides on custom ops by default (#150511) If a tag is not specified on a custom operator, then inductor will assume that it needs exact strides. 
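For op authors, a hedged sketch of the resulting behavior (`mylib`/`my_add*` are illustrative; the tag names appear in the diffs below):

```
import torch

with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
    # No layout tag: Inductor now assumes the kernel needs exact input strides.
    lib.define("my_add(Tensor x, Tensor y) -> Tensor")
    # Opt out explicitly when the kernel tolerates arbitrary strides.
    lib.define(
        "my_add_flex(Tensor x, Tensor y) -> Tensor",
        tags=[torch._C.Tag.flexible_layout],
    )
```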
Test Plan: - tests + CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/150511 Approved by: https://github.com/eellison, https://github.com/shunting314 ghstack dependencies: #150495, #148104 --- test/inductor/test_triton_kernels.py | 2 -- test/test_custom_ops.py | 3 ++- torch/_inductor/config.py | 2 +- torch/_library/custom_ops.py | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 4966821120c5..951440ff52a2 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3451,7 +3451,6 @@ def impl2(x): lib.define( "add_op(Tensor x, Tensor y) -> Tensor", - tags=[torch._C.Tag.needs_exact_strides], ) def impl(x, y): @@ -3465,7 +3464,6 @@ def meta(x, y): lib.define( "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()", - tags=[torch._C.Tag.needs_exact_strides], ) def impl_out(x, y, out): diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index a2691d5e1cbf..2c74c5531258 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3699,6 +3699,7 @@ def vmap(info, in_dims, w, x=2, *, y=3, z): self.assertEqual(result, w * 2 * 3 * 42) def test_layout_constraint_tags(self): + needs_exact_strides = torch._C.Tag.needs_exact_strides needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order flexible_layout = torch._C.Tag.flexible_layout # (tags, the result of the tag inference) @@ -3706,7 +3707,7 @@ def test_layout_constraint_tags(self): ({needs_fixed_stride_order}, needs_fixed_stride_order), ({flexible_layout}, flexible_layout), # If no tags are provided, then the following is the default - (set(), needs_fixed_stride_order), + (set(), needs_exact_strides), # If multiple tags are provided, then we use the most constrained tag. ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), ] diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 040b91917398..24445b374ac8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -127,7 +127,7 @@ def prologue_fusion_enabled() -> bool: # then we assume the following applies. custom_op_default_layout_constraint: Literal[ "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout" -] = "needs_fixed_stride_order" +] = "needs_exact_strides" # The default layout constraint for user-defined triton kernels. # See "The default layout constraint for custom operators" for options. diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 544bbbf61582..8c46dfabc058 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -615,7 +615,7 @@ def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: lib.define( schema_str, - tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order, *tags], + tags=[_C.Tag.pt2_compliant_tag, *tags], ) self._opoverload = utils.lookup_op(self._qualname) From a0e796df03bdf34b3b552589d9bce2b36b5d4295 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 9 Apr 2025 16:49:48 +0000 Subject: [PATCH 313/332] Revert "Inductor respects exact strides on custom ops by default (#150511)" This reverts commit a4bb2f106f8cc642539d4698b6d869a87adca92f. 
Reverted https://github.com/pytorch/pytorch/pull/150511 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/2e7c9d33e7f933ac3b723cb3bb05b9c88432c25c) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493)) --- test/inductor/test_triton_kernels.py | 2 ++ test/test_custom_ops.py | 3 +-- torch/_inductor/config.py | 2 +- torch/_library/custom_ops.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 951440ff52a2..4966821120c5 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -3451,6 +3451,7 @@ def impl2(x): lib.define( "add_op(Tensor x, Tensor y) -> Tensor", + tags=[torch._C.Tag.needs_exact_strides], ) def impl(x, y): @@ -3464,6 +3465,7 @@ def meta(x, y): lib.define( "add_out_op(Tensor x, Tensor y, Tensor(a!) out) -> ()", + tags=[torch._C.Tag.needs_exact_strides], ) def impl_out(x, y, out): diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 2c74c5531258..a2691d5e1cbf 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -3699,7 +3699,6 @@ def vmap(info, in_dims, w, x=2, *, y=3, z): self.assertEqual(result, w * 2 * 3 * 42) def test_layout_constraint_tags(self): - needs_exact_strides = torch._C.Tag.needs_exact_strides needs_fixed_stride_order = torch._C.Tag.needs_fixed_stride_order flexible_layout = torch._C.Tag.flexible_layout # (tags, the result of the tag inference) @@ -3707,7 +3706,7 @@ def test_layout_constraint_tags(self): ({needs_fixed_stride_order}, needs_fixed_stride_order), ({flexible_layout}, flexible_layout), # If no tags are provided, then the following is the default - (set(), needs_exact_strides), + (set(), needs_fixed_stride_order), # If multiple tags are provided, then we use the most constrained tag. ({flexible_layout, needs_fixed_stride_order}, needs_fixed_stride_order), ] diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 24445b374ac8..040b91917398 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -127,7 +127,7 @@ def prologue_fusion_enabled() -> bool: # then we assume the following applies. custom_op_default_layout_constraint: Literal[ "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout" -] = "needs_exact_strides" +] = "needs_fixed_stride_order" # The default layout constraint for user-defined triton kernels. # See "The default layout constraint for custom operators" for options. diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index 8c46dfabc058..544bbbf61582 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -615,7 +615,7 @@ def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: lib.define( schema_str, - tags=[_C.Tag.pt2_compliant_tag, *tags], + tags=[_C.Tag.pt2_compliant_tag, _C.Tag.needs_fixed_stride_order, *tags], ) self._opoverload = utils.lookup_op(self._qualname) From 01568cb17a5de2d7943d102b98cdc35b58eef411 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 9 Apr 2025 16:49:48 +0000 Subject: [PATCH 314/332] Revert "Refactor layout constraint selection logic (#148104)" This reverts commit 2e7c9d33e7f933ac3b723cb3bb05b9c88432c25c. 
Reverted https://github.com/pytorch/pytorch/pull/148104 on behalf of https://github.com/atalman due to [GH job link](https://github.com/pytorch/pytorch/actions/runs/14357056427/job/40251630946) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/2e7c9d33e7f933ac3b723cb3bb05b9c88432c25c) ([comment](https://github.com/pytorch/pytorch/pull/148104#issuecomment-2790369493)) --- torch/_inductor/config.py | 2 +- torch/_inductor/graph.py | 56 +++++++++++++-------------- torch/_inductor/lowering.py | 45 ++++++++++----------- torch/fx/experimental/proxy_tensor.py | 4 +- 4 files changed, 50 insertions(+), 57 deletions(-) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 040b91917398..27b77d199f09 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -126,7 +126,7 @@ def prologue_fusion_enabled() -> bool: # If the custom op does not have a layout constraint tag already # then we assume the following applies. custom_op_default_layout_constraint: Literal[ - "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout" + "needs_fixed_stride_order", "flexible_layout" ] = "needs_fixed_stride_order" # The default layout constraint for user-defined triton kernels. diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 9063df455b0a..ca989c431aa7 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -80,13 +80,11 @@ FALLBACK_ALLOW_LIST, fallback_handler, fallback_node_due_to_unsupported_type, - get_layout_constraint_tag, lowerings, make_fallback, maybe_layout_constraints, needs_realized_inputs, require_contiguous, - tag_to_layout_constraint, unsupported_output_tensor, ) from .runtime import autotune_cache @@ -246,14 +244,6 @@ def _get_overload_packet( cur.meta["dislike_padding"] = True continue - if ( - isinstance(cur.target, torch._ops.OpOverload) - and get_layout_constraint_tag(cur.target) - == torch._C.Tag.needs_exact_strides - ): - cur.meta["dislike_padding"] = True - continue - op = _get_overload_packet(cur) if not op: continue @@ -1160,26 +1150,34 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> error.operator_str(target, args, kwargs), ) - tag = get_layout_constraint_tag(target, with_default=False) - if ( - tag is None - and torch._library.utils.is_builtin(target) - and self.is_backward - ): - # for implicit fallback ATen ops during backward, if there - # is no layout constraint tag, we conservatively require contiguous - # input since some eager kernels do not - # support non-contiguous inputs. Otherwise they may silently cause - # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452 - # We only do this For ATen ops and for backward. - # - # TODO: should really switch to "needs_fixed_stride" constraint on these - # and identify them one by one. 
- decided_constraint = require_contiguous # type: ignore[assignment] + # use contiguous unless the (custom) op asks something else + # explicitly + if torch._C.Tag.needs_exact_strides in target.tags: + decided_constraint = constrain_to_fake_tensors # type: ignore[assignment] + elif torch._C.Tag.needs_fixed_stride_order in target.tags: + decided_constraint = constrain_to_fx_strides # type: ignore[assignment] + elif torch._C.Tag.flexible_layout in target.tags: + decided_constraint = None # type: ignore[assignment] else: - tag = get_layout_constraint_tag(target, with_default=True) - decided_constraint = tag_to_layout_constraint(tag) - + # If there are no tags, we do different things depending on + # if it's a builtin ATen/prim ops or custom ops. + # For ATen ops, we require_contiguous to fix https://github.com/pytorch/pytorch/issues/140452 + # For custom ops, we constrain_to_fx_strides to maintain the + # behavior of PyTorch 2.5: https://github.com/pytorch/pytorch/issues/148356 + # + # For ATen ops, only apply the constraint for backward + # ops since fwd ops should work for any strides. + if torch._library.utils.is_builtin(target) and self.is_backward: + decided_constraint = require_contiguous # type: ignore[assignment] + else: + # maybe_layout_constraints will decide the layout constraint for the custom op + # lazily + decided_constraint = None # type: ignore[assignment] + + # for implicitly fallback ops, we conservatively requires + # contiguous input since some eager kernels does not + # support non-contiguous inputs. They may silently cause + # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452 make_fallback(target, layout_constraint=decided_constraint) elif get_decompositions([target]): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index d9e0fb03d004..24520887f6aa 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -157,40 +157,37 @@ def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., A return None if fn in _maybe_layout_constraints: return _maybe_layout_constraints[fn] - return None - + # OpOverload with custom lowerings override tag-based layout constraints + if fn in lowerings: + _maybe_layout_constraints[fn] = None + return None + # We lazily register tag-based layout constraints. 
+ + def handle_layout_constraint_tag(tag): + if tag is torch._C.Tag.needs_fixed_stride_order: + _maybe_layout_constraints[fn] = constrain_to_fx_strides + return _maybe_layout_constraints[fn] + elif tag is torch._C.Tag.flexible_layout: + _maybe_layout_constraints[fn] = None + return None + else: + raise AssertionError(f"Unknown layout constraint tag: {tag}") -tags_by_priority = [ - torch._C.Tag.needs_exact_strides, - torch._C.Tag.needs_fixed_stride_order, - torch._C.Tag.flexible_layout, -] + tag = get_layout_constraint_tag(fn) + return handle_layout_constraint_tag(tag) -def get_layout_constraint_tag(fn, *, with_default=True): +def get_layout_constraint_tag(fn): tags_by_priority = [ - torch._C.Tag.needs_exact_strides, torch._C.Tag.needs_fixed_stride_order, torch._C.Tag.flexible_layout, ] for tag in tags_by_priority: if tag in fn.tags: return tag - if with_default: - if torch._library.utils.is_builtin(fn): - return torch._C.Tag.flexible_layout - return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) - return None - - -def tag_to_layout_constraint(tag): - if tag == torch._C.Tag.needs_exact_strides: - return constrain_to_fake_tensors - if tag == torch._C.Tag.needs_fixed_stride_order: - return constrain_to_fx_strides - if tag == torch._C.Tag.flexible_layout: - return None - raise AssertionError(f"Unknown layout constraint tag: {tag}") + if torch._library.utils.is_builtin(fn): + return torch._C.Tag.flexible_layout + return getattr(torch._C.Tag, config.custom_op_default_layout_constraint) def assert_nyi(cond, msg): diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 9bbc16a895b6..4193606d849d 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1169,9 +1169,7 @@ def _should_save_eager_input_vals( f"propagate the FakeTensor vals. Please file an issue." ) if isinstance(target, torch._ops.OpOverload): - from torch._inductor.lowering import get_layout_constraint_tag - - return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides + return torch._C.Tag.needs_exact_strides in target.tags return False From c59aaa03ffbbc3b502040e1d60469bb7aebb83b4 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 4 Apr 2025 16:47:00 -0700 Subject: [PATCH 315/332] [DTensor] add _explicit_order_placements util (#150493) The util converts a list of placements in the traditional DTensor format (e.g. [_StridedShard(0), Shard(0)], where list position is mesh_dim and sharding is always applied left-to-right (from dim 0 to higher dims)) to a more explicitly ordered format, also replacing '_StridedShard' with simple 'Shard' placements in the process. (e.g. the above becomes [(1, Shard(0)), (0, Shard(0)] where the first item in the tuple is the mesh_dim and the ordering of the tuples is the sharding order. This is useful so far as a helper for fixing local shape computation for strided sharding in the uneven shape case, in the following PR- but may also be useful more broadly if we can use explicit orderings to simplify other parts of DTensor logic. This skips implementing some combinations of _StridedSharding that are not currently used in the wild today, but could be supported easily. 
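Concretely, a small sketch drawn from the cases in the new test below:

```
from torch.distributed.tensor._utils import _explicit_order_placements
from torch.distributed.tensor.placement_types import _StridedShard, Shard

# FSDP+TP-style layout on a (2, 4) mesh: mesh dim 1 sharded tensor dim 0 first,
# then mesh dim 0 applied a strided shard on top of it.
ordered = _explicit_order_placements(
    [2, 4], [_StridedShard(0, split_factor=4), Shard(0)]
)
assert ordered == [(1, Shard(0)), (0, Shard(0))]
```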
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150493 Approved by: https://github.com/wanchaol, https://github.com/XilunWu --- test/distributed/tensor/test_utils.py | 90 ++++++++++++++++++++- torch/distributed/tensor/_utils.py | 54 +++++++++++++ torch/distributed/tensor/placement_types.py | 3 + 3 files changed, 144 insertions(+), 3 deletions(-) diff --git a/test/distributed/tensor/test_utils.py b/test/distributed/tensor/test_utils.py index a9798f9d434a..179f4a7913ce 100644 --- a/test/distributed/tensor/test_utils.py +++ b/test/distributed/tensor/test_utils.py @@ -3,13 +3,16 @@ import itertools import torch -from torch.distributed._tensor import distribute_tensor, DTensor -from torch.distributed._tensor._utils import compute_local_shape_and_global_offset from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import distribute_tensor, DTensor from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._utils import ( + _explicit_order_placements, + compute_local_shape_and_global_offset, +) from torch.distributed.tensor.debug import CommDebugMode from torch.distributed.tensor.placement_types import _StridedShard, Replicate, Shard -from torch.testing._internal.common_utils import run_tests +from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorTestBase, with_comms, @@ -19,6 +22,87 @@ c10d_functional = torch.ops.c10d_functional +class LocalTest(TestCase): + def test_explicit_order_placements(self): + # mesh_shape: ShapeType, placements: Sequence[Placement] + test_cases = [ + { + "mesh_shape": [2, 4], + "placements": [Replicate(), Replicate()], + "ordered": [(0, Replicate()), (1, Replicate())], + }, + { + "mesh_shape": [3, 2], + "placements": [Shard(0), Replicate()], + "ordered": [(0, Shard(0)), (1, Replicate())], + }, + { + "mesh_shape": [2, 4], + "placements": [_StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(1, Shard(0)), (0, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=4), Shard(0)], + "ordered": [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))], + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=12), + _StridedShard(0, split_factor=4), + Shard(0), + ], + "ordered": [(2, Shard(0)), (1, Shard(0)), (0, Shard(0))], + }, + ] + for test_case in test_cases: + actual = _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + expected = test_case["ordered"] + + self.assertEqual( + actual, + expected, + f"mesh_shape={test_case['mesh_shape']} placements={test_case['placements']}, output: {actual=}, {expected=}", + ) + + error_cases = [ + { + "mesh_shape": [2, 3, 4], + "placements": [Shard(0), _StridedShard(0, split_factor=3), Shard(0)], + "exception_type": RuntimeError, + "exception_text": "Can only convert _StridedShard to ordered Shard if split_factor", + }, + { + "mesh_shape": [2, 3, 4], + "placements": [ + _StridedShard(0, split_factor=3), + Shard(0), + Shard(0), + ], + "exception_type": NotImplementedError, + "exception_text": r"Strided sharding does not allow Shard\(\) to appear after the strided part has ended", + }, + { + "mesh_shape": [2, 3], + "placements": [ + Shard(0), + ], + "exception_type": RuntimeError, + "exception_text": "Expected one placement per mesh dim", + }, + ] + for test_case in error_cases: + with self.assertRaisesRegex( + 
test_case["exception_type"], test_case["exception_text"] + ): + _explicit_order_placements( + test_case["mesh_shape"], test_case["placements"] + ) + + class UtilTest(DTensorTestBase): @property def world_size(self): diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index 61705610f08f..34b000a34910 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Sequence from typing import cast @@ -15,6 +16,59 @@ ) +def _explicit_order_placements( + mesh_shape: ShapeType, placements: Sequence[Placement] +) -> Sequence[tuple[int, Placement]]: + """ + Replace Strided Shards with regular shards in an adjusted order. + + Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. + + ex. + [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> + [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] + + """ + if not len(placements) == len(mesh_shape): + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." + ) + ordered = [] + deferred_strided_placements = defaultdict(list) + strided_part_ended_for_dim = set() + for mesh_dim, p in enumerate(placements): + if isinstance(p, _StridedShard): + # validate the stride is the correct multiple of the meshdim and the earlier shard + deferred_strided_placements[p.dim].append((mesh_dim, p)) + + else: + ordered.append((mesh_dim, p)) + if isinstance(p, Shard): + if p.dim in strided_part_ended_for_dim: + raise NotImplementedError( + f"Strided sharding does not allow Shard() to appear after " + f"the strided part has ended. {p} at mesh dim {mesh_dim} in " + f"{placements} violates this assumption." + ) + + if p.dim in deferred_strided_placements: + strided_part_ended_for_dim.add(p.dim) + strided_placements = deferred_strided_placements.pop(p.dim) + aggregate_size = mesh_shape[mesh_dim] + while len(strided_placements) > 0: + strided_mesh_dim, strided = strided_placements.pop() + if not strided.split_factor == aggregate_size: + raise RuntimeError( + f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" + f" == aggregate mesh size ({aggregate_size})" + ) + aggregate_size *= mesh_shape[strided_mesh_dim] + ordered.append((strided_mesh_dim, Shard(p.dim))) + + return ordered + + def compute_local_shape_and_global_offset( global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] ) -> tuple[tuple[int, ...], tuple[int, ...]]: diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index ceb9f170fd3e..7b3302359e03 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -472,6 +472,9 @@ def _split_tensor( f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" ) + # num_chunks represents the size of this StridedShard mesh dim, while self.split_factor + # represents the aggregate num chunks for other shardings applied logically earlier than this strided shard. + # (e.g. 
in FSDP+TP case, num_chunks is size(dp dim), split_factor is size(tp dim)) total_split = num_chunks * self.split_factor assert tensor.size(self.dim) % total_split == 0, ( "_StridedShard currently only allows even sharding but got tensor size" From 6fb089f2a2eea75a45ac2340f0e68736524e20bf Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 8 Apr 2025 17:06:57 -0700 Subject: [PATCH 316/332] [AO] fix per token block size calculation (#150890) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150890 Approved by: https://github.com/jerryzh168 --- torch/ao/quantization/observer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 6a39bdc0fc39..673d52e8924e 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -493,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scale and zero_point are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -702,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scales and zero_points are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -997,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase): 3. Compute the scale and zero point the same way as in the :class:`~torch.ao.quantization.MinMaxObserver` """ + histogram: torch.Tensor min_val: torch.Tensor max_val: torch.Tensor @@ -1524,6 +1527,7 @@ class RecordingObserver(ObserverBase): qscheme: Quantization scheme to be used reduce_range: Reduces the range of the quantized data type by 1 bit """ + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} def __init__(self, dtype=torch.quint8): @@ -1790,7 +1794,7 @@ def get_block_size( ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" return (1, granularity.group_size) elif isinstance(granularity, PerToken): - block_size = list(input_shape) + block_size = [1] * len(input_shape) block_size[-1] = input_shape[-1] return tuple(block_size) raise ValueError(f"Unsupported Granularity: {granularity}") From cc185c32e0c9719aa67d7490c8ed89419e9d7f1d Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 9 Apr 2025 17:28:29 +0000 Subject: [PATCH 317/332] [aoti] Use generate_fake_kernels_from_real_mismatches config for draft exported programs (#150651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Sometimes we get `MetadataMismatchError` in aoti compilation because draft export uses the flag below to infer the fake kernel when there’s a mismatch, but aoti doesn’t have this flag turned on. https://fburl.com/code/9qzytl6q torch._functorch.config.generate_fake_kernels_from_real_mismatches If we set this flag to True, then aoti compilation would work. 
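A hedged sketch of the workaround, mirroring the updated test below (`M` and `example_inputs` stand in for the user's module and inputs):

```
import torch
from torch._functorch import config as functorch_config
from torch.export._draft_export import draft_export

ep = draft_export(M(), example_inputs)
with functorch_config.patch({"generate_fake_kernels_from_real_mismatches": True}):
    pt2_file = torch._inductor.aoti_compile_and_package(ep)
optimized = torch._inductor.aoti_load_package(pt2_file)
```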
Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:test_aot_inductor -- -r aoti_runtime_asserts ``` Differential Revision: D72345085 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150651 Approved by: https://github.com/angelayi --- test/inductor/test_aot_inductor.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index d008dba80421..d58809ea769e 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -3372,8 +3372,7 @@ def forward(self, q, k, v, attn_bias): self.check_model(Model(), example_inputs) def test_aoti_runtime_asserts(self): - from torch._dispatch.python import enable_python_dispatcher - from torch.export._draft_export import draft_export + from torch.export._draft_export import draft_export, FailureType with torch.library._scoped_library("mylib", "FRAGMENT") as lib: torch.library.define( @@ -3403,23 +3402,23 @@ def forward(self, a, b): example_inputs = (torch.randn(100), torch.tensor(10)) ep = draft_export(M(), example_inputs) + report = ep._report + need_config_patch = any( + not f.xfail and f.failure_type == FailureType.MISMATCHED_FAKE_KERNEL + for f in report.failures + ) m = ep.module() - from torch.fx.passes.fake_tensor_prop import FakeTensorProp - example_inputs = [ - node.meta["val"] for node in m.graph.nodes if node.op == "placeholder" - ] - fake_mode = example_inputs[0].fake_mode - with enable_python_dispatcher(), fake_mode: - FakeTensorProp(m, mode=fake_mode).propagate_dont_convert_inputs( - *example_inputs - ) + # This should no longer be needed after #150093 + from torch._functorch import config as functorch_config - # TODO: change to the tests below after MetadataMismatchError is fixed - # pt2_file = torch._inductor.aoti_compile_and_package(ep) - # optimized = torch._inductor.aoti_load_package(pt2_file) + with functorch_config.patch( + {"generate_fake_kernels_from_real_mismatches": need_config_patch} + ): + pt2_file = torch._inductor.aoti_compile_and_package(ep) + optimized = torch._inductor.aoti_load_package(pt2_file) - # self.assertTrue(same(optimized(example_inputs), m(example_inputs))) + self.assertTrue(same(optimized(*example_inputs), m(*example_inputs))) def test_index_put_with_none_index(self): # index_put falls back in the deterministic mode From d04a6ec0211858e22db7e0e3ec965b462ed19513 Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 9 Apr 2025 17:59:17 +0000 Subject: [PATCH 318/332] add reduce_scatter to symm mem ops (#150813) + a few small fixes (don't error out on 0-element tensors, a few more checks for contiguous outputs, more threads for better perf). 
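A minimal usage sketch, mirroring the new test added in this patch (the `torch.distributed._symmetric_memory` import path and the size/dtype/device choices here are assumptions for illustration):

```
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

group_name = dist.group.WORLD.group_name
world_size = dist.get_world_size()

# The input must come from a symmetric-memory buffer that has been rendezvous'd.
inp = symm_mem.empty(1024, dtype=torch.bfloat16, device="cuda")
symm_mem.rendezvous(inp, group=group_name)
inp.normal_()

# Each rank receives a 1/world_size slice of the element-wise sum across ranks.
out = torch.empty(inp.numel() // world_size, dtype=inp.dtype, device=inp.device)
torch.ops.symm_mem.reduce_scatter_out(inp, group_name, False, out)
```

As noted in the schema registration below, the op may also modify the input buffer in place.
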
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150813 Approved by: https://github.com/xw285cornell --- test/distributed/test_symmetric_memory.py | 76 ++++++- .../c10d/CUDASymmetricMemory-inl.h | 2 +- .../c10d/CUDASymmetricMemoryOps.cu | 211 ++++++++++++++++-- .../csrc/distributed/c10d/SymmetricMemory.cpp | 4 + 4 files changed, 276 insertions(+), 17 deletions(-) diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index b5e961276f87..d25da76a8931 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -771,7 +771,7 @@ def test_subgroup(self) -> None: self.assertTrue(buf.eq(peer_rank + world.size() // 2).all()) -@skipIfRocm +# @skipIfRocm @instantiate_parametrized_tests @requires_cuda_p2p_access() class SymmMemCollectiveTest(MultiProcessTestCase): @@ -912,7 +912,7 @@ def test_two_shot_all_reduce(self) -> None: shift = align_bytes // t.element_size() numel = size_bytes // t.element_size() res = t[shift : shift + numel] - res.normal_().fill_(1) + res.normal_() inp = res.clone() if not inplace: out = torch.empty_like(inp) @@ -940,6 +940,78 @@ def _verify_all_reduce_result(self, inp, res): gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01 ) + @skipIfRocm + @skip_if_lt_x_gpu(4) + def test_reduce_scatter(self) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + + for dtype, size_bytes, align_bytes, split_last_dim in itertools.product( + [torch.float, torch.bfloat16], + [128, 8192, 36 * 1024 * 16], + [4, 8, 16], + [True, False], + ): + t = symm_mem.empty(36 * 1024 * 16, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + + self.assertTrue(t.data_ptr() % 16 == 0) + self.assertTrue(align_bytes % t.element_size() == 0) + self.assertTrue(size_bytes % t.element_size() == 0) + + shift = align_bytes // t.element_size() + numel = size_bytes // t.element_size() + res = t[shift : shift + numel].normal_() + if split_last_dim: + res = res.view(-1, 128 // t.element_size()) + inp = res.clone() + out_size = list(inp.shape) + out_size[-1] = inp.shape[-1] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + torch.ops.symm_mem.reduce_scatter_out(res, group_name, split_last_dim, out) + + # Head and tail should not be written + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_reduce_scatter_result(inp, out) + + dist.destroy_process_group() + + @skipIfRocm + @skip_if_lt_x_gpu(4) + def test_reduce_scatter_corner_cases(self) -> None: + dtype = torch.bfloat16 + self._init_process() + group_name = dist.group.WORLD.group_name + t = symm_mem.empty(16384, dtype=dtype, device=self.device).fill_(0) + symm_mem.rendezvous(t, group=group_name) + res = t[:0] + out_size = res.shape[0] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out) + res = t[:48] + out_size = res.shape[0] // self.world_size + out = torch.empty(out_size, dtype=dtype, device=self.device) + with self.assertRaisesRegex(RuntimeError, "divisible"): + torch.ops.symm_mem.reduce_scatter_out(res, group_name, False, out) + res = t[: 2 * 48].view(2, 48) + out = torch.empty(2, 48 // self.world_size, dtype=dtype, device=self.device) + with self.assertRaisesRegex(RuntimeError, "divisible"): + torch.ops.symm_mem.reduce_scatter_out(res, group_name, True, out) + + def 
_verify_reduce_scatter_result(self, inp, res): + gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, *res.shape) + gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, *inp.shape) + sum_inps = gathered_inps.sum(0) + slice_width = sum_inps.shape[-1] // self.world_size + for i in range(self.world_size): + torch.testing.assert_close( + gathered_res[i], + sum_inps[..., i * slice_width : (i + 1) * slice_width], + rtol=1e-01, + atol=1e-01, + ) + @skip_if_lt_x_gpu(4) @parametrize("align_bytes", [4, 8, 16]) def test_multimem_all_gather(self, align_bytes: int) -> None: diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h index c228da413a3e..46bf5ff31987 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h @@ -314,7 +314,7 @@ __device__ __inline__ Vec ld_vec(const T* addr) { template __device__ __inline__ void st_vec(T* addr, const Vec& vec) { -#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) CUDA_KERNEL_ASSERT(false); #else if constexpr (Alignment == 16) { diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index 02baeb51e51c..e67dfa2f60f1 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -463,6 +463,10 @@ at::Tensor one_shot_all_reduce_out_impl( local_input->numel() <= input.numel(), "one_shot_all_reduce: local input size must be smaller than symm buffer size."); } + if (input.numel() == 0) { + TORCH_CHECK(input.scalar_type() == out.scalar_type()); + return out; + } auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); TORCH_CHECK( symm_mem != nullptr, @@ -555,9 +559,14 @@ at::Tensor one_shot_all_reduce_copy( } constexpr size_t two_shot_all_reduce_max_num_blocks = 24; -constexpr size_t two_shot_all_reduce_max_num_threads = 512; - -template +constexpr size_t two_shot_all_reduce_max_num_threads = 1024; + +template < + typename T, + int alignment, + int k_world_size, + bool reduce_scatter = false, + bool split_last_dim = false> static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ void two_shot_all_reduce_kernel( T** input_ptrs, @@ -566,31 +575,48 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ size_t numel, uint32_t** signal_pads, size_t rank, - size_t world_size) { + size_t world_size, + size_t last_dim_size = 0) { static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - + int32_t N_last_dim = + last_dim_size / world_size; // used only for split_last_dim reduce_scatter sync_remote_blocks(signal_pads, rank, world_size); __syncthreads(); const size_t numel_per_rank = - at::round_up(numel, alignment * world_size) / world_size; - const size_t start = numel_per_rank * rank; + at::round_up(numel, numel_per_thread * world_size) / world_size; + const size_t start = split_last_dim ? 
last_dim_size / world_size * rank + : numel_per_rank * rank; auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; auto stride = blockDim.x * gridDim.x * numel_per_thread; for (size_t i = offset; i < numel_per_rank; i += stride) { - if (start + i >= numel) { - continue; + if constexpr (!reduce_scatter) { + // we call reduce-scatter only with evenly divisible number of elements + if (start + i >= numel) { + continue; + } + } + size_t idx = i; + if constexpr (split_last_dim) { + idx = i / N_last_dim * last_dim_size + i % N_last_dim; } auto vec = load_and_reduce( - input_ptrs, rank, world_size, input_offset + start + i); - // store to local buffer - st_vec(input_ptrs[rank] + input_offset + start + i, vec); + input_ptrs, rank, world_size, input_offset + start + idx); + // store to local buffer or to output + if constexpr (reduce_scatter) { + st_vec(output_ptr + i, vec); + } else { + st_vec(input_ptrs[rank] + input_offset + start + i, vec); + } } __syncthreads(); sync_remote_blocks(signal_pads, rank, world_size); + if constexpr (reduce_scatter) { + return; + } __syncthreads(); for (size_t i = offset; i < numel_per_rank; i += stride) { Vec tmp[k_world_size]; @@ -611,8 +637,7 @@ static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ if (remote_start + i >= numel) { continue; } - st_vec( - output_ptr + remote_start + i, tmp[step]); + st_vec(output_ptr + remote_start + i, tmp[step]); } } // need to make sure all blocks exit simultaneously so that the data @@ -679,11 +704,28 @@ at::Tensor two_shot_all_reduce_impl( get_and_verify_alignment(input, "two_shot_all_reduce"); if (output.has_value()) { + TORCH_CHECK( + output->is_contiguous(), + "two_shot_all_reduce: output must be contiguous."); const size_t output_alignment = get_and_verify_alignment(*output, "two_shot_all_reduce"); TORCH_CHECK( alignment <= output_alignment, "two_shot_all_reduce: output alignment must be equal to or larger than input."); + TORCH_CHECK( + output->sizes() == input.sizes(), + "two_shot_all_reduce: input/output size mismatch, input.sizes(): ", + input.sizes(), + ", output.sizes(): ", + output->sizes()); + if (input.numel() == 0) { + TORCH_CHECK(output->scalar_type() == input.scalar_type()); + return *output; + } + } else { + if (input.numel() == 0) { + return input; + } } int num_blocks = 0, num_threads = 0; @@ -764,6 +806,146 @@ at::Tensor two_shot_all_reduce_out( at::Tensor output) { return two_shot_all_reduce_impl(input, output, reduce_op, group_name); } + +at::Tensor reduce_scatter_out( + at::Tensor input, + std::string group_name, + bool split_last_dim, + at::Tensor output) { + TORCH_CHECK( + input.is_contiguous(), "reduce_scatter: input must be contiguous."); + TORCH_CHECK( + output.is_contiguous(), "reduce_scatter: output must be contiguous."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); + TORCH_CHECK( + symm_mem != nullptr, + "reduce_scatter: input must be allocated with empty_strided_p2p()."); + + const size_t alignment = get_and_verify_alignment(input, "reduce_scatter"); + + const size_t output_alignment = + get_and_verify_alignment(input, "reduce_scatter"); + + TORCH_CHECK( + input.numel() % + (symm_mem->get_world_size() * + (alignment / input.element_size())) == + 0, + "expected number of elements to be divisible by world_size * alignment, number of elements ", + input.numel(), + " world size ", + symm_mem->get_world_size(), + "alignment ", + alignment); + + if (split_last_dim) { + TORCH_CHECK(input.dim() == output.dim()); + bool 
are_equal_except_last = std::equal( + input.sizes().begin(), input.sizes().end() - 1, output.sizes().begin()); + TORCH_CHECK( + are_equal_except_last, + "reduce_scatter expected input and output to have same sizes except in the last dimension"); + TORCH_CHECK( + output.size(-1) == input.size(-1) / symm_mem->get_world_size(), + "reduce_scatter expected output last dim size to be input last dim size / world_size"); + + TORCH_CHECK( + input.size(-1) % + (symm_mem->get_world_size() * + (alignment / input.element_size())) == + 0, + "expected last dimension to be divisible by world_size * alignment, last dimension ", + input.size(-1), + " world size ", + symm_mem->get_world_size(), + "alignment ", + alignment); + } else { + TORCH_CHECK(input.dim() == 1, "reduce_scatter expected 1D input"); + TORCH_CHECK(output.dim() == 1, "reduce_scatter expected 1D output"); + TORCH_CHECK(output.numel() == input.numel() / symm_mem->get_world_size()); + } + if (input.numel() == 0) { + TORCH_CHECK(input.scalar_type() == output.scalar_type()); + return output; + } + + TORCH_CHECK( + output_alignment >= alignment, + "reduce_scatter: output alignment should be not smaller than input alignment"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + symm_mem->get_world_size(), + two_shot_all_reduce_max_num_blocks, + two_shot_all_reduce_max_num_threads, + num_blocks, + num_threads); + if (split_last_dim) { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel< + scalar_t, + k_alignment, + k_world_size, + true, + true> + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + output.data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size(), + input.size(-1)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + } else { + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES_NO_DEFAULT(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel< + scalar_t, + k_alignment, + k_world_size, + true, + false> + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + output.data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size(), + input.size(-1)); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + } + return output; +} } // namespace #endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030 @@ -899,6 +1081,7 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { m.impl("one_shot_all_reduce_copy_out", ::one_shot_all_reduce_copy_out); m.impl("two_shot_all_reduce_", ::two_shot_all_reduce_); m.impl("two_shot_all_reduce_out", ::two_shot_all_reduce_out); + m.impl("reduce_scatter_out", ::reduce_scatter_out); m.impl("_async_input_mm", c10d::cuda::detail::async_input_mm); #endif diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index 76eb7205a398..f68681de1698 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -250,6 +250,10 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( "two_shot_all_reduce_out(Tensor(a!) 
input, str reduce_op, str group_name, Tensor(b!) output) -> Tensor(b!)"); + // note this implementation also modified the input tensor + m.def( + "reduce_scatter_out(Tensor(a!) input, str group_name, bool split_last_dim, Tensor(b!) output) -> Tensor(b!)"); + // An mm that supports consuming asynchronous input. It guarantees the // following rasterization order, and that the corresponding signal arrives // before an input chunk is consumed. From d3a2872c676b1c67ee47170422f247d429e22241 Mon Sep 17 00:00:00 2001 From: Zhuoran Zhao Date: Wed, 9 Apr 2025 18:35:36 +0000 Subject: [PATCH 319/332] Hipify global scrach defintion in AOTI codegen (#150893) Summary: as title, a refactor is very needed I think .... or at least unify internal/external AOTI wrapper hipification method Test Plan: P1780296121 Differential Revision: D72683568 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150893 Approved by: https://github.com/davidberard98 --- torch/_inductor/codegen/cpp_wrapper_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 77364ae48734..e0f0726e7a89 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -440,7 +440,7 @@ def process_args(arg, arg_type, arg_signature=None): is not None ): global_scratch_def, global_scratch_var = global_scratch - code.writeline(global_scratch_def) + code.writeline(maybe_hipify_code_wrapper(global_scratch_def)) new_args.append(f"&{global_scratch_var}") return ", ".join(new_args) From cfab04d01b734ae48e0dd1ce957a4fad793ea3a2 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 9 Apr 2025 18:52:01 +0000 Subject: [PATCH 320/332] Fix aten.div type promotion for FakeTensor (#150874) Summary: When we divide a FakeTensor by an integer using the fast op implementation, the type promotion should be `ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT` so we get a float when dividing an int FakeTensor by an integer. 
``` FAST = get_fast_op_impls() fast_div = FAST[torch.ops.aten.div.Tensor] fast_div(fake_tensor, some_int) ``` Test Plan: ``` python test/test_fake_tensor.py -k test_fast_div ``` Differential Revision: D72667430 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150874 Approved by: https://github.com/angelayi --- test/test_fake_tensor.py | 8 ++++++++ torch/_subclasses/fake_impls.py | 11 ++++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 1b99bd94061e..7dad38355e20 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -972,6 +972,14 @@ def add(x, y): self.assertIsInstance(r[0], FakeTensor) self.assertIsInstance(r[1], FakeTensor) + def test_fast_div(self): + mode = FakeTensorMode() + with mode: + x = torch.empty(2, 2, device="cpu", dtype=torch.int32) + from torch._subclasses.fake_impls import get_fast_op_impls + fast_div = get_fast_op_impls()[torch.ops.aten.div.Tensor] + y = fast_div(mode, x, 2) + self.assertEqual(y.dtype, torch.float32) instantiate_parametrized_tests(FakeTensorTest) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index bc7bc1ba7f82..9d85bf4c77b3 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -890,7 +890,9 @@ def infer_size(a, b): return tuple(expandedSizes) -def make_fast_binary_impl(slow_ref): +def make_fast_binary_impl( + slow_ref, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT +): def fast_binary_impl(mode, *args, **kwargs): def slow(msg): count_label(f"slow {msg}") @@ -957,7 +959,7 @@ def slow(msg): # compute promotion # TODO: we don't need the compute type _, common_dtype = elementwise_dtypes( - *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + *operands, type_promotion_kind=type_promotion_kind ) # check all tensors on same device @@ -1042,7 +1044,10 @@ def get_fast_op_impls(): ) register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul)) # type: ignore[has-type] register_fast_op_impl(torch.ops.aten.div.Tensor)( - make_fast_binary_impl(torch._refs.div) + make_fast_binary_impl( + torch._refs.div, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, + ) ) register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach) return FAST_OP_IMPLEMENTATIONS From a4545f09daece41a89109bb2650232072ecbaa7a Mon Sep 17 00:00:00 2001 From: "Yanan Cao (PyTorch)" Date: Wed, 9 Apr 2025 19:18:29 +0000 Subject: [PATCH 321/332] [Codemod][AddExplicitStrictExportForTrainingInferenceArg] caffe2/test/export (#150884) Differential Revision: D72667175 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150884 Approved by: https://github.com/ydwu4 --- test/export/test_db.py | 1 + test/export/test_experimental.py | 12 +++-- test/export/test_serialize.py | 66 +++++++++++++---------- test/export/test_unflatten_training_ir.py | 2 +- test/export/test_verifier.py | 22 ++++---- 5 files changed, 61 insertions(+), 42 deletions(-) diff --git a/test/export/test_db.py b/test/export/test_db.py index 7c8c1860bc5b..a035bdd23916 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -99,6 +99,7 @@ def test_exportdb_not_supported_rewrite( rewrite_case.example_args, rewrite_case.example_kwargs, dynamic_shapes=rewrite_case.dynamic_shapes, + strict=True, ) diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index f95484f0a128..bd68bb7cd772 100644 --- 
a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -60,7 +60,9 @@ def _check_equality_and_annotations(m_func, inps): ) # ExportedProgram from original module. - original_exported_module = torch.export.export_for_training(m_func(), inps) + original_exported_module = torch.export.export_for_training( + m_func(), inps, strict=True + ) # Check whether input annotations are the same as tracing the original module. orig_ph_name_list = [ @@ -116,7 +118,7 @@ def forward(self, x): m = Module() example_inputs = (torch.randn(3),) m(*example_inputs) - ep = torch.export.export_for_training(m, example_inputs) + ep = torch.export.export_for_training(m, example_inputs, strict=True) joint_ep = _export_forward_backward(ep) self.assertExpectedInline( str(joint_ep.graph_module.code).strip(), @@ -226,7 +228,7 @@ def forward(self, x): example_inputs = (torch.randn(3),) m(*example_inputs) ep = torch.export.export_for_training( - m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}} + m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True ) _export_forward_backward(ep) @@ -261,7 +263,7 @@ def forward(self, x, labels): labels = torch.ones(4, dtype=torch.int64) inputs = (x, labels) - ep = export_for_training(net, inputs) + ep = export_for_training(net, inputs, strict=True) ep = _export_forward_backward(ep) def test_joint_loss_index(self): @@ -281,7 +283,7 @@ def forward(self, x): inputs = (torch.randn(4, 4),) for i in [0, 1]: - ep = export_for_training(Foo(i), inputs) + ep = export_for_training(Foo(i), inputs, strict=True) ep_joint = _export_forward_backward(ep, joint_loss_index=i) for j, spec in enumerate(ep_joint.graph_signature.output_specs): if i == j: diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index ae9f45cbeb21..9e7a0793879d 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -100,7 +100,7 @@ def op_schema(cls, op): return torch.ops.aten.add.Tensor._schema inp = (torch.ones(10),) - ep = export_for_training(TestModule(), inp) + ep = export_for_training(TestModule(), inp, strict=True) # Register the custom op handler. foo_custom_op = FooExtensionOp() @@ -165,7 +165,9 @@ def forward(self, x, y, use_p=False): model = MyModule().eval() random_inputs = (torch.rand([2, 3]), torch.rand([2, 3])) - exp_program = export_for_training(model, random_inputs, {"use_p": True}) + exp_program = export_for_training( + model, random_inputs, {"use_p": True}, strict=True + ) output_buffer = io.BytesIO() # Tests that example inputs are preserved when saving and loading module. @@ -184,7 +186,7 @@ class M(torch.nn.Module): def forward(self, x): return x.sin() - exp_program = export_for_training(M(), (torch.randn(4, 4),)) + exp_program = export_for_training(M(), (torch.randn(4, 4),), strict=True) output_buffer = io.BytesIO() # Tests that example forward arg names are preserved when saving and loading module. @@ -224,7 +226,7 @@ def forward(self, x): inp = (torch.ones(10),) # Module will only be able to roundtrip if metadata # can be correctly parsed. - ep = export_for_training(MyModule(), inp) + ep = export_for_training(MyModule(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -288,7 +290,7 @@ def forward(self, x): # Check that module can be roundtripped, thereby confirming proper deserialization. 
inp = (torch.ones(10),) - ep = export_for_training(MyModule(), inp) + ep = export_for_training(MyModule(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -318,6 +320,7 @@ def forward(self, x, w, b): torch.ones([512]), torch.ones([512]), ), + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) @@ -355,7 +358,10 @@ def forward(self, a, b, c) -> torch.Tensor: "c": {0: dim0_ac, 1: dim1_bc}, } exported_module = export_for_training( - DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes + DynamicShapeSimpleModel(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) sym_size_nodes = [ @@ -416,7 +422,10 @@ def forward(self, a, b, c) -> torch.Tensor: "c": {0: dim0_ac, 1: dim1_bc}, } exported_module = export_for_training( - DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes + DynamicShapeSimpleModel(), + inputs, + dynamic_shapes=dynamic_shapes, + strict=True, ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) for v in serialized.exported_program.range_constraints.values(): @@ -442,7 +451,9 @@ def forward(self, x): return torch.split(x, 2) input = torch.arange(10.0).reshape(5, 2) - exported_module = export_for_training(MyModule(), (input,)).run_decompositions() + exported_module = export_for_training( + MyModule(), (input,), strict=True + ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -504,8 +515,7 @@ def forward(self, x): return torch.ops.aten.var_mean.correction(x, [1])[0] exported_module = export_for_training( - MyModule(), - (torch.ones([512, 512], requires_grad=True),), + MyModule(), (torch.ones([512, 512], requires_grad=True),), strict=True ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) @@ -526,7 +536,7 @@ def forward(self, x): return x + x ep = export_for_training( - M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},) + M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},), strict=True ) range_constraints = list(ep.range_constraints.keys()) @@ -561,7 +571,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f = Foo() x, _ = torch.sort(torch.randn(3, 4)) - exported_module = export_for_training(f, (x,)).run_decompositions() + exported_module = export_for_training(f, (x,), strict=True).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -579,7 +589,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: b = x + y return b + a - ep = export_for_training(Module(), (torch.randn(3, 2), torch.randn(3, 2))) + ep = export_for_training( + Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True + ) s = ExportedProgramSerializer().serialize(ep) c = canonicalize(s.exported_program) g = c.graph_module.graph @@ -593,7 +605,7 @@ class M(torch.nn.Module): def forward(self, x): return torch.ops.aten.sum.dim_IntList(x, []) - ep = torch.export.export_for_training(M(), (torch.randn(3, 2),)) + ep = torch.export.export_for_training(M(), (torch.randn(3, 2),), strict=True) serialized = ExportedProgramSerializer().serialize(ep) for node in serialized.exported_program.graph_module.graph.nodes: if "aten.sum.dim_IntList" in node.target: @@ -1260,7 +1272,7 @@ def forward(self, x): 
a = a * 2 return a, b - ep = torch.export.export_for_training(M(), (torch.ones(3),)) + ep = torch.export.export_for_training(M(), (torch.ones(3),), strict=True) # insert another getitem node for node in ep.graph.nodes: @@ -1406,7 +1418,7 @@ def __init__(self) -> None: def forward(self): return self.p * self.p - ep = torch.export.export_for_training(M(), ()) + ep = torch.export.export_for_training(M(), (), strict=True) ep._example_inputs = None roundtrip_ep = deserialize(serialize(ep)) self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()())) @@ -1434,7 +1446,7 @@ def forward(self, x): return x + x f = Module() - ep = export_for_training(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),), strict=True) serialized_program = ExportedProgramSerializer().serialize(ep) serialized_program.exported_program.schema_version.major = -1 @@ -1470,7 +1482,7 @@ def forward(self, x): y = self.linear(y) return y - ep = export_for_training(Module(), inp) + ep = export_for_training(Module(), inp, strict=True) buffer = io.BytesIO() save(ep, buffer) @@ -1487,7 +1499,7 @@ def forward(self, x): f = Foo() inp = (torch.randn(2, 2),) - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) with tempfile.NamedTemporaryFile() as f: save(ep, f) @@ -1504,7 +1516,7 @@ def forward(self, x, y): f = Foo() inp = (torch.tensor([6]), torch.tensor([7])) - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) with TemporaryFileName() as fname: path = Path(fname) @@ -1522,7 +1534,7 @@ def forward(self, x): f = Foo() - ep = export_for_training(f, inp) + ep = export_for_training(f, inp, strict=True) buffer = io.BytesIO() save(ep, buffer, extra_files={"extra.txt": "moo"}) @@ -1540,7 +1552,7 @@ def forward(self, x): f = Foo() - ep = export_for_training(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),), strict=True) with self.assertRaisesRegex( RuntimeError, r"Serialized version .* does not match our current" @@ -1566,7 +1578,7 @@ def forward(self, x): list_tensor = [torch.tensor(3), torch.tensor(4)] return x + self.a + list_tensor[0] + list_tensor[1] - ep = export_for_training(Foo(), (torch.tensor(1),)) + ep = export_for_training(Foo(), (torch.tensor(1),), strict=True) buffer = io.BytesIO() save(ep, buffer) buffer.seek(0) @@ -1592,7 +1604,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) # Replace one of the values with an instance of our custom class for node in ep.graph.nodes: @@ -1700,7 +1712,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1735,7 +1747,7 @@ def forward(self, x): f = Foo() inputs = (torch.ones(2, 2),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1771,7 +1783,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export_for_training(f, inputs) + ep = export_for_training(f, inputs, strict=True) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} diff --git a/test/export/test_unflatten_training_ir.py b/test/export/test_unflatten_training_ir.py index 684d9a149ecf..6816787eff22 100644 --- a/test/export/test_unflatten_training_ir.py +++ b/test/export/test_unflatten_training_ir.py 
@@ -14,7 +14,7 @@ def mocked_training_ir_export(*args, **kwargs): - return export_for_training(*args, **kwargs) + return export_for_training(*args, **kwargs, strict=True) def make_dynamic_cls(cls): diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index dd3d18db1cda..5d3cfd564637 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -20,7 +20,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export_for_training(f, (torch.randn(100), torch.randn(100))) + ep = export_for_training(f, (torch.randn(100), torch.randn(100)), strict=True) verifier = Verifier() verifier.check(ep) @@ -48,7 +48,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() ep = export_for_training( - f, (torch.randn(100), torch.randn(100)) + f, (torch.randn(100), torch.randn(100)), strict=True ).run_decompositions({}) for node in ep.graph.nodes: if node.target == torch.ops.aten.add.Tensor: @@ -72,7 +72,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3))) + ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True) verifier = Verifier() verifier.check(ep) @@ -92,7 +92,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() ep = export_for_training( - f, (torch.randn(3, 3), torch.randn(3, 3)) + f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True ).run_decompositions({}) for node in ep.graph_module.true_graph_0.graph.nodes: if node.target == torch.ops.aten.add.Tensor: @@ -111,7 +111,7 @@ def __init__(self) -> None: def forward(self, x: Tensor) -> Tensor: return self.linear(x) - ep = export_for_training(M(), (torch.randn(10, 10),)) + ep = export_for_training(M(), (torch.randn(10, 10),), strict=True) ep.validate() def test_ep_verifier_invalid_param(self) -> None: @@ -125,7 +125,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True) # Parameter doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -150,7 +150,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True) # Buffer doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -182,7 +182,9 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training( + M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True + ) ep.validate() def test_ep_verifier_invalid_output(self) -> None: @@ -205,7 +207,9 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training( + M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True + ) output_node = list(ep.graph.nodes)[-1] output_node.args = ( From f237ee54bfb35d16cd10e358d4b78578c88a5781 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Wed, 9 Apr 2025 19:29:50 +0000 Subject: [PATCH 322/332] ProcessGroupGloo: support lazy_init (#150801) This adds lazy initialization support to 
ProcessGroupGloo via `TORCH_GLOO_LAZY_INIT` or via `create_device(..., lazy_init=True)` This is still a draft PR as there's one race condition when doing coalesced operations that needs to be fixed upstream in Gloo first. Depends on https://github.com/facebookincubator/gloo/pull/427 landing first This also updates the gloo submodule to include the required changes. Test plan: added lazy init test variants ``` pytest -v test/distributed/test_c10d_gloo.py -k Lazy ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150801 Approved by: https://github.com/fduwjj --- docs/source/distributed.rst | 7 +++ test/distributed/test_c10d_gloo.py | 22 ++++++++-- third_party/gloo | 2 +- torch/_C/_distributed_c10d.pyi | 4 +- .../distributed/c10d/GlooDeviceFactory.cpp | 44 +++++++++++++------ .../distributed/c10d/GlooDeviceFactory.hpp | 9 ++-- .../distributed/c10d/ProcessGroupGloo.cpp | 27 +++++++----- .../distributed/c10d/ProcessGroupGloo.hpp | 23 +++++----- torch/csrc/distributed/c10d/init.cpp | 28 +++++++++--- torch/testing/_internal/common_distributed.py | 6 +-- 10 files changed, 119 insertions(+), 53 deletions(-) diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index 7092a836417f..8e8d14e17e54 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -284,6 +284,13 @@ The machine with rank 0 will be used to set up all connections. This is the default method, meaning that ``init_method`` does not have to be specified (or can be ``env://``). +Improving initialization time +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* ``TORCH_GLOO_LAZY_INIT`` - establishes connections on demand rather than + using a full mesh which can greatly improve initialization time for non all2all + operations. + Post-Initialization ------------------- diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index 9228efdedf34..57ad689179da 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -46,6 +46,7 @@ requires_gloo, simple_sparse_reduce_tests, skip_if_lt_x_gpu, + skip_if_win32, verify_ddp_error_logged, ) from torch.testing._internal.common_utils import ( @@ -219,6 +220,8 @@ def test_default_store_timeout_gloo(self): class ProcessGroupGlooTest(MultiProcessTestCase): + lazy_init = False + def _create_process_group_gloo(self, store, rank, world_size, opts): pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, opts) dist.barrier(group=pg) @@ -231,7 +234,7 @@ def setUp(self): def opts(self, threads=2): opts = c10d.ProcessGroupGloo._Options() opts._timeout = 50.0 - opts._devices = [create_device(interface=LOOPBACK)] + opts._devices = [create_device(interface=LOOPBACK, lazy_init=self.lazy_init)] opts._threads = threads return opts @@ -241,8 +244,8 @@ def test_multi_device_constructor(self): opts = c10d.ProcessGroupGloo._Options() opts._timeout = 5.0 opts._devices = [ - create_device(interface=LOOPBACK), - create_device(interface=LOOPBACK), + create_device(interface=LOOPBACK, lazy_init=self.lazy_init), + create_device(interface=LOOPBACK, lazy_init=self.lazy_init), ] pg = self._create_process_group_gloo(store, self.rank, self.world_size, opts) @@ -2334,6 +2337,19 @@ def test_forward_backward_optimizer(self): optimizer.step() +@skip_if_win32() +class ProcessGroupGlooLazyInitTest(ProcessGroupGlooTest): + lazy_init = True + + def setUp(self): + os.environ["TORCH_GLOO_LAZY_INIT"] = "1" + super().setUp() + + def tearDown(self) -> None: + del os.environ["TORCH_GLOO_LAZY_INIT"] + return 
super().tearDown() + + class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): @property def device(self): diff --git a/third_party/gloo b/third_party/gloo index e348db90d867..c61070427610 160000 --- a/third_party/gloo +++ b/third_party/gloo @@ -1 +1 @@ -Subproject commit e348db90d8677277e926c14c94ee2acfa77173d4 +Subproject commit c61070427610ccd923efe3e7f8b3eca12bbcc31a diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 6aaaf4b9c5f1..0487eb7c924a 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -570,9 +570,9 @@ class ProcessGroupGloo(Backend): timeout: timedelta, ) -> None: ... @staticmethod - def create_device(hostname="", interface="") -> Device: ... + def create_device(hostname="", interface="", lazy_init=None) -> Device: ... @staticmethod - def create_default_device() -> Device: ... + def create_default_device(lazy_init=None) -> Device: ... def _set_default_timeout(self, timeout) -> None: ... class _ProcessGroupWrapper(Backend): diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index af09ba39470c..32c4c4f88ac0 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -39,12 +39,14 @@ C10_DEFINE_SHARED_REGISTRY_WITHOUT_WARNING( GlooDeviceRegistry, ::gloo::transport::Device, const std::string& /* interface */, - const std::string& /* hostname */) + const std::string& /* hostname */, + bool /* lazyInit */) #if GLOO_HAVE_TRANSPORT_TCP static std::shared_ptr<::gloo::transport::Device> makeTCPDevice( const std::string& interfaceName, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeTCPDevice(): interface or hostname " @@ -56,7 +58,11 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPDevice( } else { attr.hostname = hostname; } - return ::gloo::transport::tcp::CreateDevice(attr); + if (lazyInit) { + return ::gloo::transport::tcp::CreateLazyDevice(attr); + } else { + return ::gloo::transport::tcp::CreateDevice(attr); + } } // Registry priority is per key identifier. 
We register TCP to `LINUX` for @@ -69,12 +75,15 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice) #if GLOO_HAVE_TRANSPORT_TCP_TLS static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( const std::string& interface, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interface.empty() || !hostname.empty(), "GlooDeviceFactory::makeTCPTLSDevice(): interface or hostname " "can't be empty"); + TORCH_CHECK(!lazyInit, "TCP_TLS transport does not support lazy init"); + ::gloo::transport::tcp::attr attr; if (!interface.empty()) { attr.iface = interface; @@ -105,12 +114,15 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP_TLS, makeTCPTLSDevice) #if GLOO_HAVE_TRANSPORT_UV static std::shared_ptr<::gloo::transport::Device> makeUVDevice( const std::string& interfaceName, - const std::string& hostname) { + const std::string& hostname, + bool lazyInit) { TORCH_CHECK( !interfaceName.empty() || !hostname.empty(), "GlooDeviceFactory::makeUVDevice(): interface or hostname " "can't be empty"); + TORCH_CHECK(!lazyInit, "UV transport does not support lazy init"); + ::gloo::transport::uv::attr attr; if (!interfaceName.empty()) { attr.iface = interfaceName; @@ -131,23 +143,27 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, UV, makeUVDevice) namespace { std::shared_ptr<::gloo::transport::Device> makeGlooDevice( const std::string& interfaceName, - const std::string& hostName) { + const std::string& hostName, + bool lazyInit) { static auto transportName = c10::utils::get_env("GLOO_DEVICE_TRANSPORT"); if (transportName.has_value()) { return GlooDeviceRegistry()->Create( - transportName.value().c_str(), interfaceName, hostName); + transportName.value().c_str(), interfaceName, hostName, lazyInit); } #ifdef __linux__ - return GlooDeviceRegistry()->Create("LINUX", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "LINUX", interfaceName, hostName, lazyInit); #endif #ifdef __APPLE__ - return GlooDeviceRegistry()->Create("APPLE", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "APPLE", interfaceName, hostName, lazyInit); #endif #ifdef _WIN32 - return GlooDeviceRegistry()->Create("WIN32", interfaceName, hostName); + return GlooDeviceRegistry()->Create( + "WIN32", interfaceName, hostName, lazyInit); #endif return nullptr; @@ -155,8 +171,8 @@ std::shared_ptr<::gloo::transport::Device> makeGlooDevice( } // anonymous namespace std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: - makeDeviceForInterface(const std::string& interfaceName) { - auto device = makeGlooDevice(interfaceName, ""); + makeDeviceForInterface(const std::string& interfaceName, bool lazyInit) { + auto device = makeGlooDevice(interfaceName, "", lazyInit); if (!device) { TORCH_CHECK(false, "makeDeviceForInterface(): unsupported gloo device"); } @@ -164,8 +180,8 @@ std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: } std::shared_ptr<::gloo::transport::Device> GlooDeviceFactory:: - makeDeviceForHostname(const std::string& hostname) { - auto device = makeGlooDevice("", hostname); + makeDeviceForHostname(const std::string& hostname, bool lazyInit) { + auto device = makeGlooDevice("", hostname, lazyInit); if (!device) { TORCH_CHECK(false, "makeDeviceForHostname(): unsupported gloo device"); } diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp index 1221e9d033f2..a7220f0d81c7 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp +++ 
b/torch/csrc/distributed/c10d/GlooDeviceFactory.hpp @@ -14,18 +14,21 @@ class TORCH_API GlooDeviceFactory { public: // Create new device instance for specific interface. static std::shared_ptr<::gloo::transport::Device> makeDeviceForInterface( - const std::string& interface); + const std::string& interface, + bool lazyInit); // Create new device instance for specific hostname or address. static std::shared_ptr<::gloo::transport::Device> makeDeviceForHostname( - const std::string& hostname); + const std::string& hostname, + bool lazyInit); }; TORCH_DECLARE_SHARED_REGISTRY( GlooDeviceRegistry, ::gloo::transport::Device, const std::string&, /* interface */ - const std::string& /* hostname */); + const std::string&, /* hostname */ + bool /* lazyInit */); } // namespace c10d diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 3c5644eeab68..077bf311284f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -415,6 +415,10 @@ const auto kLoopbackAddress = "127.0.0.1"; } // namespace +bool getDefaultGlooLazyInit() { + return ::c10d::getCvarBool(TORCH_GLOO_LAZY_INIT, false); +} + // static void ProcessGroupGloo::AsyncWork::execute( const c10::intrusive_ptr& work) { @@ -687,23 +691,24 @@ bool doesHostnameResolveToUsableAddress(const std::string& hostname) { } // namespace std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDeviceForInterface(const std::string& interface_name) { - return ::c10d::GlooDeviceFactory::makeDeviceForInterface(interface_name); + createDeviceForInterface(const std::string& interface_name, bool lazyInit) { + return ::c10d::GlooDeviceFactory::makeDeviceForInterface( + interface_name, lazyInit); } std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDeviceForHostname(const std::string& hostname) { + createDeviceForHostname(const std::string& hostname, bool lazyInit) { TORCH_CHECK( doesHostnameResolveToUsableAddress(hostname), "Cannot resolve ", hostname, " to a (local) address"); - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname, lazyInit); } #if defined(__linux__) || defined(_WIN32) std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDefaultDevice() { + createDefaultDevice(bool lazyInit) { // Use the hostname to resolve the network address to // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. @@ -716,7 +721,8 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: // Use this machine's hostname if it resolves to an address. if (doesHostnameResolveToUsableAddress(hostname.data())) { - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.data()); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname( + hostname.data(), lazyInit); } // Otherwise, use the loopback address. @@ -724,13 +730,13 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: "Unable to resolve hostname to a (local) address. ", "Using the loopback address as fallback. 
", "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); - return createDeviceForHostname(kLoopbackAddress); + return createDeviceForHostname(kLoopbackAddress, lazyInit); } #endif #ifdef __APPLE__ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: - createDefaultDevice() { + createDefaultDevice(bool lazyInit) { // Use the hostname to resolve the network address to // use. Note: if the hostname does not resolve to an address (e.g. // because of misconfigured /etc/hosts file), this will not work. @@ -743,7 +749,8 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: // Use this machine's hostname if it resolves to an address. if (doesHostnameResolveToUsableAddress(hostname.get())) { - return ::c10d::GlooDeviceFactory::makeDeviceForHostname(hostname.get()); + return ::c10d::GlooDeviceFactory::makeDeviceForHostname( + hostname.get(), lazyInit); } // Otherwise, use the loopback address. @@ -751,7 +758,7 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo:: "Unable to resolve hostname to a (local) address. ", "Using the loopback address as fallback. ", "Manually set the network interface to bind to with GLOO_SOCKET_IFNAME."); - return createDeviceForHostname(kLoopbackAddress); + return createDeviceForHostname(kLoopbackAddress, lazyInit); } #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 059ba8a4ee3f..917544d9e113 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -28,6 +28,13 @@ namespace c10d { constexpr const char* GLOO_BACKEND_NAME = "gloo"; +// Control whether or not connections are established in a full mesh or lazily +// as needed. +static std::vector TORCH_GLOO_LAZY_INIT = {"TORCH_GLOO_LAZY_INIT"}; + +// Returns default value for lazyInit. +bool TORCH_API getDefaultGlooLazyInit(); + // ProcessGroupGloo implements Gloo bindings for c10d. // // All functions on this class are expected to be called in the same @@ -244,24 +251,20 @@ class TORCH_API ProcessGroupGloo : public Backend { // Create new device instance for specific interface. static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface( - const std::string& interface); + const std::string& interface, + bool lazyInit = false); // Create new device instance for specific hostname or address. static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname( - const std::string& hostname); + const std::string& hostname, + bool lazyInit = false); // Create new device instance. // It tries to resolve this machine's hostname and bind to that address. // If that fails (i.e. the hostname doesn't resolve to an address), it // falls back to binding to the loopback address. - static std::shared_ptr<::gloo::transport::Device> createDefaultDevice(); - - // Create ProcessGroupGloo instance. - static c10::intrusive_ptr createProcessGroupGloo( - const c10::intrusive_ptr& store, - int rank, - int size, - std::chrono::milliseconds timeout); + static std::shared_ptr<::gloo::transport::Device> createDefaultDevice( + bool lazyInit = false); explicit ProcessGroupGloo( const c10::intrusive_ptr& store, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 0217d2471dc8..f1bd5fb14cf1 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2849,24 +2849,36 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). 
processGroupGloo .def_static( "create_device", - [](const std::string& hostname, const std::string& interface) + [](const std::string& hostname, + const std::string& interface, + std::optional lazyInit_) -> std::shared_ptr<::gloo::transport::Device> { + bool lazyInit = + lazyInit_.value_or(::c10d::getDefaultGlooLazyInit()); + if (!hostname.empty()) { return ::c10d::ProcessGroupGloo::createDeviceForHostname( - hostname); + hostname, lazyInit); } if (!interface.empty()) { return ::c10d::ProcessGroupGloo::createDeviceForInterface( - interface); + interface, lazyInit); } throw std::invalid_argument( "Specify either `hostname` or `interface` argument."); }, py::arg("hostname") = "", - py::arg("interface") = "") + py::arg("interface") = "", + py::arg("lazy_init") = std::nullopt) .def_static( "create_default_device", - &::c10d::ProcessGroupGloo::createDefaultDevice); + [](std::optional lazyInit_) { + bool lazyInit = + lazyInit_.value_or(::c10d::getDefaultGlooLazyInit()); + + return ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit); + }, + py::arg("lazy_init") = std::nullopt); processGroupGloo .def( @@ -2898,20 +2910,22 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). py::gil_scoped_release nogil{}; auto options = ::c10d::ProcessGroupGloo::Options::create(); + bool lazyInit = ::c10d::getDefaultGlooLazyInit(); // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set. char* ifnameEnv = getenv(GLOO_SOCKET_IFNAME_ENV.c_str()); if (ifnameEnv && strlen(ifnameEnv) > 1) { for (const auto& iface : ::c10d::split(',', ifnameEnv)) { options->devices.push_back( - ::c10d::ProcessGroupGloo::createDeviceForInterface(iface)); + ::c10d::ProcessGroupGloo::createDeviceForInterface( + iface, lazyInit)); } } else { // If no hostname is specified, this function looks up // the machine's hostname and returns a device instance // associated with the address that the hostname resolves to. 
options->devices.push_back( - ::c10d::ProcessGroupGloo::createDefaultDevice()); + ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit)); } options->timeout = timeout; diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 2a8fc04265c4..6a3e654d9e71 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -442,11 +442,11 @@ def create_tcp_store( TIMEOUT_OVERRIDE["test_join_kwargs"] = 200 -def create_device(interface=None): +def create_device(interface=None, lazy_init: bool = False): if sys.platform == "win32" or interface is None: - return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1") + return c10d.ProcessGroupGloo.create_device(hostname="127.0.0.1", lazy_init=lazy_init) else: - return c10d.ProcessGroupGloo.create_device(interface=interface) + return c10d.ProcessGroupGloo.create_device(interface=interface, lazy_init=lazy_init) def get_timeout(test_id) -> int: From ea0cbba1fceb969c311759fcab429548d5ba933e Mon Sep 17 00:00:00 2001 From: angelayi Date: Wed, 9 Apr 2025 19:44:26 +0000 Subject: [PATCH 323/332] [export] Refine draft-export CVE with Dim.AUTO (#150876) Instead of using refine_dynamic_shapes_from_suggested_fixes to fix ConstraintViolationErrors in draft-export, we can just convert the dims to Dim.AUTO, which is less error prone Pull Request resolved: https://github.com/pytorch/pytorch/pull/150876 Approved by: https://github.com/pianpwk --- test/export/test_draft_export.py | 4 +--- torch/export/_draft_export.py | 33 +++++++++++++++++++++----------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 6fda3fcdb0ad..1f23cb5cee4b 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -303,9 +303,7 @@ def forward(self, a): report = ep._report self.assertEqual(len(report.failures), 1) - self.assertEqual( - report.failures[0].failure_type, FailureType.CONSTRAINT_VIOLATION_ERROR - ) + self.assertEqual(report.failures[0].failure_type, FailureType.GUARD_ADDED) inp = (torch.randn(3, 3),) self.assertEqual(ep.module()(*inp), M()(*inp)) diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 604f865a2b08..103a4abb0540 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -11,10 +11,11 @@ import torch import torch._logging._internal import torch._logging.structured +import torch.utils._pytree as pytree from torch._export.passes.insert_custom_op_guards import insert_custom_op_guards from torch.export import ExportedProgram from torch.export._trace import _export -from torch.export.dynamic_shapes import refine_dynamic_shapes_from_suggested_fixes +from torch.export.dynamic_shapes import _DimHint, _DimHintType, Dim log = logging.getLogger(__name__) @@ -23,7 +24,7 @@ class FailureType(IntEnum): MISSING_FAKE_KERNEL = 1 DATA_DEPENDENT_ERROR = 2 - CONSTRAINT_VIOLATION_ERROR = 3 + GUARD_ADDED = 3 MISMATCHED_FAKE_KERNEL = 4 def __str__(self) -> str: @@ -94,17 +95,19 @@ def print(self, str_to_filename: dict[int, str]) -> str: Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation. 
""" # noqa: B950 - elif self.failure_type == FailureType.CONSTRAINT_VIOLATION_ERROR: + elif self.failure_type == FailureType.GUARD_ADDED: locals_info = ( prettify_frame_locals(**self.data["frame_locals"]) if self.data["frame_locals"] else "" ) - return f"""Constraint violation error. - The specified input dynamic_shapes spec was found to be incorrect during tracing. + return f"""Guard Added. + A guard was added during tracing, which might've resulted in some incorrect + tracing or constraint violation error. Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}. - This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}: + This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}: {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n Because of this, we have modified the dynamic shapes structure to be the following. You can also use torch.export.Dim.AUTO instead to specify your dynamic shapes, and we will automatically infer the dynamism for you. @@ -216,6 +219,8 @@ def _hash(self, element: tuple[str, dict[str, Any]]) -> int: return hash((key, data["op"], data["reason"])) elif key == "propagate_real_tensors_provenance": return hash((key, json.dumps(data["user_stack"]))) + elif key == "guard_added": + return hash((key, json.dumps(data["user_stack"]))) elif key == "create_unbacked_symbol": return hash((key, json.dumps(data["user_stack"]))) @@ -377,10 +382,16 @@ def draft_export( pre_dispatch=pre_dispatch, preserve_module_call_signature=preserve_module_call_signature, ) - except torch._dynamo.exc.UserError as exc: - new_shapes = refine_dynamic_shapes_from_suggested_fixes( - exc.msg, dynamic_shapes - ) + except torch._dynamo.exc.UserError: + + def convert_dim_to_auto(dim: Any) -> Any: + if isinstance(dim, Dim): + return Dim.AUTO(min=dim.min, max=dim.max) + elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC: + return Dim.AUTO(min=dim.min, max=dim.max) + return dim + + new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes) ep = _export( mod, args, @@ -420,7 +431,7 @@ def draft_export( if new_shapes is None: continue - failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR + failure_type = FailureType.GUARD_ADDED log_contents["new_dynamic_shapes"] = new_shapes elif log_name == "missing_fake_kernel": failure_type = FailureType.MISSING_FAKE_KERNEL From 2b9d8a56333ea1c0f05247bcef6105eefba2ea62 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 9 Apr 2025 20:15:34 +0000 Subject: [PATCH 324/332] Fix `-Wmissing-braces` in a few files (#150802) Test Plan: Sandcastle Reviewed By: wenxin0319 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150802 Approved by: https://github.com/Skylion007 --- aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu index e74d71fe1aff..75d4e8c75c9b 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu @@ -245,7 +245,7 @@ Tensor two_four_sgemm( ElementC(0), {cute::_1{}, cute::_0{}, problem_size.m()}}; } else { - return {ElementC(0)}; + return {{ElementC(0)}}; } }() }; From 860765d621e14730f8b6e7344da0053c4f00d540 Mon Sep 17 00:00:00 2001 From: 
Laith Sakka Date: Wed, 9 Apr 2025 09:09:43 -0700 Subject: [PATCH 325/332] update benchamark result due to <1% regression (#150937) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Screenshot 2025-04-09 at 9 07 13 AM PR https://hud.pytorch.org/pr/148104 which is acceptable but we have to update this to avoid flakiness in the future . Pull Request resolved: https://github.com/pytorch/pytorch/pull/150937 Approved by: https://github.com/zou3519 --- .../pr_time_benchmarks/expected_results.csv | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index af033eacff97..46c979979fdf 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -6,7 +6,7 @@ add_loop_eager_dynamic,compile_time_instruction_count,5633000000,0.025 -add_loop_inductor,compile_time_instruction_count,28810000000,0.015 +add_loop_inductor,compile_time_instruction_count,28950000000,0.015 @@ -14,7 +14,7 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42490000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,25120000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,25350000000,0.015 @@ -22,7 +22,7 @@ basic_modules_ListOfLinears_eager,compile_time_instruction_count,963100000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,17990000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18110000000,0.015 @@ -46,32 +46,32 @@ sum_floordiv_regression,compile_time_instruction_count,985300000,0.015 -symint_sum,compile_time_instruction_count,3189000000,0.015 +symint_sum,compile_time_instruction_count,3214000000,0.015 -symint_sum_loop,compile_time_instruction_count,4180000000,0.015 +symint_sum_loop,compile_time_instruction_count,4204000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2042000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2057000000,0.015 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5884000000,0.015 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5917000000,0.015 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,8501000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8561000000,0.015 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1856000000,0.015 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1876000000,0.015 -aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3751000000,0.015 +aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3779000000,0.015 -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10200000000,0.015 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10260000000,0.015 From d751698a362de177fcc569348a7bf81456ab727d Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Tue, 8 Apr 2025 17:30:09 +0000 Subject: [PATCH 326/332] Support negative values for fill with uint tensors (#144458) Fixes https://github.com/pytorch/pytorch/issues/144188 Pull Request resolved: https://github.com/pytorch/pytorch/pull/144458 Approved by: https://github.com/amjames, https://github.com/eellison --- test/inductor/test_torchinductor_opinfo.py | 2 ++ torch/_inductor/codegen/triton.py | 6 
+++++- torch/testing/_internal/common_methods_invocations.py | 10 +++++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 765927047700..ac552b312dea 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -977,6 +977,8 @@ def test_comprehensive(self, device, dtype, op): "nn.functional.interpolate.bicubic", "nn.functional.upsample_bilinear", "nn.functional.upsample_nearest", + "fill", + "full_like", ): if dtype not in allowed_dtypes: raise unittest.SkipTest("Skipped!") diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 5aaab1ed47ed..b125efd6bdbf 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -933,7 +933,11 @@ def _shaped_constant(value, dtype, shape): # NOTE: We use a tensor here in order to get the expected type. # Otherwise, e.g. float64 constants would be trunctated to float32. - return f"tl.full({shape}, {triton_val}, {triton_type})" + if value < 0 and not dtype.is_signed: + triton_signed_type = f"tl.{triton_type[4:]}" + return f"tl.full({shape}, {triton_val}, {triton_signed_type}).to({triton_type})" + else: + return f"tl.full({shape}, {triton_val}, {triton_type})" @classmethod def constant(cls, value, dtype): diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index a715b2bbd28d..8e32eaa861aa 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1933,6 +1933,10 @@ def get_val(dtype): if torch.cuda.is_available(): inputs.append(((S,), get_val(dtype), {'device': 'cuda'})) + if not dtype.is_signed: + # For unsigned dtypes, negative values are converted. + inputs.append(((S,), -get_val(dtype), {})) + for shape, fill_value, kwargs in inputs: t = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, @@ -18911,12 +18915,12 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'), )), OpInfo('full_like', - dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), + dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, + torch.uint16, torch.uint32), supports_out=False, sample_inputs_func=sample_inputs_full_like, supports_autograd=False, - skips=( - )), + ), OpInfo('new_zeros', op=lambda x, *args, **kwargs: x.new_zeros(*args, **kwargs), dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf), From 357814c85c00a2b5b3fb9add97735e4789caa7e0 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Wed, 9 Apr 2025 07:41:33 -0700 Subject: [PATCH 327/332] [AOTI] Remove typedef for half and bfloat16 (#150657) Summary: typedef is prone to name collision. Explicitly spell out the actual aten types, needed for the libtorch-free codegen. 
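As a rough illustration of this change (a minimal sketch, not the actual table from the patch; `DTYPE_TO_CPP_EXAMPLE` and `cpp_type` are made-up names used only here), the codegen's dtype-to-C++ mappings now spell out fully qualified aten types instead of relying on the removed `half`/`bfloat16` typedefs:

```
# Minimal sketch: emit fully qualified aten type names so the generated C++
# no longer depends on shims like `typedef at::Half half;`.
import torch

DTYPE_TO_CPP_EXAMPLE = {               # illustrative subset only
    torch.float32: "float",            # unchanged
    torch.float16: "at::Half",         # previously "half"
    torch.bfloat16: "at::BFloat16",    # previously "bfloat16"
}

def cpp_type(dtype: torch.dtype) -> str:
    return DTYPE_TO_CPP_EXAMPLE[dtype]

assert cpp_type(torch.bfloat16) == "at::BFloat16"
```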
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150657 Approved by: https://github.com/malfet --- torch/_inductor/codegen/cpp.py | 4 +++- torch/_inductor/codegen/cpp_micro_gemm.py | 4 ++-- torch/_inductor/codegen/cpp_prefix.h | 8 +------- torch/_inductor/codegen/cpp_utils.py | 18 +++++++++--------- torch/_inductor/runtime/hints.py | 6 +++--- torch/csrc/inductor/aoti_include/common.h | 2 -- torch/csrc/inductor/aoti_torch/c/shim.h | 5 ----- 7 files changed, 18 insertions(+), 29 deletions(-) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 23134b7916a7..b05e20ada3ee 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -2990,7 +2990,9 @@ def store_reduction(self, name, index, value): else: # Vertical reduction if out_dtype != dtype: - converted_value = f"{DTYPE_TO_CPP[out_dtype]}_{value}" + converted_value = ( + f"{DTYPE_TO_CPP[out_dtype].replace('::', '_')}_{value}" + ) if out_dtype == torch.bool: convert = f"{value}.template cast()" else: diff --git a/torch/_inductor/codegen/cpp_micro_gemm.py b/torch/_inductor/codegen/cpp_micro_gemm.py index 67a1b08cb5c4..77cf270ad894 100644 --- a/torch/_inductor/codegen/cpp_micro_gemm.py +++ b/torch/_inductor/codegen/cpp_micro_gemm.py @@ -1413,7 +1413,7 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): int64_t ldb, int64_t ldc, int64_t q_group_size, - const bfloat16* {{restrict_keyword}} ScaleAndZeros, + const at:BFloat16* {{restrict_keyword}} ScaleAndZeros, int64_t lds, // leading dimension of ScaleAndZeros int64_t k_start) { constexpr int BLOCK_K = {{block_k}}; @@ -1551,7 +1551,7 @@ class CppMicroGemmWoQInt4Avx512(CppMicroGemmFP32Vec): def get_kernel_extra_args_declare(self) -> str: return ( "const int64_t q_group_size,\n" - " const bfloat16* __restrict__ ScaleAndZeros,\n" + " const at:BFloat16* __restrict__ ScaleAndZeros,\n" " const int64_t lds,\n" " int64_t k_start," ) diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index 9d9b19b79da9..415c979c0eda 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -43,12 +43,6 @@ #include #endif -typedef at::Half half; -typedef at::BFloat16 bfloat16; - -typedef at::Float8_e4m3fn float8_e4m3fn; -typedef at::Float8_e5m2 float8_e5m2; - template struct Welford { T mean = T(0); @@ -635,7 +629,7 @@ inline int64_t randint64_cpu(uint32_t seed, uint32_t offset, int64_t low, int64_ template struct AsIntegerType { typedef T type; }; template <> struct AsIntegerType { typedef uint32_t type; }; template <> struct AsIntegerType { typedef uint64_t type; }; -template <> struct AsIntegerType { typedef uint16_t type; }; +template <> struct AsIntegerType { typedef uint16_t type; }; template typename std::enable_if_t, T> diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 8362d052f773..8707ee4d9bb2 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -30,7 +30,7 @@ DTYPE_TO_CPP = { torch.float32: "float", torch.float64: "double", - torch.float16: "half", + torch.float16: "at::Half", torch.int64: "int64_t", torch.int32: "int32_t", torch.int16: "int16_t", @@ -40,14 +40,14 @@ torch.uint16: "uint16_t", torch.uint8: "uint8_t", torch.bool: "bool", - torch.bfloat16: "bfloat16", - torch.complex32: "c10::complex", - torch.complex64: "c10::complex", - torch.complex128: "c10::complex", - torch.float8_e4m3fn: "float8_e4m3fn", - torch.float8_e5m2: "float8_e5m2", - 
torch.float8_e4m3fnuz: "float8_e4m3fnuz", - torch.float8_e5m2fnuz: "float8_e5m2fnuz", + torch.bfloat16: "at::BFloat16", + torch.complex32: "at::complex", + torch.complex64: "at::complex", + torch.complex128: "at::complex", + torch.float8_e4m3fn: "at::Float8_e4m3fn", + torch.float8_e5m2: "at::Float8_e5m2", + torch.float8_e4m3fnuz: "at::Float8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::Float8_e5m2fnuz", } DTYPE_TO_ATEN = { diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 3bc8df35a838..f224217db22b 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -181,14 +181,14 @@ class HalideInputSpec(typing.NamedTuple): alias_of: Optional[str] = None def bindings_type(self) -> str: - if self.ctype in ("half*", "bfloat16*"): + if self.ctype in ("at::Half*", "at::BFloat16*"): return "uint16_t*" # half not defined return self.ctype def halide_type(self) -> str: - if self.ctype == "half*": + if self.ctype == "at::Half*": return "halide_type_t(halide_type_float, 16)" # half not defined - if self.ctype == "bfloat16*": + if self.ctype == "at::BFloat16*": return "halide_type_t(halide_type_bfloat, 16)" # half not defined return f"halide_type_of<{self.ctype.replace('*', '')}>()" diff --git a/torch/csrc/inductor/aoti_include/common.h b/torch/csrc/inductor/aoti_include/common.h index e942e48823fa..e0e61ac0615d 100644 --- a/torch/csrc/inductor/aoti_include/common.h +++ b/torch/csrc/inductor/aoti_include/common.h @@ -9,8 +9,6 @@ #include #include -using half = at::Half; -using bfloat16 = at::BFloat16; // Round up to the nearest multiple of 64 [[maybe_unused]] inline int64_t align(int64_t nbytes) { diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index f56f6eca7449..be187c0118a2 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -783,11 +783,6 @@ int32_t aoti_torch_dtype() = delete; return aoti_torch_dtype_##typename(); \ } -namespace c10 { -struct BFloat16; -struct Half; -} // namespace c10 - DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) From 31fe258efc0573aa01f8215db8b6970e71eba6b8 Mon Sep 17 00:00:00 2001 From: Tom Ritchford Date: Wed, 9 Apr 2025 17:29:55 +0000 Subject: [PATCH 328/332] [inductor] Add features to docstring_linter (see #142496) (#145834) ## Improvements to `docstring_linter` * Add a "grandfather list" of existing undocumented classes and functions (`--grandfather`, `--grandfather-tolerance`, `--no-grandfather`, `--write-grandfather`) * In classes, now just one of the class itself or its `__init__()` method needs to be documented (`--lint-init` turns the old behavior back on) * Now classes and functions defined local to other functions do not need to be documented (`--lint-local` turns the old behavior back on) * New `--report` flag produces a compact report of long, undocumented classes or function definitions: see attached example run over all pytorch: [pytorch-docs.json](https://github.com/user-attachments/files/18455981/pytorch-docs.json) ## Help text ``` $ python tools/linter/adapters/docstring_linter.py --help usage: docstring_linter.py [-h] [-l] [-v] [--grandfather GRANDFATHER] [--grandfather-tolerance GRANDFATHER_TOLERANCE] [--lint-init] [--lint-local] [--lint-protected] [--max-class MAX_CLASS] [--max-def MAX_DEF] [--min-docstring MIN_DOCSTRING] [--no-grandfather] [--report] [--write-grandfather] [files ...] 
`docstring_linter` reports on long functions, methods or classes without docstrings positional arguments: files A list of files or directories to lint optional arguments: -h, --help show this help message and exit -l, --lintrunner Run for lintrunner and print LintMessages which aren't edits -v, --verbose Print more debug info --grandfather GRANDFATHER, -g GRANDFATHER Set the grandfather list --grandfather-tolerance GRANDFATHER_TOLERANCE, -t GRANDFATHER_TOLERANCE Tolerance for grandfather sizes, in percent --lint-init, -i Lint __init__ and class separately --lint-local, -o Lint definitions inside other functions --lint-protected, -p Lint functions, methods and classes that start with _ --max-class MAX_CLASS, -c MAX_CLASS Maximum number of lines for an undocumented class --max-def MAX_DEF, -d MAX_DEF Maximum number of lines for an undocumented function --min-docstring MIN_DOCSTRING, -s MIN_DOCSTRING Minimum number of characters for a docstring --no-grandfather, -n Disable the grandfather list --report, -r Print a report on all classes and defs --write-grandfather, -w Rewrite the grandfather list ``` --- Pull Request resolved: https://github.com/pytorch/pytorch/pull/145834 Approved by: https://github.com/amjames, https://github.com/eellison --- tools/linter/adapters/docstring_linter.py | 583 ++++++++++++++---- .../block_names.py.txt | 44 ++ .../python_code.py.txt.json | 33 - .../python_code.py.txt.lintrunner | 33 - .../python_code.py.txt.report.json | 325 ++++++++++ .../python_code.py.txt.single.line.json | 25 + .../python_code.py.txt.terse.json | 140 +++++ .../python_code.py.txt.terse.line.json | 140 +++++ tools/test/test_docstring_linter.py | 70 ++- 9 files changed, 1202 insertions(+), 191 deletions(-) create mode 100644 tools/test/docstring_linter_testdata/block_names.py.txt create mode 100644 tools/test/docstring_linter_testdata/python_code.py.txt.report.json create mode 100644 tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json create mode 100644 tools/test/docstring_linter_testdata/python_code.py.txt.terse.json create mode 100644 tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json diff --git a/tools/linter/adapters/docstring_linter.py b/tools/linter/adapters/docstring_linter.py index cb9b4ebe9881..cd67243d1ac9 100644 --- a/tools/linter/adapters/docstring_linter.py +++ b/tools/linter/adapters/docstring_linter.py @@ -1,16 +1,21 @@ from __future__ import annotations +import dataclasses as dc +import itertools +import json import sys import token -from functools import cached_property +from enum import Enum +from functools import cached_property, total_ordering from pathlib import Path -from typing import TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import Self -_PARENT = Path(__file__).parent.absolute() +_FILE = Path(__file__).absolute() _PATH = [Path(p).absolute() for p in sys.path] -if TYPE_CHECKING or _PARENT not in _PATH: +if TYPE_CHECKING or _FILE.parent not in _PATH: from . import _linter else: import _linter @@ -20,149 +25,485 @@ from tokenize import TokenInfo +GRANDFATHER_LIST = Path(str(_FILE).replace(".py", "-grandfather.json")) + +# We tolerate a 10% increase in block size before demanding a docstring +TOLERANCE_PERCENT = 10 + MAX_LINES = {"class": 100, "def": 80} -MIN_DOCSTRING = 16 # docstrings shorter than this are ignored -IGNORE_PROTECTED = True # If True, ignore classes and files whose names start with _. 
+MIN_DOCSTRING = 50 # docstrings shorter than this are too short ERROR_FMT = "Every {type} with more than {length} lines needs a docstring" DESCRIPTION = """`docstring_linter` reports on long functions, methods or classes without docstrings""" -# How many top violations to report? -REPORT_TOP_RESULTS = 3 +@total_ordering +@dc.dataclass +class Block: + """A block of Python code starting with either `def` or `class`""" + + class Category(str, Enum): + CLASS = "class" + DEF = "def" + + category: Category + + # The sequence of tokens that contains this Block. + # Tokens are represented in `Block` as indexes into `self.tokens` + tokens: Sequence[TokenInfo] = dc.field(repr=False) + + # The name of the function or class being defined + name: str + + # The index of the very first token in the block (the "class" or "def" keyword) + begin: int + + # The index of the first INDENT token for this block + indent: int + + # The index of the DEDENT token for this end of this block + dedent: int + + # The docstring for the block + docstring: str + + # These next members only get filled in after all blocks have been constructed + # and figure out family ties + + # The full qualified name of the block within the file. + # This is the name of this block and all its parents, joined with `.`. + full_name: str = "" + + # The index of this block within the full list of blocks in the file + index: int = 0 + + # Is this block contained within a function definition? + is_local: bool = dc.field(default=False, repr=False) + + # Is this block a function definition in a class definition? + is_method: bool = dc.field(default=False, repr=False) + + # A block index to the parent of this block, or None for a top-level block. + parent: int | None = None + + # A list of block indexes for the children + children: list[int] = dc.field(default_factory=list) + + @property + def start_line(self) -> int: + return self.tokens[max(self.indent, self.index)].start[0] + + @property + def end_line(self) -> int: + return self.tokens[max(self.dedent, self.index)].start[0] + + @property + def line_count(self) -> int: + return self.end_line - self.start_line + + @property + def is_class(self) -> bool: + return self.category == Block.Category.CLASS + + @property + def display_name(self) -> str: + """A user-friendly name like 'class One' or 'def One.method()'""" + ending = "" if self.is_class else "()" + return f"{self.category.value} {self.full_name}{ending}" + + DATA_FIELDS = ( + "category", + "children", + "display_name", + "docstring", + "full_name", + "index", + "is_local", + "is_method", + "line_count", + "parent", + "start_line", + ) + + def as_data(self) -> dict[str, Any]: + d = {i: getattr(self, i) for i in self.DATA_FIELDS} + d["category"] = d["category"].value + return d + + @property + def is_init(self) -> bool: + return not self.is_class and self.name == "__init__" + + def contains(self, b: Block) -> bool: + return self.start_line < b.start_line and self.end_line >= b.end_line + + def __eq__(self, o: object) -> bool: + assert isinstance(o, Block) + return o.tokens is self.tokens and o.index == self.index + + def __hash__(self) -> int: + return super().__hash__() + + def __lt__(self, o: Self) -> bool: + assert isinstance(o, Block) and o.tokens is self.tokens + return o.index < self.index + + +class DocstringFile(_linter.PythonFile): + def __getitem__(self, i: int | slice) -> TokenInfo | Sequence[TokenInfo]: + return self.tokens[i] + + def next_token(self, start: int, token_type: int, error: str) -> int: + for i in range(start, 
len(self.tokens)): + if self.tokens[i].type == token_type: + return i + raise _linter.ParseError(self.tokens[-1], error) + + def docstring(self, start: int) -> str: + for i in range(start + 1, len(self.tokens)): + tk = self.tokens[i] + if tk.type == token.STRING: + return tk.string + if tk.type not in _linter.EMPTY_TOKENS: + return "" + return "" -def _is_def(t: TokenInfo) -> bool: - return t.type == token.NAME and t.string in ("class", "def") + @cached_property + def indent_to_dedent(self) -> dict[int, int]: + dedents = dict[int, int]() + stack = list[int]() + + for i, t in enumerate(self.tokens): + if t.type == token.INDENT: + stack.append(i) + elif t.type == token.DEDENT: + dedents[stack.pop()] = i + + return dedents + + @cached_property + def errors(self) -> dict[str, str]: + return {} + @cached_property + def blocks(self) -> list[Block]: + blocks: list[Block] = [] + + for i in range(len(self.tokens)): + try: + if (b := self.block(i)) is not None: + blocks.append(b) + except _linter.ParseError as e: + self.errors[e.token.line] = " ".join(e.args) + + for i, parent in enumerate(blocks): + for j in range(i + 1, len(blocks)): + if parent.contains(child := blocks[j]): + child.parent = i + parent.children.append(j) + else: + break -class DocstringLinter(_linter.FileLinter[_linter.PythonFile]): + for i, b in enumerate(blocks): + b.index = i + + parents = [b] + while (p := parents[-1].parent) is not None: + parents.append(blocks[p]) + parents = parents[1:] + + b.is_local = not all(p.is_class for p in parents) + b.is_method = not b.is_class and bool(parents) and parents[0].is_class + + def add_full_names(children: Sequence[Block], prefix: str = "") -> None: + dupes: dict[str, list[Block]] = {} + for b in children: + dupes.setdefault(b.name, []).append(b) + + for dl in dupes.values(): + for i, b in enumerate(dl): + suffix = f"[{i + 1}]" if len(dl) > 1 else "" + b.full_name = prefix + b.name + suffix + + for b in children: + if kids := [blocks[i] for i in b.children]: + add_full_names(kids, b.full_name + ".") + + add_full_names([b for b in blocks if b.parent is None]) + return blocks + + def block(self, begin: int) -> Block | None: + t = self.tokens[begin] + if not (t.type == token.NAME and t.string in ("class", "def")): + return None + + category = Block.Category[t.string.upper()] + try: + ni = self.next_token(begin + 1, token.NAME, "Definition but no name") + name = self.tokens[ni].string + indent = self.next_token(ni + 1, token.INDENT, "Definition but no indent") + dedent = self.indent_to_dedent[indent] + docstring = self.docstring(indent) + except _linter.ParseError: + name = "(ParseError)" + indent = -1 + dedent = -1 + docstring = "" + + return Block( + begin=begin, + category=category, + dedent=dedent, + docstring=docstring, + indent=indent, + name=name, + tokens=self.tokens, + ) + + +class DocstringLinter(_linter.FileLinter[DocstringFile]): linter_name = "docstring_linter" description = DESCRIPTION is_fixer = False - results: dict[str, list[tuple[int, Path, str]]] - def __init__(self, argv: list[str] | None = None) -> None: + path_to_blocks: dict[str, list[dict[str, Any]]] + path_to_errors: dict[str, list[dict[str, Any]]] + + def __init__(self, argv: Sequence[str] | None = None) -> None: super().__init__(argv) - self.results = {} + add_arguments(self.parser.add_argument) + self.path_to_blocks = {} + self.path_to_errors = {} - help = "Maximum number of lines for an undocumented class" - self.parser.add_argument( - "--max-class", "-c", default=MAX_LINES["class"], type=int, help=help - ) + 
def lint_all(self) -> bool: + success = super().lint_all() + self._report() + self._write_grandfather() + return success - help = "Maximum number of lines for an undocumented function" - self.parser.add_argument( - "--max-def", "-d", default=MAX_LINES["def"], type=int, help=help - ) + def _lint(self, df: DocstringFile) -> Iterator[_linter.LintResult]: + if (p := str(df.path)) in self.path_to_blocks: + print("Repeated file", p, file=sys.stderr) + return - help = "Minimum number of characters for a docstring" - self.parser.add_argument( - "--min-docstring", "-m", default=MIN_DOCSTRING, type=int, help=help - ) + blocks = df.blocks + bad = {b for b in blocks if self._is_bad_block(b, df)} + bad = self._dont_require_constructor_and_class_docs(blocks, bad) + gf = self._grandfathered(df.path, bad) - help = "Lint functions, methods and classes that start with _" - self.parser.add_argument( - "--lint-protected", "-p", action="store_true", help=help - ) + yield from (self._block_result(b, df) for b in sorted(bad - gf)) + + def as_data(b: Block) -> dict[str, Any]: + status = "grandfather" if b in gf else "bad" if b in bad else "good" + return {"status": status, **b.as_data()} + + self.path_to_blocks[p] = [as_data(b) for b in blocks] + + def _error(self, df: DocstringFile, result: _linter.LintResult) -> None: + self.path_to_errors[str(df.path)] = [{str(result.line): result.name}] @cached_property - def max_lines(self) -> dict[str, int]: - return {"class": self.args.max_class, "def": self.args.max_def} + def _grandfather(self) -> dict[str, dict[str, Any]]: + try: + with open(self.args.grandfather) as fp: + return json.load(fp) # type: ignore[no-any-return] + except FileNotFoundError: + return {} + except Exception as e: + print("ERROR:", e, "in", GRANDFATHER_LIST, file=sys.stderr) + raise - def lint_all(self) -> bool: - success = super().lint_all() - if not self.args.lintrunner and self.results: - self._report_results() - return success + @cached_property + def _max_lines(self) -> dict[str, int]: + return {"class": self.args.max_class, "def": self.args.max_def} - def _lint(self, pf: _linter.PythonFile) -> Iterator[_linter.LintResult]: - tokens = pf.tokens - indents = indent_to_dedent(tokens) - defs = [i for i, t in enumerate(tokens) if _is_def(t)] - - def next_token(start: int, token_type: int, error: str) -> int: # type: ignore[return] - for i in range(start, len(tokens)): - if tokens[i].type == token_type: - return i - raise _linter.ParseError(tokens[-1], error) - - for i in defs: - name = next_token(i + 1, token.NAME, "Definition with no name") - if not self.args.lint_protected and tokens[name].string.startswith("_"): - continue - - indent = next_token(name + 1, token.INDENT, "Definition with no indent") - dedent = indents[indent] - - lines = tokens[dedent].start[0] - tokens[indent].start[0] - max_lines = self.max_lines[tokens[i].string] - if lines <= max_lines: - continue - - # Now search for a docstring - docstring_len = -1 - for k in range(indent + 1, len(tokens)): - tk = tokens[k] - if tk.type == token.STRING: - docstring_len = len(tk.string) - break - if tk.type not in _linter.EMPTY_TOKENS: - break + def _grandfathered(self, path: Path | None, bad: set[Block]) -> set[Block]: + if path is None or self.args.no_grandfather or self.args.write_grandfather: + return set() + + grand: dict[str, int] = self._grandfather.get(str(path), {}) + tolerance_ratio = 1 + self.args.grandfather_tolerance / 100.0 + + def grandfathered(b: Block) -> bool: + lines = int(grand.get(b.display_name, 0) * tolerance_ratio) + 
return b.line_count <= lines + + return {b for b in bad if grandfathered(b)} + + def _block_result(self, b: Block, df: DocstringFile) -> _linter.LintResult: + def_name = "function" if b.category == "def" else "class" + msg = f"docstring found for {def_name} '{b.name}' ({b.line_count} lines)" + if len(b.docstring): + msg = msg + f" was too short ({len(b.docstring)} characters)" + else: + msg = "No " + msg + return _linter.LintResult(msg, *df.tokens[b.begin].start) + + def _display( + self, df: DocstringFile, results: list[_linter.LintResult] + ) -> Iterator[str]: + if not self.args.report: + yield from super()._display(df, results) + + def _dont_require_constructor_and_class_docs( + self, blocks: Sequence[Block], bad: set[Block] + ) -> set[Block]: + if self.args.lint_init: + return bad + + good = {b for b in blocks if len(b.docstring) >= self.args.min_docstring} + + def has_class_init_doc(b: Block) -> bool: + if b.is_class: + # Is it a class whose constructor is documented? + children = (blocks[i] for i in b.children) + return any(b.is_init and b in good for b in children) + + # Is it a constructor whose class is documented? + return b.is_init and b.parent is not None and blocks[b.parent] in good + + return {b for b in bad if not has_class_init_doc(b)} + + def _is_bad_block(self, b: Block, df: DocstringFile) -> bool: + max_lines = self._max_lines[b.category] + return ( + not df.omitted(df.tokens, b.begin, b.dedent) + and b.line_count > max_lines + and len(b.docstring) < self.args.min_docstring + and (self.args.lint_local or not b.is_local) + and (self.args.lint_protected or not b.name.startswith("_")) + ) - if docstring_len >= self.args.min_docstring: - continue - - # Now check if it's omitted - if pf.omitted(pf.tokens[i:indent]): - continue - - t = tokens[i] - def_name = "function" if t.string == "def" else t.string - tname = tokens[name].string - msg = f"docstring found for {def_name} '{tname}' ({lines} lines)" - if docstring_len < 0: - msg = "No " + msg - else: - msg = msg + f" was too short ({docstring_len} characters)" - yield _linter.LintResult(msg, *t.start) - if pf.path is not None: - self.results.setdefault(def_name, []).append((lines, pf.path, tname)) - - def _report_results(self) -> None: - print() - for i, (k, v) in enumerate(sorted(self.results.items())): - if i: - print() - top = sorted(v, reverse=True)[:REPORT_TOP_RESULTS] - if len(top) == 1: - s = "" - t = f"{len(top)} " - else: - s = "es" if k.endswith("s") else "s" - t = "" - print(f"Top {t}undocumented {k}{s}:") - for lines, path, tname in top: - print(f" {lines} lines: {path}:{tname}") - - -def indent_to_dedent(tokens: Sequence[TokenInfo]) -> dict[int, int]: - indent_to_dedent: dict[int, int] = {} - stack: list[int] = [] - - for i, t in enumerate(tokens): - if t.type == token.INDENT: - stack.append(i) - elif t.type == token.DEDENT: - assert stack - indent_to_dedent[stack.pop()] = i - - assert not stack - # Can't happen: the tokenization process would already have failed on a bad indent - - return indent_to_dedent + def _report(self) -> None: + if not self.args.lintrunner and self.path_to_blocks and self.args.report: + report = { + k: s for k, v in self.path_to_blocks.items() if (s := file_summary(v)) + } | self.path_to_errors + print(json.dumps(report, sort_keys=True, indent=2)) + + def _write_grandfather(self) -> None: + if self.args.write_grandfather: + results: dict[str, dict[str, int]] = {} + + for path, blocks in self.path_to_blocks.items(): + for block in blocks: + if block["status"] == "bad": + d = 
results.setdefault(path, {}) + d[block["display_name"]] = block["line_count"] + + with open(self.args.grandfather, "w") as fp: + json.dump(results, fp, sort_keys=True, indent=2) + + +def make_recursive(blocks: list[dict[str, Any]]) -> list[dict[str, Any]]: + def rec(i: int) -> dict[str, Any]: + d = dict(blocks[i]) + d["children"] = [rec(c) for c in d["children"]] + return d + + return [rec(i) for i, b in enumerate(blocks) if b["parent"] is None] + + +def make_terse( + blocks: Sequence[dict[str, Any]], + index_by_line: bool = True, +) -> dict[str, dict[str, Any]]: + result: dict[str, dict[str, Any]] = {} + + max_line = max(b["start_line"] for b in blocks) if blocks else 0 + line_field_width = len(str(max_line)) + + for b in blocks: + root = f"{b['category']} {b['full_name']}" + for i in itertools.count(): + name = root + bool(i) * f"[{i + 1}]" + if name not in result: + break + + d = { + "docstring_len": len(b["docstring"]), + "lines": b["line_count"], + "status": b.get("status", "good"), + } + + start_line = b["start_line"] + if index_by_line: + d["name"] = name + result[f"{start_line:>{line_field_width}}"] = d + else: + d["line"] = start_line + result[name] = d + + if kids := b["children"]: + if not all(isinstance(k, int) for k in kids): + assert all(isinstance(k, dict) for k in kids) + d["children"] = make_terse(kids) + + return result + + +def file_summary( + blocks: Sequence[dict[str, Any]], report_all: bool = False +) -> dict[str, str]: + def to_line(v: dict[str, Any]) -> str | None: + if (status := v["status"]) == "good": + if not report_all: + return None + fail = "" + elif status == "grandfather": + fail = ": (grandfathered)" + else: + assert status == "bad" + fail = ": FAIL" + name = v["name"] + lines = v["lines"] + docs = v["docstring_len"] + parens = "()" if name.startswith("def ") else "" + return f"{name}{parens}: {lines=}, {docs=}{fail}" + + t = make_terse(blocks) + r = {k: line for k, v in t.items() if (line := to_line(v))} + while r and all(k.startswith(" ") for k in r): + r = {k[1:]: v for k, v in r.items()} + return r + + +def add_arguments(add: Callable[..., Any]) -> None: + h = "Set the grandfather list" + add("--grandfather", "-g", default=str(GRANDFATHER_LIST), type=str, help=h) + + h = "Tolerance for grandfather sizes, in percent" + add("--grandfather-tolerance", "-t", default=TOLERANCE_PERCENT, type=float, help=h) + + h = "Lint __init__ and class separately" + add("--lint-init", "-i", action="store_true", help=h) + + h = "Lint definitions inside other functions" + add("--lint-local", "-o", action="store_true", help=h) + + h = "Lint functions, methods and classes that start with _" + add("--lint-protected", "-p", action="store_true", help=h) + + h = "Maximum number of lines for an undocumented class" + add("--max-class", "-c", default=MAX_LINES["class"], type=int, help=h) + + h = "Maximum number of lines for an undocumented function" + add("--max-def", "-d", default=MAX_LINES["def"], type=int, help=h) + + h = "Minimum number of characters for a docstring" + add("--min-docstring", "-s", default=MIN_DOCSTRING, type=int, help=h) + + h = "Disable the grandfather list" + add("--no-grandfather", "-n", action="store_true", help=h) + + h = "Print a report on all classes and defs" + add("--report", "-r", action="store_true", help=h) + + h = "Rewrite the grandfather list" + add("--write-grandfather", "-w", action="store_true", help=h) if __name__ == "__main__": diff --git a/tools/test/docstring_linter_testdata/block_names.py.txt 
b/tools/test/docstring_linter_testdata/block_names.py.txt new file mode 100644 index 000000000000..a3a41ec9cb46 --- /dev/null +++ b/tools/test/docstring_linter_testdata/block_names.py.txt @@ -0,0 +1,44 @@ +def top(number): + if number == 0: + + def fun(): + if number == 10: + def sab(): + return 1 + else: + def sub(): + return 2 + return sub + + elif number == 1: + + def fun(): + if number == 11: + def sub(): + return 3 + else: + def sub(): + return 4 + return sub + + elif number == 2: + + def fun(): + if number == 12: + def sub(): + return 5 + else: + def sab(): + return 6 + return sub + + elif number == 3: + + def run(): + if number == 12: + def sub(): + return 5 + else: + def sub(): + return 6 + return sub diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.json b/tools/test/docstring_linter_testdata/python_code.py.txt.json index eebee3718730..5efc13550f3d 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.json +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.json @@ -32,39 +32,6 @@ "replacement": null, "severity": "error" }, - { - "char": 8, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 72, - "name": "No docstring found for function 'not_short' (11 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, - { - "char": 12, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 73, - "name": "No docstring found for class 'Long' (6 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, - { - "char": 0, - "code": "DOCSTRING_LINTER", - "description": null, - "line": 84, - "name": "No docstring found for class 'NotDocstring' (12 lines)", - "original": null, - "path": "tools/test/docstring_linter_testdata/python_code.py.txt", - "replacement": null, - "severity": "error" - }, { "char": null, "code": "DOCSTRING_LINTER", diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner index 07adffee6d84..a787cb1ecb32 100644 --- a/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.lintrunner @@ -21,36 +21,3 @@ tools/test/docstring_linter_testdata/python_code.py.txt:71: No docstring found f ^ 72 | def not_short(): 73 | class Long: - -tools/test/docstring_linter_testdata/python_code.py.txt:72: No docstring found for function 'not_short' (11 lines) - 70 | - 71 | def needs_docs(self): - 72 | def not_short(): - ^ - 73 | class Long: - 74 | a = 1 - -tools/test/docstring_linter_testdata/python_code.py.txt:73: No docstring found for class 'Long' (6 lines) - 71 | def needs_docs(self): - 72 | def not_short(): - 73 | class Long: - ^ - 74 | a = 1 - 75 | b = 1 - -tools/test/docstring_linter_testdata/python_code.py.txt:84: No docstring found for class 'NotDocstring' (12 lines) - 82 | - 83 | - 84 | class NotDocstring: - ^ - 85 | def short1(self): - 86 | pass - -Top undocumented classes: - 12 lines: tools/test/docstring_linter_testdata/python_code.py.txt:NotDocstring - 6 lines: tools/test/docstring_linter_testdata/python_code.py.txt:LongWithShortDocstring - 6 lines: tools/test/docstring_linter_testdata/python_code.py.txt:Long - -Top undocumented functions: - 12 lines: tools/test/docstring_linter_testdata/python_code.py.txt:needs_docs - 11 lines: 
tools/test/docstring_linter_testdata/python_code.py.txt:not_short diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.report.json b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json new file mode 100644 index 000000000000..2ccc6f05703d --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.report.json @@ -0,0 +1,325 @@ +[ + { + "category": "class", + "children": [], + "display_name": "class ShortWithDocstring", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "ShortWithDocstring", + "index": 0, + "is_local": false, + "is_method": false, + "line_count": 4, + "parent": null, + "start_line": 2 + }, + { + "category": "class", + "children": [], + "display_name": "class Short", + "docstring": "", + "full_name": "Short", + "index": 1, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 7 + }, + { + "category": "class", + "children": [ + 3 + ], + "display_name": "class LongWithDocstring", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "LongWithDocstring", + "index": 2, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 11 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithDocstring.short1()", + "docstring": "", + "full_name": "LongWithDocstring.short1", + "index": 3, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 2, + "start_line": 14 + }, + { + "category": "class", + "children": [ + 5 + ], + "display_name": "class LongWithoutDocstring", + "docstring": "", + "full_name": "LongWithoutDocstring", + "index": 4, + "is_local": false, + "is_method": false, + "line_count": 4, + "parent": null, + "start_line": 20 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithoutDocstring.short1()", + "docstring": "", + "full_name": "LongWithoutDocstring.short1", + "index": 5, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 4, + "start_line": 21 + }, + { + "category": "class", + "children": [ + 7 + ], + "display_name": "class LongWithShortDocstring", + "docstring": "\"\"\"TODO\"\"\"", + "full_name": "LongWithShortDocstring", + "index": 6, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 25 + }, + { + "category": "def", + "children": [], + "display_name": "def LongWithShortDocstring.short1()", + "docstring": "", + "full_name": "LongWithShortDocstring.short1", + "index": 7, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 6, + "start_line": 28 + }, + { + "category": "class", + "children": [ + 9 + ], + "display_name": "class _Protected", + "docstring": "\"\"\"TODO\"\"\"", + "full_name": "_Protected", + "index": 8, + "is_local": false, + "is_method": false, + "line_count": 6, + "parent": null, + "start_line": 32 + }, + { + "category": "def", + "children": [], + "display_name": "def _Protected.short1()", + "docstring": "", + "full_name": "_Protected.short1", + "index": 9, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 8, + "start_line": 35 + }, + { + "category": "def", + "children": [], + "display_name": "def short()", + "docstring": "", + "full_name": "short", + "index": 10, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 42 + }, + { + "category": "def", + "children": [], + "display_name": "def long()", + "docstring": "\"\"\"This docstring, while short, is 
enough\"\"\"", + "full_name": "long", + "index": 11, + "is_local": false, + "is_method": false, + "line_count": 8, + "parent": null, + "start_line": 46 + }, + { + "category": "def", + "children": [], + "display_name": "def long_without_docstring()", + "docstring": "", + "full_name": "long_without_docstring", + "index": 12, + "is_local": false, + "is_method": false, + "line_count": 3, + "parent": null, + "start_line": 59 + }, + { + "category": "class", + "children": [ + 14, + 15, + 16, + 17 + ], + "display_name": "class ImpossibleCombo", + "docstring": "\"\"\"This docstring, while short, is enough\"\"\"", + "full_name": "ImpossibleCombo", + "index": 13, + "is_local": false, + "is_method": false, + "line_count": 15, + "parent": null, + "start_line": 69 + }, + { + "category": "def", + "children": [ + 15, + 16, + 17 + ], + "display_name": "def ImpossibleCombo.needs_docs()", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs", + "index": 14, + "is_local": false, + "is_method": true, + "line_count": 12, + "parent": 13, + "start_line": 72 + }, + { + "category": "def", + "children": [ + 16, + 17 + ], + "display_name": "def ImpossibleCombo.needs_docs.not_short()", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short", + "index": 15, + "is_local": true, + "is_method": false, + "line_count": 11, + "parent": 14, + "start_line": 73 + }, + { + "category": "class", + "children": [], + "display_name": "class ImpossibleCombo.needs_docs.not_short.Long", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short.Long", + "index": 16, + "is_local": true, + "is_method": false, + "line_count": 6, + "parent": 15, + "start_line": 74 + }, + { + "category": "class", + "children": [], + "display_name": "class ImpossibleCombo.needs_docs.not_short.Short", + "docstring": "", + "full_name": "ImpossibleCombo.needs_docs.not_short.Short", + "index": 17, + "is_local": true, + "is_method": false, + "line_count": 3, + "parent": 15, + "start_line": 81 + }, + { + "category": "class", + "children": [ + 19, + 20, + 21 + ], + "display_name": "class NotDocstring", + "docstring": "", + "full_name": "NotDocstring", + "index": 18, + "is_local": false, + "is_method": false, + "line_count": 12, + "parent": null, + "start_line": 85 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short1()", + "docstring": "", + "full_name": "NotDocstring.short1", + "index": 19, + "is_local": false, + "is_method": true, + "line_count": 2, + "parent": 18, + "start_line": 86 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short2()", + "docstring": "", + "full_name": "NotDocstring.short2", + "index": 20, + "is_local": false, + "is_method": true, + "line_count": 2, + "parent": 18, + "start_line": 91 + }, + { + "category": "def", + "children": [], + "display_name": "def NotDocstring.short3()", + "docstring": "", + "full_name": "NotDocstring.short3", + "index": 21, + "is_local": false, + "is_method": true, + "line_count": 3, + "parent": 18, + "start_line": 94 + }, + { + "category": "def", + "children": [], + "display_name": "def long_with_omit()", + "docstring": "", + "full_name": "long_with_omit", + "index": 22, + "is_local": false, + "is_method": false, + "line_count": 1, + "parent": null, + "start_line": 102 + } +] diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json new file mode 100644 index 000000000000..bbf71643c76a --- /dev/null +++ 
b/tools/test/docstring_linter_testdata/python_code.py.txt.single.line.json @@ -0,0 +1,25 @@ +{ + " 2": "class ShortWithDocstring: lines=4, docs=44", + " 7": "class Short: lines=3, docs=0", + " 11": "class LongWithDocstring: lines=6, docs=44", + " 14": "def LongWithDocstring.short1(): lines=3, docs=0", + " 20": "class LongWithoutDocstring: lines=4, docs=0", + " 21": "def LongWithoutDocstring.short1(): lines=3, docs=0", + " 25": "class LongWithShortDocstring: lines=6, docs=10", + " 28": "def LongWithShortDocstring.short1(): lines=3, docs=0", + " 32": "class _Protected: lines=6, docs=10", + " 35": "def _Protected.short1(): lines=3, docs=0", + " 42": "def short(): lines=3, docs=0", + " 46": "def long(): lines=8, docs=44", + " 59": "def long_without_docstring(): lines=3, docs=0", + " 69": "class ImpossibleCombo: lines=15, docs=44", + " 72": "def ImpossibleCombo.needs_docs(): lines=12, docs=0", + " 73": "def ImpossibleCombo.needs_docs.not_short(): lines=11, docs=0", + " 74": "class ImpossibleCombo.needs_docs.not_short.Long: lines=6, docs=0", + " 81": "class ImpossibleCombo.needs_docs.not_short.Short: lines=3, docs=0", + " 85": "class NotDocstring: lines=12, docs=0", + " 86": "def NotDocstring.short1(): lines=2, docs=0", + " 91": "def NotDocstring.short2(): lines=2, docs=0", + " 94": "def NotDocstring.short3(): lines=3, docs=0", + "102": "def long_with_omit(): lines=1, docs=0" +} diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json new file mode 100644 index 000000000000..0b86e9e6ba1e --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.json @@ -0,0 +1,140 @@ +{ + "class ImpossibleCombo": { + "docstring_len": 44, + "line": 69, + "lines": 15, + "status": "good" + }, + "class ImpossibleCombo.needs_docs.not_short.Long": { + "docstring_len": 0, + "line": 74, + "lines": 6, + "status": "good" + }, + "class ImpossibleCombo.needs_docs.not_short.Short": { + "docstring_len": 0, + "line": 81, + "lines": 3, + "status": "good" + }, + "class LongWithDocstring": { + "docstring_len": 44, + "line": 11, + "lines": 6, + "status": "good" + }, + "class LongWithShortDocstring": { + "docstring_len": 10, + "line": 25, + "lines": 6, + "status": "good" + }, + "class LongWithoutDocstring": { + "docstring_len": 0, + "line": 20, + "lines": 4, + "status": "good" + }, + "class NotDocstring": { + "docstring_len": 0, + "line": 85, + "lines": 12, + "status": "good" + }, + "class Short": { + "docstring_len": 0, + "line": 7, + "lines": 3, + "status": "good" + }, + "class ShortWithDocstring": { + "docstring_len": 44, + "line": 2, + "lines": 4, + "status": "good" + }, + "class _Protected": { + "docstring_len": 10, + "line": 32, + "lines": 6, + "status": "good" + }, + "def ImpossibleCombo.needs_docs": { + "docstring_len": 0, + "line": 72, + "lines": 12, + "status": "good" + }, + "def ImpossibleCombo.needs_docs.not_short": { + "docstring_len": 0, + "line": 73, + "lines": 11, + "status": "good" + }, + "def LongWithDocstring.short1": { + "docstring_len": 0, + "line": 14, + "lines": 3, + "status": "good" + }, + "def LongWithShortDocstring.short1": { + "docstring_len": 0, + "line": 28, + "lines": 3, + "status": "good" + }, + "def LongWithoutDocstring.short1": { + "docstring_len": 0, + "line": 21, + "lines": 3, + "status": "good" + }, + "def NotDocstring.short1": { + "docstring_len": 0, + "line": 86, + "lines": 2, + "status": "good" + }, + "def NotDocstring.short2": { + "docstring_len": 0, + "line": 91, + "lines": 2, 
+ "status": "good" + }, + "def NotDocstring.short3": { + "docstring_len": 0, + "line": 94, + "lines": 3, + "status": "good" + }, + "def _Protected.short1": { + "docstring_len": 0, + "line": 35, + "lines": 3, + "status": "good" + }, + "def long": { + "docstring_len": 44, + "line": 46, + "lines": 8, + "status": "good" + }, + "def long_with_omit": { + "docstring_len": 0, + "line": 102, + "lines": 1, + "status": "good" + }, + "def long_without_docstring": { + "docstring_len": 0, + "line": 59, + "lines": 3, + "status": "good" + }, + "def short": { + "docstring_len": 0, + "line": 42, + "lines": 3, + "status": "good" + } +} diff --git a/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json new file mode 100644 index 000000000000..ee2facfc6b5d --- /dev/null +++ b/tools/test/docstring_linter_testdata/python_code.py.txt.terse.line.json @@ -0,0 +1,140 @@ +{ + " 2": { + "docstring_len": 44, + "lines": 4, + "name": "class ShortWithDocstring", + "status": "good" + }, + " 7": { + "docstring_len": 0, + "lines": 3, + "name": "class Short", + "status": "good" + }, + " 11": { + "docstring_len": 44, + "lines": 6, + "name": "class LongWithDocstring", + "status": "good" + }, + " 14": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithDocstring.short1", + "status": "good" + }, + " 20": { + "docstring_len": 0, + "lines": 4, + "name": "class LongWithoutDocstring", + "status": "good" + }, + " 21": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithoutDocstring.short1", + "status": "good" + }, + " 25": { + "docstring_len": 10, + "lines": 6, + "name": "class LongWithShortDocstring", + "status": "good" + }, + " 28": { + "docstring_len": 0, + "lines": 3, + "name": "def LongWithShortDocstring.short1", + "status": "good" + }, + " 32": { + "docstring_len": 10, + "lines": 6, + "name": "class _Protected", + "status": "good" + }, + " 35": { + "docstring_len": 0, + "lines": 3, + "name": "def _Protected.short1", + "status": "good" + }, + " 42": { + "docstring_len": 0, + "lines": 3, + "name": "def short", + "status": "good" + }, + " 46": { + "docstring_len": 44, + "lines": 8, + "name": "def long", + "status": "good" + }, + " 59": { + "docstring_len": 0, + "lines": 3, + "name": "def long_without_docstring", + "status": "good" + }, + " 69": { + "docstring_len": 44, + "lines": 15, + "name": "class ImpossibleCombo", + "status": "good" + }, + " 72": { + "docstring_len": 0, + "lines": 12, + "name": "def ImpossibleCombo.needs_docs", + "status": "good" + }, + " 73": { + "docstring_len": 0, + "lines": 11, + "name": "def ImpossibleCombo.needs_docs.not_short", + "status": "good" + }, + " 74": { + "docstring_len": 0, + "lines": 6, + "name": "class ImpossibleCombo.needs_docs.not_short.Long", + "status": "good" + }, + " 81": { + "docstring_len": 0, + "lines": 3, + "name": "class ImpossibleCombo.needs_docs.not_short.Short", + "status": "good" + }, + " 85": { + "docstring_len": 0, + "lines": 12, + "name": "class NotDocstring", + "status": "good" + }, + " 86": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.short1", + "status": "good" + }, + " 91": { + "docstring_len": 0, + "lines": 2, + "name": "def NotDocstring.short2", + "status": "good" + }, + " 94": { + "docstring_len": 0, + "lines": 3, + "name": "def NotDocstring.short3", + "status": "good" + }, + "102": { + "docstring_len": 0, + "lines": 1, + "name": "def long_with_omit", + "status": "good" + } +} diff --git a/tools/test/test_docstring_linter.py 
b/tools/test/test_docstring_linter.py index 85ea26de4e77..d09c84de131a 100644 --- a/tools/test/test_docstring_linter.py +++ b/tools/test/test_docstring_linter.py @@ -1,10 +1,14 @@ # mypy: ignore-errors -from __future__ import annotations +import json import sys from pathlib import Path -from tools.linter.adapters.docstring_linter import DocstringLinter +from tools.linter.adapters.docstring_linter import ( + DocstringLinter, + file_summary, + make_terse, +) _PARENT = Path(__file__).parent.absolute() @@ -16,11 +20,69 @@ from .linter_test_case import LinterTestCase TEST_FILE = Path("tools/test/docstring_linter_testdata/python_code.py.txt") +TEST_FILE2 = Path("tools/test/docstring_linter_testdata/more_python_code.py.txt") +TEST_BLOCK_NAMES = Path("tools/test/docstring_linter_testdata/block_names.py.txt") +ARGS = "--max-class=3", "--max-def=4", "--min-docstring=16" class TestDocstringLinter(LinterTestCase): LinterClass = DocstringLinter + maxDiff = 10_240 def test_python_code(self): - args = "--max-class=3 --max-def=4".split() - self.lint_test(TEST_FILE, args) + self.lint_test(TEST_FILE, ARGS) + + def test_report(self): + actual = _dumps(_data()) + self.assertExpected(TEST_FILE, actual, "report.json") + + def test_terse(self): + terse = make_terse(_data(), index_by_line=False) + actual = _dumps(terse) + self.assertExpected(TEST_FILE, actual, "terse.json") + + def test_terse_line(self): + terse = make_terse(_data(), index_by_line=True) + actual = _dumps(terse) + self.assertExpected(TEST_FILE, actual, "terse.line.json") + + def test_file_summary(self): + actual = _dumps(file_summary(_data(), report_all=True)) + self.assertExpected(TEST_FILE, actual, "single.line.json") + + def test_file_names(self): + f = DocstringLinter.make_file(TEST_BLOCK_NAMES) + actual = [b.full_name for b in f.blocks] + expected = [ + "top", + "top.fun[1]", + "top.fun[1].sab", + "top.fun[1].sub", + "top.fun[2]", + "top.fun[2].sub[1]", + "top.fun[2].sub[2]", + "top.fun[3]", + "top.fun[3].sub", + "top.fun[3].sab", + "top.run", + "top.run.sub[1]", + "top.run.sub[2]", + ] + self.assertEqual(actual, expected) + + +def _dumps(d: dict) -> str: + return json.dumps(d, sort_keys=True, indent=2) + "\n" + + +def _data(): + docstring_file = DocstringLinter.make_file(TEST_FILE) + return [b.as_data() for b in docstring_file.blocks] + + +def _next_stdout(mock_stdout): + length = 0 + while True: + s = mock_stdout.getvalue() + yield s[length:] + length = len(s) From 087e8587cd13953b629d42be5427f82f6c0ece30 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 9 Apr 2025 09:02:22 -0700 Subject: [PATCH 329/332] support backed_size_oblivious in guard_or_false/guard_or_true (#150231) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150231 Approved by: https://github.com/pianpwk --- test/test_dynamic_shapes.py | 32 +++++++++++++ torch/fx/experimental/symbolic_shapes.py | 58 ++++++++++++++++-------- 2 files changed, 71 insertions(+), 19 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 96115b7b37fd..224846681500 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -2880,6 +2880,22 @@ def func2(a, b): unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) ) + # Test backed_size_oblivious + with torch.fx.experimental._config.patch("backed_size_oblivious", True): + + def func3(a, b): + if guard_or_true(a.size()[0] != 9): + return b * 10 + else: + return b * 20 + + compiled = torch.compile(func3, dynamic=True, fullgraph=True) + a = 
torch.rand(9, 2) + b = torch.rand(3, 4) + + self.assertEqual(func3(a, b), b * 20) + self.assertEqual(compiled(a, b), b * 10) + @skipIfTorchDynamo("Not a TorchDynamo suitable test") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_guard_or_false(self): @@ -2929,6 +2945,22 @@ def func2(a, b): unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) ) + # Test backed_size_oblivious + with torch.fx.experimental._config.patch("backed_size_oblivious", True): + + def func3(a, b): + if guard_or_false(a.size()[0] == 9): + return b * 10 + else: + return b * 20 + + compiled = torch.compile(func3, dynamic=True, fullgraph=True) + a = torch.rand(9, 2) + b = torch.rand(3, 4) + + self.assertEqual(func3(a, b), b * 10) + self.assertEqual(compiled(a, b), b * 20) + def test_guards_float_div(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 8) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 7b870f298ca0..9cb699483628 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1195,20 +1195,30 @@ def guard_or_false(a: BoolLikeType) -> bool: """ Try to guard a, if data dependent error encountered just return false. """ - try: - return bool(guard_bool(a)) - except GuardOnDataDependentSymNode: - return False + if torch.fx.experimental._config.backed_size_oblivious: + return statically_known_true(a) + else: + try: + return bool(guard_bool(a)) + except GuardOnDataDependentSymNode: + return False def guard_or_true(a: BoolLikeType) -> bool: """ Try to guard a, if data dependent error encountered just return true. """ - try: - return bool(guard_bool(a)) - except GuardOnDataDependentSymNode: - return True + if torch.fx.experimental._config.backed_size_oblivious: + result = _static_eval(a) + if result is not None: + return result + else: + return True + else: + try: + return bool(guard_bool(a)) + except GuardOnDataDependentSymNode: + return True def definitely_true(a: BoolLikeType) -> bool: @@ -1253,6 +1263,23 @@ def definitely_false(a: BoolLikeType) -> bool: return not bool(a) +def _static_eval(x: Union[bool, SymBool]) -> Optional[bool]: + if isinstance(x, SymBool): + expr = x.node.expr + shape_env = x.node.shape_env + try: + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + else: + return None + except Exception: + log.debug("Could not simplify %s", expr) + return None + assert isinstance(x, bool) + return x + + def statically_known_true(x: Union[bool, SymBool]) -> bool: """ Returns True if x can be simplified to a constant and is true. @@ -1264,18 +1291,11 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool: Args: x (bool, SymBool): The expression to try statically evaluating """ - if isinstance(x, SymBool): - expr = x.node.expr - shape_env = x.node.shape_env - try: - simplified = shape_env._maybe_evaluate_static(expr) - if simplified is not None: - return bool(simplified) - except Exception: - log.debug("Could not simplify %s", expr) + result = _static_eval(x) + if result is None: return False - assert isinstance(x, bool) - return x + else: + return result def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]: From 786422a4d74b37507b12d9ee209e2517ef1dcca0 Mon Sep 17 00:00:00 2001 From: Yifei Teng Date: Wed, 9 Apr 2025 21:47:59 +0000 Subject: [PATCH 330/332] Remove a workaround added in #149381 (#150693) Remove a workaround added in https://github.com/pytorch/pytorch/pull/149381. 
Fixes https://github.com/pytorch/xla/issues/8934
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150693
Approved by: https://github.com/albanD
---
 .ci/pytorch/test.sh | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh
index 69566a244c9b..1e6b50f04f26 100755
--- a/.ci/pytorch/test.sh
+++ b/.ci/pytorch/test.sh
@@ -1175,7 +1175,6 @@ build_xla() {
 # These functions are defined in .circleci/common.sh in pytorch/xla repo
 retry install_pre_deps_pytorch_xla $XLA_DIR $USE_CACHE
 CMAKE_PREFIX_PATH="${SITE_PACKAGES}/torch:${CMAKE_PREFIX_PATH}" XLA_SANDBOX_BUILD=1 build_torch_xla $XLA_DIR
- retry install_post_deps_pytorch_xla
 assert_git_not_dirty
}

From cc2decdb25d196f20c095268f252683cd4129cc8 Mon Sep 17 00:00:00 2001
From: Wei Wang <143543872+nWEIdia@users.noreply.github.com>
Date: Wed, 9 Apr 2025 21:57:05 +0000
Subject: [PATCH 331/332] [CI][CUDA][Distributed] Update test_composability.py (#148578)

The world_size = int(os.getenv("WORLD_SIZE", 4)) in the subsequent lines indicates
that the tests in this file require not just more than 1 GPU but at least 4 GPUs.
skip_if_lt_x_gpu(4) does not properly skip them on a platform with 2 GPUs;
skip_if_lt_x_gpu appears to be broken, potentially related to a similar issue:
https://github.com/pytorch/pytorch/issues/146094

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148578
Approved by: https://github.com/atalman
---
 test/distributed/test_composability.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/distributed/test_composability.py b/test/distributed/test_composability.py
index 91b22a60e74b..812d5d8abc16 100644
--- a/test/distributed/test_composability.py
+++ b/test/distributed/test_composability.py
@@ -385,7 +385,7 @@ def apply_dp(partial_model):
 if not (
 dist.is_available()
 and dist.is_nccl_available()
- and torch.cuda.device_count() > 1
+ and torch.cuda.device_count() > 3
 ):
 print(
 "c10d NCCL not available or not enough GPUs, skipping tests",

From b347f0cfd20b3eb7bd30113aa2feeb349db152c1 Mon Sep 17 00:00:00 2001
From: yucai-intel <108388355+yucai-intel@users.noreply.github.com>
Date: Thu, 13 Mar 2025 16:43:24 +0800
Subject: [PATCH 332/332] Add xpu backend for depthwise_conv2d/3d Ops

---
 aten/src/ATen/native/Convolution.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 78cc6237451d..38bfdaa397f0 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -603,7 +603,7 @@ struct ConvParams {
 // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
 // a depthwise multiplier)
 bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const {
- return input.is_cuda() &&
+ return (input.is_cuda() || input.is_xpu()) &&
 !transposed &&
 (input.ndimension() == 4 || input.ndimension() == 5) &&
 at::symint::size(input, 1) == groups &&
@@ -1219,6 +1219,12 @@ ConvBackend _select_conv_backend(
 return ConvBackend::Cudnn;
 } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) {
 return ConvBackend::MiopenDepthwise;
+ } else if (params.use_mkldnn(input, weight)) {
+ if (params.transposed) {
+ return ConvBackend::MkldnnTranspose;
+ } else {
+ return ConvBackend::Mkldnn;
+ }
 } else {
 if (input.ndimension() == 4) {
 return ConvBackend::CudaDepthwise2d;
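
For reference, the depthwise case matched by is_depthwise above is a grouped convolution
with groups equal to the number of input channels and a weight of shape (C, 1, kH, kW),
i.e. no depthwise multiplier. A minimal sketch of such a call is shown below; the "xpu"
device string, the tensor shapes, and the availability check are illustrative assumptions
rather than part of this patch:

    import torch
    import torch.nn.functional as F

    # Depthwise convolution: groups == in_channels and weight shape (C, 1, kH, kW),
    # i.e. one output channel per input channel (no depthwise multiplier).
    device = "xpu" if torch.xpu.is_available() else "cpu"  # assumes a build with XPU support
    x = torch.randn(1, 8, 32, 32, device=device)
    w = torch.randn(8, 1, 3, 3, device=device)
    y = F.conv2d(x, w, groups=8, padding=1)
    print(y.shape)  # torch.Size([1, 8, 32, 32])

With the change above, is_depthwise can now return true for XPU inputs as well, and the new
use_mkldnn branch appears intended to route such calls to the Mkldnn/MkldnnTranspose
backends instead of the CUDA depthwise kernels.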