File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change 44import numpy as np
55import pytest
66
7+ import pytensor .tensor as pt
78import pytensor .tensor .basic as ptb
89from pytensor .compile .builders import OpFromGraph
910from pytensor .compile .function import function
@@ -343,3 +344,17 @@ def test_pytorch_OpFromGraph():
343344
344345 f = FunctionGraph ([x , y , z ], [out ])
345346 compare_pytorch_and_py (f , [xv , yv , zv ])
347+
348+
349+ def test_pytorch_scipy ():
350+ x = vector ("a" , shape = (3 ,))
351+ out = pt .expit (x )
352+ f = FunctionGraph ([x ], [out ])
353+ compare_pytorch_and_py (f , [np .random .rand (3 )])
354+
355+
356+ def test_pytorch_softplus ():
357+ x = vector ("a" , shape = (3 ,))
358+ out = pt .softplus (x )
359+ f = FunctionGraph ([x ], [out ])
360+ compare_pytorch_and_py (f , [np .random .rand (3 )])
You can’t perform that action at this time.
0 commit comments