From 6e56c04bd12c5ad0ffe07ccc600e0c41fbf24056 Mon Sep 17 00:00:00 2001
From: Virgile Andreani
Date: Tue, 7 Feb 2023 11:48:26 -0500
Subject: [PATCH] Add pyupgrade to pre-commit and apply it

---
 .pre-commit-config.yaml                     |  5 +++
 pytensor/compile/mode.py                    |  4 +--
 pytensor/gradient.py                        |  2 +-
 pytensor/graph/fg.py                        |  3 +-
 pytensor/graph/op.py                        |  3 +-
 pytensor/graph/rewriting/basic.py           |  6 ++--
 pytensor/link/c/cmodule.py                  | 13 +++++--
 .../link/numba/dispatch/elemwise_codegen.py | 36 +++++++++----------
 pytensor/printing.py                        | 14 ++++++--
 pytensor/sparse/type.py                     |  3 +-
 pytensor/tensor/random/rewriting/basic.py   |  2 +-
 pytensor/tensor/random/utils.py             |  3 +-
 pytensor/tensor/slinalg.py                  |  3 +-
 pytensor/tensor/type.py                     |  3 +-
 tests/link/numba/test_scan.py               |  4 +--
 15 files changed, 58 insertions(+), 46 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 106c42438b..7a87b0e4f0 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,6 +19,11 @@ repos:
               pytensor/tensor/var\.py|
           )$
       - id: check-merge-conflict
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.3.1
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
   - repo: https://github.com/psf/black
     rev: 22.10.0
     hooks:
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 8aecf1a902..ac9bc6b83d 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -5,9 +5,7 @@

 import logging
 import warnings
-from typing import Optional, Tuple, Union
-
-from typing_extensions import Literal
+from typing import Literal, Optional, Tuple, Union

 from pytensor.compile.function.types import Supervisor
 from pytensor.configdefaults import config
diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 29106dec70..b85f10c85a 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -8,6 +8,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     MutableSequence,
     Optional,
@@ -18,7 +19,6 @@
 )

 import numpy as np
-from typing_extensions import Literal

 import pytensor
 from pytensor.compile.ops import ViewOp
diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py
index d815316097..2c89691244 100644
--- a/pytensor/graph/fg.py
+++ b/pytensor/graph/fg.py
@@ -7,6 +7,7 @@
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
@@ -15,8 +16,6 @@
     cast,
 )

-from typing_extensions import Literal
-
 import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, AtomicVariable, Variable, applys_between
diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index 10f1057e37..b59394eb32 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -9,6 +9,7 @@
     Dict,
     List,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     TypeVar,
@@ -16,8 +17,6 @@
     cast,
 )

-from typing_extensions import Protocol
-
 import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, NoParams, Variable
diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py
index 1af3ff743f..9117295a74 100644
--- a/pytensor/graph/rewriting/basic.py
+++ b/pytensor/graph/rewriting/basic.py
@@ -15,9 +15,7 @@
 from itertools import chain
 from typing import TYPE_CHECKING, Callable, Dict
 from typing import Iterable as IterableType
-from typing import List, Optional, Sequence, Tuple, Union, cast
-
-from typing_extensions import Literal
+from typing import List, Literal, Optional, Sequence, Tuple, Union, cast

 import pytensor
 from pytensor.configdefaults import config
@@ -1185,7 +1183,7 @@ def _find_impl(self, cls) -> List[NodeRewriter]:
                 matches.extend(match)
         return matches

-    @functools.lru_cache()
+    @functools.lru_cache
     def get_trackers(self, op: Op) -> List[NodeRewriter]:
         """Get all the rewrites applicable to `op`."""
         return (
diff --git a/pytensor/link/c/cmodule.py b/pytensor/link/c/cmodule.py
index 20530803d6..606b48ef16 100644
--- a/pytensor/link/c/cmodule.py
+++ b/pytensor/link/c/cmodule.py
@@ -19,7 +19,17 @@
 import time
 import warnings
 from io import BytesIO, StringIO
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    cast,
+)

 import numpy as np
 from setuptools._distutils.sysconfig import (
@@ -28,7 +38,6 @@
     get_python_inc,
     get_python_lib,
 )
-from typing_extensions import Protocol

 # we will abuse the lockfile mechanism when reading and writing the registry
 from pytensor.compile.compilelock import lock_ctx
diff --git a/pytensor/link/numba/dispatch/elemwise_codegen.py b/pytensor/link/numba/dispatch/elemwise_codegen.py
index 0060191ad7..3138110046 100644
--- a/pytensor/link/numba/dispatch/elemwise_codegen.py
+++ b/pytensor/link/numba/dispatch/elemwise_codegen.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, List, Optional, Tuple
+from typing import Any

 import numba
 import numpy as np
@@ -14,8 +14,8 @@
 def compute_itershape(
     ctx: BaseContext,
     builder: ir.IRBuilder,
-    in_shapes: Tuple[ir.Instruction, ...],
-    broadcast_pattern: Tuple[Tuple[bool, ...], ...],
+    in_shapes: tuple[ir.Instruction, ...],
+    broadcast_pattern: tuple[tuple[bool, ...], ...],
 ):
     one = ir.IntType(64)(1)
     ndim = len(in_shapes[0])
@@ -63,12 +63,12 @@
 def make_outputs(
     ctx: numba.core.base.BaseContext,
     builder: ir.IRBuilder,
-    iter_shape: Tuple[ir.Instruction, ...],
-    out_bc: Tuple[Tuple[bool, ...], ...],
-    dtypes: Tuple[Any, ...],
-    inplace: Tuple[Tuple[int, int], ...],
-    inputs: Tuple[Any, ...],
-    input_types: Tuple[Any, ...],
+    iter_shape: tuple[ir.Instruction, ...],
+    out_bc: tuple[tuple[bool, ...], ...],
+    dtypes: tuple[Any, ...],
+    inplace: tuple[tuple[int, int], ...],
+    inputs: tuple[Any, ...],
+    input_types: tuple[Any, ...],
 ):
     arrays = []
     ar_types: list[types.Array] = []
@@ -106,13 +106,13 @@ def make_loop_call(
     builder: ir.IRBuilder,
     scalar_func: Any,
     scalar_signature: types.FunctionType,
-    iter_shape: Tuple[ir.Instruction, ...],
-    inputs: Tuple[ir.Instruction, ...],
-    outputs: Tuple[ir.Instruction, ...],
-    input_bc: Tuple[Tuple[bool, ...], ...],
-    output_bc: Tuple[Tuple[bool, ...], ...],
-    input_types: Tuple[Any, ...],
-    output_types: Tuple[Any, ...],
+    iter_shape: tuple[ir.Instruction, ...],
+    inputs: tuple[ir.Instruction, ...],
+    outputs: tuple[ir.Instruction, ...],
+    input_bc: tuple[tuple[bool, ...], ...],
+    output_bc: tuple[tuple[bool, ...], ...],
+    input_types: tuple[Any, ...],
+    output_types: tuple[Any, ...],
 ):
     safe = (False, False)
     n_outputs = len(outputs)
@@ -150,9 +150,7 @@ def extract_array(aryty, obj):
     # This part corresponds to opening the loops
     loop_stack = []
     loops = []
-    output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
-        (None, None)
-    ] * n_outputs
+    output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs
     for dim, length in enumerate(iter_shape):
         # Find outputs that only have accumulations left
         for output in range(n_outputs):
diff --git a/pytensor/printing.py b/pytensor/printing.py
index e7f9738426..8b24884944 100644
--- a/pytensor/printing.py
+++ b/pytensor/printing.py
@@ -9,10 +9,20 @@
 from copy import copy
 from functools import reduce, singledispatch
 from io import StringIO
-from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Union,
+)

 import numpy as np
-from typing_extensions import Literal

 from pytensor.compile import Function, SharedVariable
 from pytensor.compile.io import In, Out
diff --git a/pytensor/sparse/type.py b/pytensor/sparse/type.py
index a38b80179a..c5fd52cced 100644
--- a/pytensor/sparse/type.py
+++ b/pytensor/sparse/type.py
@@ -1,8 +1,7 @@
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union

 import numpy as np
 import scipy.sparse
-from typing_extensions import Literal

 import pytensor
 from pytensor import scalar as aes
diff --git a/pytensor/tensor/random/rewriting/basic.py b/pytensor/tensor/random/rewriting/basic.py
index bef13a9189..8b2cac3e97 100644
--- a/pytensor/tensor/random/rewriting/basic.py
+++ b/pytensor/tensor/random/rewriting/basic.py
@@ -143,7 +143,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
     shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
-    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    augmented_dims = {d - rv_op.ndim_supp for d in ds_op.augment}

     if (shuffled_dims | augmented_dims) & supp_dims:
         return False
diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py
index 329581f48b..a74658c21a 100644
--- a/pytensor/tensor/random/utils.py
+++ b/pytensor/tensor/random/utils.py
@@ -2,10 +2,9 @@
 from functools import wraps
 from itertools import zip_longest
 from types import ModuleType
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from typing_extensions import Literal

 from pytensor.compile.sharedvalue import shared
 from pytensor.graph.basic import Constant, Variable
diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py
index b268c7522e..0ae4a6c976 100644
--- a/pytensor/tensor/slinalg.py
+++ b/pytensor/tensor/slinalg.py
@@ -1,10 +1,9 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Literal, Union

 import numpy as np
 import scipy.linalg
-from typing_extensions import Literal

 import pytensor
 import pytensor.tensor as pt
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index 66cb60a6e3..d1826ff674 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -1,9 +1,8 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Literal, Optional, Tuple, Union

 import numpy as np
-from typing_extensions import Literal

 import pytensor
 from pytensor import scalar as aes
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py
index 0380ae0d92..27558f1d13 100644
--- a/tests/link/numba/test_scan.py
+++ b/tests/link/numba/test_scan.py
@@ -435,9 +435,9 @@ def test_inner_graph_optimized():
     # Disable scan pushout, in which case the whole scan is replaced by an Elemwise
     f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout"))

-    (scan_node,) = [
+    (scan_node,) = (
         node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
-    ]
+    )
     inner_scan_nodes = scan_node.op.fgraph.apply_nodes
     assert len(inner_scan_nodes) == 1
     (inner_scan_node,) = scan_node.op.fgraph.apply_nodes
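
The hunks above are mechanical applications of a few pyupgrade rules. The sketch below is illustrative code, not taken from the patch (the names `expensive`, `augment`, and `node` are hypothetical), condensing the rewrites that `--py38-plus` enables: importing `Literal`/`Protocol` from `typing` instead of `typing_extensions` (both live in `typing` since Python 3.8), builtin generics such as `tuple[...]` in files that already carry `from __future__ import annotations` (as elemwise_codegen.py does), the bare `@functools.lru_cache` decorator, a set comprehension in place of `set(generator)`, and unpacking straight from a generator rather than building a throwaway list.

    from __future__ import annotations  # permits builtin generics like tuple[...] in annotations on 3.8

    import functools
    from typing import Literal  # part of the stdlib typing module since Python 3.8


    @functools.lru_cache  # was: @functools.lru_cache() -- the bare decorator is valid since 3.8
    def expensive(n: int) -> tuple[int, ...]:  # was: Tuple[int, ...] from typing
        return tuple(range(n))


    def augment(dims: list[int], offset: int, op: Literal["sub"] = "sub") -> set[int]:
        # was: set(d - offset for d in dims); pyupgrade prefers the set comprehension
        return {d - offset for d in dims}


    # was: (node,) = [x for x in xs if ...]; unpacking works straight from a generator
    (node,) = (d for d in expensive(1))
    print(node, augment([3, 4], 1))

With the hook installed, the same rewrites can be rerun over the whole tree with `pre-commit run pyupgrade --all-files`; the `--py38-plus` flag keeps pyupgrade from emitting syntax that Python 3.8 cannot run.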