Skip to content

Conversation

Silv3S
Copy link
Contributor

@Silv3S Silv3S commented Oct 10, 2025

Summary

torch.special.logit for bfloat16 and float16 input runs in higher precision, because input is casted to AccumulateTypeDevice, which is float32 (pytorch/aten/src/ATen/AccumulateType.h). Output is casted back to lower precision, but because intermediate results are in float32, we have different results than CPU. It might affect other tests so I wanted to clarify if this is expected or we should always try to match CPU reference in our kernels.

Minimal repro

import torch
import pytest

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["float32", "bfloat16"])
def test_special_logit(dtype):
    input_cpu = torch.tensor([0.5234], device="cpu", dtype=dtype)
    input_xpu = input_cpu.to("xpu")

    reference_cpu = torch.log(input_cpu/(1 - input_cpu))
    reference_xpu = torch.log(input_xpu/(1 - input_xpu))
    print(f"reference_cpu logit: {reference_cpu}")
    print(f"reference_xpu logit: {reference_xpu}")
    assert torch.allclose(reference_cpu, reference_xpu.cpu(), atol=1e-5, rtol=1e-5)

    logit_cpu = torch.special.logit(input_cpu)
    logit_xpu = torch.special.logit(input_xpu)
    print(f"CPU logit: {logit_cpu}")
    print(f"XPU logit: {logit_xpu}")
    assert torch.allclose(logit_cpu, logit_xpu.cpu(), atol=1e-5, rtol=1e-5)

Results

device dtype reference torch.special.logit torch.special.logit (fix)
CPU fp32 0.0937 0.0937
XPU fp32 0.0937 0.0937 0.0937
CUDA fp32 0.0937 0.0937
CPU bf16 0.0967 0.0967
XPU bf16 0.0967 0.0938 0.0967
CUDA bf16 0.0967 0.0938

@Copilot Copilot AI review requested due to automatic review settings October 10, 2025 11:56
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR fixes a precision issue with torch.special.logit for bfloat16 and float16 inputs by modifying the kernel to run computations in reduced precision instead of casting to higher precision (float32). The change ensures consistency between CPU and XPU device results for half-precision floating point types.

  • Simplified logit computation to use native input precision instead of accumulate type casting
  • Renamed functors for clarity (Logit0Functor → LogitFunctor, Logit1Functor → LogitEpsFunctor)
  • Updated parameter names and types to match the new precision-preserving approach

Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.

@Silv3S Silv3S changed the title Run torch.special.logit in reduced precision, for bf16/f16 inputs Run torch.special.logit in reduced precision for bf16/f16 inputs Oct 10, 2025
Copy link

@australopitek australopitek left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but we need to come up with the way of handling such discrepancies between CPU and CUDA results in future, and stick to it. Currently CPU gives different results than CUDA for these ops.

@EikanWang
Copy link
Contributor

@Silv3S , what's the behavior on CUDA? In general, we should align with CUDA because running CUDA models on XPU is the use case in my mind.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants