From c5e14221db4bd51838d9b3cff05c33648daf5ccc Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Wed, 17 Sep 2025 16:43:32 +0000 Subject: [PATCH] Clean up getDeviceIndexOfCurrentQueue --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 3 +- src/ATen/native/xpu/sycl/Norm.h | 3 +- src/ATen/native/xpu/sycl/SoftMaxKernels.cpp | 12 +- .../native/xpu/sycl/TensorShapeKernels.cpp | 3 +- src/comm/DeviceProperties.h | 130 +++++++----------- src/comm/Runtime.h | 4 - 6 files changed, 60 insertions(+), 95 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index f8e1b6906e..4124890274 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -185,8 +185,7 @@ int get_num_threads_by_dev_max_group_size( int get_prefer_simd(int numPlane, int nHw) { // decide SIMD: SIMD32 or SIMD16 - auto dev_id = at::xpu::getDeviceIndexOfCurrentQueue(); - + auto dev_id = at::xpu::current_device(); auto* dev_prop = at::xpu::getDeviceProperties(dev_id); auto sub_group_size = dev_prop->sub_group_sizes; int simd = sub_group_size[1]; diff --git a/src/ATen/native/xpu/sycl/Norm.h b/src/ATen/native/xpu/sycl/Norm.h index 6117b6d261..8b887a443a 100644 --- a/src/ATen/native/xpu/sycl/Norm.h +++ b/src/ATen/native/xpu/sycl/Norm.h @@ -269,8 +269,7 @@ class NormConfig { } void get_max_vec_size() { - auto dev_id = getDeviceIndexOfCurrentQueue(); - int total_resource = syclMaxWorkItemsPerTile(dev_id); + int64_t total_resource = syclMaxWorkItemsPerTile(); constexpr int float4_size = sizeof(float) * 4; max_vec_size = float4_size / element_size_bytes; diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index fdedf5fb09..020fe4a1bc 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -1559,8 +1559,7 @@ void spatial_softmax_forward( canUse32BitIndexMath(input) && canUse32BitIndexMath(output); // decide SIMD: SIMD32 or SIMD16 - auto dev_id = at::xpu::getDeviceIndexOfCurrentQueue(); - auto* dev_prop = at::xpu::getDeviceProperties(dev_id); + auto* dev_prop = at::xpu::getCurrentDeviceProperties(); auto sub_group_size = dev_prop->sub_group_sizes; int SIMD = sub_group_size[1]; if (SIMD == SIMD32) { @@ -1749,8 +1748,7 @@ void spatial_softmax_backward( canUse32BitIndexMath(output) && canUse32BitIndexMath(gradOutput); // decide SIMD: SIMD32 or SIMD16 - auto* dev_prop = - at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue()); + auto* dev_prop = at::xpu::getCurrentDeviceProperties(); auto sub_group_size = dev_prop->sub_group_sizes; int SIMD = sub_group_size[1]; if (SIMD == SIMD32) { @@ -1901,8 +1899,7 @@ Tensor& masked_softmax_forward( canUse32BitIndexMath(input) && canUse32BitIndexMath(output); // decide SIMD: SIMD32 or SIMD16 - auto* dev_prop = - at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue()); + auto* dev_prop = at::xpu::getCurrentDeviceProperties(); auto sub_group_size = dev_prop->sub_group_sizes; int SIMD = sub_group_size[1]; if (SIMD == SIMD32) { @@ -2026,8 +2023,7 @@ void masked_softmax_backward( canUse32BitIndexMath(output) && canUse32BitIndexMath(gradOutput); // decide SIMD: SIMD32 or SIMD16 - auto* dev_prop = - at::xpu::getDeviceProperties(at::xpu::getDeviceIndexOfCurrentQueue()); + auto* dev_prop = at::xpu::getCurrentDeviceProperties(); auto sub_group_size = dev_prop->sub_group_sizes; int SIMD = sub_group_size[1]; if (SIMD == SIMD32) { diff --git a/src/ATen/native/xpu/sycl/TensorShapeKernels.cpp b/src/ATen/native/xpu/sycl/TensorShapeKernels.cpp index 8c46658e7a..fed5fee904 100644 --- a/src/ATen/native/xpu/sycl/TensorShapeKernels.cpp +++ b/src/ATen/native/xpu/sycl/TensorShapeKernels.cpp @@ -669,8 +669,7 @@ void split_with_sizes_copy_out_xpu_contiguous_no_cast( num_groups += div_up(split_chunk_size, GROUP_SIZE * BYTES_PER_THREAD); } - auto dev_id = getDeviceIndexOfCurrentQueue(); - int64_t tile_size = syclMaxWorkItemsPerTile(dev_id); + int64_t tile_size = syclMaxWorkItemsPerTile(); const int64_t max_groups = tile_size / GROUP_SIZE * 2.0; // Make each thread process BYTES_PER_THREAD * iter_factor bytes to regulate diff --git a/src/comm/DeviceProperties.h b/src/comm/DeviceProperties.h index ee0d285eaf..5d1f0c8f77 100644 --- a/src/comm/DeviceProperties.h +++ b/src/comm/DeviceProperties.h @@ -3,17 +3,15 @@ #include #include -#include namespace xpu { namespace sycl { template static int64_t syclMaxWorkGroupSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { - auto q = c10::xpu::getCurrentXPUStream(dev_id).queue(); - auto ctx = q.get_context(); - auto dev = q.get_device(); + at::DeviceIndex dev_id = at::xpu::current_device()) { + auto& ctx = c10::xpu::get_device_context(); + auto& dev = c10::xpu::get_raw_device(dev_id); auto kid = ::sycl::get_kernel_id(); // The kernel won't be built for devices except for the first device. @@ -30,73 +28,69 @@ static int64_t syclMaxWorkGroupSize( template static int64_t syclMaxWorkGroupSize( - KernelClass /*kfn*/, - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + const KernelClass& /*kfn*/, + at::DeviceIndex dev_id = at::xpu::current_device()) { return syclMaxWorkGroupSize(dev_id); } static inline int64_t syclDeviceMaxWorkGroupSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->max_work_group_size; } static inline int64_t syclMaxSubGroupSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); - auto subgroup_sizes = dev_prop->sub_group_sizes; - uint64_t max_val = 0; - for (auto i : subgroup_sizes) { - if (i > max_val) - max_val = i; - } - return max_val; + const auto& subgroup_sizes = dev_prop->sub_group_sizes; + TORCH_CHECK( + !subgroup_sizes.empty(), + "The device subgroup sizes is empty, please check the device status."); + return *std::max_element(subgroup_sizes.begin(), subgroup_sizes.end()); } static inline int64_t syclMinSubGroupSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); - auto subgroup_sizes = dev_prop->sub_group_sizes; - uint64_t min_val = dev_prop->max_work_group_size; - for (auto i : subgroup_sizes) { - if (i < min_val) - min_val = i; - } - return min_val; + const auto& subgroup_sizes = dev_prop->sub_group_sizes; + TORCH_CHECK( + !subgroup_sizes.empty(), + "The device subgroup sizes is empty, please check the device status."); + return *std::min_element(subgroup_sizes.begin(), subgroup_sizes.end()); } static inline int64_t syclMaxComputeUnitSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->max_compute_units; } static inline int64_t syclGpuEuCount( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->gpu_eu_count; } static inline int64_t syclGpuEuSimdWidth( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->gpu_eu_simd_width; } static inline int64_t syclGpuHWThreadsPerEU( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->gpu_hw_threads_per_eu; } static inline int64_t syclGpuEUCountPerSubslice( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->gpu_eu_count_per_subslice; } static inline int64_t syclMaxWorkItemsPerTile( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); int64_t eu_cnt = dev_prop->gpu_eu_count; int64_t simd_width = syclMaxSubGroupSize(dev_id); @@ -105,7 +99,7 @@ static inline int64_t syclMaxWorkItemsPerTile( } static inline int64_t syclMaxWorkItemsPerSubSlice( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); int64_t simd_width = syclMaxSubGroupSize(dev_id); int64_t eu_count = dev_prop->gpu_eu_count_per_subslice; @@ -113,7 +107,7 @@ static inline int64_t syclMaxWorkItemsPerSubSlice( } static inline int64_t syclMaxWorkItemsPerEU( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); int64_t simd_width = syclMaxSubGroupSize(dev_id); int64_t hw_threads = dev_prop->gpu_hw_threads_per_eu; @@ -121,94 +115,76 @@ static inline int64_t syclMaxWorkItemsPerEU( } static inline int64_t syclMaxNumSubGroups( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->max_num_sub_groups; } static inline int64_t syclMaxDSSNum( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { int64_t dss_num = syclMaxComputeUnitSize(dev_id) / syclGpuEUCountPerSubslice(dev_id); return dss_num; } static inline size_t syclGlobalMemSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->global_mem_size; } static inline int64_t syclLocalMemSize( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->local_mem_size; } template uint32_t syclPrefVectorWidth( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { (void)dev_id; // Suppress unused variable warning // Hot fix. This is the preferred vector width for GPUs up to LNL/BMG. - uint32_t vec_width = 16; + constexpr uint32_t vec_width = 16; - if (std::is_same::value) { - return vec_width / sizeof(char); - } - if (std::is_same::value) { - return vec_width / sizeof(short); - } - if (std::is_same::value) { - return vec_width / sizeof(int); - } - if (std::is_same::value) { - return vec_width / sizeof(int64_t); - } - if (std::is_same::value) { - return vec_width / sizeof(float); - } - if (std::is_same::value) { - return vec_width / sizeof(double); + if constexpr ( + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return vec_width / sizeof(T); + } else { + throw std::invalid_argument( + "Invalid data type to fetch preferred vector width!"); } - if (std::is_same::value) { - return vec_width / sizeof(::sycl::half); - } - throw std::invalid_argument( - "Invalid data type to fetch preferred vector width!"); } template uint32_t syclNativeVectorWidth( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); - if (std::is_same::value) { + if constexpr (std::is_same_v) { return dev_prop->native_vector_width_char; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_short; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_int; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_long; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_float; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_double; - } - if (std::is_same::value) { + } else if constexpr (std::is_same_v) { return dev_prop->native_vector_width_half; + } else { + throw std::invalid_argument( + "Invalid data type to fetch native vector width!"); } - throw std::invalid_argument( - "Invalid data type to fetch native vector width!"); } static inline bool syclHasFloat64( - at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + at::DeviceIndex dev_id = at::xpu::current_device()) { auto* dev_prop = at::xpu::getDeviceProperties(dev_id); return dev_prop->has_fp64; } diff --git a/src/comm/Runtime.h b/src/comm/Runtime.h index fa7daaf125..4fd44d08e5 100644 --- a/src/comm/Runtime.h +++ b/src/comm/Runtime.h @@ -4,10 +4,6 @@ namespace at::xpu { -static inline at::DeviceIndex getDeviceIndexOfCurrentQueue() { - return c10::xpu::getCurrentXPUStream().device_index(); -} - static inline sycl::queue& getCurrentSYCLQueue() { return c10::xpu::getCurrentXPUStream().queue(); }