From ac624bbfaa2343984a5972d2c78f08beb6644992 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Fri, 19 Sep 2025 00:08:12 +0800 Subject: [PATCH] Fixing Reduction kernel Signed-off-by: Feng Yuan --- src/ATen/native/xpu/sycl/Reduce.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h index 6900e73af0..156f2815ea 100644 --- a/src/ATen/native/xpu/sycl/Reduce.h +++ b/src/ATen/native/xpu/sycl/Reduce.h @@ -36,7 +36,7 @@ inline at::detail::Array group_reduce( at::detail::Array value, CombineFunc combine) { using vec_t = at::detail::Array; - vec_t* shared_ = reinterpret_cast(shared_); + vec_t* shared_ = reinterpret_cast(shared); int l_x = item.get_local_linear_id(); // int dim_x = wg_size; auto sg = item.get_sub_group(); @@ -105,7 +105,7 @@ inline at::detail::Array group_x_reduce( at::detail::Array value, CombineFunc combine) { using vec_t = at::detail::Array; - vec_t* shared_ = reinterpret_cast(shared_); + vec_t* shared_ = reinterpret_cast(shared); int l_x = item.get_local_id(1), l_y = item.get_local_id(0); int g_x = item.get_local_range(1); int dim_x = g_x; @@ -147,7 +147,7 @@ inline at::detail::Array group_y_reduce( at::detail::Array value, CombineFunc combine) { using vec_t = at::detail::Array; - vec_t* shared_ = reinterpret_cast(shared_); + vec_t* shared_ = reinterpret_cast(shared); int l_x = item.get_local_id(1), l_y = item.get_local_id(0); int g_x = item.get_local_range(1); int dim_y = item.get_local_range(0);