diff --git a/kerngen/pisa_generators/basic.py b/kerngen/pisa_generators/basic.py
index e96901b4..17118113 100644
--- a/kerngen/pisa_generators/basic.py
+++ b/kerngen/pisa_generators/basic.py
@@ -1,8 +1,6 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-# Copyright (C) 2024 Intel Corporation
-
 """Module containing conversions or operations from isa to p-isa."""
 
 import itertools as it
@@ -296,3 +294,150 @@ def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Poly
     upto_last_coeffs.start_parts = 0
 
     return input_last_part, last_coeff, upto_last_coeffs
+
+
+def split_last_rns_polys(input0: Polys) -> Tuple[Polys, Polys]:
+    """Split input0 into (last RNS, all RNS but the last) polys"""
+    return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys(
+        input0, mode="drop_last_rns"
+    )
+
+
+def duplicate_polys(input0: Polys, name: str) -> Polys:
+    """Creates a duplicate of input0 with new name"""
+    return Polys(name, input0.parts, input0.rns, input0.start_parts, input0.start_rns)
+
+
+def common_immediates(
+    r2_rns=None, iq_rns=None
+) -> Tuple[Immediate, Immediate, Immediate]:
+    """Generate the commonly used immediates `one`, `R2` and `iq`"""
+    return (
+        Immediate(name="one"),
+        Immediate(name="R2", rns=r2_rns),
+        Immediate(name="iq", rns=iq_rns),
+    )
+
+
+# pylint: disable=too-many-arguments
+# pylint: disable=too-many-positional-arguments
+
+
+@dataclass
+class PartialOpOptions:
+    """Optional arguments for partial_op helper function"""
+
+    output_last_q: bool = False
+    input0_last_q: bool = False
+    input1_last_q: bool = False
+    input1_first_part: bool = False
+    op_last_q: bool = False
+
+
+@dataclass
+class PartialOpPolys:
+    """Polynomials used in partial ops"""
+
+    output: Polys
+    input0: Polys
+    input1: Polys
+    input_remaining_rns: Polys
+
+
+def partial_op(
+    context: KernelContext,
+    op,
+    polys: PartialOpPolys,
+    options: PartialOpOptions,
+    last_q: int,
+):
+    """Emit one `op` per (part, q, unit) of `polys.input_remaining_rns`.
+
+    Each `*_last_q` flag in `options` makes that operand (or the op's own rns
+    argument) use `last_q` instead of the loop's `q`; `input1_first_part`
+    makes input1 always read part 0. Used e.g. to add/sub the last rns of
+    input0 against every remaining rns.
+    """
+    return [
+        op(
+            context.label,
+            polys.output(part, last_q if options.output_last_q else q, unit),
+            polys.input0(part, last_q if options.input0_last_q else q, unit),
+            polys.input1(
+                0 if options.input1_first_part else part,
+                last_q if options.input1_last_q else q,
+                unit,
+            ),
+            last_q if options.op_last_q else q,
+        )
+        for part, q, unit in it.product(
+            range(polys.input_remaining_rns.parts),
+            range(polys.input_remaining_rns.rns),
+            range(context.units),
+        )
+    ]
+
+
+def add_last_half(
+    context: KernelContext,
+    output: Polys,
+    input0: Polys,
+    input1: Polys,
+    input_remaining_rns: Polys,
+    last_q: int,
+):
+    """Add input1 (first part) to input0, using the last RNS for every operand
+
+    NOTE(review): every operand is indexed with last_q, so the loop over the
+    rns of input_remaining_rns emits the same instruction once per rns;
+    confirm the repetition is intended.
+    """
+    return partial_op(
+        context,
+        pisa_op.Add,
+        PartialOpPolys(output, input0, input1, input_remaining_rns),
+        PartialOpOptions(
+            output_last_q=True,
+            input0_last_q=True,
+            input1_last_q=True,
+            input1_first_part=True,
+            op_last_q=True,
+        ),
+        last_q,
+    )
+
+
+def sub_last_half(
+    context: KernelContext,
+    output: Polys,
+    input0: Polys,
+    input1: Polys,
+    input_remaining_rns: Polys,
+    last_q: int,
+):
+    """Sub of input0 (last RNS) and input1 (first part) for every remaining RNS"""
+    return partial_op(
+        context,
+        pisa_op.Sub,
+        PartialOpPolys(output, input0, input1, input_remaining_rns),
+        PartialOpOptions(input0_last_q=True, input1_first_part=True),
+        last_q,
+    )
+
+
+def muli_last_half(
+    context: KernelContext,
+    output: Polys,
+    input0: Polys,
+    input1: Polys,
+    input_remaining_rns: Polys,
+    last_q: int,
+):
+    """Muli of input0 (last RNS) and input1 for every part and remaining RNS"""
+    return partial_op(
+        context,
+        pisa_op.Muli,
+        PartialOpPolys(output, input0, input1, input_remaining_rns),
+        PartialOpOptions(input0_last_q=True),
+        last_q,
+    )
diff --git a/kerngen/pisa_generators/manifest.json b/kerngen/pisa_generators/manifest.json
index 79c8012a..5a833a02 100644
--- a/kerngen/pisa_generators/manifest.json
+++ b/kerngen/pisa_generators/manifest.json
@@ -14,5 +14,16 @@
         "ROTATE": ["Rotate", "rotate.py"]
     },
     "CKKS": {
+        "ADD": ["Add", "basic.py"],
+        "MUL": ["Mul", "basic.py"],
+        "MULI": ["Muli", "basic.py"],
+        "COPY": ["Copy", "basic.py"],
+        "SUB": ["Sub", "basic.py"],
+        "SQUARE": ["Square", "square.py"],
+        "NTT": ["NTT", "ntt.py"],
+        "INTT": ["INTT", "ntt.py"],
+        "MOD": ["Mod", "mod.py"],
+        "MODUP": ["Modup", "mod.py"],
+        "RESCALE": ["Rescale", "rescale.py"]
     }
 }
diff --git a/kerngen/pisa_generators/mod.py b/kerngen/pisa_generators/mod.py
index 4ec75a94..8383d250 100644
--- a/kerngen/pisa_generators/mod.py
+++ b/kerngen/pisa_generators/mod.py
@@ -1,14 +1,22 @@
 # Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
 
 """Module containing conversions or operations from isa to p-isa."""
 
 from dataclasses import dataclass
-from itertools import product
-from high_parser.pisa_operations import PIsaOp, Comment, Muli as pisa_op_muli
+from high_parser.pisa_operations import PIsaOp, Comment
 
 from high_parser import KernelContext, Immediate, HighOp, Polys
 
-from .basic import Add, Muli, mixed_to_pisa_ops
+from .basic import (
+    Add,
+    Muli,
+    mixed_to_pisa_ops,
+    split_last_rns_polys,
+    duplicate_polys,
+    common_immediates,
+    muli_last_half,
+)
 from .ntt import INTT, NTT
 
 
@@ -22,27 +30,18 @@ class Mod(HighOp):
 
     def to_pisa(self) -> list[PIsaOp]:
         """Return the p-isa code to perform an mod switch down"""
-        # Convenience and Immediates
-        context = self.context
+        # Immediates
         last_q = self.input0.rns - 1
         it = Immediate(name="it")
-        one = Immediate(name="one")
-        r2 = Immediate(name="R2", rns=last_q)
-        iq = Immediate(name="iq", rns=last_q)
         t = Immediate(name="t", rns=last_q)
+        one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q)
 
         # Drop down input rns
-        input_last_rns = Polys.from_polys(self.input0, mode="last_rns")
-        input_remaining_rns = Polys.from_polys(self.input0, mode="drop_last_rns")
+        input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)
 
         # Temp.
-        y = Polys(
-            "y",
-            input_last_rns.parts,
-            input_last_rns.rns,
-            start_rns=input_last_rns.start_rns,
-        )
-        x = Polys("x", input_remaining_rns.parts, input_remaining_rns.rns)
+        temp_input_last_rns = duplicate_polys(input_last_rns, "y")
+        temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x")
 
         # Compute the `delta_i = t * [-t^-1 * c_i] mod ql` where `i` are the parts
         # The `one` acts as a select flag as whether or not R2 the Montgomery
@@ -51,30 +50,31 @@ def to_pisa(self) -> list[PIsaOp]:
             [
                 Comment("Start of mod kernel"),
                 Comment("Compute the delta from last rns"),
-                INTT(context, y, input_last_rns),
-                Muli(context, y, y, it),
-                Muli(context, y, y, one),
+                INTT(self.context, temp_input_last_rns, input_last_rns),
+                Muli(self.context, temp_input_last_rns, temp_input_last_rns, it),
+                Muli(self.context, temp_input_last_rns, temp_input_last_rns, one),
                 Comment("Compute the remaining rns"),
                 # drop down to pisa ops to use correct rns q
-                [
-                    pisa_op_muli(
-                        self.context.label,
-                        x(part, q, unit),
-                        y(part, last_q, unit),
-                        r2(part, q, unit),
-                        q,
-                    )
-                    for part, q, unit in product(
-                        range(input_remaining_rns.parts),
-                        range(input_remaining_rns.rns),
-                        range(context.units),
-                    )
-                ],
-                NTT(context, x, x),
-                Muli(context, x, x, t),
+                muli_last_half(
+                    self.context,
+                    temp_input_remaining_rns,
+                    temp_input_last_rns,
+                    r2,
+                    input_remaining_rns,
+                    last_q,
+                ),
+                NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
+                Muli(
+                    self.context, temp_input_remaining_rns, temp_input_remaining_rns, t
+                ),
                 Comment("Add the delta correction to mod down polys"),
-                Add(context, x, x, input_remaining_rns),
-                Muli(context, self.output, x, iq),
+                Add(
+                    self.context,
+                    temp_input_remaining_rns,
+                    temp_input_remaining_rns,
+                    input_remaining_rns,
+                ),
+                Muli(self.context, self.output, temp_input_remaining_rns, iq),
                 Comment("End of mod kernel"),
             ]
         )
diff --git a/kerngen/pisa_generators/rescale.py b/kerngen/pisa_generators/rescale.py
new file mode 100644
index 00000000..951de892
--- /dev/null
+++ b/kerngen/pisa_generators/rescale.py
@@ -0,0 +1,87 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Module containing conversions or operations from isa to p-isa."""
+
+from dataclasses import dataclass
+
+from high_parser.pisa_operations import PIsaOp, Comment
+
+from high_parser import KernelContext, HighOp, Polys
+
+from .basic import (
+    Muli,
+    mixed_to_pisa_ops,
+    Sub,
+    split_last_rns_polys,
+    duplicate_polys,
+    common_immediates,
+    add_last_half,
+    sub_last_half,
+)
+from .ntt import INTT, NTT
+
+
+@dataclass
+class Rescale(HighOp):
+    """Class representing rescale operation"""
+
+    context: KernelContext
+    output: Polys
+    input0: Polys
+
+    def to_pisa(self) -> list[PIsaOp]:
+        """Return the p-isa code to perform a rescale"""
+        # Immediates
+        last_q = self.input0.rns - 1
+        one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q)
+
+        q_last_half = Polys("qLastHalf", 1, self.input0.rns)
+        q_i_last_half = Polys("qiLastHalf", 1, rns=last_q)
+
+        # split input
+        input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)
+
+        # Create temp vars for input_last/remaining
+        temp_input_last_rns = duplicate_polys(input_last_rns, "y")
+        temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x")
+
+        # The `one` acts as a select flag as to whether or not R2, the
+        # Montgomery factor, should be applied
+        return mixed_to_pisa_ops(
+            [
+                Comment("Start of Rescale kernel."),
+                INTT(self.context, temp_input_last_rns, input_last_rns),
+                Muli(self.context, temp_input_last_rns, temp_input_last_rns, one),
+                Comment("Add the last part of the input to y"),
+                add_last_half(
+                    self.context,
+                    temp_input_last_rns,
+                    temp_input_last_rns,
+                    q_last_half,
+                    input_remaining_rns,
+                    last_q,
+                ),
+                Comment("Subtract q_i (last half/last rns) from y"),
+                sub_last_half(
+                    self.context,
+                    temp_input_remaining_rns,
+                    temp_input_last_rns,
+                    q_i_last_half,
+                    input_remaining_rns,
+                    last_q,
+                ),
+                Muli(
+                    self.context, temp_input_remaining_rns, temp_input_remaining_rns, r2
+                ),
+                NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
+                Sub(
+                    self.context,
+                    temp_input_remaining_rns,
+                    input_remaining_rns,
+                    temp_input_remaining_rns,
+                ),
+                Muli(self.context, self.output, temp_input_remaining_rns, iq),
+                Comment("End of Rescale kernel."),
+            ]
+        )