File tree Expand file tree Collapse file tree 2 files changed +17
-2
lines changed
pytensor/tensor/rewriting Expand file tree Collapse file tree 2 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
20712071 rval = [reciprocal (sqr (xsym ))]
20722072 if rval :
20732073 rval [0 ] = cast (rval [0 ], odtype )
2074- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2074+ assert rval [0 ].type .is_super (node .outputs [0 ].type ), (
2075+ rval [0 ].type ,
2076+ node .outputs [0 ].type ,
2077+ )
20752078 return rval
20762079 else :
20772080 return False
Original file line number Diff line number Diff line change 9696 perform_sigm_times_exp ,
9797 simplify_mul ,
9898)
99- from pytensor .tensor .shape import Reshape , Shape_i
99+ from pytensor .tensor .shape import Reshape , Shape_i , SpecifyShape
100100from pytensor .tensor .type import (
101101 TensorType ,
102102 cmatrix ,
@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
16711671 assert isinstance (nodes [1 ].scalar_op , aes .basic .Reciprocal )
16721672 utt .assert_allclose (f (val_no0 ), val_no0 ** (- 0.5 ))
16731673
1674+ twos = np .full (shape = (10 ,), fill_value = 2.0 ).astype (config .floatX )
1675+ f = function ([v ], v ** twos , mode = mode )
1676+ topo = f .maker .fgraph .toposort ()
1677+ assert len (topo ) == 2
1678+ # Depending on the mode the SpecifyShape is lifted or not
1679+ if topo [0 ].op == sqr :
1680+ assert isinstance (topo [1 ].op , SpecifyShape )
1681+ else :
1682+ assert isinstance (topo [0 ].op , SpecifyShape )
1683+ assert topo [1 ].op == sqr
1684+ utt .assert_allclose (f (val ), val ** twos )
1685+
16741686
16751687def test_local_pow_to_nested_squaring ():
16761688 mode = config .mode
You can’t perform that action at this time.
0 commit comments