diff --git a/kerngen/pisa_generators/manifest.json b/kerngen/pisa_generators/manifest.json index 87836ceb..17adbc2f 100644 --- a/kerngen/pisa_generators/manifest.json +++ b/kerngen/pisa_generators/manifest.json @@ -9,8 +9,9 @@ "NTT": ["NTT", "ntt.py"], "INTT": ["INTT", "ntt.py"], "MOD": ["Mod", "mod.py"], - "MODUP": ["ModUp", "mod.py"], - "RELIN": ["Relin", "relin.py"] + "MODUP": ["Modup", "mod.py"], + "RELIN": ["Relin", "relin.py"], + "ROTATE": ["Rotate", "rotate.py"] }, "CKKS": { } diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index 693fbd4d..52f4c711 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from itertools import product from string import ascii_letters +from typing import Tuple import high_parser.pisa_operations as pisa_op from high_parser.pisa_operations import PIsaOp, Comment @@ -16,6 +17,22 @@ from .ntt import INTT, NTT +def init_common_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]: + """Initialize commonly used polys in both relin and rotate kernels""" + input_last_part = Polys.from_polys(input0, mode="last_part") + input_last_part.name = input0.name + + last_coeff = Polys.from_polys(input_last_part) + last_coeff.name = "coeffs" + last_coeff.rns = rns + + upto_last_coeffs = Polys.from_polys(last_coeff) + upto_last_coeffs.parts = 1 + upto_last_coeffs.start_parts = 0 + + return input_last_part, last_coeff, upto_last_coeffs + + @dataclass class KeyMul(HighOp): """Class representing a key multiplication operation""" @@ -24,6 +41,7 @@ class KeyMul(HighOp): output: Polys input0: Polys input1: KeyPolys + input0_fixed_part: int def to_pisa(self) -> list[PIsaOp]: """Return the p-isa code to perform a key multiplication""" @@ -40,7 +58,7 @@ def get_pisa_op(num): op( self.context.label, self.output(part, q, unit), - input0_tmp(2, q, unit), + input0_tmp(self.input0_fixed_part, q, unit), self.input1(digit, part, q, unit), q, ) @@ -120,15 +138,10 @@ def to_pisa(self) -> list[PIsaOp]: mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns) mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns - input_last_part = Polys.from_polys(self.input0, mode="last_part") - input_last_part.name = self.input0.name - last_coeff = Polys.from_polys(input_last_part) - last_coeff.name = "coeffs" - last_coeff.rns = self.context.key_rns - upto_last_coeffs = Polys.from_polys(last_coeff) - upto_last_coeffs.parts = 1 - upto_last_coeffs.start_parts = 0 + input_last_part, last_coeff, upto_last_coeffs = init_common_polys( + self.input0, self.context.key_rns + ) add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name @@ -138,7 +151,7 @@ def to_pisa(self) -> list[PIsaOp]: Comment("Digit decomposition and extend base from Q to PQ"), DigitDecompExtend(self.context, last_coeff, input_last_part), Comment("Multiply by relin key"), - KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key), + KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 2), Comment("Mod switch down to Q"), Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), Comment("Add to original poly"), diff --git a/kerngen/pisa_generators/rotate.py b/kerngen/pisa_generators/rotate.py new file mode 100644 index 00000000..f51054ce --- /dev/null +++ b/kerngen/pisa_generators/rotate.py @@ -0,0 +1,68 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module containing conversions or operations from isa to p-isa.""" + +from dataclasses import dataclass + +from high_parser.pisa_operations import PIsaOp, Comment +from high_parser import KernelContext, HighOp, Polys, KeyPolys + +from .basic import Add, mixed_to_pisa_ops +from .relin import KeyMul, DigitDecompExtend, init_common_polys +from .mod import Mod +from .ntt import INTT, NTT + + +@dataclass +class Rotate(HighOp): + """Class representing rotate operation""" + + context: KernelContext + output: Polys + input0: Polys + + def to_pisa(self) -> list[PIsaOp]: + """Return the p-isa code to perform a rotate. Note: + currently only supports polynomials with two parts. Currently only + supports number of digits equal to the RNS size""" + self.output.parts = 2 + self.input0.parts = 2 + relin_key = KeyPolys( + "gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns + ) + mul_by_rlk = Polys("c2_gk", parts=2, rns=self.context.key_rns) + mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) + mul_by_rlk_modded_down.rns = self.input0.rns + mul_by_rlk_modded_down.name = self.output.name + + input_last_part, last_coeff, upto_last_coeffs = init_common_polys( + self.input0, self.context.key_rns + ) + + cd = Polys.from_polys(self.input0) + cd.name = "cd" + cd.parts = 1 + cd.start_parts = 0 + + start_input = Polys.from_polys(self.input0) + start_input.parts = 1 + start_input.start_parts = 0 + + first_part_rlk = Polys.from_polys(mul_by_rlk_modded_down) + first_part_rlk.parts = 1 + first_part_rlk.start_parts = 0 + + return mixed_to_pisa_ops( + Comment("Start of rotate kernel"), + Comment("Digit Decomp"), + DigitDecompExtend(self.context, last_coeff, input_last_part), + Comment("Multiply by rotate key"), + KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1), + Comment("Mod switch down to Q"), + Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), + INTT(self.context, cd, start_input), + NTT(self.context, cd, cd), + Add(self.context, self.output, cd, first_part_rlk), + Comment("End of rotate kernel"), + )