Closed
Labels: bug (Something isn't working), help wanted (Extra attention is needed)
Description
The issue comes from the backward computation of aten.mul on two complex DTensors: for a gradient that should be a + bi, the result comes back as b + ai, i.e. with the real and imaginary parts swapped. It is unclear why this happens -- by the time the aten operations run, the input tensors have been unwrapped and should have nothing to do with DTensor.
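For reference, here is what the correct gradient looks like on plain tensors. This is a hand-derived sanity sketch, not code from the test file; the names a, b, c, d follow the a + bi notation above. Writing xq_ = a + bi and freqs_cis = c + di, the product is (ac - bd) + (ad + bc)i, so the gradient of view_as_real(xq_ * freqs_cis).sum() with respect to the real view of xq is [c + d, c - d]:

import torch

# Hand-derived check on plain tensors (no DTensor involved):
# for xq_ = a + bi and freqs_cis = c + di,
#   xq_ * freqs_cis = (ac - bd) + (ad + bc)i,
# so d(sum of real and imag parts)/da = c + d and /db = c - d.
xq = torch.randn(1, 1, 2, requires_grad=True)
freqs_cis = torch.randn(1, 1, dtype=torch.complex64)

out = torch.view_as_real(torch.view_as_complex(xq) * freqs_cis)
out.sum().backward()

c, d = freqs_cis.real, freqs_cis.imag
expected = torch.stack([c + d, c - d], dim=-1)
assert torch.allclose(xq.grad, expected)  # holds on plain tensors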
To reproduce, add the following test to pytorch/test/distributed/tensor/parallel/test_tp_examples.py:
@with_comms
def test_apply_rotary_embedding(self):
    device_mesh = self.build_device_mesh()

    def apply_rotary_emb(xq, freqs_cis):
        # Reinterpret the trailing dim of size 2 as a complex tensor,
        # multiply by the complex rotation, and view back as real.
        xq_ = torch.view_as_complex(xq)
        xq_out = torch.view_as_real(xq_ * freqs_cis)
        return xq_out

    with CommDebugMode():
        # Plain-tensor baseline, which produces the correct gradient:
        # xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
        # freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
        # xq_out = apply_rotary_emb(xq, freqs_cis)
        # xq_out.sum().backward()
        xq = torch.randn(1, 1, 2, requires_grad=True, device=self.device_type)
        freqs_cis = torch.randn(1, 1, dtype=torch.complex64, requires_grad=False, device=self.device_type)
        xq_dt = distribute_tensor(xq, device_mesh, (Replicate(),))
        freqs_cis_dt = distribute_tensor(freqs_cis, device_mesh, (Replicate(),))
        xq_out_dt = apply_rotary_emb(xq_dt, freqs_cis_dt)
        # The gradient produced here has its real and imaginary components
        # swapped relative to the plain-tensor baseline above.
        xq_out_dt.sum().backward()
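To make the mismatch explicit, one can compare the accumulated gradients after the DTensor run. This is a rough sketch, under the assumptions that xq_dt is a leaf DTensor (so backward populates xq_dt.grad) and that to_local() materializes the fully replicated gradient:

# Sketch only: assumes xq_dt.grad is populated and that to_local()
# returns the full gradient for a Replicate() placement; exact
# accessors may differ across PyTorch versions.
dtensor_grad = xq_dt.grad.to_local()
plain_grad = torch.autograd.grad(apply_rotary_emb(xq, freqs_cis).sum(), xq)[0]
# Per this report the two differ: the DTensor gradient's real and
# imaginary components come back swapped.
print(torch.allclose(dtensor_grad, plain_grad))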