@@ -1647,6 +1647,120 @@ def aten_ops_logical_xor(
1647
1647
)
1648
1648
1649
1649
1650
+ def bitwise_type_validator (node : Node ) -> bool :
1651
+ targets = [
1652
+ torch .ops .aten .bitwise_and .Tensor ,
1653
+ torch .ops .aten .bitwise_or .Tensor ,
1654
+ torch .ops .aten .bitwise_xor .Tensor ,
1655
+ ]
1656
+ if node .target not in targets :
1657
+ return False
1658
+
1659
+ lhs_val = node .args [0 ]
1660
+ rhs_val = node .args [1 ]
1661
+ lhs_meta = lhs_val .meta .get ("tensor_meta" )
1662
+ rhs_meta = rhs_val .meta .get ("tensor_meta" )
1663
+
1664
+ if lhs_meta is None or rhs_meta is None :
1665
+ return False
1666
+
1667
+ supported_type = [torch .bool , bool ]
1668
+ return lhs_meta .dtype in supported_type and rhs_meta .dtype in supported_type
1669
+
1670
+
1671
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1672
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar ) # type: ignore[misc]
1673
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_and .Scalar_Tensor ) # type: ignore[misc]
1674
+ def aten_ops_bitwise_and (
1675
+ ctx : ConversionContext ,
1676
+ target : Target ,
1677
+ args : Tuple [Argument , ...],
1678
+ kwargs : Dict [str , Argument ],
1679
+ name : str ,
1680
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1681
+ return impl .elementwise .bitwise_and (
1682
+ ctx ,
1683
+ target ,
1684
+ SourceIR .ATEN ,
1685
+ name ,
1686
+ args [0 ],
1687
+ args [1 ],
1688
+ )
1689
+
1690
+
1691
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1692
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar ) # type: ignore[misc]
1693
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_or .Scalar_Tensor ) # type: ignore[misc]
1694
+ def aten_ops_bitwise_or (
1695
+ ctx : ConversionContext ,
1696
+ target : Target ,
1697
+ args : Tuple [Argument , ...],
1698
+ kwargs : Dict [str , Argument ],
1699
+ name : str ,
1700
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1701
+ return impl .elementwise .bitwise_or (
1702
+ ctx ,
1703
+ target ,
1704
+ SourceIR .ATEN ,
1705
+ name ,
1706
+ args [0 ],
1707
+ args [1 ],
1708
+ )
1709
+
1710
+
1711
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Tensor , capability_validator = bitwise_type_validator ) # type: ignore[misc]
1712
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar ) # type: ignore[misc]
1713
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_xor .Scalar_Tensor ) # type: ignore[misc]
1714
+ def aten_ops_bitwise_xor (
1715
+ ctx : ConversionContext ,
1716
+ target : Target ,
1717
+ args : Tuple [Argument , ...],
1718
+ kwargs : Dict [str , Argument ],
1719
+ name : str ,
1720
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1721
+ return impl .elementwise .bitwise_xor (
1722
+ ctx ,
1723
+ target ,
1724
+ SourceIR .ATEN ,
1725
+ name ,
1726
+ args [0 ],
1727
+ args [1 ],
1728
+ )
1729
+
1730
+
1731
+ def bitwise_not_type_validator (node : Node ) -> bool :
1732
+ val = node .args [0 ]
1733
+ val_meta = val .meta .get ("tensor_meta" )
1734
+
1735
+ if val_meta is None :
1736
+ return False
1737
+
1738
+ supported_type = [torch .bool , bool ]
1739
+ return val_meta .dtype in supported_type
1740
+
1741
+
1742
+ @dynamo_tensorrt_converter (torch .ops .aten .bitwise_not .default , capability_validator = bitwise_not_type_validator ) # type: ignore[misc]
1743
+ @enforce_tensor_types (
1744
+ {
1745
+ 0 : (TRTTensor ,),
1746
+ }
1747
+ ) # type: ignore[misc]
1748
+ def aten_ops_bitwise_not (
1749
+ ctx : ConversionContext ,
1750
+ target : Target ,
1751
+ args : Tuple [Argument , ...],
1752
+ kwargs : Dict [str , Argument ],
1753
+ name : str ,
1754
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
1755
+ return impl .unary .bitwise_not (
1756
+ ctx ,
1757
+ target ,
1758
+ SourceIR .ATEN ,
1759
+ name ,
1760
+ args [0 ],
1761
+ )
1762
+
1763
+
1650
1764
@dynamo_tensorrt_converter (torch .ops .aten .eq .Tensor ) # type: ignore[misc]
1651
1765
@dynamo_tensorrt_converter (torch .ops .aten .eq .Scalar ) # type: ignore[misc]
1652
1766
@enforce_tensor_types (
0 commit comments