From 72ec47aa0932609e4904fa6c6f774ea52f0e6ade Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:52:07 +0800 Subject: [PATCH 1/7] Update LossNLLKernel.cpp --- src/ATen/native/xpu/sycl/LossNLLKernel.cpp | 1131 +++++++++----------- 1 file changed, 528 insertions(+), 603 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp index cd682e725e..f13e7f5147 100644 --- a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp +++ b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp @@ -1,715 +1,640 @@ #include +#include #include #include #include +#include #include #include +#include +#include +#include +#include #include +#include +#include + namespace at::native::xpu { +#define CHECK_INDEX_IN_CLASS(INDEX, N_CLASSES) \ + if constexpr (std::is_unsigned_v) { \ + SYCL_KERNEL_ASSERT(INDEX < N_CLASSES); \ + } else { \ + SYCL_KERNEL_ASSERT(INDEX >= 0 && INDEX < N_CLASSES); \ + } + +int nll_loss_threads(int64_t nframe) { + return std::clamp( + 1 << static_cast(std::round(std::log2(nframe / 16))), 32, 1024); +} + using namespace at::xpu; + template struct NllLossForwardNoReduceKernelFunctor { - void operator()(sycl::item<1> item_id) const { - auto input_ptr = input_data; - auto target_ptr = target_data; - auto weight_ptr = has_weight ? weight_data : NULL; - auto output_ptr = output_data; - auto local_item_id = item_id.get_id(0); - for (int i = local_item_id; i < batch_size; i += local_size) { - int cur_target = target_ptr[i * target_stride]; + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, batch_size) { + index_t cur_target = target[index]; if (cur_target == ignore_index) { - output_ptr[i * output_stride_0] = 0.0f; + output[index] = static_cast(0); continue; } - scalar_t cur_weight = - has_weight ? weight_ptr[cur_target] : static_cast(1.0f); - output_ptr[i * output_stride_0] = - -static_cast( - input_ptr[i * input_stride_0 + cur_target * input_stride_1]) * - cur_weight; + CHECK_INDEX_IN_CLASS(cur_target, n_classes); + auto cur_weight = + weights != nullptr ? weights[cur_target] : static_cast(1); + output[index] = -cur_weight * input[index][cur_target]; } } + NllLossForwardNoReduceKernelFunctor( - scalar_t* input_data_, - index_t* target_data_, - scalar_t* weight_data_, - scalar_t* output_data_, - bool has_weight_, - int64_t batch_size_, - int64_t local_size_, - int64_t target_stride_, - int n_classes_, - int64_t ignore_index_, - int64_t output_stride_0_, - int64_t input_stride_0_, - int64_t input_stride_1_) - : input_data(input_data_), - target_data(target_data_), - weight_data(weight_data_), - output_data(output_data_), - has_weight(has_weight_), - batch_size(batch_size_), - local_size(local_size_), - target_stride(target_stride_), - n_classes(n_classes_), - ignore_index(ignore_index_), - output_stride_0(output_stride_0_), - input_stride_0(input_stride_0_), - input_stride_1(input_stride_1_) {} + int64_t batch_size, + PackedTensorAccessor64 input, + const index_t* target, + scalar_t* output, + const scalar_t* weights, + int64_t n_classes, + int64_t ignore_index) + : batch_size(batch_size), + input(input), + target(target), + output(output), + weights(weights), + n_classes(n_classes), + ignore_index(ignore_index) {} private: - scalar_t* input_data; - index_t* target_data; - scalar_t* weight_data; - scalar_t* output_data; - bool has_weight; int64_t batch_size; - int64_t local_size; - int64_t target_stride; - int n_classes; + PackedTensorAccessor64 input; + const index_t* target; + scalar_t* output; + const scalar_t* weights; + int64_t n_classes; int64_t ignore_index; - int64_t output_stride_0; - int64_t input_stride_0; - int64_t input_stride_1; }; template struct NllLossForwardReduce1DKernelFunctor { - void operator()(sycl::item<1> item_id) const { - auto input_ptr = input_data; - auto target_ptr = target_data; - auto weight_ptr = has_weight ? weight_data : NULL; - auto total_weight_ptr = total_weight_data; - auto output_ptr = output_data; - int cur_target = target_ptr[0]; - total_weight_ptr[0] = - has_weight ? weight_ptr[cur_target] : static_cast(1.0f); - if (cur_target != ignore_index) { - output_ptr[0] = -static_cast(input_ptr[cur_target]) * - static_cast(total_weight_ptr[0]); + void operator()(sycl::nd_item<1> item) const { + SYCL_KERNEL_ASSERT( + item.get_local_id(2) == 0 && item.get_local_id(1) == 0 && + item.get_local_id(0) == 0); + + const index_t t = *target; + if (t != ignore_index) { + CHECK_INDEX_IN_CLASS(t, n_classes); + const auto cur_weight = weights != nullptr ? weights[t] : scalar_t{1}; + *total_weight = cur_weight; + + if (size_average) { + // If we try to normalize a zero then we return a NaN + if (cur_weight == 0) { + *output = std::numeric_limits::quiet_NaN(); + } else { + *output = -input[t]; + } + } else { + *output = -cur_weight * input[t]; + } } else { - output_ptr[0] = static_cast(0.f); - total_weight_ptr[0] = static_cast(0.f); - } - if (reduction == at::Reduction::Mean) { - output_ptr[0] /= total_weight_ptr[0]; + *output = scalar_t{0}; + *total_weight = scalar_t{0}; } } + NllLossForwardReduce1DKernelFunctor( - scalar_t* input_data_, - index_t* target_data_, - scalar_t* weight_data_, - scalar_t* output_data_, - scalar_t* total_weight_data_, - bool has_weight_, - int64_t ignore_index_, - int64_t reduction_) - : input_data(input_data_), - target_data(target_data_), - weight_data(weight_data_), - output_data(output_data_), - total_weight_data(total_weight_data_), - has_weight(has_weight_), - ignore_index(ignore_index_), - reduction(reduction_) {} + scalar_t* output, + scalar_t* total_weight, + const scalar_t* input, + const index_t* target, + const scalar_t* weights, + bool size_average, + int64_t n_classes, + int64_t ignore_index) + : output(output), + total_weight(total_weight), + input(input), + target(target), + weights(weights), + size_average(size_average), + n_classes(n_classes), + ignore_index(ignore_index) {} private: - scalar_t* input_data; - index_t* target_data; - scalar_t* weight_data; - scalar_t* output_data; - scalar_t* total_weight_data; - bool has_weight; + scalar_t* output; + scalar_t* total_weight; + const scalar_t* input; + const index_t* target; + const scalar_t* weights; + bool size_average; + int64_t n_classes; int64_t ignore_index; - int64_t reduction; }; template struct NllLossForwardReduce2DKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { - void operator()(sycl::nd_item<1> item_id) const { - auto input_ptr = input_data; - auto target_ptr = target_data; - auto weight_ptr = has_weight ? weight_data : NULL; - auto total_weight_ptr = total_weight_data; - auto output_ptr = output_data; - int64_t local_id = item_id.get_local_id(0); - local_output_acc[local_id] = accscalar_t(0); - local_total_weight_acc[local_id] = accscalar_t(0); - for (int i = local_id; i < batch_size; i += local_size) { - int cur_target = target_ptr[i]; - if (cur_target != ignore_index) { + void operator()(sycl::nd_item<1> item) const { + auto local_id = item.get_local_id(0); + auto local_range = item.get_local_range(0); + + sh_inputs[local_id] = static_cast(0); + acc_weight[local_id] = static_cast(0); + + for (int i = local_id; i < nframe; i += local_range) { + index_t t = target[i]; + if (t != ignore_index) { + CHECK_INDEX_IN_CLASS(t, n_classes); scalar_t cur_weight = - has_weight ? weight_ptr[cur_target] : static_cast(1.0f); - local_total_weight_acc[local_id] += - static_cast(cur_weight); - local_output_acc[local_id] -= - static_cast(input_ptr[i * n_target + cur_target]) * - static_cast(cur_weight); + weights != nullptr ? weights[t] : static_cast(1); + sh_inputs[local_id] -= input[i * ndim + t] * cur_weight; + acc_weight[local_id] += cur_weight; } } - // reduce - for (int64_t i = (local_size >> 1); i > 0; i >>= 1) { - item_id.barrier(sycl_global_and_local_fence); - if (local_id < i) { - local_total_weight_acc[local_id] += - local_total_weight_acc[local_id + i]; - local_output_acc[local_id] += local_output_acc[local_id + i]; + item.barrier(sycl_global_and_local_fence); + + for (int stride = local_range / 2; stride > 0; stride >>= 1) { + if (local_id < stride) { + sh_inputs[local_id] += sh_inputs[local_id + stride]; + acc_weight[local_id] += acc_weight[local_id + stride]; } + item.barrier(sycl_global_and_local_fence); } - item_id.barrier(sycl_global_and_local_fence); - if (reduction == at::Reduction::Mean) { - output_ptr[0] = static_cast( - local_output_acc[0] / local_total_weight_acc[0]); - } else { - output_ptr[0] = static_cast(local_output_acc[0]); + if (local_id == 0) { + *total_weight = static_cast(acc_weight[0]); + if (size_average) { + *output = static_cast(sh_inputs[0] / acc_weight[0]); + } else { + *output = static_cast(sh_inputs[0]); + } } - total_weight_ptr[0] = static_cast(local_total_weight_acc[0]); } + NllLossForwardReduce2DKernelFunctor( - scalar_t* input_data_, - index_t* target_data_, - scalar_t* weight_data_, - scalar_t* output_data_, - scalar_t* total_weight_data_, - bool has_weight_, - int64_t batch_size_, - int64_t local_size_, - int64_t ignore_index_, - int n_target_, - int64_t reduction_) - : input_data(input_data_), - target_data(target_data_), - weight_data(weight_data_), - output_data(output_data_), - total_weight_data(total_weight_data_), - has_weight(has_weight_), - batch_size(batch_size_), - local_size(local_size_), - ignore_index(ignore_index_), - n_target(n_target_), - reduction(reduction_) {} + scalar_t* output, + scalar_t* total_weight, + const scalar_t* input, + const index_t* target, + const scalar_t* weights, + bool size_average, + int64_t nframe, + int64_t ndim, + int64_t n_classes, + int64_t ignore_index, + int64_t smem_size) + : output(output), + total_weight(total_weight), + input(input), + target(target), + weights(weights), + size_average(size_average), + nframe(nframe), + ndim(ndim), + n_classes(n_classes), + ignore_index(ignore_index), + smem_size(smem_size) {} void sycl_ker_config_convention(sycl::handler& cgh) { - local_output_acc = sycl_local_acc_t(local_size, cgh); - local_total_weight_acc = sycl_local_acc_t(local_size, cgh); + sh_inputs = sycl_local_acc_t(smem_size, cgh); + acc_weight = sycl_local_acc_t(smem_size, cgh); } private: - scalar_t* input_data; - index_t* target_data; - scalar_t* weight_data; - scalar_t* output_data; - scalar_t* total_weight_data; - bool has_weight; - int64_t batch_size; - int64_t local_size; + scalar_t* output; + scalar_t* total_weight; + const scalar_t* input; + const index_t* target; + const scalar_t* weights; + bool size_average; + int64_t nframe; + int64_t ndim; + int64_t n_classes; int64_t ignore_index; - int n_target; - sycl_local_acc_t local_output_acc; - sycl_local_acc_t local_total_weight_acc; - int64_t reduction; + int64_t smem_size; + sycl_local_acc_t sh_inputs; + sycl_local_acc_t acc_weight; }; -template -void nll_loss_forward_template( - const Tensor& input, - const Tensor& target, - const Tensor& output, - const Tensor& weight, - const Tensor& total_weight, - int64_t reduction, - int64_t ignore_index) { - int n_dims = input.dim(); - int n_classes = input.size(-1); - ignore_index -= 0; - - int64_t batch_size = input.size(0); - - if (reduction == at::Reduction::None && n_dims == 2) { - using NllLossForwardNoReduceKernel = - NllLossForwardNoReduceKernelFunctor; - - output.resize_({batch_size}); - total_weight.zero_(); - int64_t target_stride = target.stride(0); - - auto weight_cont = weight.defined() ? weight.contiguous() : weight; - - auto& queue = getCurrentSYCLQueue(); - int64_t local_size = syclMaxWorkGroupSize(); - bool has_weight = weight.defined() - ? true - : false; // sycl kernel can not accept host pointer - - auto output_stride_0 = output.stride(0); - auto input_stride_0 = input.stride(0); - auto input_stride_1 = input.stride(1); - - auto input_data = input.data_ptr(); - auto target_data = target.data_ptr(); - auto weight_data = has_weight - ? weight_cont.data_ptr() - : input_data; // use the input as the dummy data. - auto output_data = output.data_ptr(); - NllLossForwardNoReduceKernel kfn( - input_data, - target_data, - weight_data, - output_data, - has_weight, - batch_size, - local_size, - target_stride, - n_classes, - ignore_index, - output_stride_0, - input_stride_0, - input_stride_1); - - sycl_kernel_submit(sycl::range<1>(local_size), queue, kfn); - return; - } - - output.resize_({}); - total_weight.resize_({}); - - auto input_cont = input.contiguous(); - auto weight_cont = weight.defined() ? weight.contiguous() : weight; - auto target_cont = target.contiguous(); - - scalar_t* _input_data = input_cont.data_ptr(); - scalar_t* _weight_data = - weight.defined() ? weight_cont.data_ptr() : NULL; - index_t* _target_data = target_cont.data_ptr(); - scalar_t* _output_data = output.data_ptr(); - scalar_t* _total_weight_data = total_weight.data_ptr(); - bool has_weight = _weight_data != NULL ? true : false; - auto& queue = getCurrentSYCLQueue(); - - if (input_cont.dim() == 1 || input_cont.dim() == 0) { - int64_t local_size = 1; - auto input_data = _input_data; - auto weight_data = has_weight - ? _weight_data - : input_data; // use the input as the dummy data. - auto target_data = _target_data; - auto total_weight_data = _total_weight_data; - auto output_data = _output_data; - NllLossForwardReduce1DKernelFunctor kfn( - input_data, - target_data, - weight_data, - output_data, - total_weight_data, - has_weight, - ignore_index, - reduction); - - sycl_kernel_submit(sycl::range<1>(local_size), queue, kfn); - } else if (input_cont.dim() == 2) { - using accscalar_t = at::acc_type; - using NllLossForwardReduce2DKernel = - NllLossForwardReduce2DKernelFunctor; - - int64_t batch_size = input.size(0); - int n_target = input.size(1); - int64_t local_size = syclMaxWorkGroupSize(); - auto input_data = _input_data; - auto weight_data = has_weight - ? _weight_data - : input_data; // use the input as the dummy data. - auto target_data = _target_data; - auto total_weight_data = _total_weight_data; - auto output_data = _output_data; - NllLossForwardReduce2DKernelFunctor kfn( - input_data, - target_data, - weight_data, - output_data, - total_weight_data, - has_weight, - batch_size, - local_size, - ignore_index, - n_target, - reduction); - - sycl_kernel_submit( - sycl::range<1>(local_size), sycl::range<1>(local_size), queue, kfn); - } -} - template struct NllLossBackwardNoReduceKernelFunctor { - void operator()(sycl::nd_item<1> item_id) const { - auto target_ptr = target_data; - auto gradOutput_ptr = gradOutput_data; - auto weights_ptr = has_weights ? weights_data : NULL; - auto gradInput_ptr = gradInput_data; - - auto local_id = item_id.get_local_id(0); - auto group_id = item_id.get_group(0); - - for (int i = group_id * local_size + local_id; i < batch_size; - i += item_id.get_global_range(0)) { - int cur_target = target_ptr[i * target_stride]; + void operator()(sycl::nd_item<1> item) const { + XPU_KERNEL_LOOP(item, index, batch_size) { + index_t cur_target = target[index]; if (cur_target == ignore_index) { continue; } - scalar_t cur_weight = - has_weights ? weights_ptr[cur_target] : static_cast(1.0f); - gradInput_ptr[i * gradInput_stride_0 + cur_target * gradInput_stride_1] = - -cur_weight * - static_cast(gradOutput_ptr[i * gradOutput_stride_0]); + CHECK_INDEX_IN_CLASS(cur_target, n_classes); + scalar_t weight = + weights != nullptr ? weights[cur_target] : static_cast(1); + auto grad_input_ = grad_input; + grad_input_[index][cur_target] = -weight * grad_output[index]; } } + NllLossBackwardNoReduceKernelFunctor( - index_t* target_data_, - scalar_t* gradOutput_data_, - scalar_t* weights_data_, - scalar_t* gradInput_data_, - bool has_weights_, - int64_t local_size_, - int64_t batch_size_, - int64_t target_stride_, - int64_t ignore_index_, - int64_t gradInput_stride_0_, - int64_t gradInput_stride_1_, - int64_t gradOutput_stride_0_) - : target_data(target_data_), - gradOutput_data(gradOutput_data_), - weights_data(weights_data_), - gradInput_data(gradInput_data_), - has_weights(has_weights_), - local_size(local_size_), - batch_size(batch_size_), - target_stride(target_stride_), - ignore_index(ignore_index_), - gradInput_stride_0(gradInput_stride_0_), - gradInput_stride_1(gradInput_stride_1_), - gradOutput_stride_0(gradOutput_stride_0_) {} + int batch_size, + const index_t* target, + PackedTensorAccessor64 grad_output, + PackedTensorAccessor64 grad_input, + const scalar_t* weights, + int64_t n_classes, + int64_t ignore_index) + : batch_size(batch_size), + target(target), + grad_output(grad_output), + grad_input(grad_input), + weights(weights), + n_classes(n_classes), + ignore_index(ignore_index) {} private: - index_t* target_data; - scalar_t* gradOutput_data; - scalar_t* weights_data; - scalar_t* gradInput_data; - bool has_weights; - int64_t local_size; - int64_t batch_size; - int64_t target_stride; + int batch_size; + const index_t* target; + PackedTensorAccessor64 grad_output; + PackedTensorAccessor64 grad_input; + const scalar_t* weights; + int64_t n_classes; int64_t ignore_index; - int64_t gradInput_stride_0; - int64_t gradInput_stride_1; - int64_t gradOutput_stride_0; }; template struct NllLossBackwardReduce1DKernelFunctor { - void operator()(sycl::item<1> item_id) const { - auto gradOutput_ptr = gradOutput_data; - auto weights_ptr = has_weights ? weights_data : NULL; - auto gradInput_ptr = gradInput_data; - auto target_ptr = target_data; - auto total_weight_ptr = total_weight_data; - - int t = (int)*target_ptr; - if (t != (int)ignore_index) { - scalar_t grad = - -((reduction == at::Reduction::Mean) - ? static_cast(gradOutput_ptr[0]) / - static_cast(*total_weight_ptr) - : static_cast(gradOutput_ptr[0])); - gradInput_ptr[t] = has_weights ? weights_ptr[t] * grad : grad; + void operator()(sycl::nd_item<1> item) const { + const index_t t = *target; + if (t != ignore_index) { + CHECK_INDEX_IN_CLASS(t, n_classes); + const auto grad = + -(size_average ? *grad_output / *total_weight : *grad_output); + grad_input[t] = weights != nullptr ? weights[t] * grad : grad; } } + NllLossBackwardReduce1DKernelFunctor( - index_t* target_data_, - scalar_t* gradOutput_data_, - scalar_t* weights_data_, - scalar_t* gradInput_data_, - scalar_t* total_weight_data_, - bool has_weights_, - int64_t ignore_index_, - int64_t reduction_) - : target_data(target_data_), - gradOutput_data(gradOutput_data_), - weights_data(weights_data_), - gradInput_data(gradInput_data_), - total_weight_data(total_weight_data_), - has_weights(has_weights_), - ignore_index(ignore_index_), - reduction(reduction_) {} + scalar_t* grad_input, + const scalar_t* grad_output, + const scalar_t* weights, + const index_t* target, + const scalar_t* total_weight, + bool size_average, + int64_t n_classes, + int64_t ignore_index) + : grad_input(grad_input), + grad_output(grad_output), + weights(weights), + target(target), + total_weight(total_weight), + size_average(size_average), + n_classes(n_classes), + ignore_index(ignore_index) {} private: - index_t* target_data; - scalar_t* gradOutput_data; - scalar_t* weights_data; - scalar_t* gradInput_data; - scalar_t* total_weight_data; - bool has_weights; + scalar_t* grad_input; + const scalar_t* grad_output; + const scalar_t* weights; + const index_t* target; + const scalar_t* total_weight; + bool size_average; + int64_t n_classes; int64_t ignore_index; - int64_t reduction; +}; + +template +struct bwd_index_type { + using type = T; +}; +template <> +struct bwd_index_type { + using type = int; +}; +template <> +struct bwd_index_type { + using type = uint64_t; }; template struct NllLossBackwardReduce2DKernelFunctor { - void operator()(sycl::item<1> item_id) const { - auto gradOutput_ptr = gradOutput_data; - auto weights_ptr = has_weights ? weights_data : NULL; - auto gradInput_ptr = gradInput_data; - auto target_ptr = target_data; - auto total_weight_ptr = total_weight_data; - - auto local_item_id = item_id.get_id(0); - - int i, t; - - scalar_t grad = - -((reduction == at::Reduction::Mean) - ? static_cast(gradOutput_ptr[0]) / - static_cast(*total_weight_ptr) - : static_cast(gradOutput_ptr[0])); - for (i = local_item_id; i < nframe; i += local_size) { - t = (int)target_ptr[i]; - if (t != (int)ignore_index) { - gradInput_ptr[i * ndim + t] = - has_weights ? weights_ptr[t] * grad : grad; + void operator()(sycl::nd_item<1> item) const { + auto local_id = item.get_local_id(0); + auto local_range = item.get_local_range(0); + using bwd_index_t = typename bwd_index_type::type; + const auto grad = + -(size_average ? *grad_output / *total_weight : *grad_output); + + for (int i = local_id; i < nframe; i += local_range) { + const index_t t = target[i]; + if (t != ignore_index) { + CHECK_INDEX_IN_CLASS(t, n_classes); + const bwd_index_t index = static_cast(i) * ndim + t; + if constexpr (!std::is_unsigned_v) { + SYCL_KERNEL_ASSERT(index >= 0); + } + grad_input[index] = weights != nullptr ? weights[t] * grad : grad; } } } + NllLossBackwardReduce2DKernelFunctor( - index_t* target_data_, - scalar_t* gradOutput_data_, - scalar_t* weights_data_, - scalar_t* gradInput_data_, - scalar_t* total_weight_data_, - bool has_weights_, - int64_t ignore_index_, - int64_t reduction_, - int ndim_, - int64_t local_size_, - int nframe_) - : target_data(target_data_), - gradOutput_data(gradOutput_data_), - weights_data(weights_data_), - gradInput_data(gradInput_data_), - total_weight_data(total_weight_data_), - has_weights(has_weights_), - ignore_index(ignore_index_), - reduction(reduction_), - ndim(ndim_), - local_size(local_size_), - nframe(nframe_) {} + scalar_t* grad_input, + const scalar_t* grad_output, + const index_t* target, + const scalar_t* weights, + const scalar_t* total_weight, + bool size_average, + int nframe, + int ndim, + int64_t n_classes, + int64_t ignore_index) + : grad_input(grad_input), + grad_output(grad_output), + target(target), + weights(weights), + total_weight(total_weight), + size_average(size_average), + nframe(nframe), + ndim(ndim), + n_classes(n_classes), + ignore_index(ignore_index) {} private: - index_t* target_data; - scalar_t* gradOutput_data; - scalar_t* weights_data; - scalar_t* gradInput_data; - scalar_t* total_weight_data; - bool has_weights; - int64_t ignore_index; - int64_t reduction; - int ndim; - int64_t local_size; + scalar_t* grad_input; + const scalar_t* grad_output; + const index_t* target; + const scalar_t* weights; + const scalar_t* total_weight; + bool size_average; int nframe; + int ndim; + int64_t n_classes; + int64_t ignore_index; }; -template -static inline void nll_loss_backward_template( - const Tensor& input, - const Tensor& target, - const Tensor& gradOutput, - const Tensor& gradInput, - int64_t reduction, - const Tensor& weight, +#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Byte, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, index_t, __VA_ARGS__)) + +void nll_loss_forward_kernel( + const Tensor& output, const Tensor& total_weight, + const Tensor& input_, + const Tensor& target_, + const Tensor& weight, + int64_t reduction, int64_t ignore_index) { - int n_dims = input.dim(); + auto input = *input_.expect_contiguous(); + auto target = *target_.expect_contiguous(); - gradInput.resize_as_(input); - gradInput.zero_(); + int64_t n_classes = input.size(-1); + int64_t n_dims = input.dim(); + int64_t batch_size = n_dims == 1 ? 1 : input.size(0); - int64_t batch_size = input.size(0); + auto weight_ = weight.defined() ? weight.contiguous() : weight; if (reduction == at::Reduction::None && n_dims == 2) { - using NllLossBackwardNoReduceKernel = - NllLossBackwardNoReduceKernelFunctor; - - int64_t target_stride = target.stride(0); - check_dim_size(gradOutput, 1, 0, batch_size); - auto weight_cont = weight.defined() ? weight.contiguous() : weight; - - auto& queue = getCurrentSYCLQueue(); - int64_t local_size = syclMaxWorkGroupSize(); - int64_t global_size = - ((batch_size + local_size - 1) / local_size) * local_size; - bool has_weight = weight.defined() ? true : false; - - auto gradInput_stride_0 = gradInput.stride(0); - auto gradInput_stride_1 = gradInput.stride(1); - auto gradOutput_stride_0 = gradOutput.stride(0); - - auto target_data = target.data_ptr(); - auto gradOutput_data = gradOutput.data_ptr(); - auto weight_data = has_weight - ? weight_cont.data_ptr() - : gradOutput_data; // Use gradOutput handler as dummy weight - auto gradInput_data = gradInput.data_ptr(); - NllLossBackwardNoReduceKernel kfn( - target_data, - gradOutput_data, - weight_data, - gradInput_data, - has_weight, - local_size, - batch_size, - target_stride, - ignore_index, - gradInput_stride_0, - gradInput_stride_1, - gradOutput_stride_0); - - sycl_kernel_submit( - sycl::range<1>(global_size), sycl::range<1>(local_size), queue, kfn); + at::native::resize_output(output, {batch_size}); + total_weight.zero_(); + if (batch_size == 0) { + // This guards from unnecessary operations and launching SYCL kernel with + // 0 blocks. + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_forward_no_reduce_xpu_kernel", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_forward_no_reduce_xpu_kernel_index", + [&] { + auto kfn = + NllLossForwardNoReduceKernelFunctor( + batch_size, + input.packed_accessor64(), + target.const_data_ptr(), + output.mutable_data_ptr(), + weight_.defined() ? weight_.const_data_ptr() + : nullptr, + n_classes, + ignore_index); + sycl_kernel_submit( + GET_GROUPS(batch_size) * SYCL_NUM_THREADS, + SYCL_NUM_THREADS, + getCurrentSYCLQueue(), + kfn); + }); + }); return; } - auto weight_cont = weight.defined() ? weight.contiguous() : weight; - auto target_cont = target.contiguous(); - bool has_weight = weight.defined() ? true : false; + // produce scalar outputs for the reduction case + at::native::resize_output(output, {}); + total_weight.resize_({}); - TORCH_CHECK( - gradOutput.dim() <= 1 && gradOutput.numel() == 1, - "Expected a single element grad_output tensor, but got: ", - gradOutput.sizes()); + if (target.numel() == 0) { + if (reduction == at::Reduction::Mean) { + output.fill_(std::numeric_limits::quiet_NaN()); + } else { + output.zero_(); + } + total_weight.zero_(); + return; + } - auto& queue = getCurrentSYCLQueue(); if (n_dims == 1) { - auto gradOutput_data = gradOutput.data_ptr(); - auto weight_data = has_weight - ? weight_cont.data_ptr() - : gradOutput_data; // Use gradOutput handler as dummy weight - auto gradInput_data = gradInput.data_ptr(); - auto target_data = target_cont.data_ptr(); - auto total_weight_data = total_weight.data_ptr(); - NllLossBackwardReduce1DKernelFunctor kfn( - target_data, - gradOutput_data, - weight_data, - gradInput_data, - total_weight_data, - has_weight, - ignore_index, - reduction); - - sycl_kernel_submit(sycl::range<1>(1), queue, kfn); - } else { - int nframe = input.size(0); - int ndim = input.size(1); - int64_t local_size = 32; - - auto gradOutput_data = gradOutput.data_ptr(); - auto weight_data = has_weight - ? weight_cont.data_ptr() - : gradOutput_data; // use the gradOutput handler as dummy weight - auto gradInput_data = gradInput.data_ptr(); - auto target_data = target_cont.data_ptr(); - auto total_weight_data = total_weight.data_ptr(); - NllLossBackwardReduce2DKernelFunctor kfn( - target_data, - gradOutput_data, - weight_data, - gradInput_data, - total_weight_data, - has_weight, - ignore_index, - reduction, - ndim, - local_size, - nframe); - - sycl_kernel_submit(sycl::range<1>(local_size), queue, kfn); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_forward_reduce_xpu_kernel_1d", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_forward_reduce_xpu_kernel_1d_index", + [&] { + auto kfn = + NllLossForwardReduce1DKernelFunctor( + output.mutable_data_ptr(), + total_weight.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weight_.defined() ? weight_.const_data_ptr() + : nullptr, + reduction == at::Reduction::Mean, + n_classes, + ignore_index); + sycl_kernel_submit(1, 1, getCurrentSYCLQueue(), kfn); + }); + }); + + } else if (n_dims == 2) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_forward_reduce_xpu_kernel_2d", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_forward_reduce_xpu_kernel_2d_index", + [&] { + using accscalar_t = at::acc_type; + int nthreads = nll_loss_threads(input.size(0)); + using NllLossForwardReduce2DKernel = + NllLossForwardReduce2DKernelFunctor< + scalar_t, + index_t, + accscalar_t>; + int64_t local_size = + syclMaxWorkGroupSize(); + NllLossForwardReduce2DKernel kfn( + output.mutable_data_ptr(), + total_weight.mutable_data_ptr(), + input.const_data_ptr(), + target.const_data_ptr(), + weight_.defined() ? weight_.const_data_ptr() + : nullptr, + reduction == at::Reduction::Mean, + input.size(0), + input.size(1), + n_classes, + ignore_index, + local_size); + sycl_kernel_submit( + sycl::range<1>(nthreads), + sycl::range<1>(nthreads), + getCurrentSYCLQueue(), + kfn); + }); + }); } } -#define AT_DISPATCH_NLL_LOSS_INDEX_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, \ - NAME, \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Byte, index_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE_USING_HINT( \ - at::ScalarType::Long, index_t, __VA_ARGS__)) - -void nll_loss_forward_kernel( - const Tensor& self, - const Tensor& target, - const OptionalTensorRef weight_opt, - int64_t reduction, - int64_t ignore_index, - const Tensor& output, - const Tensor& total_weight) { - const Tensor& weight = weight_opt.getTensorRef(); - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - self.scalar_type(), - "nll_loss_forward_out_kernel", - [&]() { - AT_DISPATCH_NLL_LOSS_INDEX_TYPES( - target.scalar_type(), "nll_loss_forward_out_kernel_index", [&]() { - nll_loss_forward_template( - self, - target, - output, - weight, - total_weight, - reduction, - ignore_index); - }); - }); - - // return std::tuple(output, total_weight); -} - void nll_loss_backward_kernel( - const Tensor& grad_output, - const Tensor& self, - const Tensor& target, - const OptionalTensorRef weight_opt, - int64_t reduction, - int64_t ignore_index, + const Tensor& grad_input_, + const Tensor& grad_output_, + const Tensor& input_, + const Tensor& target_, const Tensor& total_weight, - const Tensor& grad_input) { - const Tensor& weight = weight_opt.getTensorRef(); - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - self.scalar_type(), - "nll_loss_backward_out_kernel", - [&]() { - AT_DISPATCH_NLL_LOSS_INDEX_TYPES( - target.scalar_type(), "nll_loss_backward_out_kernel_index", [&]() { - nll_loss_backward_template( - self, - target, - grad_output, - grad_input, - reduction, - weight, - total_weight, - ignore_index); - }); - }); - // return grad_input; + const Tensor& weight, + int64_t reduction, + int64_t ignore_index) { + auto target = *target_.expect_contiguous(); + auto input = *input_.expect_contiguous(); + auto grad_input = *grad_input_.expect_contiguous(); + auto grad_output = *grad_output_.expect_contiguous(); + + int64_t n_dims = input.dim(); + int64_t n_classes = input.size(-1); + int64_t batch_size = n_dims == 1 ? 1 : input.size(0); + + auto weight_ = weight.defined() ? weight.contiguous() : weight; + + if (reduction == at::Reduction::None && n_dims == 2) { + if (batch_size == 0) { + return; + } + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_backward_no_reduce_xpu_kernel", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_backward_no_reduce_xpu_kernel_index", + [&] { + auto kfn = + NllLossBackwardNoReduceKernelFunctor( + batch_size, + target.const_data_ptr(), + grad_output.packed_accessor64(), + grad_input.packed_accessor64(), + weight.defined() ? weight_.const_data_ptr() + : nullptr, + n_classes, + ignore_index); + sycl_kernel_submit( + GET_GROUPS(batch_size) * SYCL_NUM_THREADS, + SYCL_NUM_THREADS, + getCurrentSYCLQueue(), + kfn); + }); + }); + return; + } + + if (n_dims == 1) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_backward_reduce_xpu_kernel_1d", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_backward_reduce_xpu_kernel_1d_index", + [&] { + auto kfn = + NllLossBackwardReduce1DKernelFunctor( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + weight.defined() ? weight_.const_data_ptr() + : nullptr, + target.const_data_ptr(), + total_weight.const_data_ptr(), + reduction == at::Reduction::Mean, + n_classes, + ignore_index); + sycl_kernel_submit( + sycl::range<1>(1), + sycl::range<1>(1), + getCurrentSYCLQueue(), + kfn); + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + input.scalar_type(), + "nll_loss_backward_reduce_xpu_kernel_2d", + [&] { + AT_DISPATCH_NLL_LOSS_INDEX_TYPES( + target.scalar_type(), + "nll_loss_backward_reduce_xpu_kernel_2d_index", + [&] { + auto kfn = + NllLossBackwardReduce2DKernelFunctor( + grad_input.mutable_data_ptr(), + grad_output.const_data_ptr(), + target.const_data_ptr(), + weight.defined() ? weight_.const_data_ptr() + : nullptr, + total_weight.const_data_ptr(), + reduction == at::Reduction::Mean, + input.size(0), + input.size(1), + n_classes, + ignore_index); + sycl_kernel_submit( + nll_loss_threads(input.size(0)), + nll_loss_threads(input.size(0)), + getCurrentSYCLQueue(), + kfn); + }); + }); + } } #undef AT_DISPATCH_NLL_LOSS_INDEX_TYPES From 235618f84151f7bb6c4035bef77e04f9bcdcaf48 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:53:00 +0800 Subject: [PATCH 2/7] Update LossNLLKernel.h --- src/ATen/native/xpu/sycl/LossNLLKernel.h | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LossNLLKernel.h b/src/ATen/native/xpu/sycl/LossNLLKernel.h index f680aeb291..cb344ebab2 100644 --- a/src/ATen/native/xpu/sycl/LossNLLKernel.h +++ b/src/ATen/native/xpu/sycl/LossNLLKernel.h @@ -4,22 +4,22 @@ namespace at::native::xpu { TORCH_XPU_API void nll_loss_forward_kernel( - const Tensor& self, - const Tensor& target, - const OptionalTensorRef weight_opt, - int64_t reduction, - int64_t ignore_index, const Tensor& output, - const Tensor& total_weight); + const Tensor& total_weight, + const Tensor& input_, + const Tensor& target_, + const Tensor& weight, + int64_t reduction, + int64_t ignore_index); TORCH_XPU_API void nll_loss_backward_kernel( - const Tensor& grad_output, - const Tensor& self, - const Tensor& target, - const OptionalTensorRef weight_opt, - int64_t reduction, - int64_t ignore_index, + const Tensor& grad_input_, + const Tensor& grad_output_, + const Tensor& input_, + const Tensor& target_, const Tensor& total_weight, - const Tensor& grad_input); + const Tensor& weight, + int64_t reduction, + int64_t ignore_index); } // namespace at::native::xpu From cdb78be861e4baa3243e6185a28799ad30bab12a Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:53:50 +0800 Subject: [PATCH 3/7] Update LossNLL.cpp --- src/ATen/native/xpu/LossNLL.cpp | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/src/ATen/native/xpu/LossNLL.cpp b/src/ATen/native/xpu/LossNLL.cpp index 28cceca996..93e0639f7c 100644 --- a/src/ATen/native/xpu/LossNLL.cpp +++ b/src/ATen/native/xpu/LossNLL.cpp @@ -18,16 +18,9 @@ TORCH_IMPL_FUNC(nll_loss_forward_out_xpu) int64_t ignore_index, const Tensor& output, const Tensor& total_weight) { + const Tensor& weight = weight_opt.getTensorRef(); xpu::nll_loss_forward_kernel( - self, - target, - ((weight_opt.has_value() && (*weight_opt).defined()) - ? at::OptionalTensorRef(*weight_opt) - : at::OptionalTensorRef()), - reduction, - ignore_index, - output, - total_weight); + output, total_weight, self, target, weight, reduction, ignore_index); } TORCH_IMPL_FUNC(nll_loss_backward_out_xpu) @@ -39,19 +32,18 @@ TORCH_IMPL_FUNC(nll_loss_backward_out_xpu) int64_t ignore_index, const Tensor& total_weight, const Tensor& grad_input) { + const Tensor& weight = weight_opt.getTensorRef(); grad_input.zero_(); xpu::nll_loss_backward_kernel( + grad_input, grad_output, self, target, - ((weight_opt.has_value() && (*weight_opt).defined()) - ? at::OptionalTensorRef(*weight_opt) - : at::OptionalTensorRef()), - reduction, - ignore_index, total_weight, - grad_input); + weight, + reduction, + ignore_index); } } // namespace native -} // namespace at \ No newline at end of file +} // namespace at From f87f1e56647b252e0e5afeeb950ccf37e5ac32b6 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 26 Sep 2025 08:54:36 +0800 Subject: [PATCH 4/7] Update KernelUtils.h --- src/ATen/native/xpu/sycl/KernelUtils.h | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/ATen/native/xpu/sycl/KernelUtils.h b/src/ATen/native/xpu/sycl/KernelUtils.h index 98fbe9a074..bc467c36dc 100644 --- a/src/ATen/native/xpu/sycl/KernelUtils.h +++ b/src/ATen/native/xpu/sycl/KernelUtils.h @@ -11,3 +11,22 @@ i = _i_n_d_e_x) #define XPU_KERNEL_LOOP(item, i, n) XPU_KERNEL_LOOP_TYPE(item, i, n, int) + +// Use 1024 threads per block, which requires cuda sm_2x or above +constexpr int SYCL_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. +inline int GET_GROUPS( + const int64_t N, + const int64_t max_threads_per_group = SYCL_NUM_THREADS) { + TORCH_INTERNAL_ASSERT( + N > 0, "XPU kernel launch blocks must be positive, but got N=", N); + constexpr int64_t max_int = std::numeric_limits::max(); + + // Round up division for positive number that cannot cause integer overflow + auto group_num = (N - 1) / max_threads_per_group + 1; + TORCH_INTERNAL_ASSERT( + group_num <= max_int, "Can't schedule too many blocks on XPU device"); + + return static_cast(group_num); +} From 5b4ce0ecd5c00a1de433cadcb87e784eebfb99f1 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Tue, 7 Oct 2025 18:42:57 +0800 Subject: [PATCH 5/7] Update LossNLLKernel.cpp --- src/ATen/native/xpu/sycl/LossNLLKernel.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp index f13e7f5147..d9da83e846 100644 --- a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp +++ b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp @@ -77,9 +77,7 @@ struct NllLossForwardNoReduceKernelFunctor { template struct NllLossForwardReduce1DKernelFunctor { void operator()(sycl::nd_item<1> item) const { - SYCL_KERNEL_ASSERT( - item.get_local_id(2) == 0 && item.get_local_id(1) == 0 && - item.get_local_id(0) == 0); + SYCL_KERNEL_ASSERT(item.get_local_id(0) == 0 && item.get_group(0) == 0); const index_t t = *target; if (t != ignore_index) { @@ -263,6 +261,7 @@ struct NllLossBackwardNoReduceKernelFunctor { template struct NllLossBackwardReduce1DKernelFunctor { void operator()(sycl::nd_item<1> item) const { + SYCL_KERNEL_ASSERT(item.get_local_id(0) == 0 && item.get_group(0) == 0); const index_t t = *target; if (t != ignore_index) { CHECK_INDEX_IN_CLASS(t, n_classes); From 9006cfe0a4bc0612ce245bcf59a2b38d0f91f4cc Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:05:01 +0800 Subject: [PATCH 6/7] Update KernelUtils.h --- src/ATen/native/xpu/sycl/KernelUtils.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ATen/native/xpu/sycl/KernelUtils.h b/src/ATen/native/xpu/sycl/KernelUtils.h index bc467c36dc..fb1b6c4c22 100644 --- a/src/ATen/native/xpu/sycl/KernelUtils.h +++ b/src/ATen/native/xpu/sycl/KernelUtils.h @@ -12,10 +12,8 @@ #define XPU_KERNEL_LOOP(item, i, n) XPU_KERNEL_LOOP_TYPE(item, i, n, int) -// Use 1024 threads per block, which requires cuda sm_2x or above constexpr int SYCL_NUM_THREADS = 1024; -// CUDA: number of blocks for threads. inline int GET_GROUPS( const int64_t N, const int64_t max_threads_per_group = SYCL_NUM_THREADS) { From adc4152ee6512b906c3aea4330a37af59f53ed62 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Wed, 15 Oct 2025 14:50:00 +0800 Subject: [PATCH 7/7] Update LossNLLKernel.cpp --- src/ATen/native/xpu/sycl/LossNLLKernel.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp index d9da83e846..604289c5b7 100644 --- a/src/ATen/native/xpu/sycl/LossNLLKernel.cpp +++ b/src/ATen/native/xpu/sycl/LossNLLKernel.cpp @@ -493,8 +493,6 @@ void nll_loss_forward_kernel( scalar_t, index_t, accscalar_t>; - int64_t local_size = - syclMaxWorkGroupSize(); NllLossForwardReduce2DKernel kfn( output.mutable_data_ptr(), total_weight.mutable_data_ptr(), @@ -507,7 +505,7 @@ void nll_loss_forward_kernel( input.size(1), n_classes, ignore_index, - local_size); + nthreads); sycl_kernel_submit( sycl::range<1>(nthreads), sycl::range<1>(nthreads),