Skip to content

Commit 7160e5a

Browse files
Fix NTT issue on odd-numbered butterfly stages
Fixes a bug that caused odd number of total NTT stages to write to incorrect output.
1 parent ae3c257 commit 7160e5a

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

kerngen/pisa_generators/ntt.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,25 +36,21 @@ def butterflies_ops(
3636
"""Helper to return butterflies pisa operations for NTT/INTT"""
3737
ntt_stages_div_by_two = context.ntt_stages % 2
3838

39+
# generate the stages, which depends on the total ntt stages.
40+
stage_dst_srcs = [
41+
(
42+
(stage, outtmp, output)
43+
if ntt_stages_div_by_two == stage % 2
44+
else (stage, output, outtmp)
45+
)
46+
for stage in range(context.ntt_stages)
47+
]
48+
49+
# For INTTs, start with input0 on the first stage destinations
3950
if init_input is True:
40-
# intt
41-
stage_dst_srcs = [
42-
(
43-
(stage, outtmp, output)
44-
if ntt_stages_div_by_two == stage % 2
45-
else (stage, output, outtmp)
46-
)
47-
for stage in range(context.ntt_stages)
48-
]
4951
stage_dst_srcs[0] = (
5052
(0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, output, input0)
5153
)
52-
else:
53-
# ntt
54-
stage_dst_srcs = [
55-
((stage, outtmp, output) if stage % 2 == 0 else (stage, output, outtmp))
56-
for stage in range(context.ntt_stages)
57-
]
5854

5955
return [
6056
op(
@@ -94,15 +90,21 @@ def to_pisa(self) -> list[PIsaOp]:
9490
# TODO We need to decide whether output symbols need to be defined
9591
outtmp = Polys("outtmp", self.output.parts, self.output.rns)
9692

97-
# Essentially a scalar mul since psi 1 part
98-
mul = Mul(self.context, self.output, self.input0, psi)
93+
# control the input for the butterfly method.
94+
if self.context.ntt_stages % 2 == 0:
95+
# Essentially a scalar mul since psi 1 part
96+
mul = Mul(self.context, self.output, self.input0, psi)
97+
butterfly_input = outtmp
98+
else:
99+
mul = Mul(self.context, outtmp, self.input0, psi)
100+
butterfly_input = self.input0
99101

100102
butterflies = butterflies_ops(
101103
pisa_op.NTT,
102104
context=self.context,
103105
output=self.output,
104106
outtmp=outtmp,
105-
input0=self.input0,
107+
input0=butterfly_input,
106108
)
107109

108110
return mixed_to_pisa_ops(mul, butterflies)

0 commit comments

Comments
 (0)