
[XCCL] torch OOM in TP scenario with torch2.9 #2084

@faaany


🐛 Describe the bug

With the latest torch-2.9 (built with 2025.2.1), a simple unit test with tensor parallelism (TP=2) fails with the following error:

memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter:  268
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter:  269
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter:  270
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
Process Process-2:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/sdp/fanli/test_xccl.py", line 60, in init_process
    fn(rank, world_size)
  File "/home/sdp/fanli/test_xccl.py", line 46, in run
    output = ret + weight + gradient
torch.OutOfMemoryError: XPU out of memory. Tried to allocate 512.00 MiB. GPU 1 has a total capacity of 63.98 GiB. Of the allocated memory 2.50 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. Please use `empty_cache` to release all unoccupied cached memory. 

The full script is attached (test_xccl.py). To put it simply, the following loop does not free memory after each iteration, so the reserved memory keeps growing until the process crashes:

for epoch in range(1):
    for iter in range(1000):
        # allocate a fresh buffer each iteration (stand-in for random input)
        output = torch.empty_like(ret)
        # mock the compute part (fwd/bwd)
        output = ret + weight + gradient
        ret = torch.nn.functional.linear(output, weight=ret)
        # issue the collective on a dedicated side stream
        with torch.xpu.stream(coll_stream):
            dist.all_reduce(ret, op=dist.ReduceOp.SUM)

        if rank == 0:
            print("iter: ", iter)
            report_memory(True)
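report_memory is defined in the attached test_xccl.py; based on the "memory (MiB) | allocated: ..." log lines above, a minimal equivalent (an assumption, not the original helper) would look roughly like:

def report_memory(flush=False):
    # Rough stand-in for the helper in test_xccl.py: print the XPU
    # caching-allocator stats in MiB, matching the log format above.
    mib = 1024.0 * 1024.0
    print(
        f"memory (MiB) | allocated: {torch.xpu.memory_allocated() / mib}"
        f" | max allocated: {torch.xpu.max_memory_allocated() / mib}"
        f" | reserved: {torch.xpu.memory_reserved() / mib}"
        f" | max reserved: {torch.xpu.max_memory_reserved() / mib}",
        flush=flush,
    )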

After inspecting the memory addresses of the variables, we found that instead of reusing existing allocations, a new memory address is created for output and ret on every iteration.
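For reference, this is roughly how the addresses were checked; the helper name check_reuse is illustrative and not part of the attached script:

def check_reuse(step, ret, output):
    # Illustrative helper (not in test_xccl.py): print the raw device pointers.
    # If the caching allocator were reusing blocks, these addresses would repeat
    # across iterations; in this run they change every iteration while the
    # reserved memory keeps growing.
    print(f"iter {step}: ret ptr={ret.data_ptr():#x}, output ptr={output.data_ptr():#x}, "
          f"reserved (MiB): {torch.xpu.memory_reserved() / 2**20:.1f}")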

If the line dist.all_reduce(ret, op=dist.ReduceOp.SUM) is commented out, the code runs to completion.
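One way to narrow this down further (a diagnostic sketch, not a fix from the attached script) is to force synchronization with the collective stream at the end of each iteration and check whether the allocator then starts reusing the blocks:

        with torch.xpu.stream(coll_stream):
            dist.all_reduce(ret, op=dist.ReduceOp.SUM)
        # Diagnostic only: block until the collective stream has finished so the
        # caching allocator can safely hand the blocks back on the next iteration.
        torch.xpu.synchronize()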

Versions

torch 2.9.0.dev20250903+xpu
torchaudio 2.8.0.dev20250904+xpu
torchvision 0.24.0.dev20250904+xpu
intel_extension_for_pytorch 2.9.10+git2426de8
vllm 0.1.dev8696+g4935c6506.xpu /fanli/vllm-xpu

Attachment: test_xccl.py
