import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
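
# Standalone reproducer: a mock data-parallel training step on Intel XPU
# devices. Each rank allocates a few 16384 x 16384 fp16 tensors, runs a fake
# compute phase (elementwise ops plus a matmul), all-reduces the result on a
# side stream via the xccl backend, and rank 0 prints a memory report per
# iteration.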
 
def report_memory(use_mib=True):
    """Simple GPU memory report."""
    if use_mib:
        unit = 1024.0 * 1024.0
        string = 'memory (MiB)'
    else:
        unit = 1024.0 * 1024.0 * 1024.0
        string = 'memory (GiB)'
    string += ' | allocated: {:.1f}'.format(
        torch.xpu.memory_allocated() / unit)
    string += ' | max allocated: {:.1f}'.format(
        torch.xpu.max_memory_allocated() / unit)
    string += ' | reserved: {:.1f}'.format(
        torch.xpu.memory_reserved() / unit)
    string += ' | max reserved: {:.1f}'.format(
        torch.xpu.max_memory_reserved() / unit)
    print(string)
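
# Optional helper (a sketch): the "max allocated/reserved" counters above are
# monotonic high-water marks, so per-iteration peaks get hidden once a large
# peak is set. This assumes torch.xpu mirrors torch.cuda's
# reset_peak_memory_stats(); call it after each report if you want fresh
# peaks every iteration.
def reset_peak_memory():
    """Reset the peak memory counters for the current XPU device."""
    torch.xpu.reset_peak_memory_stats()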
 
def run(rank, world_size):
    # Data-parallel group spanning all ranks; the all_reduce below runs over it.
    dp_ranks = range(world_size)
    dp_group = dist.new_group(dp_ranks)

    device = "xpu:{}".format(rank)
    torch.xpu.set_device(device)
    
    shape = (16384, 16384)
    # Initialize mock weight, gradient, and activation buffers directly on
    # the device, avoiding the intermediate host-side allocation and copy.
    weight = torch.ones(shape, dtype=torch.half, device=device)
    gradient = torch.zeros(shape, dtype=torch.half, device=device)
    ret = torch.randn(shape, dtype=torch.half, device=device)

    # Side stream for the collective so it can run concurrently with compute.
    coll_stream = torch.xpu.Stream()
        
    for epoch in range(1):
        for step in range(1000):
            # Mock the compute part (fwd/bwd).
            output = ret + weight + gradient
            ret = torch.nn.functional.linear(output, weight=ret)

            # Run the all-reduce on the side stream. The side stream must
            # first wait for the compute stream that produced `ret`, and the
            # compute stream must wait for the collective before the next
            # iteration reuses `ret`; otherwise the two streams race.
            coll_stream.wait_stream(torch.xpu.current_stream())
            with torch.xpu.stream(coll_stream):
                dist.all_reduce(ret, op=dist.ReduceOp.SUM, group=dp_group)
            torch.xpu.current_stream().wait_stream(coll_stream)

            if rank == 0:
                print("iter:", step)
                report_memory(True)
    torch.xpu.synchronize()
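
# Alternative to manual stream management (a sketch): torch.distributed can
# launch the collective asynchronously and hand back a work handle, letting
# independent compute proceed before the wait.
#
#   work = dist.all_reduce(ret, op=dist.ReduceOp.SUM, group=dp_group,
#                          async_op=True)
#   ...  # independent compute here
#   work.wait()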
    
def init_process(rank, world_size, fn, backend='xccl'):
    """Initialize the distributed environment, run fn, then tear down."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '27500'
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)
    dist.destroy_process_group()
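
# The rendezvous is hardcoded to localhost:27500; any free port on the
# machine works here, since all ranks run on the same host.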
 
 
if __name__ == "__main__":
    mp.set_start_method("spawn")
    world_size = 2
    processes = []
 
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)
 
    for p in processes:
        p.join()
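
# Usage (assumes a PyTorch build with XPU support and at least world_size
# XPU devices visible; the filename below is arbitrary):
#
#   python allreduce_xpu.py
#
# The script spawns its own world_size=2 workers, so no external launcher
# such as torchrun is needed.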