Skip to content

Commit 7f6098e

Browse files
dakinggggojiteji
authored andcommitted
Fix the regex in get_imports to support multiline try blocks and excepts with specific exception types (huggingface#23725)
* fix and test get_imports for multiline try blocks, and excepts with specific errors * fixup * add some more tests * add license
1 parent 121ca59 commit 7f6098e

File tree

2 files changed

+130
-1
lines changed

2 files changed

+130
-1
lines changed

src/transformers/dynamic_module_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def get_imports(filename):
123123
content = f.read()
124124

125125
# filter out try/except block so in custom code we can have try/except imports
126-
content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*:", "", content, flags=re.MULTILINE)
126+
content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL)
127127

128128
# Imports of the form `import xxx`
129129
imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2023 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import pytest
18+
19+
from transformers.dynamic_module_utils import get_imports
20+
21+
22+
TOP_LEVEL_IMPORT = """
23+
import os
24+
"""
25+
26+
IMPORT_IN_FUNCTION = """
27+
def foo():
28+
import os
29+
return False
30+
"""
31+
32+
DEEPLY_NESTED_IMPORT = """
33+
def foo():
34+
def bar():
35+
if True:
36+
import os
37+
return False
38+
return bar()
39+
"""
40+
41+
TOP_LEVEL_TRY_IMPORT = """
42+
import os
43+
44+
try:
45+
import bar
46+
except ImportError:
47+
raise ValueError()
48+
"""
49+
50+
TRY_IMPORT_IN_FUNCTION = """
51+
import os
52+
53+
def foo():
54+
try:
55+
import bar
56+
except ImportError:
57+
raise ValueError()
58+
"""
59+
60+
MULTIPLE_EXCEPTS_IMPORT = """
61+
import os
62+
63+
try:
64+
import bar
65+
except (ImportError, AttributeError):
66+
raise ValueError()
67+
"""
68+
69+
EXCEPT_AS_IMPORT = """
70+
import os
71+
72+
try:
73+
import bar
74+
except ImportError as e:
75+
raise ValueError()
76+
"""
77+
78+
GENERIC_EXCEPT_IMPORT = """
79+
import os
80+
81+
try:
82+
import bar
83+
except:
84+
raise ValueError()
85+
"""
86+
87+
MULTILINE_TRY_IMPORT = """
88+
import os
89+
90+
try:
91+
import bar
92+
import baz
93+
except ImportError:
94+
raise ValueError()
95+
"""
96+
97+
MULTILINE_BOTH_IMPORT = """
98+
import os
99+
100+
try:
101+
import bar
102+
import baz
103+
except ImportError:
104+
x = 1
105+
raise ValueError()
106+
"""
107+
108+
CASES = [
109+
TOP_LEVEL_IMPORT,
110+
IMPORT_IN_FUNCTION,
111+
DEEPLY_NESTED_IMPORT,
112+
TOP_LEVEL_TRY_IMPORT,
113+
GENERIC_EXCEPT_IMPORT,
114+
MULTILINE_TRY_IMPORT,
115+
MULTILINE_BOTH_IMPORT,
116+
MULTIPLE_EXCEPTS_IMPORT,
117+
EXCEPT_AS_IMPORT,
118+
TRY_IMPORT_IN_FUNCTION,
119+
]
120+
121+
122+
@pytest.mark.parametrize("case", CASES)
123+
def test_import_parsing(tmp_path, case):
124+
tmp_file_path = os.path.join(tmp_path, "test_file.py")
125+
with open(tmp_file_path, "w") as _tmp_file:
126+
_tmp_file.write(case)
127+
128+
parsed_imports = get_imports(tmp_file_path)
129+
assert parsed_imports == ["os"]

0 commit comments

Comments
 (0)