@@ -9239,6 +9239,323 @@ class AdjointGenerator
92399239 return ;
92409240 }
92419241
9242+ if (funcName == " __mulsc3" || funcName == " __muldc3" ||
9243+ funcName == " __multc3" || funcName == " __mulxc3" ) {
9244+ if (gutils->knownRecomputeHeuristic .find (orig) !=
9245+ gutils->knownRecomputeHeuristic .end ()) {
9246+ if (!gutils->knownRecomputeHeuristic [orig]) {
9247+ gutils->cacheForReverse (BuilderZ, newCall,
9248+ getIndex (orig, CacheType::Self));
9249+ }
9250+ }
9251+
9252+ eraseIfUnused (*orig);
9253+ if (gutils->isConstantInstruction (orig))
9254+ return ;
9255+
9256+ Value *orig_op0 = call.getOperand (0 );
9257+ Value *orig_op1 = call.getOperand (1 );
9258+ Value *orig_op2 = call.getOperand (2 );
9259+ Value *orig_op3 = call.getOperand (3 );
9260+
9261+ bool constantval0 = gutils->isConstantValue (orig_op0);
9262+ bool constantval1 = gutils->isConstantValue (orig_op1);
9263+ bool constantval2 = gutils->isConstantValue (orig_op2);
9264+ bool constantval3 = gutils->isConstantValue (orig_op3);
9265+
9266+ Value *prim[4 ] = {gutils->getNewFromOriginal (orig_op0),
9267+ gutils->getNewFromOriginal (orig_op1),
9268+ gutils->getNewFromOriginal (orig_op2),
9269+ gutils->getNewFromOriginal (orig_op3)};
9270+
9271+ auto mul = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9272+ funcName, called->getFunctionType (), called->getAttributes ());
9273+
9274+ switch (Mode) {
9275+ case DerivativeMode::ForwardMode:
9276+ case DerivativeMode::ForwardModeSplit: {
9277+ IRBuilder<> Builder2 (&call);
9278+ getForwardBuilder (Builder2);
9279+
9280+ Value *diff[4 ] = {
9281+ constantval0 ? Constant::getNullValue (orig_op0->getType ())
9282+ : diffe (orig_op0, Builder2),
9283+ constantval1 ? Constant::getNullValue (orig_op1->getType ())
9284+ : diffe (orig_op1, Builder2),
9285+ constantval2 ? Constant::getNullValue (orig_op2->getType ())
9286+ : diffe (orig_op2, Builder2),
9287+ constantval3 ? Constant::getNullValue (orig_op3->getType ())
9288+ : diffe (orig_op3, Builder2)};
9289+
9290+ auto cal1 =
9291+ Builder2.CreateCall (mul, {diff[0 ], diff[1 ], prim[2 ], prim[3 ]});
9292+ auto cal2 =
9293+ Builder2.CreateCall (mul, {prim[0 ], prim[1 ], diff[2 ], diff[3 ]});
9294+
9295+ Value *resReal =
9296+ Builder2.CreateFAdd (Builder2.CreateExtractValue (cal1, {0 }),
9297+ Builder2.CreateExtractValue (cal2, {0 }));
9298+ Value *resImag =
9299+ Builder2.CreateFAdd (Builder2.CreateExtractValue (cal1, {1 }),
9300+ Builder2.CreateExtractValue (cal2, {1 }));
9301+
9302+ Value *res = Builder2.CreateInsertValue (
9303+ UndefValue::get (call.getType ()), resReal, {0 });
9304+ res = Builder2.CreateInsertValue (res, resImag, {1 });
9305+
9306+ setDiffe (&call, res, Builder2);
9307+ return ;
9308+ }
9309+ case DerivativeMode::ReverseModeGradient:
9310+ case DerivativeMode::ReverseModeCombined: {
9311+ IRBuilder<> Builder2 (call.getParent ());
9312+ getReverseBuilder (Builder2);
9313+
9314+ Value *idiff = diffe (&call, Builder2);
9315+ Value *idiffReal = Builder2.CreateExtractValue (idiff, {0 });
9316+ Value *idiffImag = Builder2.CreateExtractValue (idiff, {1 });
9317+
9318+ Value *diff0 = nullptr ;
9319+ Value *diff1 = nullptr ;
9320+
9321+ if (!constantval0 || !constantval1)
9322+ diff0 = Builder2.CreateCall (mul, {idiffReal, idiffImag,
9323+ lookup (prim[2 ], Builder2),
9324+ lookup (prim[3 ], Builder2)});
9325+
9326+ if (!constantval2 || !constantval3)
9327+ diff1 = Builder2.CreateCall (mul, {lookup (prim[0 ], Builder2),
9328+ lookup (prim[1 ], Builder2),
9329+ idiffReal, idiffImag});
9330+
9331+ if (diff0 || diff1)
9332+ setDiffe (&call, Constant::getNullValue (call.getType ()), Builder2);
9333+
9334+ if (diff0) {
9335+ addToDiffe (orig_op0, Builder2.CreateExtractValue (diff0, {0 }),
9336+ Builder2, orig_op0->getType ());
9337+ addToDiffe (orig_op1, Builder2.CreateExtractValue (diff0, {1 }),
9338+ Builder2, orig_op1->getType ());
9339+ }
9340+
9341+ if (diff1) {
9342+ addToDiffe (orig_op2, Builder2.CreateExtractValue (diff1, {0 }),
9343+ Builder2, orig_op2->getType ());
9344+ addToDiffe (orig_op3, Builder2.CreateExtractValue (diff1, {1 }),
9345+ Builder2, orig_op3->getType ());
9346+ }
9347+
9348+ return ;
9349+ }
9350+ case DerivativeMode::ReverseModePrimal:
9351+ return ;
9352+ }
9353+ }
9354+
9355+ if (funcName == " __divsc3" || funcName == " __divdc3" ||
9356+ funcName == " __divtc3" || funcName == " __divxc3" ) {
9357+ if (gutils->knownRecomputeHeuristic .find (orig) !=
9358+ gutils->knownRecomputeHeuristic .end ()) {
9359+ if (!gutils->knownRecomputeHeuristic [orig]) {
9360+ gutils->cacheForReverse (BuilderZ, newCall,
9361+ getIndex (orig, CacheType::Self));
9362+ }
9363+ }
9364+
9365+ if (gutils->isConstantInstruction (orig))
9366+ return ;
9367+
9368+ StringMap<StringRef> map = {
9369+ {" __divsc3" , " __mulsc3" },
9370+ {" __divdc3" , " __muldc3" },
9371+ {" __divtc3" , " __multc3" },
9372+ {" __divxc3" , " __mulxc3" },
9373+ };
9374+
9375+ auto mul = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9376+ map[funcName], called->getFunctionType (), called->getAttributes ());
9377+
9378+ auto div = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9379+ funcName, called->getFunctionType (), called->getAttributes ());
9380+
9381+ Value *orig_op0 = call.getOperand (0 );
9382+ Value *orig_op1 = call.getOperand (1 );
9383+ Value *orig_op2 = call.getOperand (2 );
9384+ Value *orig_op3 = call.getOperand (3 );
9385+
9386+ bool constantval0 = gutils->isConstantValue (orig_op0);
9387+ bool constantval1 = gutils->isConstantValue (orig_op1);
9388+ bool constantval2 = gutils->isConstantValue (orig_op2);
9389+ bool constantval3 = gutils->isConstantValue (orig_op3);
9390+
9391+ Value *prim[4 ] = {gutils->getNewFromOriginal (orig_op0),
9392+ gutils->getNewFromOriginal (orig_op1),
9393+ gutils->getNewFromOriginal (orig_op2),
9394+ gutils->getNewFromOriginal (orig_op3)};
9395+
9396+ switch (Mode) {
9397+ case DerivativeMode::ForwardMode:
9398+ case DerivativeMode::ForwardModeSplit: {
9399+ IRBuilder<> Builder2 (&call);
9400+ getForwardBuilder (Builder2);
9401+
9402+ Value *diff[4 ] = {
9403+ constantval0 ? Constant::getNullValue (orig_op0->getType ())
9404+ : diffe (orig_op0, Builder2),
9405+ constantval1 ? Constant::getNullValue (orig_op1->getType ())
9406+ : diffe (orig_op1, Builder2),
9407+ constantval2 ? Constant::getNullValue (orig_op2->getType ())
9408+ : diffe (orig_op2, Builder2),
9409+ constantval3 ? Constant::getNullValue (orig_op3->getType ())
9410+ : diffe (orig_op3, Builder2)};
9411+
9412+ auto mul1 =
9413+ Builder2.CreateCall (mul, {diff[0 ], diff[1 ], prim[2 ], prim[3 ]});
9414+ auto mul2 =
9415+ Builder2.CreateCall (mul, {prim[0 ], prim[1 ], diff[2 ], diff[3 ]});
9416+ auto sq1 =
9417+ Builder2.CreateCall (mul, {prim[2 ], prim[3 ], prim[2 ], prim[3 ]});
9418+
9419+ Value *subReal =
9420+ Builder2.CreateFSub (Builder2.CreateExtractValue (mul1, {0 }),
9421+ Builder2.CreateExtractValue (mul2, {0 }));
9422+ Value *subImag =
9423+ Builder2.CreateFSub (Builder2.CreateExtractValue (mul1, {1 }),
9424+ Builder2.CreateExtractValue (mul2, {1 }));
9425+
9426+ auto div1 = Builder2.CreateCall (
9427+ div, {subReal, subImag, Builder2.CreateExtractValue (sq1, {0 }),
9428+ Builder2.CreateExtractValue (sq1, {1 })});
9429+
9430+ setDiffe (&call, div1, Builder2);
9431+
9432+ eraseIfUnused (*orig);
9433+
9434+ return ;
9435+ }
9436+ case DerivativeMode::ReverseModeGradient:
9437+ case DerivativeMode::ReverseModeCombined: {
9438+ IRBuilder<> Builder2 (call.getParent ());
9439+ getReverseBuilder (Builder2);
9440+
9441+ Value *idiff = diffe (&call, Builder2);
9442+ Value *idiffReal = Builder2.CreateExtractValue (idiff, {0 });
9443+ Value *idiffImag = Builder2.CreateExtractValue (idiff, {1 });
9444+
9445+ Value *diff0 = nullptr ;
9446+ Value *diff1 = nullptr ;
9447+
9448+ if (!constantval0 || !constantval1)
9449+ diff0 = Builder2.CreateCall (div, {idiffReal, idiffImag,
9450+ lookup (prim[2 ], Builder2),
9451+ lookup (prim[3 ], Builder2)});
9452+
9453+ if (!constantval2 || !constantval3) {
9454+ auto fdiv = Builder2.CreateCall (div, {idiffReal, idiffImag,
9455+ lookup (prim[1 ], Builder2),
9456+ lookup (prim[2 ], Builder2)});
9457+
9458+ Value *newcall = gutils->getNewFromOriginal (&call);
9459+
9460+ diff1 = Builder2.CreateCall (
9461+ mul,
9462+ {Builder2.CreateFNeg (Builder2.CreateExtractValue (newcall, {0 })),
9463+ Builder2.CreateFNeg (Builder2.CreateExtractValue (newcall, {1 })),
9464+ Builder2.CreateExtractValue (fdiv, {0 }),
9465+ Builder2.CreateExtractValue (fdiv, {1 })});
9466+ }
9467+
9468+ if (diff0 || diff1)
9469+ setDiffe (&call, Constant::getNullValue (call.getType ()), Builder2);
9470+
9471+ if (diff0) {
9472+ addToDiffe (orig_op0, Builder2.CreateExtractValue (diff0, {0 }),
9473+ Builder2, orig_op0->getType ());
9474+ addToDiffe (orig_op1, Builder2.CreateExtractValue (diff0, {1 }),
9475+ Builder2, orig_op1->getType ());
9476+ }
9477+
9478+ if (diff1) {
9479+ addToDiffe (orig_op2, Builder2.CreateExtractValue (diff1, {0 }),
9480+ Builder2, orig_op2->getType ());
9481+ addToDiffe (orig_op3, Builder2.CreateExtractValue (diff1, {1 }),
9482+ Builder2, orig_op3->getType ());
9483+ }
9484+
9485+ if (constantval2 && constantval3)
9486+ eraseIfUnused (*orig);
9487+
9488+ return ;
9489+ }
9490+ case DerivativeMode::ReverseModePrimal:;
9491+ return ;
9492+ }
9493+ }
9494+
9495+ if (funcName == " scalbn" || funcName == " scalbnf" ||
9496+ funcName == " scalbnl" || funcName == " scalbln" ||
9497+ funcName == " scalblnf" || funcName == " scalblnl" ) {
9498+ eraseIfUnused (*orig);
9499+
9500+ Value *orig_op0 = call.getOperand (0 );
9501+ Value *orig_op1 = call.getOperand (1 );
9502+
9503+ bool constantval0 = gutils->isConstantValue (orig_op0);
9504+
9505+ if (gutils->isConstantInstruction (orig) || constantval0)
9506+ return ;
9507+
9508+ Value *op0 = gutils->getNewFromOriginal (orig_op0);
9509+ Value *op1 = gutils->getNewFromOriginal (orig_op1);
9510+
9511+ auto scal = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9512+ funcName, called->getFunctionType (), called->getAttributes ());
9513+
9514+ switch (Mode) {
9515+ case DerivativeMode::ForwardMode:
9516+ case DerivativeMode::ForwardModeSplit: {
9517+ IRBuilder<> Builder2 (&call);
9518+ getForwardBuilder (Builder2);
9519+
9520+ Value *diff0 = diffe (orig_op0, Builder2);
9521+
9522+ auto cal1 = Builder2.CreateCall (scal, {op0, op1});
9523+ auto cal2 = Builder2.CreateCall (scal, {diff0, op1});
9524+
9525+ Value *diff = Builder2.CreateFMul (
9526+ cal1, ConstantFP::get (call.getType (), 0.3010299957 ));
9527+ diff = Builder2.CreateFAdd (diff, cal2);
9528+
9529+ setDiffe (&call, diff, Builder2);
9530+ return ;
9531+ }
9532+ case DerivativeMode::ReverseModeGradient:
9533+ case DerivativeMode::ReverseModeCombined: {
9534+ IRBuilder<> Builder2 (call.getParent ());
9535+ getReverseBuilder (Builder2);
9536+
9537+ Value *idiff = diffe (&call, Builder2);
9538+
9539+ if (idiff && !constantval0) {
9540+ op1 = lookup (op1, Builder2);
9541+
9542+ auto cal1 = Builder2.CreateCall (scal, {op0, op1});
9543+ auto cal2 = Builder2.CreateCall (scal, {idiff, op1});
9544+
9545+ Value *diff = Builder2.CreateFMul (
9546+ cal1, ConstantFP::get (call.getType (), 0.3010299957 ));
9547+ diff = Builder2.CreateFAdd (diff, cal2);
9548+
9549+ addToDiffe (orig_op0, diff, Builder2, call.getType ());
9550+ }
9551+
9552+ return ;
9553+ }
9554+ case DerivativeMode::ReverseModePrimal:;
9555+ return ;
9556+ }
9557+ }
9558+
92429559 if (called) {
92439560 if (funcName == " erf" || funcName == " erfi" || funcName == " erfc" ||
92449561 funcName == " Faddeeva_erf" || funcName == " Faddeeva_erfi" ||
0 commit comments