diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7efc2412eaf7..1c4796a533a4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -256,197 +256,30 @@ def call_binary_op(op, lhs, rhs): return convert - ########## Creation ########## - - def _arange(self, node: fx.Node) -> relax.Var: - import torch - - start_end_step = [None, None, None] - if "start" in node.kwargs: - start_end_step[0] = node.kwargs["start"] - if "end" in node.kwargs: - start_end_step[1] = node.kwargs["end"] - if "step" in node.kwargs: - start_end_step[2] = node.kwargs["step"] - - if len(node.args) == 1: - assert start_end_step[1] is None - start_end_step[1] = node.args[0] - elif len(node.args) == 2: - assert start_end_step[0] is None - assert start_end_step[1] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - elif len(node.args) == 3: - assert start_end_step[0] is None - assert start_end_step[1] is None - assert start_end_step[2] is None - start_end_step[0] = node.args[0] - start_end_step[1] = node.args[1] - start_end_step[2] = node.args[2] - - if start_end_step[0] is None: - start_end_step[0] = 0 - if start_end_step[2] is None: - start_end_step[2] = 1 - - if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - elif any([isinstance(x, float) for x in start_end_step]): - dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) - else: - dtype = "int64" - start_end_step = [ - self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step - ] - return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) - - def _empty(self, node: fx.Node) -> relax.Var: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - return self.block_builder.emit(relax.op.zeros(node.args, dtype)) - - def _inplace_fill(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - dtype = x.struct_info.dtype - value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) - filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) - self.env[node.args[0]] = filled - return filled - - def _tensor(self, node: fx.Node) -> relax.Var: - dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None - if isinstance(node.args[0], float): - return relax.const(node.args[0], dtype if dtype is not None else "float32") - elif isinstance(node.args[0], int): - return relax.const(node.args[0], dtype if dtype is not None else "int64") - raise ValueError("torch.tensor with value not a float or int is not accepted") - - def _inplace_tril_triu(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - k = node.args[1] if len(node.args) > 1 else 0 - assert isinstance(k, int) - - mutated = self.block_builder.emit(op(x, k)) - self.env[node.args[0]] = mutated - return mutated - - return convert - - def _new_ones(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - self_var = args[0] - size = args[1:] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, self_var.struct_info.dtype), - self_var.struct_info.dtype, - ) - ) - - def _ones(self, node: fx.Node) -> relax.Var: - import torch + ########## Neural Network ########## - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] return self.block_builder.emit( - relax.op.full( - size, - relax.const(1, dtype), - dtype, - ) + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - def _full(self, node: fx.Node) -> relax.Var: - import torch + def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - size = args[0] - if not isinstance(size, (list, tuple)): - size = (size,) - size = relax.ShapeExpr(size) - dtype = ( - TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - if "dtype" in node.kwargs - else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) - ) - value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size return self.block_builder.emit( - relax.op.full( - size, - value, - dtype, - ) + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) - ########## Statistical ########## - - def _sum(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.sum(args[0], args[1])) - - def _mean(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False - if len(args) == 1: - return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) - return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) - - ########## DataType ########## - - def _float(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) - - def _half(self, node: fx.Node) -> relax.Var: - return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) - - def _type(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - - def _to(self, node: fx.Node) -> relax.Var: - import torch - - x = self.env[node.args[0]] - if len(node.args) == 2: - if isinstance(node.args[1], torch.dtype): - dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - elif "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) - return self.block_builder.emit(relax.op.astype(x, dtype)) - return x - - ########## Linear Algebra ########## - - def _matmul_impl(self, a: relax.Expr, b: relax.Expr): - return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _addmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] y = self.env[node.args[1]] z = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) res = None if alpha != 0: @@ -463,12 +296,50 @@ def _addmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + + def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] a = self.env[node.args[1]] b = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) res = None if alpha != 0: @@ -485,229 +356,73 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) - return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) - - def _unbind(self, node: fx.Node) -> relax.Var: - if len(node.args) == 2: - assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" - dim = node.args[1] - elif "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 - x = self.env[node.args[0]] - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) - ########## Manipulation ########## + if bias is None: + return conv1d_transpose - def _cat(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) - return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - def _expand(self, node: fx.Node) -> relax.Var: + def _conv1d_transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - broadcast_shape, in_shape = [], self.shape_of(args[0]) - for idx, i in enumerate(args[1:]): - if isinstance(i, int) and i == -1: - broadcast_shape.append(in_shape[idx]) - else: - broadcast_shape.append(i) - return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) - - def _flatten(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - start_dim = module.start_dim - end_dim = module.end_dim - else: - start_dim = node.args[1] if len(node.args) >= 2 else 0 - end_dim = node.args[2] if len(node.args) == 3 else -1 - shape = self.shape_of(x) - start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim - end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim - flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) - new_shape = ( - [shape[i] for i in range(0, start_dim)] - + [flattened] - + [shape[i] for i in range(end_dim + 1, len(shape))] + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, ) - return self.block_builder.emit(relax.op.reshape(x, new_shape)) - - def _permute(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) - - def _reshape(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - - def _split(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - split_size = node.args[1] - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 - if isinstance(split_size, (list, tuple)): - n_section = [] - for s in split_size[:-1]: - cum_sum = 0 if not n_section else n_section[-1] - n_section.append(s + cum_sum) - else: - n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size - return self.block_builder.emit(relax.op.split(x, n_section, dim)) - - def _chunk(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - chunks = node.args[1] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 2: - dim = node.args[2] - else: - dim = 0 - return self.block_builder.emit(relax.op.split(x, chunks, dim)) - - def _transpose(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - full_idx = list(range(len(self.shape_of(args[0])))) - full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] - return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - - def _squeeze(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None - return self.block_builder.emit(relax.op.squeeze(x, dim)) - - def _repeat(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - - def _tile(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) - return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - - def _cumsum(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - elif len(node.args) > 1: - dim = node.args[1] - else: - dim = None - if "dtype" in node.kwargs: - dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) - else: - dtype = None - if "out" in node.kwargs: - raise ValueError("specifying out for cumsum is not supported yet") - - return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - - def _index_select(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - dim = node.args[1] - index = self.env[node.args[2]] - return self.block_builder.emit(relax.op.take(x, index, dim)) - - def _masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - return self.block_builder.emit(relax.op.where(mask, values, x)) - - def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - mask = self.env[node.args[1]] - value = node.args[2] - rx_value = relax.const(value) - values = self.block_builder.emit(relax.op.full_like(x, rx_value)) - output = self.block_builder.emit(relax.op.where(mask, values, x)) - self.env[node.args[0]] = output - return output - - ########## Search ########## - - def _argmax_argmin(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node): - x = self.env[node.args[0]] - dim = None - keepdims = False - - if len(node.args) > 1: - dim = node.args[1] - if len(node.args) > 2: - keepdims = node.args[2] - - if "dim" in node.kwargs: - dim = node.kwargs["dim"] - if "keepdim" in node.kwargs: - keepdims = node.kwargs["keepdim"] - if "keepdims" in node.kwargs: - keepdims = node.kwargs["keepdims"] - - return self.block_builder.emit(op(x, dim, keepdims)) - - return convert - - ########## Neural Network ########## - def _linear(self, node: fx.Node) -> relax.Var: + def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - bias = None if module.bias is None else self.params[module.bias] - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + bias = self.params.get(module.bias, None) - def _linear_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) - def _conv1d_impl( + def _conv2d_transpose_impl( self, x: relax.Expr, weight: relax.Expr, @@ -717,45 +432,28 @@ def _conv1d_impl( dilation: Optional[Tuple], groups: Optional[Tuple], ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( x, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, - data_layout="NCW", - kernel_layout="OIW", + data_layout="NCHW", + kernel_layout="OIHW", out_dtype="float32", ) ) if bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + return conv2d_transpose - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - def _conv1d_functional(self, node: fx.Node) -> relax.Var: + def _conv2d_transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -764,7 +462,7 @@ def _conv1d_functional(self, node: fx.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( + return self._conv2d_transpose_impl( x, weight, bias=bias, @@ -774,7 +472,23 @@ def _conv1d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv1d_transpose_impl( + def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv1d_impl( self, x: relax.Expr, weight: relax.Expr, @@ -784,8 +498,8 @@ def _conv1d_transpose_impl( dilation: Optional[Tuple], groups: Optional[Tuple], ) -> relax.Var: - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( x, weight, strides=strides, @@ -799,31 +513,12 @@ def _conv1d_transpose_impl( ) if bias is None: - return conv1d_transpose - + return conv1d assert len(self.shape_of(bias)) == 1 bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + return self.block_builder.emit(relax.op.add(conv1d, bias)) - def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: + def _conv1d(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -832,7 +527,7 @@ def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: padding = args[4] if len(args) > 4 else 0 dilation = args[5] if len(args) > 5 else 1 groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( + return self._conv1d_impl( x, weight, bias=bias, @@ -842,6 +537,22 @@ def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) + def _conv1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + def _conv2d_impl( self, x: relax.Expr, @@ -873,24 +584,6 @@ def _conv2d_impl( return self.block_builder.emit(relax.op.add(conv2d, bias)) def _conv2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv2d_functional(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] @@ -909,7 +602,23 @@ def _conv2d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv2d_transpose_impl( + def _conv2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + + def _conv3d_impl( self, x: relax.Expr, weight: relax.Expr, @@ -918,37 +627,53 @@ def _conv2d_transpose_impl( padding: Optional[Tuple], dilation: Optional[Tuple], groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( x, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", + data_layout="NCDHW", + kernel_layout="OIDHW", out_dtype="float32", ) ) if bias is None: - return conv2d_transpose - + return conv3d assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv3d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + bias = self.params.get(module.bias, None) - return self._conv2d_transpose_impl( + return self._conv3d_impl( x, weight, bias=bias, @@ -958,182 +683,570 @@ def _conv2d_transpose(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var: + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _embedding_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + return self._embedding_impl(x, weight) + + def _group_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + eps = module.eps + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _linear(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) x = args[0] weight = args[1] bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv3d_impl( + def _linear_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _max_pool2d_impl( self, x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None else stride + return self.block_builder.emit( + relax.op.nn.max_pool2d( x, - weight, - strides=strides, + pool_size=kernel_size, + strides=stride, padding=padding, dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", + ceil_mode=ceil_mode, + layout="NCHW", ) ) - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) + def _max_pool2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _max_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + + ########## Creation ########## + + def _arange(self, node: fx.Node) -> relax.Var: + import torch + + start_end_step = [None, None, None] + if "start" in node.kwargs: + start_end_step[0] = node.kwargs["start"] + if "end" in node.kwargs: + start_end_step[1] = node.kwargs["end"] + if "step" in node.kwargs: + start_end_step[2] = node.kwargs["step"] + + if len(node.args) == 1: + assert start_end_step[1] is None + start_end_step[1] = node.args[0] + elif len(node.args) == 2: + assert start_end_step[0] is None + assert start_end_step[1] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + elif len(node.args) == 3: + assert start_end_step[0] is None + assert start_end_step[1] is None + assert start_end_step[2] is None + start_end_step[0] = node.args[0] + start_end_step[1] = node.args[1] + start_end_step[2] = node.args[2] + + if start_end_step[0] is None: + start_end_step[0] = 0 + if start_end_step[2] is None: + start_end_step[2] = 1 + + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + elif any([isinstance(x, float) for x in start_end_step]): + dtype = TorchFXImporter._convert_data_type(torch.get_default_dtype()) + else: + dtype = "int64" + start_end_step = [ + self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step + ] + return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype)) + + def _empty(self, node: fx.Node) -> relax.Var: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + return self.block_builder.emit(relax.op.zeros(node.args, dtype)) + + def _inplace_fill(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + dtype = x.struct_info.dtype + value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype) + filled = self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype)) + self.env[node.args[0]] = filled + return filled + + def _tensor(self, node: fx.Node) -> relax.Var: + dtype = node.kwargs["dtype"] if "dtype" in node.kwargs else None + if isinstance(node.args[0], float): + return relax.const(node.args[0], dtype if dtype is not None else "float32") + elif isinstance(node.args[0], int): + return relax.const(node.args[0], dtype if dtype is not None else "int64") + raise ValueError("torch.tensor with value not a float or int is not accepted") + + def _inplace_tril_triu(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + k = node.args[1] if len(node.args) > 1 else 0 + assert isinstance(k, int) + + mutated = self.block_builder.emit(op(x, k)) + self.env[node.args[0]] = mutated + return mutated + + return convert + + def _new_ones(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + self_var = args[0] + size = args[1:] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, self_var.struct_info.dtype), + self_var.struct_info.dtype, + ) + ) + + def _ones(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + return self.block_builder.emit( + relax.op.full( + size, + relax.const(1, dtype), + dtype, + ) + ) + + def _full(self, node: fx.Node) -> relax.Var: + import torch + + args = self.retrieve_args(node) + size = args[0] + if not isinstance(size, (list, tuple)): + size = (size,) + size = relax.ShapeExpr(size) + dtype = ( + TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + if "dtype" in node.kwargs + else TorchFXImporter._convert_data_type(torch.get_default_dtype(), self.env) + ) + value = args[1] if isinstance(args[1], relax.expr.Constant) else relax.const(args[1], dtype) + return self.block_builder.emit( + relax.op.full( + size, + value, + dtype, + ) + ) + + ########## Statistical ########## + + def _sum(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.sum(args[0], args[1])) + + def _mean(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False + if len(args) == 1: + return self.block_builder.emit(relax.op.mean(args[0], keepdims=keepdim)) + return self.block_builder.emit(relax.op.mean(args[0], args[1], keepdims=keepdim)) + + ########## DataType ########## + + def _float(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float32")) + + def _half(self, node: fx.Node) -> relax.Var: + return self.block_builder.emit(relax.op.astype(self.env[node.args[0]], "float16")) + + def _type(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + + def _to(self, node: fx.Node) -> relax.Var: + import torch + + x = self.env[node.args[0]] + if len(node.args) == 2: + if isinstance(node.args[1], torch.dtype): + dtype = TorchFXImporter._convert_data_type(node.args[1], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + elif "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(node.kwargs["dtype"], self.env) + return self.block_builder.emit(relax.op.astype(x, dtype)) + return x + + ########## Manipulation ########## + + def _cat(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + return self.block_builder.emit(relax.op.concat(args[0], axis=axis)) + + def _expand(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + broadcast_shape, in_shape = [], self.shape_of(args[0]) + for idx, i in enumerate(args[1:]): + if isinstance(i, int) and i == -1: + broadcast_shape.append(in_shape[idx]) + else: + broadcast_shape.append(i) + return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape)) + + def _flatten(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + if node.target in self.named_modules: + module = self.named_modules[node.target] + start_dim = module.start_dim + end_dim = module.end_dim + else: + start_dim = node.args[1] if len(node.args) >= 2 else 0 + end_dim = node.args[2] if len(node.args) == 3 else -1 + shape = self.shape_of(x) + start_dim = start_dim if start_dim >= 0 else len(shape) + start_dim + end_dim = end_dim if end_dim >= 0 else len(shape) + end_dim + flattened = reduce(lambda x, y: x * y, [shape[i] for i in range(start_dim, end_dim + 1)]) + new_shape = ( + [shape[i] for i in range(0, start_dim)] + + [flattened] + + [shape[i] for i in range(end_dim + 1, len(shape))] + ) + return self.block_builder.emit(relax.op.reshape(x, new_shape)) + + def _permute(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.permute_dims(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.permute_dims(args[0], args[1:])) + + def _reshape(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.reshape(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.reshape(args[0], args[1:])) - def _conv3d(self, node: fx.Node) -> relax.Var: + def _split(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] + split_size = node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + else: + dim = 0 + if isinstance(split_size, (list, tuple)): + n_section = [] + for s in split_size[:-1]: + cum_sum = 0 if not n_section else n_section[-1] + n_section.append(s + cum_sum) + else: + n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size + return self.block_builder.emit(relax.op.split(x, n_section, dim)) - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) + def _chunk(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + chunks = node.args[1] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 2: + dim = node.args[2] + else: + dim = 0 + return self.block_builder.emit(relax.op.split(x, chunks, dim)) - def _conv3d_functional(self, node: fx.Node) -> relax.Var: + def _transpose(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) + full_idx = list(range(len(self.shape_of(args[0])))) + full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]] + return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx)) - def _max_pool2d(self, node: fx.Node) -> relax.Var: + def _squeeze(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - dilation = module.dilation - ceil_mode = module.ceil_mode + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - stride = node.args[2] if nargs > 2 else node.kwargs["stride"] - padding = node.args[3] if nargs > 3 else node.kwargs["padding"] - dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] - ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] + dim = None + return self.block_builder.emit(relax.op.squeeze(x, dim)) - stride = kernel if stride is None else stride + def _repeat(self, node: fx.Node) -> relax.Var: + import torch # type: ignore - return self.block_builder.emit( - relax.op.nn.max_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - dilation=dilation, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) - def _avg_pool2d(self, node: fx.Node) -> relax.Var: + def _tile(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + if isinstance(args[1], (torch.Size, tuple, list)): + return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1]))) + return self.block_builder.emit(relax.op.tile(args[0], args[1:])) + + def _cumsum(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - ceil_mode = module.ceil_mode + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1: + dim = node.args[1] else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - if nargs > 2: - stride = node.args[2] - elif "stride" in node.kwargs.keys(): - stride = node.kwargs["stride"] - else: - stride = None - if nargs > 3: - padding = node.args[3] - elif "padding" in node.kwargs.keys(): - padding = node.kwargs["padding"] - else: - padding = 0 - if nargs > 4: - ceil_mode = node.args[4] - elif "ceil_mode" in node.kwargs.keys(): - ceil_mode = node.kwargs["ceil_mode"] - else: - ceil_mode = False + dim = None + if "dtype" in node.kwargs: + dtype = TorchFXImporter._convert_data_type(str(node.kwargs["dtype"]), self.env) + else: + dtype = None + if "out" in node.kwargs: + raise ValueError("specifying out for cumsum is not supported yet") - stride = kernel if stride is None else stride + return self.block_builder.emit(relax.op.cumsum(x, dim, dtype)) - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) + def _index_select(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] + index = self.env[node.args[2]] + return self.block_builder.emit(relax.op.take(x, index, dim)) - def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: + def _masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + return self.block_builder.emit(relax.op.where(mask, values, x)) + + def _inplace_masked_fill(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + mask = self.env[node.args[1]] + value = node.args[2] + rx_value = relax.const(value) + values = self.block_builder.emit(relax.op.full_like(x, rx_value)) + output = self.block_builder.emit(relax.op.where(mask, values, x)) + self.env[node.args[0]] = output + return output + + ########## Search ########## + + def _argmax_argmin(self, op: Callable) -> Callable: from torch import fx - def _impl(node: fx.Node) -> relax.Var: - if is_module: - module = self.named_modules[node.target] - x = self.env[node.args[0]] - output_size = module.output_size - else: - x = self.env[node.args[0]] - output_size = node.args[1] - return self.block_builder.emit( - relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") - ) + def convert(node: fx.Node): + x = self.env[node.args[0]] + dim = None + keepdims = False + + if len(node.args) > 1: + dim = node.args[1] + if len(node.args) > 2: + keepdims = node.args[2] + + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + if "keepdim" in node.kwargs: + keepdims = node.kwargs["keepdim"] + if "keepdims" in node.kwargs: + keepdims = node.kwargs["keepdims"] - return _impl + return self.block_builder.emit(op(x, dim, keepdims)) + + return convert + + ########## Neural Network ########## def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -1169,115 +1282,6 @@ def _batch_norm_2d(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _layer_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - x = self.env[node.args[0]] - - # functional.layer_norm - if node.target not in self.named_modules: - # static or symbolic - arg = node.args[1] - if isinstance(arg, (immutable_list, tuple)): - value = tuple(arg) - else: - try: - value = self.env[arg] - except TypeError: - value = tuple(arg) - normalized_shape = value - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - gamma = node.kwargs["weight"] - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - else: - gamma = self.env[gamma] - beta = node.kwargs["bias"] - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - else: - beta = self.env[beta] - eps = node.kwargs["eps"] - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - module = self.named_modules[node.target] - - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - dim_num = len(module.normalized_shape) - axes = list(range(-dim_num, 0)) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=module.eps, - ) - ) - - def _group_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - - if module.affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) - beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) - - dim = len(self.shape_of(x)) - return self.block_builder.emit( - relax.op.nn.group_norm( - x, - gamma, - beta, - num_groups=module.num_groups, - channel_axis=1, - axes=list(range(2, dim)), - epsilon=module.eps, - ) - ) - - def _embedding(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, @@ -1387,26 +1391,6 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr: ) ) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - assert ( - len(node.args) <= 4 - ), "Dropout is not supported, and is_causal should be called by kwargs." - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - causal_mask = "TopLeft" if node.kwargs.get("is_causal", False) else None - - if len(node.args) == 4: - mask = self.env[node.args[3]] - msg = "Only a float mask is supported for the attn_mask input." - assert "float" in mask.struct_info.dtype, msg - attn = relax.op.nn.attention(query, key, value, bias=mask, causal_mask=causal_mask) - else: - attn = relax.op.nn.attention(query, key, value, causal_mask=causal_mask) - - return self.block_builder.emit(attn) - ########## Others ########## def _sym_size_int(self, node: fx.Node) -> relax.Expr: @@ -1538,20 +1522,20 @@ def create_convert_map(self): nn.Softmax: self._softmax_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network - nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), - nn.AvgPool2d: self._avg_pool2d, + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, + nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d, - nn.Conv1d: self._conv1d, - nn.Conv2d: self._conv2d, - nn.Conv3d: self._conv3d, - nn.ConvTranspose1d: self._conv1d_transpose, - nn.ConvTranspose2d: self._conv2d_transpose, + nn.Conv1d: self._conv1d_module, + nn.Conv2d: self._conv2d_module, + nn.Conv3d: self._conv3d_module, + nn.ConvTranspose1d: self._conv1d_transpose_module, + nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, - nn.GroupNorm: self._group_norm, - nn.LayerNorm: self._layer_norm, - nn.Linear: self._linear, - nn.MaxPool2d: self._max_pool2d, - nn.modules.sparse.Embedding: self._embedding, + nn.GroupNorm: self._group_norm_module, + nn.LayerNorm: self._layer_norm_module, + nn.Linear: self._linear_module, + nn.MaxPool2d: self._max_pool2d_module, + nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation nn.Flatten: self._flatten, ## call_function and call_method @@ -1603,23 +1587,23 @@ def create_convert_map(self): "sub": self._binary_op(relax.op.subtract, operator.sub), "truediv": self._binary_op(relax.op.divide, operator.truediv), # neural network - "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "adaptive_avg_pool2d": self._adaptive_avg_pool2d, "addmm": self._addmm, "avg_pool2d": self._avg_pool2d, "baddbmm": self._baddbmm, "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose_functional, - "conv_transpose2d": self._conv2d_transpose_functional, - "conv1d": self._conv1d_functional, - "conv2d": self._conv2d_functional, - "conv3d": self._conv3d_functional, + "conv_transpose1d": self._conv1d_transpose, + "conv_transpose2d": self._conv2d_transpose, + "conv1d": self._conv1d, + "conv2d": self._conv2d, + "conv3d": self._conv3d, "cross_entropy": self._cross_entropy, "einsum": self._einsum, "interpolate": self._interpolate, "layer_norm": self._layer_norm, - "linear": self._linear_functional, + "linear": self._linear, "max_pool2d": self._max_pool2d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]],