diff --git a/kerngen/pisa_generators/ntt.py b/kerngen/pisa_generators/ntt.py index 10ac840b..383041dc 100644 --- a/kerngen/pisa_generators/ntt.py +++ b/kerngen/pisa_generators/ntt.py @@ -1,4 +1,5 @@ # Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 """Module containing conversions or operations from isa to p-isa.""" @@ -12,6 +13,15 @@ from .basic import Mul, Muli, mixed_to_pisa_ops +def generate_unit_index(size: int, op: pisa_op.NTT | pisa_op.INTT): + """Helper to return unit indices for ntt/intt""" + for i in range(int(size / 2)): + if issubclass(op, pisa_op.NTT): + yield (i, int(size / 2) + i, i * 2, i * 2 + 1) + else: + yield (i * 2, i * 2 + 1, i, int(size / 2) + i) + + # pylint: disable=too-many-arguments def butterflies_ops( op: pisa_op.NTT | pisa_op.INTT, @@ -21,42 +31,48 @@ def butterflies_ops( input0: Polys, *, # only kwargs after init_input: bool = False, + unit_size: int = 8192 ) -> list[PIsaOp]: """Helper to return butterflies pisa operations for NTT/INTT""" - ntt_stages = context.ntt_stages - ntt_stages_div_by_two = ntt_stages % 2 - - stage_dst_srcs = [ - ( - (stage, outtmp, output) - if ntt_stages_div_by_two == stage % 2 - else (stage, output, outtmp) - ) - for stage in range(ntt_stages) - ] + ntt_stages_div_by_two = context.ntt_stages % 2 if init_input is True: + # intt + stage_dst_srcs = [ + ( + (stage, outtmp, output) + if ntt_stages_div_by_two == stage % 2 + else (stage, output, outtmp) + ) + for stage in range(context.ntt_stages) + ] stage_dst_srcs[0] = ( - (0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, input0, outtmp) + (0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, output, input0) ) + else: + # ntt + stage_dst_srcs = [ + ((stage, outtmp, output) if stage % 2 == 0 else (stage, output, outtmp)) + for stage in range(context.ntt_stages) + ] return [ op( context.label, - dst(part, q, unit), - dst(part, q, next_unit), - src(part, q, unit), - src(part, q, next_unit), + dst(part, q, unit[0]), + dst(part, q, unit[1]), + src(part, q, unit[2]), + src(part, q, unit[3]), stage, - unit, + unit[0] if issubclass(op, pisa_op.NTT) else unit[2], q, ) # units for omegas (aka w) taken from 16K onwards - for part, (stage, dst, src), q, (unit, next_unit) in it.product( + for part, (stage, dst, src), q, unit in it.product( range(input0.start_parts, input0.parts), stage_dst_srcs, range(input0.start_rns, input0.rns), - it.pairwise(range(context.units)), + generate_unit_index(int(context.poly_order / unit_size), op), ) ]