Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 21 additions & 24 deletions colossalai/fx/profiler/tensor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import uuid
from copy import deepcopy
from typing import Optional

import torch
from torch.types import _bool, _device, _dtype
Expand Down Expand Up @@ -28,8 +26,6 @@ class MetaTensor(torch.Tensor):

_tensor: torch.Tensor

__slots__ = ['_tensor']

@staticmethod
def __new__(cls, elem, fake_device=None):
# Avoid multiple wrapping
Expand All @@ -47,7 +43,7 @@ def __new__(cls, elem, fake_device=None):
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device=fake_device if fake_device is not None else elem.device,
device=fake_device if fake_device is not None else torch.device('cpu'),
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
Expand All @@ -59,8 +55,8 @@ def __new__(cls, elem, fake_device=None):

def __repr__(self):
if self.grad_fn:
return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor}, fake_device='{self.device}')"
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
Expand All @@ -76,13 +72,13 @@ def unwrap(x):
x = x.to(torch.device('meta'))
return x

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)

if 'device' in kwargs:
fake_device = kwargs['device']
kwargs['device'] = torch.device('meta')

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)

# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)

Expand Down Expand Up @@ -118,23 +114,24 @@ def to(self, *args, **kwargs) -> torch.Tensor:
MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
"""
# this imitates c++ function in the way of @overload
device = None
for arg in args:
if isinstance(arg, str) or isinstance(arg, _device):
device = arg
if 'device' in kwargs:
device = kwargs['device']
result = super().to(*args, **kwargs)
if device is not None:
result = MetaTensor(result, fake_device=device)
return result
fake_device = None

def replace(x):
nonlocal fake_device
if isinstance(x, str) or isinstance(x, _device):
fake_device = x
return 'meta'
return x

elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
return MetaTensor(elem, fake_device=fake_device)

def cpu(self, *args, **kwargs):
if self.device.type == 'cpu':
return self.to(*args, **kwargs)
return self.to(*args, device='cpu', **kwargs)

def cuda(self, *args, **kwargs):
if self.device.type == 'cuda':
return self.to(*args, **kwargs)
return self.to(*args, device='cuda', **kwargs)
def cuda(self, device=None, non_blocking=False):
if device is not None:
return self.to(device=device, non_blocking=non_blocking)
return self.to(device='cuda:0', non_blocking=non_blocking)
Loading