Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 136 additions & 2 deletions kerngen/pisa_generators/basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Copyright (C) 2024 Intel Corporation

"""Module containing conversions or operations from isa to p-isa."""

import itertools as it
Expand Down Expand Up @@ -296,3 +294,139 @@ def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Poly
upto_last_coeffs.start_parts = 0

return input_last_part, last_coeff, upto_last_coeffs


def split_last_rns_polys(input0: Polys) -> Tuple[Polys, Polys]:
"""Split and extract last RNS of input0"""
return Polys.from_polys(input0, mode="last_rns"), Polys.from_polys(
input0, mode="drop_last_rns"
)


def duplicate_polys(input0: Polys, name: str) -> Polys:
"""Creates a duplicate of input0 with new name"""
return Polys(name, input0.parts, input0.rns, input0.start_parts, input0.start_rns)


def common_immediates(
r2_rns=None, iq_rns=None
) -> Tuple[Immediate, Immediate, Immediate]:
"""Generate commonly used immediates"""
return (
Immediate(name="one"),
Immediate(name="R2", rns=r2_rns),
Immediate(name="iq", rns=iq_rns),
)


# pylint: disable=too-many-arguments
# pylint: disable=too-many-positional-arguments


@dataclass
class PartialOpOptions:
"""Optional arguments for partial_op helper function"""

output_last_q: bool = False
input0_last_q: bool = False
input1_last_q: bool = False
input1_first_part: bool = False
op_last_q: bool = False


@dataclass
class PartialOpPolys:
"""Polynomials used in partial ops"""

output: Polys
input0: Polys
input1: Polys
input_remaining_rns: Polys


def partial_op(
context: KernelContext,
op,
polys: PartialOpPolys,
options: PartialOpOptions,
last_q: int,
):
""" "A helper function to perform partial operation, such as add/sub on last half (input1) to all of input0"""
return [
op(
context.label,
polys.output(part, last_q if options.output_last_q else q, unit),
polys.input0(part, last_q if options.input0_last_q else q, unit),
polys.input1(
0 if options.input1_first_part else part,
last_q if options.input1_last_q else q,
unit,
),
last_q if options.op_last_q else q,
)
for part, q, unit in it.product(
range(polys.input_remaining_rns.parts),
range(polys.input_remaining_rns.rns),
range(context.units),
)
]


def add_last_half(
context: KernelContext,
output: Polys,
input0: Polys,
input1: Polys,
input_remaining_rns: Polys,
last_q: int,
):
"""Add input0 to input1 (first part)"""
return partial_op(
context,
pisa_op.Add,
PartialOpPolys(output, input0, input1, input_remaining_rns),
PartialOpOptions(
output_last_q=True,
input0_last_q=True,
input1_last_q=True,
input1_first_part=True,
op_last_q=True,
),
last_q,
)


def sub_last_half(
context: KernelContext,
output: Polys,
input0: Polys,
input1: Polys,
input_remaining_rns: Polys,
last_q: int,
):
"""Subtract input1 (first part) with input0 (last RNS)"""
return partial_op(
context,
pisa_op.Sub,
PartialOpPolys(output, input0, input1, input_remaining_rns),
PartialOpOptions(input0_last_q=True, input1_first_part=True),
last_q,
)


def muli_last_half(
context: KernelContext,
output: Polys,
input0: Polys,
input1: Polys,
input_remaining_rns: Polys,
last_q: int,
):
"""Muli input0/1 w/input0 last RNS"""
return partial_op(
context,
pisa_op.Muli,
PartialOpPolys(output, input0, input1, input_remaining_rns),
PartialOpOptions(input0_last_q=True),
last_q,
)
11 changes: 11 additions & 0 deletions kerngen/pisa_generators/manifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,16 @@
"ROTATE": ["Rotate", "rotate.py"]
},
"CKKS": {
"ADD": ["Add", "basic.py"],
"MUL": ["Mul", "basic.py"],
"MULI": ["Muli", "basic.py"],
"COPY": ["Copy", "basic.py"],
"SUB": ["Sub", "basic.py"],
"SQUARE": ["Square", "square.py"],
"NTT": ["NTT", "ntt.py"],
"INTT": ["INTT", "ntt.py"],
"MOD": ["Mod", "mod.py"],
"MODUP": ["Modup", "mod.py"],
"RESCALE": ["Rescale", "rescale.py"]
}
}
76 changes: 38 additions & 38 deletions kerngen/pisa_generators/mod.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module containing conversions or operations from isa to p-isa."""

from dataclasses import dataclass
from itertools import product

from high_parser.pisa_operations import PIsaOp, Comment, Muli as pisa_op_muli
from high_parser.pisa_operations import PIsaOp, Comment
from high_parser import KernelContext, Immediate, HighOp, Polys

from .basic import Add, Muli, mixed_to_pisa_ops
from .basic import (
Add,
Muli,
mixed_to_pisa_ops,
split_last_rns_polys,
duplicate_polys,
common_immediates,
muli_last_half,
)
from .ntt import INTT, NTT


Expand All @@ -22,27 +30,18 @@ class Mod(HighOp):

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform an mod switch down"""
# Convenience and Immediates
context = self.context
# Immediates
last_q = self.input0.rns - 1
it = Immediate(name="it")
one = Immediate(name="one")
r2 = Immediate(name="R2", rns=last_q)
iq = Immediate(name="iq", rns=last_q)
t = Immediate(name="t", rns=last_q)
one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q)

# Drop down input rns
input_last_rns = Polys.from_polys(self.input0, mode="last_rns")
input_remaining_rns = Polys.from_polys(self.input0, mode="drop_last_rns")
input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)

# Temp.
y = Polys(
"y",
input_last_rns.parts,
input_last_rns.rns,
start_rns=input_last_rns.start_rns,
)
x = Polys("x", input_remaining_rns.parts, input_remaining_rns.rns)
temp_input_last_rns = duplicate_polys(input_last_rns, "y")
temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x")

# Compute the `delta_i = t * [-t^-1 * c_i] mod ql` where `i` are the parts
# The `one` acts as a select flag as whether or not R2 the Montgomery
Expand All @@ -51,30 +50,31 @@ def to_pisa(self) -> list[PIsaOp]:
[
Comment("Start of mod kernel"),
Comment("Compute the delta from last rns"),
INTT(context, y, input_last_rns),
Muli(context, y, y, it),
Muli(context, y, y, one),
INTT(self.context, temp_input_last_rns, input_last_rns),
Muli(self.context, temp_input_last_rns, temp_input_last_rns, it),
Muli(self.context, temp_input_last_rns, temp_input_last_rns, one),
Comment("Compute the remaining rns"),
# drop down to pisa ops to use correct rns q
[
pisa_op_muli(
self.context.label,
x(part, q, unit),
y(part, last_q, unit),
r2(part, q, unit),
q,
)
for part, q, unit in product(
range(input_remaining_rns.parts),
range(input_remaining_rns.rns),
range(context.units),
)
],
NTT(context, x, x),
Muli(context, x, x, t),
muli_last_half(
self.context,
temp_input_remaining_rns,
temp_input_last_rns,
r2,
input_remaining_rns,
last_q,
),
NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
Muli(
self.context, temp_input_remaining_rns, temp_input_remaining_rns, t
),
Comment("Add the delta correction to mod down polys"),
Add(context, x, x, input_remaining_rns),
Muli(context, self.output, x, iq),
Add(
self.context,
temp_input_remaining_rns,
temp_input_remaining_rns,
input_remaining_rns,
),
Muli(self.context, self.output, temp_input_remaining_rns, iq),
Comment("End of mod kernel"),
]
)
Expand Down
90 changes: 90 additions & 0 deletions kerngen/pisa_generators/rescale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Copyright (C) 2024 Intel Corporation

"""Module containing conversions or operations from isa to p-isa."""

from dataclasses import dataclass

from high_parser.pisa_operations import PIsaOp, Comment

from high_parser import KernelContext, HighOp, Polys

from .basic import (
Muli,
mixed_to_pisa_ops,
Sub,
split_last_rns_polys,
duplicate_polys,
common_immediates,
add_last_half,
sub_last_half,
)
from .ntt import INTT, NTT


@dataclass
class Rescale(HighOp):
"""Class representing mod down operation"""

context: KernelContext
output: Polys
input0: Polys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform an mod switch down"""
# Immediates
last_q = self.input0.rns - 1
one, r2, iq = common_immediates(r2_rns=last_q, iq_rns=last_q)

q_last_half = Polys("qLastHalf", 1, self.input0.rns)
q_i_last_half = Polys("qiLastHalf", 1, rns=last_q)

# split input
input_last_rns, input_remaining_rns = split_last_rns_polys(self.input0)

# Create temp vars for input_last/remaining
temp_input_last_rns = duplicate_polys(input_last_rns, "y")
temp_input_remaining_rns = duplicate_polys(input_remaining_rns, "x")

# Compute the `delta_i = t * [-t^-1 * c_i] mod ql` where `i` are the parts
# The `one` acts as a select flag as whether or not R2 the Montgomery
# factor should be applied
return mixed_to_pisa_ops(
[
Comment("Start of Rescale kernel."),
INTT(self.context, temp_input_last_rns, input_last_rns),
Muli(self.context, temp_input_last_rns, temp_input_last_rns, one),
Comment("Add the last part of the input to y"),
add_last_half(
self.context,
temp_input_last_rns,
temp_input_last_rns,
q_last_half,
input_remaining_rns,
last_q,
),
Comment("Subtract q_i (last half/last rns) from y"),
sub_last_half(
self.context,
temp_input_remaining_rns,
temp_input_last_rns,
q_i_last_half,
input_remaining_rns,
last_q,
),
Muli(
self.context, temp_input_remaining_rns, temp_input_remaining_rns, r2
),
NTT(self.context, temp_input_remaining_rns, temp_input_remaining_rns),
Sub(
self.context,
temp_input_remaining_rns,
Polys.from_polys(self.input0, mode="drop_last_rns"),
temp_input_remaining_rns,
),
Muli(self.context, self.output, temp_input_remaining_rns, iq),
Comment("End of Rescale kernel."),
]
)