From 2426bc5be5ab4b98d296273292b02f58e1f8da82 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 23 May 2023 15:17:53 -0700 Subject: [PATCH 1/4] fix and test get_imports for multiline try blocks, and excepts with specific errors --- src/transformers/dynamic_module_utils.py | 2 +- tests/utils/test_dynamic_module_utils.py | 81 ++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_dynamic_module_utils.py diff --git a/src/transformers/dynamic_module_utils.py b/src/transformers/dynamic_module_utils.py index e7ee18a278fa..8492b4c2d56c 100644 --- a/src/transformers/dynamic_module_utils.py +++ b/src/transformers/dynamic_module_utils.py @@ -123,7 +123,7 @@ def get_imports(filename): content = f.read() # filter out try/except block so in custom code we can have try/except imports - content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*:", "", content, flags=re.MULTILINE) + content = re.sub(r"\s*try\s*:\s*.*?\s*except\s*.*?:", "", content, flags=re.MULTILINE | re.DOTALL) # Imports of the form `import xxx` imports = re.findall(r"^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE) diff --git a/tests/utils/test_dynamic_module_utils.py b/tests/utils/test_dynamic_module_utils.py new file mode 100644 index 000000000000..04beecdd6ec9 --- /dev/null +++ b/tests/utils/test_dynamic_module_utils.py @@ -0,0 +1,81 @@ +import os +import pytest + +from transformers.dynamic_module_utils import get_imports + +TOP_LEVEL_IMPORT = """ +import os +""" + +IMPORT_IN_FUNCTION = """ +def foo(): + import os + return False +""" + +DEEPLY_NESTED_IMPORT = """ +def foo(): + def bar(): + if True: + import os + return False + return bar() +""" + +TOP_LEVEL_TRY_IMPORT = """ +import os + +try: + import bar +except ImportError: + raise ValueError() +""" + +GENERIC_EXCEPT_IMPORT = """ +import os + +try: + import bar +except: + raise ValueError() +""" + +MULTILINE_TRY_IMPORT = """ +import os + +try: + import bar + import baz +except ImportError: + raise ValueError() +""" + +MULTILINE_BOTH_IMPORT = """ +import os + +try: + import bar + import baz +except ImportError: + x = 1 + raise ValueError() +""" + +CASES = [ + TOP_LEVEL_IMPORT, + IMPORT_IN_FUNCTION, + DEEPLY_NESTED_IMPORT, + TOP_LEVEL_TRY_IMPORT, + GENERIC_EXCEPT_IMPORT, + MULTILINE_TRY_IMPORT, + MULTILINE_BOTH_IMPORT +] + +@pytest.mark.parametrize('case', CASES) +def test_import_parsing(tmp_path, case): + tmp_file_path = os.path.join(tmp_path, 'test_file.py') + with open(tmp_file_path, 'w') as _tmp_file: + _tmp_file.write(case) + + parsed_imports = get_imports(tmp_file_path) + assert parsed_imports == ['os'] \ No newline at end of file From a07a38bef2a6c5af4833d18dd785a7e4e78b77d0 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 23 May 2023 15:19:03 -0700 Subject: [PATCH 2/4] fixup --- tests/utils/test_dynamic_module_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/utils/test_dynamic_module_utils.py b/tests/utils/test_dynamic_module_utils.py index 04beecdd6ec9..7c15a12f91f6 100644 --- a/tests/utils/test_dynamic_module_utils.py +++ b/tests/utils/test_dynamic_module_utils.py @@ -1,8 +1,10 @@ import os + import pytest from transformers.dynamic_module_utils import get_imports + TOP_LEVEL_IMPORT = """ import os """ @@ -68,14 +70,15 @@ def bar(): TOP_LEVEL_TRY_IMPORT, GENERIC_EXCEPT_IMPORT, MULTILINE_TRY_IMPORT, - MULTILINE_BOTH_IMPORT + MULTILINE_BOTH_IMPORT, ] -@pytest.mark.parametrize('case', CASES) + +@pytest.mark.parametrize("case", CASES) def test_import_parsing(tmp_path, case): - tmp_file_path = os.path.join(tmp_path, 'test_file.py') - with open(tmp_file_path, 'w') as _tmp_file: + tmp_file_path = os.path.join(tmp_path, "test_file.py") + with open(tmp_file_path, "w") as _tmp_file: _tmp_file.write(case) parsed_imports = get_imports(tmp_file_path) - assert parsed_imports == ['os'] \ No newline at end of file + assert parsed_imports == ["os"] From 44a89fc86fc2d8489d88093533c10d5b78f14125 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Tue, 23 May 2023 15:26:25 -0700 Subject: [PATCH 3/4] add some more tests --- tests/utils/test_dynamic_module_utils.py | 31 ++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/utils/test_dynamic_module_utils.py b/tests/utils/test_dynamic_module_utils.py index 7c15a12f91f6..3d76791c734a 100644 --- a/tests/utils/test_dynamic_module_utils.py +++ b/tests/utils/test_dynamic_module_utils.py @@ -33,6 +33,34 @@ def bar(): raise ValueError() """ +TRY_IMPORT_IN_FUNCTION = """ +import os + +def foo(): + try: + import bar + except ImportError: + raise ValueError() +""" + +MULTIPLE_EXCEPTS_IMPORT = """ +import os + +try: + import bar +except (ImportError, AttributeError): + raise ValueError() +""" + +EXCEPT_AS_IMPORT = """ +import os + +try: + import bar +except ImportError as e: + raise ValueError() +""" + GENERIC_EXCEPT_IMPORT = """ import os @@ -71,6 +99,9 @@ def bar(): GENERIC_EXCEPT_IMPORT, MULTILINE_TRY_IMPORT, MULTILINE_BOTH_IMPORT, + MULTIPLE_EXCEPTS_IMPORT, + EXCEPT_AS_IMPORT, + TRY_IMPORT_IN_FUNCTION, ] From 93085a7803e97eb52bd49408f6e20c85eaf71054 Mon Sep 17 00:00:00 2001 From: Daniel King Date: Wed, 24 May 2023 10:03:09 -0700 Subject: [PATCH 4/4] add license --- tests/utils/test_dynamic_module_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/utils/test_dynamic_module_utils.py b/tests/utils/test_dynamic_module_utils.py index 3d76791c734a..dfdc63460cd3 100644 --- a/tests/utils/test_dynamic_module_utils.py +++ b/tests/utils/test_dynamic_module_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import pytest