Skip to content

Commit d17d4a9

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add tests
1 parent d14d152 commit d17d4a9

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pytensor.ifelse import ifelse
1818
from pytensor.link.pytorch.linker import PytorchLinker
1919
from pytensor.raise_op import CheckAndRaise
20-
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
20+
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
2121
from pytensor.tensor.type import matrices, matrix, scalar, vector
2222

2323

@@ -343,3 +343,17 @@ def test_pytorch_OpFromGraph():
343343

344344
f = FunctionGraph([x, y, z], [out])
345345
compare_pytorch_and_py(f, [xv, yv, zv])
346+
347+
348+
def test_pytorch_scipy():
349+
x = vector("a", shape=(3,))
350+
out = expit(x)
351+
f = FunctionGraph([x], [out])
352+
compare_pytorch_and_py(f, [np.random.rand(3)])
353+
354+
355+
def test_pytorch_softplus():
356+
x = vector("a", shape=(3,))
357+
out = softplus(x)
358+
f = FunctionGraph([x], [out])
359+
compare_pytorch_and_py(f, [np.random.rand(3)])

0 commit comments

Comments
 (0)