From c33d436f814b2b74a3d65c3857fb8e81fd575fc0 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:23:39 +0900 Subject: [PATCH 01/18] cleanup `_adaptive_avg_pool2d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 39 ++++++++++--------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 7efc2412eaf7..5a5466b9c274 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -256,6 +256,24 @@ def call_binary_op(op, lhs, rhs): return convert + ########## Neural Network ########## + + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: + + module = self.named_modules[node.target] + x = self.env[node.args[0]] + output_size = module.output_size + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1118,23 +1136,6 @@ def _avg_pool2d(self, node: fx.Node) -> relax.Var: ) ) - def _adaptive_avg_pool2d(self, is_module: bool) -> Callable: - from torch import fx - - def _impl(node: fx.Node) -> relax.Var: - if is_module: - module = self.named_modules[node.target] - x = self.env[node.args[0]] - output_size = module.output_size - else: - x = self.env[node.args[0]] - output_size = node.args[1] - return self.block_builder.emit( - relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") - ) - - return _impl - def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1538,7 +1539,7 @@ def create_convert_map(self): nn.Softmax: self._softmax_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network - nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True), + nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, nn.AvgPool2d: self._avg_pool2d, nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d, @@ -1603,7 +1604,7 @@ def create_convert_map(self): "sub": self._binary_op(relax.op.subtract, operator.sub), "truediv": self._binary_op(relax.op.divide, operator.truediv), # neural network - "adaptive_avg_pool2d": self._adaptive_avg_pool2d(is_module=False), + "adaptive_avg_pool2d": self._adaptive_avg_pool2d, "addmm": self._addmm, "avg_pool2d": self._avg_pool2d, "baddbmm": self._baddbmm, From 87e40b76680a0ccc5e84c0b4445620c7b87e07c9 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:26:35 +0900 Subject: [PATCH 02/18] cleanup `addmm()` --- .../tvm/relax/frontend/torch/fx_translator.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5a5466b9c274..99de1fcd67fc 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -274,6 +274,28 @@ def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") ) + def _addmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + y = self.env[node.args[1]] + z = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) + return res + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -459,28 +481,6 @@ def _to(self, node: fx.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _addmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - y = self.env[node.args[1]] - z = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.linear_algebra.matmul(y, z, out_dtype="float32")) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) - return res - def _baddbmm(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] a = self.env[node.args[1]] From e14ed29b315f9ee0656e99c704cd4299245995e3 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:29:30 +0900 Subject: [PATCH 03/18] cleanup `_avg_pool2d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 83 +++++++++---------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 99de1fcd67fc..9d372cfaef22 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -296,6 +296,44 @@ def _addmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(bias, res)) return res + def _avg_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None or stride == [] else stride + return self.block_builder.emit( + relax.op.nn.avg_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _avg_pool2d(self, node: fx.Node) -> relax.Var: + args, kwargs = node.normalized_arguments(node) + x = self.env[args[0]] + kernel_size = args[1] if len(args) > 1 else kwargs["kernel_size"] + stride = args[2] if len(args) > 2 else kwargs.get("stride", None) + padding = args[3] if len(args) > 3 else kwargs.get("padding", 0) + ceil_mode = args[4] if len(args) > 4 else kwargs.get("ceil_mode", False) + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + + def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + ceil_mode = module.ceil_mode + return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1093,49 +1131,6 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var: ) ) - def _avg_pool2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - ceil_mode = module.ceil_mode - else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - if nargs > 2: - stride = node.args[2] - elif "stride" in node.kwargs.keys(): - stride = node.kwargs["stride"] - else: - stride = None - if nargs > 3: - padding = node.args[3] - elif "padding" in node.kwargs.keys(): - padding = node.kwargs["padding"] - else: - padding = 0 - if nargs > 4: - ceil_mode = node.args[4] - elif "ceil_mode" in node.kwargs.keys(): - ceil_mode = node.kwargs["ceil_mode"] - else: - ceil_mode = False - - stride = kernel if stride is None else stride - - return self.block_builder.emit( - relax.op.nn.avg_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) - def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1540,7 +1535,7 @@ def create_convert_map(self): nn.Tanh: self._unary_op(relax.op.tanh), # neural network nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, - nn.AvgPool2d: self._avg_pool2d, + nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, From 10737f4ad335fa422a86074954ef03a3bdf4d959 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:30:04 +0900 Subject: [PATCH 04/18] cleanup `_baddbmm()` --- .../tvm/relax/frontend/torch/fx_translator.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 9d372cfaef22..2b64f30df2a2 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -334,6 +334,28 @@ def _avg_pool2d_module(self, node: fx.Node) -> relax.Var: ceil_mode = module.ceil_mode return self._avg_pool2d_impl(x, kernel_size, stride, padding, ceil_mode) + def _baddbmm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + a = self.env[node.args[1]] + b = self.env[node.args[2]] + alpha = node.kwargs.get("alpha", 1) + beta = node.kwargs.get("beta", 1) + + res = None + if alpha != 0: + res = self.block_builder.emit(relax.op.matmul(a, b)) + if alpha != 1: + dtype = res.struct_info.dtype + res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) + if beta != 0: + dtype = x.struct_info.dtype + if beta != 1: + bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) + else: + bias = x + res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) + return res + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -519,28 +541,6 @@ def _to(self, node: fx.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _baddbmm(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - a = self.env[node.args[1]] - b = self.env[node.args[2]] - alpha = node.kwargs["alpha"] if "alpha" in node.kwargs else 1 - beta = node.kwargs["beta"] if "beta" in node.kwargs else 1 - - res = None - if alpha != 0: - res = self.block_builder.emit(relax.op.matmul(a, b)) - if alpha != 1: - dtype = res.struct_info.dtype - res = self.block_builder.emit(relax.op.multiply(res, relax.const(alpha, dtype))) - if beta != 0: - dtype = x.struct_info.dtype - if beta != 1: - bias = self.block_builder.emit(relax.op.multiply(x, relax.const(beta, dtype))) - else: - bias = x - res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) - return res - def _einsum(self, node: fx.Node) -> relax.Var: import torch # type: ignore From 8d2d98363e2772225807e9ffbba6bac54ceeda92 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:32:33 +0900 Subject: [PATCH 05/18] cleanup `_conv1d_transpose()` --- .../tvm/relax/frontend/torch/fx_translator.py | 138 +++++++++--------- 1 file changed, 68 insertions(+), 70 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 2b64f30df2a2..69f52818171f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -356,6 +356,72 @@ def _baddbmm(self, node: fx.Node) -> relax.Var: res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias)) return res + def _conv1d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d_transpose = self.block_builder.emit( + relax.op.nn.conv1d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) + + def _conv1d_transpose(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv1d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -830,74 +896,6 @@ def _conv1d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv1d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d_transpose = self.block_builder.emit( - relax.op.nn.conv1d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d_transpose, bias)) - - def _conv1d_transpose(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv1d_transpose_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv2d_impl( self, x: relax.Expr, @@ -1540,7 +1538,7 @@ def create_convert_map(self): nn.Conv1d: self._conv1d, nn.Conv2d: self._conv2d, nn.Conv3d: self._conv3d, - nn.ConvTranspose1d: self._conv1d_transpose, + nn.ConvTranspose1d: self._conv1d_transpose_module, nn.ConvTranspose2d: self._conv2d_transpose, nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm, @@ -1606,7 +1604,7 @@ def create_convert_map(self): "bmm": self._binary_op( partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), - "conv_transpose1d": self._conv1d_transpose_functional, + "conv_transpose1d": self._conv1d_transpose, "conv_transpose2d": self._conv2d_transpose_functional, "conv1d": self._conv1d_functional, "conv2d": self._conv2d_functional, From 35c43ac70428886933d732971ffb95ae05480e8f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:34:58 +0900 Subject: [PATCH 06/18] cleanup `_conv2d_transpose()` --- .../tvm/relax/frontend/torch/fx_translator.py | 138 +++++++++--------- 1 file changed, 68 insertions(+), 70 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 69f52818171f..05746110ced6 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -422,6 +422,72 @@ def _conv1d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _conv2d_transpose_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv2d_transpose = self.block_builder.emit( + relax.op.nn.conv2d_transpose( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d_transpose + + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) + + def _conv2d_transpose(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_transpose_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -963,74 +1029,6 @@ def _conv2d_functional(self, node: fx.Node) -> relax.Var: groups=groups, ) - def _conv2d_transpose_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv2d_transpose = self.block_builder.emit( - relax.op.nn.conv2d_transpose( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d_transpose - - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d_transpose, bias)) - - def _conv2d_transpose(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv2d_transpose_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_transpose_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv3d_impl( self, x: relax.Expr, @@ -1539,7 +1537,7 @@ def create_convert_map(self): nn.Conv2d: self._conv2d, nn.Conv3d: self._conv3d, nn.ConvTranspose1d: self._conv1d_transpose_module, - nn.ConvTranspose2d: self._conv2d_transpose, + nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm, nn.LayerNorm: self._layer_norm, @@ -1605,7 +1603,7 @@ def create_convert_map(self): partial(relax.op.linear_algebra.matmul, out_dtype="float32"), operator.matmul ), "conv_transpose1d": self._conv1d_transpose, - "conv_transpose2d": self._conv2d_transpose_functional, + "conv_transpose2d": self._conv2d_transpose, "conv1d": self._conv1d_functional, "conv2d": self._conv2d_functional, "conv3d": self._conv3d_functional, From 5001b3d5a754099d965e4c2f757280c6fd33b4c3 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:38:32 +0900 Subject: [PATCH 07/18] cleanup `_conv1d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 136 +++++++++--------- 1 file changed, 67 insertions(+), 69 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 05746110ced6..67a6ad423eb2 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -488,6 +488,71 @@ def _conv2d_transpose_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _conv1d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ) -> relax.Var: + conv1d = self.block_builder.emit( + relax.op.nn.conv1d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCW", + kernel_layout="OIW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv1d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1)) + return self.block_builder.emit(relax.op.add(conv1d, bias)) + + def _conv1d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv1d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv1d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -895,73 +960,6 @@ def _linear_functional(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv1d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ) -> relax.Var: - conv1d = self.block_builder.emit( - relax.op.nn.conv1d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCW", - kernel_layout="OIW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv1d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1)) - return self.block_builder.emit(relax.op.add(conv1d, bias)) - - def _conv1d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv1d_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv1d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv2d_impl( self, x: relax.Expr, @@ -1533,7 +1531,7 @@ def create_convert_map(self): nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d, - nn.Conv1d: self._conv1d, + nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d, nn.Conv3d: self._conv3d, nn.ConvTranspose1d: self._conv1d_transpose_module, @@ -1604,7 +1602,7 @@ def create_convert_map(self): ), "conv_transpose1d": self._conv1d_transpose, "conv_transpose2d": self._conv2d_transpose, - "conv1d": self._conv1d_functional, + "conv1d": self._conv1d, "conv2d": self._conv2d_functional, "conv3d": self._conv3d_functional, "cross_entropy": self._cross_entropy, From 6ada5fac6844df937edec3520864cf92bf13fdb0 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:39:59 +0900 Subject: [PATCH 08/18] cleanup `_conv2d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 136 +++++++++--------- 1 file changed, 67 insertions(+), 69 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 67a6ad423eb2..b17e9aef6d0b 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -553,6 +553,71 @@ def _conv1d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _conv2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _conv2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -960,73 +1025,6 @@ def _linear_functional(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv2d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv2d = self.block_builder.emit( - relax.op.nn.conv2d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d, bias)) - - def _conv2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv2d_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv3d_impl( self, x: relax.Expr, @@ -1532,7 +1530,7 @@ def create_convert_map(self): nn.AvgPool2d: self._avg_pool2d_module, nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d_module, - nn.Conv2d: self._conv2d, + nn.Conv2d: self._conv2d_module, nn.Conv3d: self._conv3d, nn.ConvTranspose1d: self._conv1d_transpose_module, nn.ConvTranspose2d: self._conv2d_transpose_module, @@ -1603,7 +1601,7 @@ def create_convert_map(self): "conv_transpose1d": self._conv1d_transpose, "conv_transpose2d": self._conv2d_transpose, "conv1d": self._conv1d, - "conv2d": self._conv2d_functional, + "conv2d": self._conv2d, "conv3d": self._conv3d_functional, "cross_entropy": self._cross_entropy, "einsum": self._einsum, From 76d63bdcb57bf896b0973b788489b9535b71bbba Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:41:07 +0900 Subject: [PATCH 09/18] cleanup `_conv3d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 136 +++++++++--------- 1 file changed, 67 insertions(+), 69 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index b17e9aef6d0b..d08301a135a5 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -618,6 +618,71 @@ def _conv2d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _conv3d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv3d = self.block_builder.emit( + relax.op.nn.conv3d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv3d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv3d, bias)) + + def _conv3d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _conv3d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + + return self._conv3d_impl( + x, + weight, + bias=bias, + strides=module.stride, + padding=module.padding, + dilation=module.dilation, + groups=module.groups, + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1025,73 +1090,6 @@ def _linear_functional(self, node: fx.Node) -> relax.Var: bias = args[2] if len(args) > 2 else None return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _conv3d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv3d = self.block_builder.emit( - relax.op.nn.conv3d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCDHW", - kernel_layout="OIDHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv3d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv3d, bias)) - - def _conv3d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None - if module.bias is not None: - bias = self.params[module.bias] - - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - def _conv3d_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv3d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _max_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1531,7 +1529,7 @@ def create_convert_map(self): nn.BatchNorm2d: self._batch_norm_2d, nn.Conv1d: self._conv1d_module, nn.Conv2d: self._conv2d_module, - nn.Conv3d: self._conv3d, + nn.Conv3d: self._conv3d_module, nn.ConvTranspose1d: self._conv1d_transpose_module, nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, @@ -1602,7 +1600,7 @@ def create_convert_map(self): "conv_transpose2d": self._conv2d_transpose, "conv1d": self._conv1d, "conv2d": self._conv2d, - "conv3d": self._conv3d_functional, + "conv3d": self._conv3d, "cross_entropy": self._cross_entropy, "einsum": self._einsum, "interpolate": self._interpolate, From c121c5f09570e6cfa70376f465abe0cb7fbd7389 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:41:55 +0900 Subject: [PATCH 10/18] cleanup `_einsum()` --- python/tvm/relax/frontend/torch/fx_translator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d08301a135a5..8717466a7c79 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -683,6 +683,13 @@ def _conv3d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) + def _einsum(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.einsum(operands, args[0])) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -868,14 +875,6 @@ def _to(self, node: fx.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _einsum(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - if isinstance(args[1], (torch.Size, tuple, list)): - return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0])) - return self.block_builder.emit(relax.op.einsum(args[1:], args[0])) - def _unbind(self, node: fx.Node) -> relax.Var: if len(node.args) == 2: assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" From ae2ea5ad506692d7c5ecead9a18928920a39af88 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:43:21 +0900 Subject: [PATCH 11/18] cleanup `_embedding()` --- .../tvm/relax/frontend/torch/fx_translator.py | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8717466a7c79..a7edc96063a0 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -690,6 +690,29 @@ def _einsum(self, node: fx.Node) -> relax.Var: operands = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.einsum(operands, args[0])) + def _embedding_impl( + self, + x, + weight, + ) -> relax.Var: + x = self.block_builder.emit(relax.op.astype(x, "int32")) + + ndim = x.struct_info.ndim + if ndim == 1: + return self.block_builder.emit(relax.op.take(weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = weight.struct_info.shape.values[-1] + x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) + embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) + return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) + + def _embedding_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + return self._embedding_impl(x, weight) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1247,22 +1270,6 @@ def _group_norm(self, node: fx.Node) -> relax.Var: ) ) - def _embedding(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - x = self.block_builder.emit(relax.op.astype(x, "int32")) - - ndim = x.struct_info.ndim - if ndim == 1: - return self.block_builder.emit(relax.op.take(weight, x, axis=0)) - else: - x_shape = x.struct_info.shape.values - emb_size = weight.struct_info.shape.values[-1] - x = self.block_builder.emit(relax.op.reshape(x, shape=[-1])) - embedding = self.block_builder.emit(relax.op.take(weight, x, axis=0)) - return self.block_builder.emit(relax.op.reshape(embedding, [*x_shape, emb_size])) - def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, @@ -1536,7 +1543,7 @@ def create_convert_map(self): nn.LayerNorm: self._layer_norm, nn.Linear: self._linear, nn.MaxPool2d: self._max_pool2d, - nn.modules.sparse.Embedding: self._embedding, + nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation nn.Flatten: self._flatten, ## call_function and call_method From e5a1e6d8163806fc179d56734cd0373bf936f64d Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:46:03 +0900 Subject: [PATCH 12/18] cleanup `_group_norm()` --- .../tvm/relax/frontend/torch/fx_translator.py | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index a7edc96063a0..5789956b919e 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -713,6 +713,33 @@ def _embedding_module(self, node: fx.Node) -> relax.Var: weight = self.params[module.weight] return self._embedding_impl(x, weight) + def _group_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + num_groups = module.num_groups + if module.affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) + beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) + eps = module.eps + + dim = len(self.shape_of(x)) + return self.block_builder.emit( + relax.op.nn.group_norm( + x, + gamma, + beta, + num_groups=num_groups, + channel_axis=1, + axes=list(range(2, dim)), + epsilon=eps, + ) + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1244,32 +1271,6 @@ def _layer_norm(self, node: fx.Node) -> relax.Var: ) ) - def _group_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - x = self.env[node.args[0]] - module = self.named_modules[node.target] - - if module.affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.num_channels), x.checked_type) - beta = relax.const(torch.zeros_like(module.num_channels), x.checked_type) - - dim = len(self.shape_of(x)) - return self.block_builder.emit( - relax.op.nn.group_norm( - x, - gamma, - beta, - num_groups=module.num_groups, - channel_axis=1, - axes=list(range(2, dim)), - epsilon=module.eps, - ) - ) - def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, @@ -1539,7 +1540,7 @@ def create_convert_map(self): nn.ConvTranspose1d: self._conv1d_transpose_module, nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, - nn.GroupNorm: self._group_norm, + nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm, nn.Linear: self._linear, nn.MaxPool2d: self._max_pool2d, From 41748aed4cd149df8c4cc68021519dd3e0d6ae3a Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:52:38 +0900 Subject: [PATCH 13/18] cleanup `_layer_norm()` --- .../tvm/relax/frontend/torch/fx_translator.py | 124 ++++++++---------- 1 file changed, 56 insertions(+), 68 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5789956b919e..1ca4e6a5dc9c 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -740,6 +740,61 @@ def _group_norm_module(self, node: fx.Node) -> relax.Var: ) ) + def _layer_norm_impl(self, x, gamma, beta, eps, normalized_shape) -> relax.Var: + from torch.fx.immutable_collections import immutable_list + import numpy as np # type: ignore + + if isinstance(normalized_shape, (immutable_list, tuple)): + normalized_shape = tuple(normalized_shape) + else: + try: + normalized_shape = self.env[normalized_shape] + except TypeError: + normalized_shape = tuple(normalized_shape) + + dim_num = len(normalized_shape) + axes = list(range(-dim_num, 0)) + + if gamma is None: + shape_tuple = [int(s) for s in normalized_shape] + gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) + if beta is None: + shape_tuple = [int(s) for s in normalized_shape] + beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) + + return self.block_builder.emit( + relax.op.nn.layer_norm( + x, + gamma, + beta, + axes=axes, + epsilon=eps, + ) + ) + + def _layer_norm(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + normalized_shape = node.args[1] + gamma = self.env[node.args[2]] if len(node.args) > 2 else None + beta = self.env[node.args[3]] if len(node.args) > 3 else None + eps = node.args[4] if len(node.args) > 4 else 1e-05 + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + + def _layer_norm_module(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + x = self.env[node.args[0]] + module = self.named_modules[node.target] + normalized_shape = module.normalized_shape + if module.elementwise_affine: + gamma = self.params[module.weight] + beta = self.params[module.bias] + else: + gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) + beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) + eps = module.eps + return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1204,73 +1259,6 @@ def _batch_norm_2d(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.TupleGetItem(res_tuple, 0)) - def _layer_norm(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - from torch.fx.immutable_collections import immutable_list - import numpy as np # type: ignore - - x = self.env[node.args[0]] - - # functional.layer_norm - if node.target not in self.named_modules: - # static or symbolic - arg = node.args[1] - if isinstance(arg, (immutable_list, tuple)): - value = tuple(arg) - else: - try: - value = self.env[arg] - except TypeError: - value = tuple(arg) - normalized_shape = value - dim_num = len(normalized_shape) - axes = list(range(-dim_num, 0)) - - gamma = node.kwargs["weight"] - if gamma is None: - shape_tuple = [int(s) for s in normalized_shape] - gamma = relax.const(np.ones(shape_tuple), x.struct_info.dtype) - else: - gamma = self.env[gamma] - beta = node.kwargs["bias"] - if beta is None: - shape_tuple = [int(s) for s in normalized_shape] - beta = relax.const(np.zeros(shape_tuple), x.struct_info.dtype) - else: - beta = self.env[beta] - eps = node.kwargs["eps"] - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=eps, - ) - ) - - module = self.named_modules[node.target] - - if module.elementwise_affine: - gamma = self.params[module.weight] - beta = self.params[module.bias] - else: - gamma = relax.const(torch.ones_like(module.normalized_shape), x.struct_info.dtype) - beta = relax.const(torch.zeros_like(module.normalized_shape), x.struct_info.dtype) - dim_num = len(module.normalized_shape) - axes = list(range(-dim_num, 0)) - - return self.block_builder.emit( - relax.op.nn.layer_norm( - x, - gamma, - beta, - axes=axes, - epsilon=module.eps, - ) - ) - def _interpolate(self, node: fx.Node) -> relax.Var: # torch.nn.functional.interpolate( # input, size=None, scale_factor=None, mode='nearest', align_corners=None, @@ -1541,7 +1529,7 @@ def create_convert_map(self): nn.ConvTranspose2d: self._conv2d_transpose_module, nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm_module, - nn.LayerNorm: self._layer_norm, + nn.LayerNorm: self._layer_norm_module, nn.Linear: self._linear, nn.MaxPool2d: self._max_pool2d, nn.modules.sparse.Embedding: self._embedding_module, From 4c5a15662048db1fe68f8886db545b41ed728b6b Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:53:54 +0900 Subject: [PATCH 14/18] cleanup `_linear()` --- .../tvm/relax/frontend/torch/fx_translator.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 1ca4e6a5dc9c..d23ccaafbcdc 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -795,6 +795,20 @@ def _layer_norm_module(self, node: fx.Node) -> relax.Var: eps = module.eps return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) + def _linear(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _linear_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + weight = self.params[module.weight] + bias = self.params.get(module.bias, None) + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1180,20 +1194,6 @@ def convert(node: fx.Node): ########## Neural Network ########## - def _linear(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - module = self.named_modules[node.target] - weight = self.params[module.weight] - bias = None if module.bias is None else self.params[module.bias] - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - - def _linear_functional(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _max_pool2d(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1530,7 +1530,7 @@ def create_convert_map(self): nn.CrossEntropyLoss: self._cross_entropy, nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, - nn.Linear: self._linear, + nn.Linear: self._linear_module, nn.MaxPool2d: self._max_pool2d, nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation @@ -1600,7 +1600,7 @@ def create_convert_map(self): "einsum": self._einsum, "interpolate": self._interpolate, "layer_norm": self._layer_norm, - "linear": self._linear_functional, + "linear": self._linear, "max_pool2d": self._max_pool2d, "scaled_dot_product_attention": self._scaled_dot_product_attention, "stochastic_depth": lambda node: self.env[node.args[0]], From 1dc0f86db3b8782b4b79610dead19da6124cda39 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:54:56 +0900 Subject: [PATCH 15/18] cleanup `_max_pool2d()` --- .../tvm/relax/frontend/torch/fx_translator.py | 77 +++++++++++-------- 1 file changed, 45 insertions(+), 32 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d23ccaafbcdc..502780254891 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -809,6 +809,50 @@ def _linear_module(self, node: fx.Node) -> relax.Var: bias = self.params.get(module.bias, None) return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + def _max_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None else stride + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _max_pool2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + def _max_pool2d_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + dilation = module.dilation + ceil_mode = module.ceil_mode + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1194,37 +1238,6 @@ def convert(node: fx.Node): ########## Neural Network ########## - def _max_pool2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - if node.target in self.named_modules: - module = self.named_modules[node.target] - kernel = module.kernel_size - stride = module.stride - padding = module.padding - dilation = module.dilation - ceil_mode = module.ceil_mode - else: - nargs = len(node.args) - kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - stride = node.args[2] if nargs > 2 else node.kwargs["stride"] - padding = node.args[3] if nargs > 3 else node.kwargs["padding"] - dilation = node.args[4] if nargs > 4 else node.kwargs["dilation"] - ceil_mode = node.args[5] if nargs > 5 else node.kwargs["ceil_mode"] - - stride = kernel if stride is None else stride - - return self.block_builder.emit( - relax.op.nn.max_pool2d( - x, - pool_size=kernel, - strides=stride, - padding=padding, - dilation=dilation, - layout="NCHW", - ceil_mode=ceil_mode, - ) - ) - def _softmax(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] if node.target in self.named_modules: @@ -1531,7 +1544,7 @@ def create_convert_map(self): nn.GroupNorm: self._group_norm_module, nn.LayerNorm: self._layer_norm_module, nn.Linear: self._linear_module, - nn.MaxPool2d: self._max_pool2d, + nn.MaxPool2d: self._max_pool2d_module, nn.modules.sparse.Embedding: self._embedding_module, # tensor manipulation nn.Flatten: self._flatten, From 78c50a6acca34523034587c6c15ad708477ea943 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:55:58 +0900 Subject: [PATCH 16/18] cleanup `_scaled_dot_product_attention()` --- .../tvm/relax/frontend/torch/fx_translator.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 502780254891..ef6a5469f0b8 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -853,6 +853,26 @@ def _max_pool2d_module(self, node: fx.Node) -> relax.Var: return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: + transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) + query = transpose_S_H(self.env[node.args[0]]) + key = transpose_S_H(self.env[node.args[1]]) + value = transpose_S_H(self.env[node.args[2]]) + attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None) + dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0) + assert dropout_p == 0.0, "Dropout is not supported" + is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) + causal_mask = "TopLeft" if is_causal else None + + if attn_mask is not None: + attn_mask = self.env[attn_mask] + msg = "Only a float mask is supported for the attn_mask input." + assert "float" in attn_mask.struct_info.dtype, msg + + return self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) + ) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1381,26 +1401,6 @@ def _cross_entropy(self, node: fx.Node) -> relax.Expr: ) ) - def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: - assert ( - len(node.args) <= 4 - ), "Dropout is not supported, and is_causal should be called by kwargs." - transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3]) - query = transpose_S_H(self.env[node.args[0]]) - key = transpose_S_H(self.env[node.args[1]]) - value = transpose_S_H(self.env[node.args[2]]) - causal_mask = "TopLeft" if node.kwargs.get("is_causal", False) else None - - if len(node.args) == 4: - mask = self.env[node.args[3]] - msg = "Only a float mask is supported for the attn_mask input." - assert "float" in mask.struct_info.dtype, msg - attn = relax.op.nn.attention(query, key, value, bias=mask, causal_mask=causal_mask) - else: - attn = relax.op.nn.attention(query, key, value, causal_mask=causal_mask) - - return self.block_builder.emit(attn) - ########## Others ########## def _sym_size_int(self, node: fx.Node) -> relax.Expr: From 544a902ecb9e18d4de221d6eea00adaed9a1aadb Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:57:19 +0900 Subject: [PATCH 17/18] cleanup `_unbind()` --- .../tvm/relax/frontend/torch/fx_translator.py | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ef6a5469f0b8..e12753d92eba 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -873,6 +873,17 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var: relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) ) + def _unbind(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) + assert isinstance(dim, int), "Expected 2nd argument of unbind as int" + selections = self.shape_of(x)[dim].value + n_section = list(range(1, selections + 1)) + ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) + for i in range(selections): + ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) + return self.block_builder.emit(relax.Tuple(ret)) + ########## Creation ########## def _arange(self, node: fx.Node) -> relax.Var: @@ -1058,22 +1069,6 @@ def _to(self, node: fx.Node) -> relax.Var: def _matmul_impl(self, a: relax.Expr, b: relax.Expr): return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - def _unbind(self, node: fx.Node) -> relax.Var: - if len(node.args) == 2: - assert isinstance(node.args[1], int), "Expected 2nd argument of unbind as int" - dim = node.args[1] - elif "dim" in node.kwargs: - dim = node.kwargs["dim"] - else: - dim = 0 - x = self.env[node.args[0]] - selections = self.shape_of(x)[dim].value - n_section = list(range(1, selections + 1)) - ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim)) - for i in range(selections): - ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim))) - return self.block_builder.emit(relax.Tuple(ret)) - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: From ef3df9db78abeebe644060d93f1ff36d5a5aa1d9 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 12 Sep 2024 22:57:56 +0900 Subject: [PATCH 18/18] remove `_matmul_impl()` since we don't use it anymore --- python/tvm/relax/frontend/torch/fx_translator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index e12753d92eba..1c4796a533a4 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1064,11 +1064,6 @@ def _to(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.astype(x, dtype)) return x - ########## Linear Algebra ########## - - def _matmul_impl(self, a: relax.Expr, b: relax.Expr): - return self.block_builder.emit(relax.op.linear_algebra.matmul(a, b, out_dtype="float32")) - ########## Manipulation ########## def _cat(self, node: fx.Node) -> relax.Var: