Conversation

@mshr-h (Contributor) commented on Sep 3, 2024

With torch==2.4, the normalized_shape argument of torch.nn.functional.layer_norm can be a torch.fx.immutable_collections.immutable_list.
This PR updates the layer_norm converter to cast normalized_shape to a tuple when it is an immutable_list.
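
A minimal sketch of the intended handling (illustrative only; the helper name is hypothetical and the merged diff may differ):

    from torch.fx.immutable_collections import immutable_list

    def _as_shape_tuple(arg):
        # Hypothetical helper: torch.fx records a literal shape list such as
        # [10, 10] as an immutable_list (a list subclass), so cast it to a
        # tuple before the converter's isinstance(arg, tuple) check.
        if isinstance(arg, immutable_list):
            return tuple(arg)
        return arg

With handling like this, the functional.layer_norm branch treats a shape traced as an immutable_list the same as a literal tuple, instead of falling through to the env lookup that produced the failure below.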

>       verify_model(model, input_info, binding, expected3)

tests/python/relax/test_frontend_from_fx.py:1323: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/relax/test_frontend_from_fx.py:36: in verify_model
    mod = from_fx(graph_model, input_info)
python/tvm/relax/frontend/torch/fx_translator.py:1770: in from_fx
    return TorchFXImporter().from_fx(
python/tvm/relax/frontend/torch/fx_translator.py:1653: in from_fx
    self.env[node] = self.convert_map[func_name](node)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <tvm.relax.frontend.torch.fx_translator.TorchFXImporter object at 0x778d11543da0>, node = layer_norm

    def _layer_norm(self, node: fx.node.Node) -> relax.Var:
        import torch  # type: ignore
        import numpy as np  # type: ignore
    
        x = self.env[node.args[0]]
    
        # functional.layer_norm
        if node.target not in self.named_modules:
            # static or symbolic
            arg = node.args[1]
            if isinstance(arg, tuple):
                value = arg
            else:
                try:
>                   value = self.env[arg]
E                   KeyError: [10, 10]

python/tvm/relax/frontend/torch/fx_translator.py:1084: KeyError
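
For context, immutable_list subclasses list, so the converter's isinstance(arg, tuple) check does not match it and control falls through to the self.env lookup, which fails. A small standalone check, assuming only torch is installed:

    from torch.fx.immutable_collections import immutable_list

    shape = immutable_list([10, 10])
    print(isinstance(shape, tuple))  # False: immutable_list subclasses list, not tuple
    print(isinstance(shape, list))   # True
    # immutable_list is hashable (the traceback shows KeyError rather than
    # TypeError), so self.env[shape] is a valid dict lookup -- it raises
    # KeyError: [10, 10] because the shape literal was never an fx node
    # registered in env.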

cc @vinx13 @yongwww @Hzfengsy

@mshr-h force-pushed the fix-layer-norm-arg branch from eee422f to 29c5a15 on September 3, 2024 at 12:15
@yongwww merged commit e19541d into apache:main on Sep 4, 2024
@mshr-h deleted the fix-layer-norm-arg branch on September 5, 2024 at 02:11