Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion kerngen/pisa_generators/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
"INTT": ["INTT", "ntt.py"],
"MOD": ["Mod", "mod.py"],
"MODUP": ["Modup", "mod.py"],
"RESCALE": ["Rescale", "rescale.py"]
"RELIN": ["Relin", "relin.py"],
"RESCALE": ["Rescale", "rescale.py"],
"ROTATE": ["Rotate", "rotate.py"]
}
}
148 changes: 128 additions & 20 deletions kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@

from .basic import (
Add,
Sub,
Muli,
mixed_to_pisa_ops,
split_last_rns_polys,
duplicate_polys,
common_immediates,
muli_last_half,
add_last_half,
sub_last_half,
)
from .ntt import INTT, NTT

Expand Down Expand Up @@ -43,6 +46,122 @@ def to_pisa(self) -> list[PIsaOp]:
temp_input_last_rns = duplicate_polys(input_last_rns, "y")
temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x")

# ckks variable
p_half = Polys("pHalf", 1, rns=last_q)

class Stage:
"""Stage hold a list of PISA ops, which are used by generate_mod_stages"""

pisa_ops: list[PIsaOp]

def __init__(self, pisa_ops: list[PIsaOp]):
self.pisa_ops = pisa_ops

def generate_mod_stages() -> list[Stage]:
"""Generate stages based on the current scheme (e.g. BGV or CKKS)"""
# init empty stages
stages = []
if self.context.scheme == "BGV":
stages.append(
Stage(
[
Comment("Mod Stage 0"),
Muli(
self.context,
temp_input_last_rns,
temp_input_last_rns,
it,
),
]
)
)
stages.append(
Stage(
[
Comment("Mod Stage 1"),
muli_last_half(
self.context,
temp_input_remaining_rns,
temp_input_last_rns,
r2,
input_remaining_rns,
last_q,
),
]
)
)
stages.append(
Stage(
[
Comment("Mod Stage 2"),
Muli(
self.context,
temp_input_remaining_rns,
temp_input_remaining_rns,
t,
),
Comment("Add the delta correction to mod down polys"),
Add(
self.context,
temp_input_remaining_rns,
temp_input_remaining_rns,
input_remaining_rns,
),
]
)
)
elif self.context.scheme == "CKKS":
# stage 1 is empty for CKKS
stages.append(Stage([Comment("Mod Stage 0 - Empty")]))
stages.append(
Stage(
[
Comment("Mod Stage 1 - add/sub half and muli"),
add_last_half(
self.context,
temp_input_last_rns,
temp_input_last_rns,
p_half,
Polys.from_polys(
input_remaining_rns, mode="drop_last_rns"
),
last_q,
),
sub_last_half(
self.context,
temp_input_remaining_rns,
temp_input_last_rns,
p_half,
input_remaining_rns,
last_q,
),
Muli(
self.context,
temp_input_remaining_rns,
temp_input_remaining_rns,
r2,
),
]
)
)
stages.append(
Stage(
[
Comment("Mod Stage 2 - Sub"),
Sub(
self.context,
temp_input_remaining_rns,
input_remaining_rns,
temp_input_remaining_rns,
),
]
)
)

return stages

stages = generate_mod_stages()

# Compute the `delta_i = t * [-t^-1 * c_i] mod ql` where `i` are the parts
# The `one` acts as a select flag as whether or not R2 the Montgomery
# factor should be applied
Expand All @@ -51,29 +170,18 @@ def to_pisa(self) -> list[PIsaOp]:
Comment("Start of mod kernel"),
Comment("Compute the delta from last rns"),
INTT(self.context, temp_input_last_rns, input_last_rns),
Muli(self.context, temp_input_last_rns, temp_input_last_rns, it),
]
+ stages[0].pisa_ops
+ [
Muli(self.context, temp_input_last_rns, temp_input_last_rns, one),
Comment("Compute the remaining rns"),
# drop down to pisa ops to use correct rns q
muli_last_half(
self.context,
temp_input_remaining_rns,
temp_input_last_rns,
r2,
input_remaining_rns,
last_q,
),
]
+ stages[1].pisa_ops
+ [
NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
Muli(
self.context, temp_input_remaining_rns, temp_input_remaining_rns, t
),
Comment("Add the delta correction to mod down polys"),
Add(
self.context,
temp_input_remaining_rns,
temp_input_remaining_rns,
input_remaining_rns,
),
]
+ stages[2].pisa_ops
+ [
Muli(self.context, self.output, temp_input_remaining_rns, iq),
Comment("End of mod kernel"),
]
Expand Down