diff --git a/kerngen/pisa_generators/manifest.json b/kerngen/pisa_generators/manifest.json index 5a833a02..00cf2f91 100644 --- a/kerngen/pisa_generators/manifest.json +++ b/kerngen/pisa_generators/manifest.json @@ -24,6 +24,8 @@ "INTT": ["INTT", "ntt.py"], "MOD": ["Mod", "mod.py"], "MODUP": ["Modup", "mod.py"], - "RESCALE": ["Rescale", "rescale.py"] + "RELIN": ["Relin", "relin.py"], + "RESCALE": ["Rescale", "rescale.py"], + "ROTATE": ["Rotate", "rotate.py"] } } diff --git a/kerngen/pisa_generators/mod.py b/kerngen/pisa_generators/mod.py index 8383d250..f99bf20d 100644 --- a/kerngen/pisa_generators/mod.py +++ b/kerngen/pisa_generators/mod.py @@ -10,12 +10,15 @@ from .basic import ( Add, + Sub, Muli, mixed_to_pisa_ops, split_last_rns_polys, duplicate_polys, common_immediates, muli_last_half, + add_last_half, + sub_last_half, ) from .ntt import INTT, NTT @@ -43,6 +46,122 @@ def to_pisa(self) -> list[PIsaOp]: temp_input_last_rns = duplicate_polys(input_last_rns, "y") temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x") + # ckks variable + p_half = Polys("pHalf", 1, rns=last_q) + + class Stage: + """Stage hold a list of PISA ops, which are used by generate_mod_stages""" + + pisa_ops: list[PIsaOp] + + def __init__(self, pisa_ops: list[PIsaOp]): + self.pisa_ops = pisa_ops + + def generate_mod_stages() -> list[Stage]: + """Generate stages based on the current scheme (e.g. BGV or CKKS)""" + # init empty stages + stages = [] + if self.context.scheme == "BGV": + stages.append( + Stage( + [ + Comment("Mod Stage 0"), + Muli( + self.context, + temp_input_last_rns, + temp_input_last_rns, + it, + ), + ] + ) + ) + stages.append( + Stage( + [ + Comment("Mod Stage 1"), + muli_last_half( + self.context, + temp_input_remaining_rns, + temp_input_last_rns, + r2, + input_remaining_rns, + last_q, + ), + ] + ) + ) + stages.append( + Stage( + [ + Comment("Mod Stage 2"), + Muli( + self.context, + temp_input_remaining_rns, + temp_input_remaining_rns, + t, + ), + Comment("Add the delta correction to mod down polys"), + Add( + self.context, + temp_input_remaining_rns, + temp_input_remaining_rns, + input_remaining_rns, + ), + ] + ) + ) + elif self.context.scheme == "CKKS": + # stage 1 is empty for CKKS + stages.append(Stage([Comment("Mod Stage 0 - Empty")])) + stages.append( + Stage( + [ + Comment("Mod Stage 1 - add/sub half and muli"), + add_last_half( + self.context, + temp_input_last_rns, + temp_input_last_rns, + p_half, + Polys.from_polys( + input_remaining_rns, mode="drop_last_rns" + ), + last_q, + ), + sub_last_half( + self.context, + temp_input_remaining_rns, + temp_input_last_rns, + p_half, + input_remaining_rns, + last_q, + ), + Muli( + self.context, + temp_input_remaining_rns, + temp_input_remaining_rns, + r2, + ), + ] + ) + ) + stages.append( + Stage( + [ + Comment("Mod Stage 2 - Sub"), + Sub( + self.context, + temp_input_remaining_rns, + input_remaining_rns, + temp_input_remaining_rns, + ), + ] + ) + ) + + return stages + + stages = generate_mod_stages() + # Compute the `delta_i = t * [-t^-1 * c_i] mod ql` where `i` are the parts # The `one` acts as a select flag as whether or not R2 the Montgomery # factor should be applied @@ -51,29 +170,18 @@ def to_pisa(self) -> list[PIsaOp]: Comment("Start of mod kernel"), Comment("Compute the delta from last rns"), INTT(self.context, temp_input_last_rns, input_last_rns), - Muli(self.context, temp_input_last_rns, temp_input_last_rns, it), + ] + + stages[0].pisa_ops + + [ Muli(self.context, temp_input_last_rns, temp_input_last_rns, one), Comment("Compute the remaining rns"), - # drop down to pisa ops to use correct rns q - muli_last_half( - self.context, - temp_input_remaining_rns, - temp_input_last_rns, - r2, - input_remaining_rns, - last_q, - ), + ] + + stages[1].pisa_ops + + [ NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns), - Muli( - self.context, temp_input_remaining_rns, temp_input_remaining_rns, t - ), - Comment("Add the delta correction to mod down polys"), - Add( - self.context, - temp_input_remaining_rns, - temp_input_remaining_rns, - input_remaining_rns, - ), + ] + + stages[2].pisa_ops + + [ Muli(self.context, self.output, temp_input_remaining_rns, iq), Comment("End of mod kernel"), ]