Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
32e984d
Modified default manifest to use MOD_SWITCH key instead of MOD. The f…
christopherngutierrez Jul 24, 2024
3761595
Initial port of rotate kernel. Currently untested as there is a known…
christopherngutierrez Jul 24, 2024
4f977d1
WIP for relin code, currently failing tests but graph structure looks…
christopherngutierrez Aug 26, 2024
62defde
Merge branch 'main' of github.com:IntelLabs/hec-p-isa-tools into chri…
christopherngutierrez Aug 26, 2024
c41e237
added comments '<hack></hack>' for solutions that should be removed o…
christopherngutierrez Sep 3, 2024
7c7a5a0
Merge branch 'main' into christopherngutierrez/relin
christopherngutierrez Sep 3, 2024
5e32b1b
Merge remote-tracking branch 'origin/christopherngutierrez/relin' int…
christopherngutierrez Sep 3, 2024
0e74871
modified KeyPoly to match argument order found in legacy implementation
christopherngutierrez Sep 5, 2024
9a030f1
added execute permissions to kerngen.py
christopherngutierrez Sep 5, 2024
d1915a2
fixed bug with KeyPolys. Modified to match legacy output
christopherngutierrez Sep 9, 2024
cefe295
Merge branch 'christopherngutierrez/relin' into christopherngutierrez…
christopherngutierrez Sep 9, 2024
69eb30e
modified rotate to closely match legacy output, WIP, untested.
christopherngutierrez Sep 10, 2024
ff50a0a
Add support for BGV rotate. Tested and working.
christopherngutierrez Sep 12, 2024
70a7a80
eliminated duplicated code. Created a helper function in relin to ini…
christopherngutierrez Sep 13, 2024
78951fe
renamed variable for KeyMul
christopherngutierrez Sep 13, 2024
46783db
modified comments for rotate
christopherngutierrez Sep 13, 2024
8cbccdb
Fixed bug in init_common_polys, variable name typo
christopherngutierrez Sep 16, 2024
424c1a5
Merge branch 'main' into christopherngutierrez/rotate
christopherngutierrez Sep 18, 2024
741edb6
Fixed a merge issue. Removes opt temp variable and returns mixed_to_p…
christopherngutierrez Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions kerngen/pisa_generators/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
"NTT": ["NTT", "ntt.py"],
"INTT": ["INTT", "ntt.py"],
"MOD": ["Mod", "mod.py"],
"MODUP": ["ModUp", "mod.py"],
"RELIN": ["Relin", "relin.py"]
"MODUP": ["Modup", "mod.py"],
"RELIN": ["Relin", "relin.py"],
"ROTATE": ["Rotate", "rotate.py"]
},
"CKKS": {
}
Expand Down
33 changes: 23 additions & 10 deletions kerngen/pisa_generators/relin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from itertools import product
from string import ascii_letters
from typing import Tuple

import high_parser.pisa_operations as pisa_op
from high_parser.pisa_operations import PIsaOp, Comment
Expand All @@ -16,6 +17,22 @@
from .ntt import INTT, NTT


def init_common_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]:
"""Initialize commonly used polys in both relin and rotate kernels"""
input_last_part = Polys.from_polys(input0, mode="last_part")
input_last_part.name = input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = rns

upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0

return input_last_part, last_coeff, upto_last_coeffs


@dataclass
class KeyMul(HighOp):
"""Class representing a key multiplication operation"""
Expand All @@ -24,6 +41,7 @@ class KeyMul(HighOp):
output: Polys
input0: Polys
input1: KeyPolys
input0_fixed_part: int

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform a key multiplication"""
Expand All @@ -40,7 +58,7 @@ def get_pisa_op(num):
op(
self.context.label,
self.output(part, q, unit),
input0_tmp(2, q, unit),
input0_tmp(self.input0_fixed_part, q, unit),
self.input1(digit, part, q, unit),
q,
)
Expand Down Expand Up @@ -120,15 +138,10 @@ def to_pisa(self) -> list[PIsaOp]:
mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns)
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns
input_last_part = Polys.from_polys(self.input0, mode="last_part")
input_last_part.name = self.input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = self.context.key_rns
upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0
input_last_part, last_coeff, upto_last_coeffs = init_common_polys(
self.input0, self.context.key_rns
)

add_original = Polys.from_polys(mul_by_rlk_modded_down)
add_original.name = self.input0.name
Expand All @@ -138,7 +151,7 @@ def to_pisa(self) -> list[PIsaOp]:
Comment("Digit decomposition and extend base from Q to PQ"),
DigitDecompExtend(self.context, last_coeff, input_last_part),
Comment("Multiply by relin key"),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 2),
Comment("Mod switch down to Q"),
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk),
Comment("Add to original poly"),
Expand Down
68 changes: 68 additions & 0 deletions kerngen/pisa_generators/rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module containing conversions or operations from isa to p-isa."""

from dataclasses import dataclass

from high_parser.pisa_operations import PIsaOp, Comment
from high_parser import KernelContext, HighOp, Polys, KeyPolys

from .basic import Add, mixed_to_pisa_ops
from .relin import KeyMul, DigitDecompExtend, init_common_polys
from .mod import Mod
from .ntt import INTT, NTT


@dataclass
class Rotate(HighOp):
"""Class representing rotate operation"""

context: KernelContext
output: Polys
input0: Polys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform a rotate. Note:
currently only supports polynomials with two parts. Currently only
supports number of digits equal to the RNS size"""
self.output.parts = 2
self.input0.parts = 2
relin_key = KeyPolys(
"gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns
)
mul_by_rlk = Polys("c2_gk", parts=2, rns=self.context.key_rns)
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns
mul_by_rlk_modded_down.name = self.output.name

input_last_part, last_coeff, upto_last_coeffs = init_common_polys(
self.input0, self.context.key_rns
)

cd = Polys.from_polys(self.input0)
cd.name = "cd"
cd.parts = 1
cd.start_parts = 0

start_input = Polys.from_polys(self.input0)
start_input.parts = 1
start_input.start_parts = 0

first_part_rlk = Polys.from_polys(mul_by_rlk_modded_down)
first_part_rlk.parts = 1
first_part_rlk.start_parts = 0

return mixed_to_pisa_ops(
Comment("Start of rotate kernel"),
Comment("Digit Decomp"),
DigitDecompExtend(self.context, last_coeff, input_last_part),
Comment("Multiply by rotate key"),
KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1),
Comment("Mod switch down to Q"),
Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk),
INTT(self.context, cd, start_input),
NTT(self.context, cd, cd),
Add(self.context, self.output, cd, first_part_rlk),
Comment("End of rotate kernel"),
)