Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/ATen/native/xpu/sycl/DistanceKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
const size_t stride = item_id.get_local_range().size();

int64_t i = static_cast<int64_t>(
(n2_val_ - device_sqrt<accscalar_t>(n2_squared_minus_1_val_ - 2 * k)));
(n2_val_ - device_sqrt<double>(n2_squared_minus_1_val_ - 2 * k)));
int64_t j = k - n_ * i + i * (i + 1) / 2 + i + 1;

const scalar_t* const start = in_ptr + i * m_;
Expand Down Expand Up @@ -760,8 +760,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
const int64_t n,
const int64_t m,
accscalar_t p_val,
accscalar_t n2_val,
accscalar_t n2_squared_minus_1_val,
const double n2_val,
const double n2_squared_minus_1_val,
scalar_t* out_data,
const scalar_t* in_data,
const int64_t wgroup_size)
Expand All @@ -778,8 +778,8 @@ struct PdistKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
const int64_t n_;
const int64_t m_;
accscalar_t p_val_;
accscalar_t n2_val_;
accscalar_t n2_squared_minus_1_val_;
const double n2_val_;
const double n2_squared_minus_1_val_;
scalar_t* out_data_;
const scalar_t* in_data_;
sycl_local_acc_t<scalar_t, 1> shared_;
Expand All @@ -805,8 +805,6 @@ static void pdist_kernel_impl(
}

auto p_val = static_cast<accscalar_t>(p);
auto n2_val = static_cast<accscalar_t>(n2);
auto n2_squared_minus_1_val = static_cast<accscalar_t>(n2_squared_minus_1);

auto out_data = result.mutable_data_ptr<scalar_t>();
auto in_data = self.const_data_ptr<scalar_t>();
Expand All @@ -815,8 +813,8 @@ static void pdist_kernel_impl(
n,
m,
p_val,
n2_val,
n2_squared_minus_1_val,
n2,
n2_squared_minus_1,
out_data,
in_data,
wgroup_size / min_sg_size);
Expand Down