@@ -103,7 +103,7 @@
 from pytensor.link.c.params_type import ParamsType
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
-from pytensor.tensor import basic as ptb
+from pytensor.tensor.basic import as_tensor_variable, cast
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
 from pytensor.tensor.math import dot, tensordot
 from pytensor.tensor.shape import specify_broadcastable
@@ -157,11 +157,11 @@ def __str__(self):
         return f"{self.__class__.__name__}{{no_inplace}}"

     def make_node(self, y, alpha, A, x, beta):
-        y = ptb.as_tensor_variable(y)
-        x = ptb.as_tensor_variable(x)
-        A = ptb.as_tensor_variable(A)
-        alpha = ptb.as_tensor_variable(alpha)
-        beta = ptb.as_tensor_variable(beta)
+        y = as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        A = as_tensor_variable(A)
+        alpha = as_tensor_variable(alpha)
+        beta = as_tensor_variable(beta)
         if y.dtype != A.dtype or y.dtype != x.dtype:
             raise TypeError(
                 "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype)
@@ -257,10 +257,10 @@ def __str__(self):
         return f"{self.__class__.__name__}{{non-destructive}}"

     def make_node(self, A, alpha, x, y):
-        A = ptb.as_tensor_variable(A)
-        y = ptb.as_tensor_variable(y)
-        x = ptb.as_tensor_variable(x)
-        alpha = ptb.as_tensor_variable(alpha)
+        A = as_tensor_variable(A)
+        y = as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        alpha = as_tensor_variable(alpha)
         if not (A.dtype == x.dtype == y.dtype == alpha.dtype):
             raise TypeError(
                 "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype)
@@ -859,7 +859,7 @@ def __getstate__(self):
         return rval

     def make_node(self, *inputs):
-        inputs = list(map(ptb.as_tensor_variable, inputs))
+        inputs = list(map(as_tensor_variable, inputs))

         if any(not isinstance(i.type, DenseTensorType) for i in inputs):
             raise NotImplementedError("Only dense tensor types are supported")
@@ -1129,8 +1129,8 @@ class Dot22(GemmRelated):
     check_input = False

     def make_node(self, x, y):
-        x = ptb.as_tensor_variable(x)
-        y = ptb.as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)

         if any(not isinstance(i.type, DenseTensorType) for i in (x, y)):
             raise NotImplementedError("Only dense tensor types are supported")
@@ -1322,8 +1322,8 @@ class BatchedDot(COp):
     gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)"

     def make_node(self, x, y):
-        x = ptb.as_tensor_variable(x)
-        y = ptb.as_tensor_variable(y)
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)

         if not (
             isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType)
@@ -1357,7 +1357,7 @@ def extract_static_dim(dim_x, dim_y):

         # Change dtype if needed
         dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype)
-        x, y = ptb.cast(x, dtype), ptb.cast(y, dtype)
+        x, y = cast(x, dtype), cast(y, dtype)
         out = tensor(dtype=dtype, shape=out_shape)
         return Apply(self, [x, y], [out])

@@ -1738,7 +1738,7 @@ def batched_dot(a, b):
         "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
         FutureWarning,
     )
-    a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
+    a, b = as_tensor_variable(a), as_tensor_variable(b)

     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")
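
For context, a minimal sketch of what the two directly imported helpers do (assumes an installed `pytensor`; the array values and printed reprs are illustrative):

```python
import numpy as np
from pytensor.tensor.basic import as_tensor_variable, cast

# as_tensor_variable wraps raw Python/NumPy data in a symbolic TensorVariable;
# every make_node in the diff above uses it to normalize its inputs.
x = as_tensor_variable(np.arange(6, dtype="float32").reshape(2, 3))
print(x.type)  # e.g. TensorType(float32, shape=(2, 3))

# cast returns a new variable with the requested dtype, mirroring the
# upcast step in BatchedDot.make_node.
x64 = cast(x, "float64")
print(x64.type)  # e.g. TensorType(float64, shape=(2, 3))
```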
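
The `FutureWarning` in the `batched_dot` hunk points users at a replacement; below is a minimal sketch of that migration, assuming `vectorize_graph`'s dict-based `replace` argument:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core (non-batched) graph: a plain matrix-matrix dot.
a = pt.matrix("a")  # shape (m, k)
b = pt.matrix("b")  # shape (k, n)
core_out = pt.dot(a, b)

# Batched inputs carry one extra leading axis.
a_batched = pt.tensor3("a_batched")  # shape (batch, m, k)
b_batched = pt.tensor3("b_batched")  # shape (batch, k, n)

# vectorize_graph lifts the core graph over the leading batch dimension,
# which is what batched_dot(a_batched, b_batched) used to do in one call.
batched_out = vectorize_graph(core_out, replace={a: a_batched, b: b_batched})
```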