@@ -1,47 +1,78 @@
-from pytensor import scalar as aes
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.math import Sum, exp, log
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.basic import register_stabilize
 from pytensor.tensor.rewriting.math import local_mul_canonizer
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
+
+
+@register_stabilize
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
 
+    This also lifts Subtensor or DimShuffle operations that may appear between the log and the softmax
+
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
-    ):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-        ret = new_op(inVars)
-        ret.tag.values_eq_approx = values_eq_approx_remove_inf
-        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-        return [ret]
+
+    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+        if inp_node is None:
+            return
+
+        if isinstance(inp_node.op, Softmax):
+            return inp_node
+
+        if isinstance(inp_node.op, subtensor_ops):
+            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+        if isinstance(inp_node.op, DimShuffle):
+            ops_to_lift.append((inp_node.op, ()))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+    ops_to_lift = []
+    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+    if softmax_node is None:
+        return
+
+    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+    ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+    # Lift ops that used to be between log and softmax
+    for op_to_lift, parameters in reversed(ops_to_lift):
+        ret = op_to_lift(ret, *parameters)
+
+    copy_stack_trace(node.outputs, ret)
+    return [ret]
 
 
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
+@register_stabilize
 @node_rewriter([SoftmaxGrad])
 def local_logsoftmax_grad(fgraph, node):
     """
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
     Note: only grad is affected
     """
     if (
-        isinstance(node.op, SoftmaxGrad)
-        and len(node.inputs) == 2
-        and node.inputs[0].owner is not None
+        node.inputs[0].owner is not None
         and node.inputs[0].owner.op == true_div
         and len(node.inputs[0].owner.inputs) >= 2
         and node.inputs[0].owner.inputs[1].owner is not None
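
With this change both rewrites are registered in `stabilize`, so they run in the default compilation mode. For context, here is how the forward rewrite can be exercised from user code. This is a minimal sketch, not part of the diff; it assumes a PyTensor build where `softmax` and `LogSoftmax` live in `pytensor.tensor.special`, as the imports above indicate:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import LogSoftmax, softmax

x = pt.matrix("x")
# Naive composition: log(softmax(x)) underflows to -inf for extreme logits
y = pt.log(softmax(x, axis=-1))

# Stabilize rewrites run by default when compiling
f = pytensor.function([x], y)

# After the rewrite, a fused LogSoftmax node should replace log + softmax
assert any(isinstance(n.op, LogSoftmax) for n in f.maker.fgraph.toposort())

# The fused op stays finite where the naive graph would return -inf
print(f(np.array([[0.0, 1000.0]])))  # approx. [[-1000., 0.]]
```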
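
The behavioural change is the lifting step: previously the rewrite only matched `log` applied directly to a `Softmax`, so any indexing or dimshuffle in between blocked it. Continuing the sketch above, both of the following graphs should now be stabilized, with the subtensor or dimshuffle re-applied on top of the fused op:

```python
# Subtensor between log and softmax: lifted to log_softmax(x)[0]
y_idx = pt.log(softmax(x, axis=-1)[0])

# DimShuffle between log and softmax: lifted to log_softmax(x).T
y_ds = pt.log(softmax(x, axis=-1).T)

for graph in (y_idx, y_ds):
    fn = pytensor.function([x], graph)
    assert any(isinstance(n.op, LogSoftmax) for n in fn.maker.fgraph.toposort())
```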
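
`local_logsoftmax_grad` targets the same pattern in gradient graphs: differentiating `log(softmax(x))` yields `SoftmaxGrad(dy / sm, sm)`, where the division by softmax outputs can turn into `1 / 0`. A hedged sketch of how that rewrite is triggered, again continuing the snippet above:

```python
# grad(sum(log(softmax(x)))) contains SoftmaxGrad(true_div(dy, sm), sm),
# which the rewrite replaces with a numerically stable expression
cost = pt.log(softmax(x, axis=-1)).sum()
g = pytensor.grad(cost, x)

f_grad = pytensor.function([x], g)
print(f_grad(np.array([[0.0, 1000.0]])))  # finite, no nan from dividing by 0
```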