@Silv3S (Contributor) commented Oct 10, 2025

Summary

torch.special.logit for bfloat16 and float16 input runs in higher precision: the input is cast to AccumulateTypeDevice, which is float32 (pytorch/aten/src/ATen/AccumulateType.h). The output is cast back to the lower precision, but because the intermediate results are computed in float32, we get different results than on CPU. This might affect other tests, so I wanted to clarify whether this is expected or whether we should always try to match the CPU reference in our kernels.
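To make the two code paths concrete, here is a small Python emulation (not the actual kernel code) of the cast-to-accumulate-type path versus the keep-native-precision path; the printed values correspond to the 0.0938 / 0.0967 entries in the Results table below.

```python
import torch

x = torch.tensor([0.5234], dtype=torch.bfloat16)

# Current XPU/CUDA kernel behavior: cast the input to the accumulate type
# (float32), compute logit there, then cast the result back to bfloat16.
acc = x.to(torch.float32)
via_acc_type = torch.log(acc / (1 - acc)).to(torch.bfloat16)

# Computation kept in the input precision throughout, which is what the
# CPU reference in the repro below effectively produces.
native = torch.log(x / (1 - x))

print(via_acc_type)  # ~0.0938 (unfixed XPU/CUDA bf16 result)
print(native)        # ~0.0967 (CPU bf16 reference)
```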

Minimal repro

```python
import torch
import pytest

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["float32", "bfloat16"])
def test_special_logit(dtype):
    input_cpu = torch.tensor([0.5234], device="cpu", dtype=dtype)
    input_xpu = input_cpu.to("xpu")

    # Elementwise reference: logit(x) = log(x / (1 - x))
    reference_cpu = torch.log(input_cpu / (1 - input_cpu))
    reference_xpu = torch.log(input_xpu / (1 - input_xpu))
    print(f"reference_cpu logit: {reference_cpu}")
    print(f"reference_xpu logit: {reference_xpu}")
    assert torch.allclose(reference_cpu, reference_xpu.cpu(), atol=1e-5, rtol=1e-5)

    # Operator under test
    logit_cpu = torch.special.logit(input_cpu)
    logit_xpu = torch.special.logit(input_xpu)
    print(f"CPU logit: {logit_cpu}")
    print(f"XPU logit: {logit_xpu}")
    assert torch.allclose(logit_cpu, logit_xpu.cpu(), atol=1e-5, rtol=1e-5)
```
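Assuming the snippet is saved as `test_logit.py` (a file name chosen here just for illustration), running it with `pytest -s test_logit.py` disables output capture so the printed reference and operator values are visible.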

Results

| device | dtype | reference | torch.special.logit | torch.special.logit (fix) |
|--------|-------|-----------|---------------------|---------------------------|
| CPU    | fp32  | 0.0937    | 0.0937              |                           |
| XPU    | fp32  | 0.0937    | 0.0937              | 0.0937                    |
| CUDA   | fp32  | 0.0937    | 0.0937              |                           |
| CPU    | bf16  | 0.0967    | 0.0967              |                           |
| XPU    | bf16  | 0.0967    | 0.0938              | 0.0967                    |
| CUDA   | bf16  | 0.0967    | 0.0938              |                           |

Copilot AI review requested due to automatic review settings — October 10, 2025 11:56

Copilot AI left a comment


Pull Request Overview

This PR fixes a precision issue with torch.special.logit for bfloat16 and float16 inputs by modifying the kernel to run the computation in the input's reduced precision instead of casting to higher precision (float32). The change makes XPU results consistent with CPU results for these 16-bit floating-point types.

  • Simplified logit computation to use native input precision instead of accumulate type casting
  • Renamed functors for clarity (Logit0Functor → LogitFunctor, Logit1Functor → LogitEpsFunctor)
  • Updated parameter names and types to match the new precision-preserving approach


@Silv3S changed the title from "Run torch.special.logit in reduced precision, for bf16/f16 inputs" to "Run torch.special.logit in reduced precision for bf16/f16 inputs" on Oct 10, 2025

@australopitek left a comment


LGTM, but we need to come up with a way of handling such discrepancies between CPU and CUDA results in the future, and stick to it. Currently, CPU gives different results than CUDA for these ops.

@EikanWang (Contributor)

@Silv3S, what's the behavior on CUDA? In general, we should align with CUDA, because the use case I have in mind is running CUDA models on XPU.

@CuiYifeng (Contributor)

@Silv3S I noticed that the community has also observed the inconsistency between CPU and CUDA, but there is currently no clear solution. This PR is a temporary fix, right?

@Silv3S (Contributor, Author) commented Oct 21, 2025

@EikanWang Our current kernel is aligned with CUDA, but it feels wrong to execute a bf16/fp16 operator in fp32 precision. We lose the benefits of low precision through the extra casts and by computing everything as regular floats. In my opinion, the proper solution is to either 1) calculate everything in the original precision, or 2) keep T_ACC for the subtraction and division but do the log in the original precision (see the sketch after the diagram below).

[attached image: graphviz (1)]
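A rough Python emulation of the two options (again, not the actual kernel code; the real change lives in the kernel functors):

```python
import torch

x = torch.tensor([0.5234], dtype=torch.bfloat16)

# Option 1: everything in the original precision.
opt1 = torch.log(x / (1 - x))

# Option 2: subtraction and division in T_ACC (float32),
# only the final log in the original precision.
acc = x.to(torch.float32)
opt2 = torch.log((acc / (1 - acc)).to(torch.bfloat16))

# For this particular input, both options land on the CPU bf16
# reference value (~0.0967) from the table above.
```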

@CuiYifeng Thanks for pointing to the community PR pytorch/pytorch#140252; I wasn't aware that someone else had encountered this bug. I think we should align with CPU regarding expected accuracy, so this is not a temporary solution.
