diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 676f63b5c359..28d9a9f52373 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1069,6 +1069,7 @@ def _batch_norm_2d(self, node: fx.node.Node) -> relax.Var: def _layer_norm(self, node: fx.node.Node) -> relax.Var: import torch # type: ignore + from torch.fx.immutable_collections import immutable_list import numpy as np # type: ignore x = self.env[node.args[0]] @@ -1077,8 +1078,8 @@ def _layer_norm(self, node: fx.node.Node) -> relax.Var: if node.target not in self.named_modules: # static or symbolic arg = node.args[1] - if isinstance(arg, tuple): - value = arg + if isinstance(arg, (immutable_list, tuple)): + value = tuple(arg) else: try: value = self.env[arg]