@@ -36,25 +36,21 @@ def butterflies_ops(
3636 """Helper to return butterflies pisa operations for NTT/INTT"""
3737 ntt_stages_div_by_two = context .ntt_stages % 2
3838
39+ # generate the stages, which depends on the total ntt stages.
40+ stage_dst_srcs = [
41+ (
42+ (stage , outtmp , output )
43+ if ntt_stages_div_by_two == stage % 2
44+ else (stage , output , outtmp )
45+ )
46+ for stage in range (context .ntt_stages )
47+ ]
48+
49+ # For INTTs, start with input0 on the first stage destinations
3950 if init_input is True :
40- # intt
41- stage_dst_srcs = [
42- (
43- (stage , outtmp , output )
44- if ntt_stages_div_by_two == stage % 2
45- else (stage , output , outtmp )
46- )
47- for stage in range (context .ntt_stages )
48- ]
4951 stage_dst_srcs [0 ] = (
5052 (0 , outtmp , input0 ) if ntt_stages_div_by_two == 0 else (0 , output , input0 )
5153 )
52- else :
53- # ntt
54- stage_dst_srcs = [
55- ((stage , outtmp , output ) if stage % 2 == 0 else (stage , output , outtmp ))
56- for stage in range (context .ntt_stages )
57- ]
5854
5955 return [
6056 op (
@@ -94,15 +90,21 @@ def to_pisa(self) -> list[PIsaOp]:
9490 # TODO We need to decide whether output symbols need to be defined
9591 outtmp = Polys ("outtmp" , self .output .parts , self .output .rns )
9692
97- # Essentially a scalar mul since psi 1 part
98- mul = Mul (self .context , self .output , self .input0 , psi )
93+ # control the input for the butterfly method.
94+ if self .context .ntt_stages % 2 == 0 :
95+ # Essentially a scalar mul since psi 1 part
96+ mul = Mul (self .context , self .output , self .input0 , psi )
97+ butterfly_input = outtmp
98+ else :
99+ mul = Mul (self .context , outtmp , self .input0 , psi )
100+ butterfly_input = self .input0
99101
100102 butterflies = butterflies_ops (
101103 pisa_op .NTT ,
102104 context = self .context ,
103105 output = self .output ,
104106 outtmp = outtmp ,
105- input0 = self . input0 ,
107+ input0 = butterfly_input ,
106108 )
107109
108110 return mixed_to_pisa_ops (mul , butterflies )
0 commit comments