Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 12 additions & 19 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,31 +130,24 @@ void exp2_kernel(TensorIteratorBase& iter) {
}

template <typename scalar_t>
struct Logit0Functor {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
struct LogitFunctor {
scalar_t operator()(scalar_t x) const {
const T_ACC x_acc = static_cast<T_ACC>(x);
// suppress compiler optimization on data type promotion.
volatile T_ACC res = std::log(x_acc / (T_ACC(1) - x_acc));
return res;
return std::log(x / (1 - x));
}
};

template <typename scalar_t>
struct Logit1Functor {
struct LogitEpsFunctor {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
scalar_t operator()(scalar_t x) const {
const T_ACC x_acc = static_cast<T_ACC>(x);
T_ACC z = x_acc < lo_ ? lo_ : (x_acc > hi_ ? hi_ : x_acc);
// suppress compiler optimization on data type promotion.
volatile T_ACC res = std::log(z / (T_ACC(1) - z));
return res;
scalar_t x_clamped = x < low_ ? low_ : (x > high_ ? high_ : x);
return std::log(x_clamped / (1 - x_clamped));
}
Logit1Functor(const T_ACC lo, const T_ACC hi) : lo_(lo), hi_(hi) {}
LogitEpsFunctor(const T_ACC low, const T_ACC high) : low_(low), high_(high) {}

private:
T_ACC lo_;
T_ACC hi_;
scalar_t low_;
scalar_t high_;
};

void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
Expand All @@ -167,11 +160,11 @@ void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
using T_ACC = acc_type_device<scalar_t, c10::DeviceType::XPU>;
const T_ACC eps = eps_scalar.to<T_ACC>();
if (eps < T_ACC(0)) {
gpu_kernel(iter, Logit0Functor<scalar_t>());
gpu_kernel(iter, LogitFunctor<scalar_t>());
} else {
const T_ACC lo = eps;
const T_ACC hi = T_ACC(1) - eps;
gpu_kernel(iter, Logit1Functor<scalar_t>(lo, hi));
const T_ACC low = eps;
const T_ACC high = T_ACC(1) - eps;
gpu_kernel(iter, LogitEpsFunctor<scalar_t>(low, high));
}
});
}
Expand Down