diff --git a/kerngen/pisa_generators/ntt.py b/kerngen/pisa_generators/ntt.py index 383041dc..316de6cc 100644 --- a/kerngen/pisa_generators/ntt.py +++ b/kerngen/pisa_generators/ntt.py @@ -36,25 +36,21 @@ def butterflies_ops( """Helper to return butterflies pisa operations for NTT/INTT""" ntt_stages_div_by_two = context.ntt_stages % 2 + # generate the stages, which depends on the total ntt stages. + stage_dst_srcs = [ + ( + (stage, outtmp, output) + if ntt_stages_div_by_two == stage % 2 + else (stage, output, outtmp) + ) + for stage in range(context.ntt_stages) + ] + + # For INTTs, start with input0 on the first stage destinations if init_input is True: - # intt - stage_dst_srcs = [ - ( - (stage, outtmp, output) - if ntt_stages_div_by_two == stage % 2 - else (stage, output, outtmp) - ) - for stage in range(context.ntt_stages) - ] stage_dst_srcs[0] = ( (0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, output, input0) ) - else: - # ntt - stage_dst_srcs = [ - ((stage, outtmp, output) if stage % 2 == 0 else (stage, output, outtmp)) - for stage in range(context.ntt_stages) - ] return [ op( @@ -94,15 +90,21 @@ def to_pisa(self) -> list[PIsaOp]: # TODO We need to decide whether output symbols need to be defined outtmp = Polys("outtmp", self.output.parts, self.output.rns) - # Essentially a scalar mul since psi 1 part - mul = Mul(self.context, self.output, self.input0, psi) + # control the input for the butterfly method. + if self.context.ntt_stages % 2 == 0: + # Essentially a scalar mul since psi 1 part + mul = Mul(self.context, self.output, self.input0, psi) + butterfly_input = outtmp + else: + mul = Mul(self.context, outtmp, self.input0, psi) + butterfly_input = self.input0 butterflies = butterflies_ops( pisa_op.NTT, context=self.context, output=self.output, outtmp=outtmp, - input0=self.input0, + input0=butterfly_input, ) return mixed_to_pisa_ops(mul, butterflies)