5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -19,6 +19,11 @@ repos:
             pytensor/tensor/var\.py|
           )$
       - id: check-merge-conflict
+  - repo: https://github.com/asottile/pyupgrade
+    rev: v3.3.1
+    hooks:
+      - id: pyupgrade
+        args: [--py38-plus]
   - repo: https://github.com/psf/black
     rev: 22.10.0
     hooks:
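With this hook in place, pre-commit runs pyupgrade over staged files and rewrites syntax that is redundant on Python 3.8+. A before/after sketch of the kind of rewrite driving the rest of this PR (illustrative only, not code from the diff):

    # Before: Python < 3.8 needed the typing_extensions backport
    from typing_extensions import Literal, Protocol

    # After pyupgrade --py38-plus: the stdlib typing module provides both names
    from typing import Literal, Protocol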
4 changes: 1 addition & 3 deletions pytensor/compile/mode.py
@@ -5,9 +5,7 @@

 import logging
 import warnings
-from typing import Optional, Tuple, Union
-
-from typing_extensions import Literal
+from typing import Literal, Optional, Tuple, Union

 from pytensor.compile.function.types import Supervisor
 from pytensor.configdefaults import config
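Since Python 3.8, `Literal` lives in the standard `typing` module, so the `typing_extensions` fallback is no longer needed. A minimal sketch of what `Literal` buys in a signature (hypothetical function, not from this PR):

    from typing import Literal


    def set_linker(name: Literal["c", "py", "cvm"]) -> None:
        # A type checker rejects any value other than the three listed strings.
        print(f"linker set to {name}")


    set_linker("py")     # OK
    # set_linker("jax")  # flagged by mypy: not a permitted literal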
2 changes: 1 addition & 1 deletion pytensor/gradient.py
@@ -8,6 +8,7 @@
     Callable,
     Dict,
     List,
+    Literal,
     Mapping,
     MutableSequence,
     Optional,
@@ -18,7 +19,6 @@
 )

 import numpy as np
-from typing_extensions import Literal

 import pytensor
 from pytensor.compile.ops import ViewOp
3 changes: 1 addition & 2 deletions pytensor/graph/fg.py
@@ -7,6 +7,7 @@
     Dict,
     Iterable,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
@@ -15,8 +16,6 @@
     cast,
 )

-from typing_extensions import Literal
-
 import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, AtomicVariable, Variable, applys_between
3 changes: 1 addition & 2 deletions pytensor/graph/op.py
@@ -9,15 +9,14 @@
     Dict,
     List,
     Optional,
+    Protocol,
     Sequence,
     Tuple,
     TypeVar,
     Union,
     cast,
 )

-from typing_extensions import Protocol
-
 import pytensor
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, NoParams, Variable
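`Protocol` likewise moved into `typing` in Python 3.8 (PEP 544). A small sketch of the structural typing it enables (hypothetical names, not from this PR):

    from typing import Protocol


    class HasPerform(Protocol):
        # Structural typing: any object with a matching `perform` method
        # satisfies this protocol, no inheritance required.
        def perform(self, inputs: list, outputs: list) -> None:
            ...


    class AddOp:  # note: does not subclass HasPerform
        def perform(self, inputs: list, outputs: list) -> None:
            outputs.append(sum(inputs))


    def run(op: HasPerform, inputs: list) -> list:
        out: list = []
        op.perform(inputs, out)
        return out


    print(run(AddOp(), [1, 2, 3]))  # [6]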
6 changes: 2 additions & 4 deletions pytensor/graph/rewriting/basic.py
@@ -15,9 +15,7 @@
 from itertools import chain
 from typing import TYPE_CHECKING, Callable, Dict
 from typing import Iterable as IterableType
-from typing import List, Optional, Sequence, Tuple, Union, cast
-
-from typing_extensions import Literal
+from typing import List, Literal, Optional, Sequence, Tuple, Union, cast

 import pytensor
 from pytensor.configdefaults import config
@@ -1185,7 +1183,7 @@ def _find_impl(self, cls) -> List[NodeRewriter]:
         matches.extend(match)
         return matches

-    @functools.lru_cache()
+    @functools.lru_cache
     def get_trackers(self, op: Op) -> List[NodeRewriter]:
         """Get all the rewrites applicable to `op`."""
         return (
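Since Python 3.8, `functools.lru_cache` can be applied as a bare decorator; it behaves exactly like `@functools.lru_cache()` with the default `maxsize=128`. A quick standalone illustration:

    import functools


    @functools.lru_cache  # Python 3.8+: equivalent to @functools.lru_cache()
    def fib(n: int) -> int:
        return n if n < 2 else fib(n - 1) + fib(n - 2)


    print(fib(64))           # fast: intermediate results are memoized
    print(fib.cache_info())  # CacheInfo(hits=..., misses=..., maxsize=128, ...)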
13 changes: 11 additions & 2 deletions pytensor/link/c/cmodule.py
@@ -19,7 +19,17 @@
 import time
 import warnings
 from io import BytesIO, StringIO
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Protocol,
+    Set,
+    Tuple,
+    cast,
+)

 import numpy as np
 from setuptools._distutils.sysconfig import (
@@ -28,7 +38,6 @@
     get_python_inc,
     get_python_lib,
 )
-from typing_extensions import Protocol

 # we will abuse the lockfile mechanism when reading and writing the registry
 from pytensor.compile.compilelock import lock_ctx
36 changes: 17 additions & 19 deletions pytensor/link/numba/dispatch/elemwise_codegen.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, List, Optional, Tuple
+from typing import Any

 import numba
 import numpy as np
@@ -14,8 +14,8 @@
 def compute_itershape(
     ctx: BaseContext,
     builder: ir.IRBuilder,
-    in_shapes: Tuple[ir.Instruction, ...],
-    broadcast_pattern: Tuple[Tuple[bool, ...], ...],
+    in_shapes: tuple[ir.Instruction, ...],
+    broadcast_pattern: tuple[tuple[bool, ...], ...],
 ):
     one = ir.IntType(64)(1)
     ndim = len(in_shapes[0])
@@ -63,12 +63,12 @@ def compute_itershape(
 def make_outputs(
     ctx: numba.core.base.BaseContext,
     builder: ir.IRBuilder,
-    iter_shape: Tuple[ir.Instruction, ...],
-    out_bc: Tuple[Tuple[bool, ...], ...],
-    dtypes: Tuple[Any, ...],
-    inplace: Tuple[Tuple[int, int], ...],
-    inputs: Tuple[Any, ...],
-    input_types: Tuple[Any, ...],
+    iter_shape: tuple[ir.Instruction, ...],
+    out_bc: tuple[tuple[bool, ...], ...],
+    dtypes: tuple[Any, ...],
+    inplace: tuple[tuple[int, int], ...],
+    inputs: tuple[Any, ...],
+    input_types: tuple[Any, ...],
 ):
     arrays = []
     ar_types: list[types.Array] = []
@@ -106,13 +106,13 @@ def make_loop_call(
     builder: ir.IRBuilder,
     scalar_func: Any,
     scalar_signature: types.FunctionType,
-    iter_shape: Tuple[ir.Instruction, ...],
-    inputs: Tuple[ir.Instruction, ...],
-    outputs: Tuple[ir.Instruction, ...],
-    input_bc: Tuple[Tuple[bool, ...], ...],
-    output_bc: Tuple[Tuple[bool, ...], ...],
-    input_types: Tuple[Any, ...],
-    output_types: Tuple[Any, ...],
+    iter_shape: tuple[ir.Instruction, ...],
+    inputs: tuple[ir.Instruction, ...],
+    outputs: tuple[ir.Instruction, ...],
+    input_bc: tuple[tuple[bool, ...], ...],
+    output_bc: tuple[tuple[bool, ...], ...],
+    input_types: tuple[Any, ...],
+    output_types: tuple[Any, ...],
 ):
     safe = (False, False)
     n_outputs = len(outputs)
@@ -150,9 +150,7 @@ def extract_array(aryty, obj):
     # This part corresponds to opening the loops
     loop_stack = []
     loops = []
-    output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
-        (None, None)
-    ] * n_outputs
+    output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs
     for dim, length in enumerate(iter_shape):
         # Find outputs that only have accumulations left
         for output in range(n_outputs):
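This module already has `from __future__ import annotations`, which makes all annotations lazily evaluated strings (PEP 563), so the PEP 585 builtin generics (`tuple[...]`, `list[...]`) and PEP 604 unions (`X | None`) are valid here even on Python 3.8. A minimal sketch of the pattern (hypothetical function, not from this PR):

    from __future__ import annotations  # defers annotation evaluation on 3.8

    from typing import Any


    def pad_shapes(shapes: tuple[tuple[int, ...], ...], fill: int | None = None) -> list[Any]:
        # tuple[...] and int | None need no typing.Tuple or typing.Optional,
        # because the future import keeps annotations unevaluated at runtime.
        width = max(len(s) for s in shapes)
        return [(fill,) * (width - len(s)) + s for s in shapes]


    print(pad_shapes(((2, 3), (3,)), fill=1))  # [(2, 3), (1, 3)]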
14 changes: 12 additions & 2 deletions pytensor/printing.py
@@ -9,10 +9,20 @@
 from copy import copy
 from functools import reduce, singledispatch
 from io import StringIO
-from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    TextIO,
+    Tuple,
+    Union,
+)

 import numpy as np
-from typing_extensions import Literal

 from pytensor.compile import Function, SharedVariable
 from pytensor.compile.io import In, Out
3 changes: 1 addition & 2 deletions pytensor/sparse/type.py
@@ -1,8 +1,7 @@
-from typing import Iterable, Optional, Union
+from typing import Iterable, Literal, Optional, Union

 import numpy as np
 import scipy.sparse
-from typing_extensions import Literal

 import pytensor
 from pytensor import scalar as aes
2 changes: 1 addition & 1 deletion pytensor/tensor/random/rewriting/basic.py
@@ -143,7 +143,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
     shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
-    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    augmented_dims = {d - rv_op.ndim_supp for d in ds_op.augment}
     if (shuffled_dims | augmented_dims) & supp_dims:
         return False
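pyupgrade rewrites `set(...)` over a generator expression into a set comprehension; the result is identical, but the comprehension skips the intermediate generator-plus-call and reads more directly. A tiny illustration with hypothetical values:

    augment = [0, 2]
    ndim_supp = 1

    # Equivalent results; the comprehension is the form pyupgrade emits.
    assert set(d - ndim_supp for d in augment) == {d - ndim_supp for d in augment} == {-1, 1}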
3 changes: 1 addition & 2 deletions pytensor/tensor/random/utils.py
@@ -2,10 +2,9 @@
 from functools import wraps
 from itertools import zip_longest
 from types import ModuleType
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
-from typing_extensions import Literal

 from pytensor.compile.sharedvalue import shared
 from pytensor.graph.basic import Constant, Variable
3 changes: 1 addition & 2 deletions pytensor/tensor/slinalg.py
@@ -1,10 +1,9 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Literal, Union

 import numpy as np
 import scipy.linalg
-from typing_extensions import Literal

 import pytensor
 import pytensor.tensor as pt
3 changes: 1 addition & 2 deletions pytensor/tensor/type.py
@@ -1,9 +1,8 @@
 import logging
 import warnings
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, Literal, Optional, Tuple, Union

 import numpy as np
-from typing_extensions import Literal

 import pytensor
 from pytensor import scalar as aes
4 changes: 2 additions & 2 deletions tests/link/numba/test_scan.py
@@ -435,9 +435,9 @@ def test_inner_graph_optimized():

     # Disable scan pushout, in which case the whole scan is replaced by an Elemwise
     f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout"))
-    (scan_node,) = [
+    (scan_node,) = (
         node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
-    ]
+    )
     inner_scan_nodes = scan_node.op.fgraph.apply_nodes
     assert len(inner_scan_nodes) == 1
     (inner_scan_node,) = scan_node.op.fgraph.apply_nodes
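The `(scan_node,) = ...` unpacking works with any iterable and raises `ValueError` unless there is exactly one element, so the list comprehension can become a generator expression without building a throwaway list. A minimal illustration with hypothetical values:

    nodes = ["add", "scan", "mul"]

    # Exactly-one unpacking from a generator expression: no temporary list.
    (scan_node,) = (n for n in nodes if n == "scan")
    print(scan_node)  # scan

    # Zero or multiple matches raise ValueError — the implicit assertion
    # the test relies on.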