@@ -151,6 +151,8 @@ class OneUse2<SDPatternOperator operator>
// FP-immediate leaf for value type `vt`; its predicate accepts the immediate
// only when Imm.isPosInfinity() returns true.
151151class fpimm_pos_inf<ValueType vt>
152152 : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
153153
// Pattern leaf for a `vt`-typed value written as a bitconvert of the integer
// constant 0. The integer type is looked up by name: "i" concatenated with
// vt.Size, then !cast to a ValueType.
// NOTE(review): presumably this resolves to i32 for the packed 2-element
// types used with it (v2i16, v2f16, v2bf16) - vt.Size is defined outside this
// view; confirm.
154+ class zeroinitializer<ValueType vt> :
155+ PatLeaf<(vt (bitconvert (!cast<ValueType>("i" # vt.Size) 0)))>;
154156
155157
156158// Operands which can hold a Register or an Immediate.
@@ -789,6 +791,23 @@ def UMAX16x2 : I16x2<"max.u", umax>;
// I16x2 instances for the "min.s" and "min.u" mnemonics, paired with the
// smin and umin operators respectively (I16x2 itself is defined outside this
// view).
789791def SMIN16x2 : I16x2<"min.s", smin>;
790792def UMIN16x2 : I16x2<"min.u", umin>;
791793
// Signed integer min/max instructions with the ".relu" suffix, enabled only
// under the hasPTX<80> and hasSM<90> predicates. Every selection pattern below
// has the shape smax(<smin|smax>(a, b), zero): the min/max of the two operands
// followed by a signed max against zero. All four take two B32 inputs and
// produce one B32 output.
794+ let Predicates = [hasPTX<80>, hasSM<90>] in {
795+
// Scalar i32 forms: the zero operand is the literal 0.
796+ def MIN_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
797+ "min.relu.s32",
798+ [(set i32:$dst, (smax (smin i32:$a, i32:$b), 0))]>;
799+ def MAX_RELU_S32 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
800+ "max.relu.s32",
801+ [(set i32:$dst, (smax (smax i32:$a, i32:$b), 0))]>;
// Packed v2i16 forms: the zero operand is zeroinitializer<v2i16> instead of a
// scalar literal; operands still use the B32 register class.
802+ def MIN_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
803+ "min.relu.s16x2",
804+ [(set v2i16:$dst, (smax (smin v2i16:$a, v2i16:$b),
805+ zeroinitializer<v2i16>))]>;
806+ def MAX_RELU_S16x2 : BasicNVPTXInst<(outs B32:$dst), (ins B32:$a, B32:$b),
807+ "max.relu.s16x2",
808+ [(set v2i16:$dst, (smax (smax v2i16:$a, v2i16:$b),
809+ zeroinitializer<v2i16>))]>;
810+ }
792811
793812//
794813// Wide multiplication
@@ -2379,9 +2398,6 @@ def fpimm_any_zero : FPImmLeaf<fAny, [{
23792398 return Imm.isZero();
23802399}]>;
23812400
2382- def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
2383- def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;
2384-
23852401// Perform substitution if fma only has one use, and also if instruction has
23862402// nnan instruction flag or if the TM has NoNaNsFPMath
23872403def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
@@ -2404,10 +2420,10 @@ class FMARELUInst<RegTyInfo t, bit allow_ftz, PatFrag zero_pat>
24042420
// FMARELUInst instances for F16RT and F16X2RT, enabled under useFP16Math,
// hasPTX<70> and hasSM<80>. Arguments follow the class signature
// <RegTyInfo t, bit allow_ftz, PatFrag zero_pat>: both pass allow_ftz = true;
// the scalar form uses fpimm_any_zero as its zero pattern, while the packed
// form's zero pattern changes in this diff from fpimm_positive_zero_v2f16 to
// zeroinitializer<v2f16>.
24052421let Predicates = [useFP16Math, hasPTX<70>, hasSM<80>] in {
24062422 def FMARELU_F16 : FMARELUInst<F16RT, true, fpimm_any_zero>;
2407- def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, fpimm_positive_zero_v2f16 >;
2423+ def FMARELU_F16X2 : FMARELUInst<F16X2RT, true, zeroinitializer<v2f16> >;
24082424}
24092425
// FMARELUInst instances for BF16RT and BF16X2RT, enabled under hasBF16Math,
// hasPTX<70> and hasSM<80>. Both pass allow_ftz = false; the scalar form uses
// fpimm_any_zero as its zero pattern, while the packed form's zero pattern
// changes in this diff from fpimm_positive_zero_v2bf16 to
// zeroinitializer<v2bf16>.
24102426let Predicates = [hasBF16Math, hasPTX<70>, hasSM<80>] in {
24112427 def FMARELU_BF16 : FMARELUInst<BF16RT, false, fpimm_any_zero>;
2412- def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, fpimm_positive_zero_v2bf16 >;
2428+ def FMARELU_BF16X2 : FMARELUInst<BF16X2RT, false, zeroinitializer<v2bf16> >;
24122429}
0 commit comments