diff --git a/src/ATen/native/xpu/sycl/DistanceKernels.cpp b/src/ATen/native/xpu/sycl/DistanceKernels.cpp index cd807b1be..2c787d747 100644 --- a/src/ATen/native/xpu/sycl/DistanceKernels.cpp +++ b/src/ATen/native/xpu/sycl/DistanceKernels.cpp @@ -729,7 +729,7 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const size_t stride = item_id.get_local_range().size(); int64_t i = static_cast( - (n2_val_ - device_sqrt(n2_squared_minus_1_val_ - 2 * k))); + (n2_val_ - device_sqrt(n2_squared_minus_1_val_ - 2 * k))); int64_t j = k - n_ * i + i * (i + 1) / 2 + i + 1; const scalar_t* const start = in_ptr + i * m_; @@ -760,8 +760,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const int64_t n, const int64_t m, accscalar_t p_val, - accscalar_t n2_val, - accscalar_t n2_squared_minus_1_val, + const double n2_val, + const double n2_squared_minus_1_val, scalar_t* out_data, const scalar_t* in_data, const int64_t wgroup_size) @@ -778,8 +778,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { const int64_t n_; const int64_t m_; accscalar_t p_val_; - accscalar_t n2_val_; - accscalar_t n2_squared_minus_1_val_; + const double n2_val_; + const double n2_squared_minus_1_val_; scalar_t* out_data_; const scalar_t* in_data_; sycl_local_acc_t shared_; @@ -805,8 +805,6 @@ static void pdist_kernel_impl( } auto p_val = static_cast(p); - auto n2_val = static_cast(n2); - auto n2_squared_minus_1_val = static_cast(n2_squared_minus_1); auto out_data = result.mutable_data_ptr(); auto in_data = self.const_data_ptr(); @@ -815,8 +813,8 @@ static void pdist_kernel_impl( n, m, p_val, - n2_val, - n2_squared_minus_1_val, + n2, + n2_squared_minus_1, out_data, in_data, wgroup_size / min_sg_size);