11# Copyright (C) 2024 Intel Corporation
2+ # SPDX-License-Identifier: Apache-2.0
23
34"""Module containing conversions or operations from isa to p-isa."""
45
1213from .basic import Mul , Muli , mixed_to_pisa_ops
1314
1415
def generate_unit_index(size: int, op: pisa_op.NTT | pisa_op.INTT):
    """Yield the four unit indices used by each NTT/INTT butterfly.

    Args:
        size: Total number of units. ``size // 2`` tuples are produced, so a
            trailing odd unit is ignored.
        op: The p-isa operation class (it is tested with ``issubclass``, so it
            must be a class, not an instance). Subclasses of ``pisa_op.NTT``
            select the NTT ordering; any other class selects the INTT
            ordering.

    Yields:
        4-tuples of ints. For each ``i`` in ``range(size // 2)``:

        * NTT:   ``(i, size // 2 + i, 2 * i, 2 * i + 1)``
        * other: ``(2 * i, 2 * i + 1, i, size // 2 + i)``
    """
    # Floor division keeps the arithmetic in exact integers instead of
    # round-tripping through float the way int(size / 2) does.
    half = size // 2
    # `op` does not change while iterating, so decide the ordering once.
    is_ntt = issubclass(op, pisa_op.NTT)
    for i in range(half):
        split_pair = (i, half + i)
        adjacent_pair = (2 * i, 2 * i + 1)
        if is_ntt:
            yield split_pair + adjacent_pair
        else:
            yield adjacent_pair + split_pair
23+
24+
1525# pylint: disable=too-many-arguments
1626def butterflies_ops (
1727 op : pisa_op .NTT | pisa_op .INTT ,
@@ -21,42 +31,48 @@ def butterflies_ops(
2131 input0 : Polys ,
2232 * , # only kwargs after
2333 init_input : bool = False ,
34+ unit_size : int = 8192
2435) -> list [PIsaOp ]:
2536 """Helper to return butterflies pisa operations for NTT/INTT"""
26- ntt_stages = context .ntt_stages
27- ntt_stages_div_by_two = ntt_stages % 2
28-
29- stage_dst_srcs = [
30- (
31- (stage , outtmp , output )
32- if ntt_stages_div_by_two == stage % 2
33- else (stage , output , outtmp )
34- )
35- for stage in range (ntt_stages )
36- ]
37+ ntt_stages_div_by_two = context .ntt_stages % 2
3738
3839 if init_input is True :
40+ # intt
41+ stage_dst_srcs = [
42+ (
43+ (stage , outtmp , output )
44+ if ntt_stages_div_by_two == stage % 2
45+ else (stage , output , outtmp )
46+ )
47+ for stage in range (context .ntt_stages )
48+ ]
3949 stage_dst_srcs [0 ] = (
40- (0 , outtmp , input0 ) if ntt_stages_div_by_two == 0 else (0 , input0 , outtmp )
50+ (0 , outtmp , input0 ) if ntt_stages_div_by_two == 0 else (0 , output , input0 )
4151 )
52+ else :
53+ # ntt
54+ stage_dst_srcs = [
55+ ((stage , outtmp , output ) if stage % 2 == 0 else (stage , output , outtmp ))
56+ for stage in range (context .ntt_stages )
57+ ]
4258
4359 return [
4460 op (
4561 context .label ,
46- dst (part , q , unit ),
47- dst (part , q , next_unit ),
48- src (part , q , unit ),
49- src (part , q , next_unit ),
62+ dst (part , q , unit [ 0 ] ),
63+ dst (part , q , unit [ 1 ] ),
64+ src (part , q , unit [ 2 ] ),
65+ src (part , q , unit [ 3 ] ),
5066 stage ,
51- unit ,
67+ unit [ 0 ] if issubclass ( op , pisa_op . NTT ) else unit [ 2 ] ,
5268 q ,
5369 )
5470 # units for omegas (aka w) taken from 16K onwards
55- for part , (stage , dst , src ), q , ( unit , next_unit ) in it .product (
71+ for part , (stage , dst , src ), q , unit in it .product (
5672 range (input0 .start_parts , input0 .parts ),
5773 stage_dst_srcs ,
5874 range (input0 .start_rns , input0 .rns ),
59- it . pairwise ( range (context .units ) ),
75+ generate_unit_index ( int (context .poly_order / unit_size ), op ),
6076 )
6177 ]
6278
0 commit comments