Skip to content

Commit 6802239

Browse files
fix typos, gradients
1 parent bb9b02b commit 6802239

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

pytensor/gradient.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,17 +1815,22 @@ def random_projection(shape, dtype):
18151815
# This sum() is defined above, it's not the builtin sum.
18161816
if sum_outputs:
18171817
t_rs = [
1818-
shared(random_projection(o.shape, o.dtype), borrow=True) for o in o_fn_out
1818+
shared(
1819+
value=random_projection(o.shape, o.dtype),
1820+
borrow=True,
1821+
name=f"random_projection_{i}",
1822+
)
1823+
for i, o in enumerate(o_fn_out)
18191824
]
1820-
for i, x in enumerate(t_rs):
1821-
x.name = "ranom_projection_{i}"
18221825
cost = pytensor.tensor.sum(
18231826
[pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
18241827
)
18251828
else:
1826-
t_r = shared(random_projection(o_fn_out.shape, o_fn_out.dtype), borrow=True)
1827-
t_r.name = "random_projection"
1828-
1829+
t_r = shared(
1830+
value=random_projection(o_fn_out.shape, o_fn_out.dtype),
1831+
borrow=True,
1832+
name="random_projection",
1833+
)
18291834
cost = pytensor.tensor.sum(t_r * o_output)
18301835

18311836
if no_debug_ref:

pytensor/tensor/nlinalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from numpy.core.numeric import normalize_axis_tuple # type: ignore
88

9+
import pytensor.printing
910
from pytensor import scalar as ps
1011
from pytensor.gradient import DisconnectedType
1112
from pytensor.graph.basic import Apply
@@ -552,7 +553,7 @@ def __init__(self, full_matrices: bool = True, compute_uv: bool = True):
552553
if self.full_matrices:
553554
self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)"
554555
else:
555-
self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)"
556+
self.gufunc_signature = "(m,n)->(o,k),(k),(k,p)"
556557
else:
557558
self.gufunc_signature = "(m,n)->(k)"
558559

@@ -653,9 +654,10 @@ def h(t):
653654
sign_t = ptb.where(ptm.eq(t, 0), 1, ptm.sign(t))
654655
return ptm.maximum(ptm.abs(t), eps) * sign_t
655656

656-
numer = ptb.ones_like(A) - eye
657+
numer = ptb.ones((k, k)) - eye
657658
denom = h(s[None] - s[:, None]) * h(s[None] + s[:, None])
658659
E = numer / denom
660+
E = pytensor.printing.Print("E")(E)
659661

660662
utgu = U.T @ dU
661663
vtgv = VT @ dV

0 commit comments

Comments
 (0)