Skip to content

Commit fa1b859

Browse files
committed
[RELAY][PYTORCH]isNan, isinf, isfinite, ceil, clamp, round ops
1 parent 6805d54 commit fa1b859

File tree

7 files changed

+201
-2
lines changed

7 files changed

+201
-2
lines changed

docs/frontend/tensorflow.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ Supported Ops
162162
- Identity
163163
- IsFinite
164164
- IsInf
165+
- IsNan
165166
- LeakyRelu
166167
- LeftShift
167168
- Less

python/tvm/relay/frontend/pytorch.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1118,12 +1118,45 @@ def _impl(inputs, input_types):
11181118
return _op.tensor.sqrt(data)
11191119
return _impl
11201120

1121+
1122+
def _rsqrt():
    """Convert ``aten::rsqrt`` to the Relay reciprocal-square-root op."""
    def _impl(inputs, input_types):
        return _op.tensor.rsqrt(inputs[0])
    return _impl
1127+
1128+
1129+
def _ceil():
    """Convert ``aten::ceil`` to the Relay ceil op."""
    def _impl(inputs, input_types):
        return _op.ceil(inputs[0])
    return _impl
1134+
1135+
1136+
def _clamp():
    """Convert ``aten::clamp`` to the Relay clip op.

    PyTorch passes ``None`` for an unspecified min/max bound; substitute the
    full float32 range in that case so the corresponding side is a no-op.
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        # Compare against None explicitly: a bound of 0 (or 0.0) is falsy
        # but is still a valid, user-supplied clamp limit.
        amin = inputs[1] if inputs[1] is not None else np.finfo(np.float32).min
        amax = inputs[2] if inputs[2] is not None else np.finfo(np.float32).max
        return _op.clip(data, amin, amax)
    return _impl
1144+
1145+
11211146
def _floor():
    """Convert ``aten::floor`` to the Relay floor op."""
    def _impl(inputs, input_types):
        return _op.floor(inputs[0])
    return _impl
11261151

1152+
1153+
def _round():
    """Convert ``aten::round`` to the Relay round op."""
    def _impl(inputs, input_types):
        return _op.round(inputs[0])
    return _impl
1158+
1159+
11271160
def _to():
11281161
def _impl(inputs, input_types):
11291162
data = inputs[0]
@@ -1232,6 +1265,18 @@ def _impl(inputs, input_types):
12321265
return _impl
12331266

12341267

1268+
def _isfinite():
    """Convert ``aten::isfinite`` to the Relay isfinite op."""
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.isfinite(data)
    return _impl
1272+
1273+
1274+
def _isnan():
    """Convert ``aten::isnan`` to the Relay isnan op."""
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.isnan(data)
    return _impl
1278+
1279+
12351280
def _list_getitem(prelude):
12361281
def _impl(inputs, input_types):
12371282
return prelude.nth(inputs[0], _wrap_const(inputs[1]))
@@ -1429,7 +1474,11 @@ def _get_convert_map(prelude):
14291474
"aten::std" : _std(),
14301475
"aten::var" : _variance(),
14311476
"aten::sqrt" : _sqrt(),
1432-
'aten::floor' : _floor(),
1477+
"aten::rsqrt" : _rsqrt(),
1478+
"aten::ceil" : _ceil(),
1479+
"aten::clamp" : _clamp(),
1480+
"aten::floor" : _floor(),
1481+
"aten::round" : _round(),
14331482
"aten::detach" : _identity(),
14341483
"aten::upsample_bilinear2d" : _upsample("bilinear"),
14351484
"aten::upsample_nearest2d" : _upsample("nearest_neighbor"),
@@ -1439,6 +1488,9 @@ def _get_convert_map(prelude):
14391488
"aten::le" : _elemwise("less_equal"),
14401489
"aten::ge" : _elemwise("greater_equal"),
14411490
"aten::ne" : _elemwise("not_equal"),
1491+
"aten::eq" : _elemwise("equal"),
1492+
"aten::isfinite" : _isfinite(),
1493+
"aten::isnan" : _isnan(),
14421494
"aten::Bool" : _Bool(),
14431495
"aten::Float" : _Float(),
14441496
"aten::neg" : _neg(),

python/tvm/relay/op/_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
register_broadcast_schedule("less_equal")
6767
register_broadcast_schedule("greater")
6868
register_broadcast_schedule("greater_equal")
69+
register_broadcast_schedule("isnan")
6970
register_broadcast_schedule("isfinite")
7071
register_broadcast_schedule("isinf")
7172
register_injective_schedule("maximum")

python/tvm/relay/op/tensor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,22 @@ def ndarray_size(data, dtype="int32"):
10101010
return _make.ndarray_size(data, dtype)
10111011

10121012

1013+
def isnan(data):
    """Test element-wise whether each entry of the input is NaN.

    Parameters
    ----------
    data : relay.Expr
        The input tensor expression.

    Returns
    -------
    result : relay.Expr
        The element-wise NaN-test result.
    """
    return _make.isnan(data)
1027+
1028+
10131029
def isfinite(data):
10141030
"""Compute element-wise finiteness of data.
10151031

src/relay/op/tensor/unary.cc

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,15 @@ ElemwiseArbitraryLayout)
426426
.set_support_level(10)
427427
.set_attr<FTVMCompute>("FTVMCompute", NdarraySizeCompute);
428428

429+
// Element-wise NaN test. Registered at support level 3 alongside
// isfinite/isinf and computed via topi::isnan; the "IdentityCompRel"
// type relation is the same one those sibling comparison-style unary
// ops use. Wording fix: the op is element-wise, not a reduction, so
// the description must not say "contains any NaN".
RELAY_REGISTER_UNARY_OP("isnan")
.describe(R"code(Returns whether the input is NaN, computed element-wise.
.. math::
   isnan(x)
)code" TVM_ADD_FILELINE)
.set_support_level(3)
.add_type_rel("IdentityCompRel", IdentityCompRel)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::isnan));
437+
429438
RELAY_REGISTER_UNARY_OP("isfinite")
430439
.describe(R"code(Returns the finiteness of input, computed element-wise.
431440
.. math::
@@ -438,7 +447,7 @@ RELAY_REGISTER_UNARY_OP("isfinite")
438447
RELAY_REGISTER_UNARY_OP("isinf")
439448
.describe(R"code(Returns the infiniteness of input, computed element-wise.
440449
.. math::
441-
isfinite(x)
450+
isinf(x)
442451
)code" TVM_ADD_FILELINE)
443452
.set_support_level(3)
444453
.add_type_rel("IdentityCompRel", IdentityCompRel)

src/target/intrin_rule.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sigmoid")
9696
*rv = one / (one + exp(-call->args[0]));
9797
});
9898

99+
// Default lowering rule for the NaN-test intrinsic: unwrap the intrinsic
// CallNode and forward its single argument to isnan().
// NOTE(review): this is registered under "default.nan" while the sibling
// rules in this file use the full op name (e.g. "default.isfinite",
// "default.isinf") — confirm the intrinsic lowering pass actually looks
// this up as "nan" rather than "isnan".
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.nan")
.set_body([](const TVMArgs& args, TVMRetValue* rv){
    PrimExpr e = args[0];
    const CallNode* call = e.as<CallNode>();
    CHECK(call != nullptr);  // the rule must receive an intrinsic call node
    *rv = isnan(call->args[0]);
  });
106+
99107
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.isfinite")
100108
.set_body([](const TVMArgs& args, TVMRetValue* rv){
101109
PrimExpr e = args[0];

tests/python/frontend/pytorch/test_forward.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,6 +1441,110 @@ def forward(self, *args):
14411441
verify_model(Variance5().float().eval(), input_data=input_data)
14421442

14431443

1444+
1445+
def test_forward_isfinite():
    """aten::isfinite on a tensor mixing finite values, +/-inf and NaN."""
    torch.set_grad_enabled(False)

    class IsFinite1(Module):
        def forward(self, *args):
            return torch.isfinite(args[0])

    values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(values).float()
    verify_model(IsFinite1().float().eval(), input_data=input_data)
1454+
1455+
1456+
def test_forward_isnan():
    """aten::isnan on a tensor mixing finite values, +/-inf and NaN."""
    torch.set_grad_enabled(False)

    class IsNan1(Module):
        def forward(self, *args):
            return torch.isnan(args[0])

    values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(values).float()
    verify_model(IsNan1().float().eval(), input_data=input_data)
1465+
1466+
1467+
def test_forward_isinf():
    """aten::isinf on a tensor mixing finite values, +/-inf and NaN."""
    torch.set_grad_enabled(False)

    class IsInf1(Module):
        def forward(self, *args):
            return torch.isinf(args[0])

    values = [1, float('inf'), 2, float('-inf'), float('nan')]
    input_data = torch.tensor(values).float()
    verify_model(IsInf1().float().eval(), input_data=input_data)
1476+
1477+
1478+
def test_forward_rsqrt():
    """aten::rsqrt on a random (non-negative) 4-D tensor."""
    torch.set_grad_enabled(False)

    class Rsqrt1(Module):
        def forward(self, *args):
            return torch.rsqrt(args[0])

    input_data = torch.rand([1, 3, 10, 10]).float()
    verify_model(Rsqrt1().float().eval(), input_data=input_data)
1488+
1489+
1490+
def test_forward_ceil():
    """aten::ceil on a random 4-D tensor."""
    torch.set_grad_enabled(False)

    class Ceil1(Module):
        def forward(self, *args):
            return torch.ceil(args[0])

    input_data = torch.rand([1, 3, 10, 10]).float()
    verify_model(Ceil1().float().eval(), input_data=input_data)
1500+
1501+
1502+
def test_forward_clamp():
    """aten::clamp with both bounds, only a min bound, and only a max bound."""
    torch.set_grad_enabled(False)

    class Clamp1(Module):
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.5, max=0.5)

    class Clamp2(Module):
        def forward(self, *args):
            return torch.clamp(args[0], min=-0.3)

    class Clamp3(Module):
        def forward(self, *args):
            return torch.clamp(args[0], max=1.0)

    input_data = torch.rand([1, 3, 10, 10]).float()
    for model in (Clamp1(), Clamp2(), Clamp3()):
        verify_model(model.float().eval(), input_data=input_data)
1522+
1523+
1524+
def test_forward_floor():
    """aten::floor on a random 4-D tensor."""
    torch.set_grad_enabled(False)

    class Floor1(Module):
        def forward(self, *args):
            return torch.floor(args[0])

    input_data = torch.rand([1, 3, 10, 10]).float()
    verify_model(Floor1().float().eval(), input_data=input_data)
1534+
1535+
1536+
def test_forward_round():
    """aten::round on a random 4-D tensor."""
    torch.set_grad_enabled(False)

    class Round1(Module):
        def forward(self, *args):
            return torch.round(args[0])

    input_data = torch.rand([1, 3, 10, 10]).float()
    verify_model(Round1().float().eval(), input_data=input_data)
1546+
1547+
14441548
if __name__ == "__main__":
14451549
# Single operator tests
14461550
test_forward_add()
@@ -1497,6 +1601,14 @@ def forward(self, *args):
14971601
test_forward_expand()
14981602
test_forward_pow()
14991603
test_forward_abs()
1604+
test_forward_rsqrt()
1605+
test_forward_ceil()
1606+
test_forward_clamp()
1607+
test_forward_floor()
1608+
test_forward_round()
1609+
test_forward_isfinite()
1610+
test_forward_isnan()
1611+
test_forward_isinf()
15001612
test_forward_arange()
15011613
test_forward_chunk()
15021614
test_forward_split()

0 commit comments

Comments
 (0)