diff --git a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp index 8e8b09d5c..cd6c27788 100644 --- a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp @@ -130,31 +130,24 @@ void exp2_kernel(TensorIteratorBase& iter) { } template -struct Logit0Functor { - using T_ACC = acc_type_device; +struct LogitFunctor { scalar_t operator()(scalar_t x) const { - const T_ACC x_acc = static_cast(x); - // suppress compiler optimization on data type promotion. - volatile T_ACC res = std::log(x_acc / (T_ACC(1) - x_acc)); - return res; + return std::log(x / (1 - x)); } }; template -struct Logit1Functor { +struct LogitEpsFunctor { using T_ACC = acc_type_device; scalar_t operator()(scalar_t x) const { - const T_ACC x_acc = static_cast(x); - T_ACC z = x_acc < lo_ ? lo_ : (x_acc > hi_ ? hi_ : x_acc); - // suppress compiler optimization on data type promotion. - volatile T_ACC res = std::log(z / (T_ACC(1) - z)); - return res; + scalar_t x_clamped = x < low_ ? low_ : (x > high_ ? high_ : x); + return std::log(x_clamped / (1 - x_clamped)); } - Logit1Functor(const T_ACC lo, const T_ACC hi) : lo_(lo), hi_(hi) {} + LogitEpsFunctor(const T_ACC low, const T_ACC high) : low_(low), high_(high) {} private: - T_ACC lo_; - T_ACC hi_; + scalar_t low_; + scalar_t high_; }; void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) { @@ -167,11 +160,11 @@ void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) { using T_ACC = acc_type_device; const T_ACC eps = eps_scalar.to(); if (eps < T_ACC(0)) { - gpu_kernel(iter, Logit0Functor()); + gpu_kernel(iter, LogitFunctor()); } else { - const T_ACC lo = eps; - const T_ACC hi = T_ACC(1) - eps; - gpu_kernel(iter, Logit1Functor(lo, hi)); + const T_ACC low = eps; + const T_ACC high = T_ACC(1) - eps; + gpu_kernel(iter, LogitEpsFunctor(low, high)); } }); }