🐛 Describe the bug
With the latest torch-2.9 (built with 2025.2.1), a simple unit test with TP=2 fails with the following error message:
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter: 268
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter: 269
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
iter: 270
memory (MiB) | allocated: 2048.0 | max allocated: 3072.0 | reserved: 62976.0 | max reserved: 62976.0
Process Process-2:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/sdp/fanli/test_xccl.py", line 60, in init_process
fn(rank, world_size)
File "/home/sdp/fanli/test_xccl.py", line 46, in run
output = ret + weight + gradient
torch.OutOfMemoryError: XPU out of memory. Tried to allocate 512.00 MiB. GPU 1 has a total capacity of 63.98 GiB. Of the allocated memory 2.50 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. Please use `empty_cache` to release all unoccupied cached memory.
The full script is attached. To put it simply, the following loop does not free memory after each iteration, so the reserved memory keeps increasing until the process finally crashes:
for epoch in range(1):
    for iter in range(1000):
        # generate the random input for each iteration
        output = torch.empty_like(ret)
        # mock the compute part (fwd/bwd)
        output = ret + weight + gradient
        ret = torch.nn.functional.linear(output, weight=ret)
        with torch.xpu.stream(coll_stream):
            dist.all_reduce(ret, op=dist.ReduceOp.SUM)
        if rank == 0:
            print("iter: ", iter)
            report_memory(True)
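For reference, a minimal way to inspect the underlying buffers (a sketch, not the exact instrumentation used in the attached script) is to log data_ptr() of output and ret inside the loop:

        if rank == 0:
            # data_ptr() returns the device address of the tensor's storage;
            # if the caching allocator were reusing blocks, these values would
            # repeat across iterations instead of continually changing.
            print(f"iter {iter}: output={hex(output.data_ptr())} ret={hex(ret.data_ptr())}")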
After inspecting the memory addresses of the variables, we found that the allocator does not reuse the buffers: a new memory address is allocated for output and ret in every iteration.
If the line dist.all_reduce(ret, op=dist.ReduceOp.SUM) is commented out, the script runs to completion.
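This may be related to how the caching allocator handles tensors touched by a side stream: blocks used on another stream are normally marked with Tensor.record_stream (or the streams are explicitly synchronized) before the allocator is allowed to recycle them. The sketch below shows that usual cross-stream pattern; it is only an assumption about the cause, untested on this XPU build, and not a confirmed fix for this issue:

    # Hypothetical variant of the collective call, assuming torch.xpu streams
    # follow the same wait_stream/record_stream semantics as torch.cuda.
    coll_stream.wait_stream(torch.xpu.current_stream())  # collective waits for the compute that produced `ret`
    with torch.xpu.stream(coll_stream):
        dist.all_reduce(ret, op=dist.ReduceOp.SUM)
        # tell the caching allocator that `ret` is in use on coll_stream,
        # so its block is recycled only after that work completes
        ret.record_stream(coll_stream)
    # let the compute stream wait for the collective before the next iteration
    # rebinds `ret`, so the allocator can reuse the blocks
    torch.xpu.current_stream().wait_stream(coll_stream)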
Versions
torch 2.9.0.dev20250903+xpu
torchaudio 2.8.0.dev20250904+xpu
torchvision 0.24.0.dev20250904+xpu
intel_extension_for_pytorch 2.9.10+git2426de8
vllm 0.1.dev8696+g4935c6506.xpu /fanli/vllm-xpu