Skip to content

Commit ec7b551

Browse files
Enhancements to the NTT/INTT kernel to support support generic power of 2 polynomials - Tested for 16k, 32k, 64k, and 128k. (#46)
Enhancements to the NTT/INTT kernel to support support generic power of 2 polynomials - Tested for 16k, 32k, 64k, and 128k. Co-authored-by: Flavio Bergamaschi <[email protected]>
1 parent 8739d69 commit ec7b551

File tree

1 file changed

+35
-19
lines changed

1 file changed

+35
-19
lines changed

kerngen/pisa_generators/ntt.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
23

34
"""Module containing conversions or operations from isa to p-isa."""
45

@@ -12,6 +13,15 @@
1213
from .basic import Mul, Muli, mixed_to_pisa_ops
1314

1415

16+
def generate_unit_index(size: int, op: pisa_op.NTT | pisa_op.INTT):
17+
"""Helper to return unit indices for ntt/intt"""
18+
for i in range(int(size / 2)):
19+
if issubclass(op, pisa_op.NTT):
20+
yield (i, int(size / 2) + i, i * 2, i * 2 + 1)
21+
else:
22+
yield (i * 2, i * 2 + 1, i, int(size / 2) + i)
23+
24+
1525
# pylint: disable=too-many-arguments
1626
def butterflies_ops(
1727
op: pisa_op.NTT | pisa_op.INTT,
@@ -21,42 +31,48 @@ def butterflies_ops(
2131
input0: Polys,
2232
*, # only kwargs after
2333
init_input: bool = False,
34+
unit_size: int = 8192
2435
) -> list[PIsaOp]:
2536
"""Helper to return butterflies pisa operations for NTT/INTT"""
26-
ntt_stages = context.ntt_stages
27-
ntt_stages_div_by_two = ntt_stages % 2
28-
29-
stage_dst_srcs = [
30-
(
31-
(stage, outtmp, output)
32-
if ntt_stages_div_by_two == stage % 2
33-
else (stage, output, outtmp)
34-
)
35-
for stage in range(ntt_stages)
36-
]
37+
ntt_stages_div_by_two = context.ntt_stages % 2
3738

3839
if init_input is True:
40+
# intt
41+
stage_dst_srcs = [
42+
(
43+
(stage, outtmp, output)
44+
if ntt_stages_div_by_two == stage % 2
45+
else (stage, output, outtmp)
46+
)
47+
for stage in range(context.ntt_stages)
48+
]
3949
stage_dst_srcs[0] = (
40-
(0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, input0, outtmp)
50+
(0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, output, input0)
4151
)
52+
else:
53+
# ntt
54+
stage_dst_srcs = [
55+
((stage, outtmp, output) if stage % 2 == 0 else (stage, output, outtmp))
56+
for stage in range(context.ntt_stages)
57+
]
4258

4359
return [
4460
op(
4561
context.label,
46-
dst(part, q, unit),
47-
dst(part, q, next_unit),
48-
src(part, q, unit),
49-
src(part, q, next_unit),
62+
dst(part, q, unit[0]),
63+
dst(part, q, unit[1]),
64+
src(part, q, unit[2]),
65+
src(part, q, unit[3]),
5066
stage,
51-
unit,
67+
unit[0] if issubclass(op, pisa_op.NTT) else unit[2],
5268
q,
5369
)
5470
# units for omegas (aka w) taken from 16K onwards
55-
for part, (stage, dst, src), q, (unit, next_unit) in it.product(
71+
for part, (stage, dst, src), q, unit in it.product(
5672
range(input0.start_parts, input0.parts),
5773
stage_dst_srcs,
5874
range(input0.start_rns, input0.rns),
59-
it.pairwise(range(context.units)),
75+
generate_unit_index(int(context.poly_order / unit_size), op),
6076
)
6177
]
6278

0 commit comments

Comments
 (0)