@@ -1441,6 +1441,110 @@ def forward(self, *args):
14411441 verify_model (Variance5 ().float ().eval (), input_data = input_data )
14421442
14431443
1444+
def test_forward_isfinite():
    """Verify conversion of torch.isfinite against PyTorch's own output."""
    torch.set_grad_enabled(False)

    class IsFinite1(Module):
        def forward(self, *args):
            return torch.isfinite(args[0])

    # Mix of finite values, +/-inf and NaN exercises every predicate branch.
    special_values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(special_values).float()
    verify_model(IsFinite1().float().eval(), input_data=input_data)
1454+
1455+
def test_forward_isnan():
    """Verify conversion of torch.isnan against PyTorch's own output."""
    torch.set_grad_enabled(False)

    class IsNan1(Module):
        def forward(self, *args):
            return torch.isnan(args[0])

    # Mix of finite values, +/-inf and NaN exercises every predicate branch.
    special_values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(special_values).float()
    verify_model(IsNan1().float().eval(), input_data=input_data)
1465+
1466+
def test_forward_isinf():
    """Verify conversion of torch.isinf against PyTorch's own output."""
    torch.set_grad_enabled(False)

    class IsInf1(Module):
        def forward(self, *args):
            return torch.isinf(args[0])

    # Mix of finite values, +/-inf and NaN exercises every predicate branch.
    special_values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(special_values).float()
    verify_model(IsInf1().float().eval(), input_data=input_data)
1476+
1477+
def test_forward_rsqrt():
    """Verify conversion of torch.rsqrt on a random 4-D input."""
    torch.set_grad_enabled(False)

    class Rsqrt1(Module):
        def forward(self, *args):
            return torch.rsqrt(args[0])

    # torch.rand draws from [0, 1), so rsqrt stays real-valued here.
    input_data = torch.rand(1, 3, 10, 10).float()
    verify_model(Rsqrt1().float().eval(), input_data=input_data)
1488+
1489+
def test_forward_ceil():
    """Verify conversion of torch.ceil on a random 4-D input."""
    torch.set_grad_enabled(False)

    class Ceil1(Module):
        def forward(self, *args):
            return torch.ceil(args[0])

    input_data = torch.rand(1, 3, 10, 10).float()
    verify_model(Ceil1().float().eval(), input_data=input_data)
1500+
1501+
def test_forward_clamp():
    """Verify torch.clamp conversion with min+max, min-only and max-only."""
    torch.set_grad_enabled(False)

    class Clamp1(Module):
        # Two-sided clamp.
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.5, max=0.5)

    class Clamp2(Module):
        # Lower bound only.
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.3)

    class Clamp3(Module):
        # Upper bound only.
        def forward(self, *args):
            return torch.clamp(args[0], max=1.0)

    input_data = torch.rand(1, 3, 10, 10).float()
    for model in (Clamp1(), Clamp2(), Clamp3()):
        verify_model(model.float().eval(), input_data=input_data)
1522+
1523+
def test_forward_floor():
    """Verify conversion of torch.floor on a random 4-D input."""
    torch.set_grad_enabled(False)

    class Floor1(Module):
        def forward(self, *args):
            return torch.floor(args[0])

    input_data = torch.rand(1, 3, 10, 10).float()
    verify_model(Floor1().float().eval(), input_data=input_data)
1534+
1535+
def test_forward_round():
    """Verify conversion of torch.round on a random 4-D input."""
    torch.set_grad_enabled(False)

    class Round1(Module):
        def forward(self, *args):
            return torch.round(args[0])

    input_data = torch.rand(1, 3, 10, 10).float()
    verify_model(Round1().float().eval(), input_data=input_data)
1546+
1547+
14441548if __name__ == "__main__" :
14451549 # Single operator tests
14461550 test_forward_add ()
@@ -1497,6 +1601,14 @@ def forward(self, *args):
14971601 test_forward_expand ()
14981602 test_forward_pow ()
14991603 test_forward_abs ()
1604+ test_forward_rsqrt ()
1605+ test_forward_ceil ()
1606+ test_forward_clamp ()
1607+ test_forward_floor ()
1608+ test_forward_round ()
1609+ test_forward_isfinite ()
1610+ test_forward_isnan ()
1611+ test_forward_isinf ()
15001612 test_forward_arange ()
15011613 test_forward_chunk ()
15021614 test_forward_split ()
0 commit comments