diff --git a/kerngen/pisa_generators/ntt.py b/kerngen/pisa_generators/ntt.py index 316de6cc..dd0ab555 100644 --- a/kerngen/pisa_generators/ntt.py +++ b/kerngen/pisa_generators/ntt.py @@ -90,21 +90,20 @@ def to_pisa(self) -> list[PIsaOp]: # TODO We need to decide whether output symbols need to be defined outtmp = Polys("outtmp", self.output.parts, self.output.rns) - # control the input for the butterfly method. + # Essentially a scalar mul since psi 1 part if self.context.ntt_stages % 2 == 0: - # Essentially a scalar mul since psi 1 part + # Even case: butterfly input starts "coeff" mul = Mul(self.context, self.output, self.input0, psi) - butterfly_input = outtmp else: + # Odd case: butterfly input stats with "outtmp" mul = Mul(self.context, outtmp, self.input0, psi) - butterfly_input = self.input0 butterflies = butterflies_ops( pisa_op.NTT, context=self.context, output=self.output, outtmp=outtmp, - input0=butterfly_input, + input0=self.input0, ) return mixed_to_pisa_ops(mul, butterflies)