numerical issue when running SDPA with DTensor #267

@tianyu-l

Description

The issue comes from the backward computation of aten.mul on two complex numbers from DTensors: the gradient comes back with its real and imaginary parts swapped, i.e. the result is b + ai when it should be a + bi. It is not clear why this happens -- by the time the aten operations run, the input tensors have already been de-sugared to local tensors and should have nothing to do with DTensor.
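For context, the backward of a complex multiply behaves as expected on plain tensors: for y = x * w, autograd produces x.grad = grad_out * w.conj(). A minimal single-device sketch illustrating the correct behavior (not from the issue itself):

    import torch

    # Plain-tensor baseline: complex multiply backward is correct here.
    # For y = x * w, autograd computes x.grad = grad_out * conj(w).
    x = torch.randn(2, dtype=torch.complex64, requires_grad=True)
    w = torch.randn(2, dtype=torch.complex64)

    y = x * w
    y.backward(torch.ones_like(y))  # all-ones cotangent

    assert torch.allclose(x.grad, w.conj())  # real/imag parts are NOT swapped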

To reproduce, add the following test to pytorch/test/distributed/tensor/parallel/test_tp_examples.py:

    @with_comms
    def test_apply_rotary_embedding(self):
        device_mesh = self.build_device_mesh()

        def apply_rotary_emb(xq, freqs_cis):
            # View the trailing dim of size 2 as a complex tensor,
            # multiply by the (complex) rotary frequencies, view back as real.
            xq_ = torch.view_as_complex(xq)
            xq_out = torch.view_as_real(xq_ * freqs_cis)
            return xq_out

        with CommDebugMode():
            # Single-device reference (produces the correct gradient):
            # xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            # freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            # xq_out = apply_rotary_emb(xq, freqs_cis)
            # xq_out.sum().backward()

            # DTensor version (gradient comes back with real/imag swapped):
            xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
            freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
            xq_dt = distribute_tensor(xq, device_mesh, (Replicate(),))
            freqs_cis_dt = distribute_tensor(freqs_cis, device_mesh, (Replicate(),))
            xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
            xq_out_dt.sum().backward()
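As a sanity check outside the test harness, the gradient of the single-device computation can be verified analytically. A minimal sketch assuming the shapes from the repro (not part of the original test):

    import torch

    # For loss = sum(view_as_real((u + v*i) * (a + b*i))), expanding the
    # product gives loss = u*(a + b) + v*(a - b) per element, so the
    # expected gradient is dL/du = a + b and dL/dv = a - b.
    xq = torch.randn(1, 1, 2, requires_grad=True)
    freqs_cis = torch.randn(1, 1, dtype=torch.complex64)

    out = torch.view_as_real(torch.view_as_complex(xq) * freqs_cis)
    out.sum().backward()

    a, b = freqs_cis.real, freqs_cis.imag
    expected = torch.stack([a + b, a - b], dim=-1)  # shape (1, 1, 2)
    assert torch.allclose(xq.grad, expected)  # passes on plain tensors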
