diff --git a/p-isa_tools/kerngen/const/options.py b/p-isa_tools/kerngen/const/options.py new file mode 100644 index 00000000..e5d403e1 --- /dev/null +++ b/p-isa_tools/kerngen/const/options.py @@ -0,0 +1,23 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for defining constants and enums used in the kernel generator""" +from enum import Enum + + +class LoopKey(Enum): + """Sort keys for PIsaOp instructions""" + + RNS = "rns" + PART = "part" + UNIT = "unit" + + @classmethod + def from_str(cls, value: str) -> "LoopKey": + """Convert a string to a LoopKey enum""" + if value is None: + raise ValueError("LoopKey cannot be None") + try: + return cls[value.upper()] + except KeyError: + raise ValueError(f"Invalid LoopKey: {value}") from None diff --git a/p-isa_tools/kerngen/high_parser/types.py b/p-isa_tools/kerngen/high_parser/types.py index d618741c..f6a6ead1 100644 --- a/p-isa_tools/kerngen/high_parser/types.py +++ b/p-isa_tools/kerngen/high_parser/types.py @@ -46,7 +46,7 @@ def __call__(self, *args) -> str: return self.expand(*args) def __repr__(self) -> str: - return self.name + return f"Polys(name={self.name}, parts={self.parts}, rns={self.rns})" @classmethod def from_polys(cls, poly: "Polys", *, mode: str | None = None) -> "Polys": diff --git a/p-isa_tools/kerngen/kernel_optimization/__init__.py b/p-isa_tools/kerngen/kernel_optimization/__init__.py new file mode 100644 index 00000000..916f3a44 --- /dev/null +++ b/p-isa_tools/kerngen/kernel_optimization/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/kerngen/kernel_optimization/loops.py b/p-isa_tools/kerngen/kernel_optimization/loops.py new file mode 100644 index 00000000..7db78607 --- /dev/null +++ b/p-isa_tools/kerngen/kernel_optimization/loops.py @@ -0,0 +1,56 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for loop interchange optimization in P-ISA operations""" + +import re +from const.options import LoopKey +from high_parser.pisa_operations import PIsaOp, Comment + + +def loop_interchange( + pisa_list: list[PIsaOp], + primary_key: LoopKey | None = LoopKey.PART, + secondary_key: LoopKey | None = LoopKey.RNS, +) -> list[PIsaOp]: + """Batch pisa_list into groups and sort them by primary and optional secondary keys. + + Args: + pisa_list: List of PIsaOp instructions + primary_key: Primary sort criterion from SortKey enum + secondary_key: Optional secondary sort criterion from SortKey enum + + Returns: + List of processed PIsaOp instructions + + Raises: + ValueError: If invalid sort key values provided + """ + if primary_key is None and secondary_key is None: + return pisa_list + + def get_sort_value(pisa: PIsaOp, key: LoopKey) -> int: + match key: + case LoopKey.RNS: + return pisa.q + case LoopKey.PART: + match = re.search(r"_(\d+)_", str(pisa)) + return int(match[1]) if match else 0 + case LoopKey.UNIT: + match = re.search(r"_(\d+),", str(pisa)) + return int(match[1]) if match else 0 + case _: + raise ValueError(f"Invalid sort key value: {key}") + + def get_sort_key(pisa: PIsaOp) -> tuple: + primary_value = get_sort_value(pisa, primary_key) + if secondary_key: + secondary_value = get_sort_value(pisa, secondary_key) + return (primary_value, secondary_value) + return (primary_value,) + + # Filter out comments + pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)] + # Sort based on primary and optional secondary keys + pisa_list_wo_comments.sort(key=get_sort_key) + return pisa_list_wo_comments diff --git a/p-isa_tools/kerngen/kernel_parser/__init__.py b/p-isa_tools/kerngen/kernel_parser/__init__.py new file mode 100644 index 00000000..916f3a44 --- /dev/null +++ b/p-isa_tools/kerngen/kernel_parser/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/p-isa_tools/kerngen/kernel_parser/parser.py b/p-isa_tools/kerngen/kernel_parser/parser.py new file mode 100644 index 00000000..074a3653 --- /dev/null +++ b/p-isa_tools/kerngen/kernel_parser/parser.py @@ -0,0 +1,136 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Module for parsing kernel commands from Kerngen""" + +import re +from high_parser.types import Immediate, KernelContext, Polys, Context +from pisa_generators.basic import Copy, HighOp, Add, Sub, Mul, Muli +from pisa_generators.ntt import NTT, INTT +from pisa_generators.square import Square +from pisa_generators.relin import Relin +from pisa_generators.rotate import Rotate +from pisa_generators.mod import Mod, ModUp +from pisa_generators.rescale import Rescale + + +class KernelParser: + """Parser for kernel operations.""" + + high_op_map = { + "Add": Add, + "Mul": Mul, + "Muli": Muli, + "Copy": Copy, + "Sub": Sub, + "Square": Square, + "NTT": NTT, + "INTT": INTT, + "Mod": Mod, + "ModUp": ModUp, + "Relin": Relin, + "Rotate": Rotate, + "Rescale": Rescale, + } + + @staticmethod + def parse_context(context_str: str) -> KernelContext: + """Parse the context string and return a KernelContext object.""" + context_match = re.search( + r"KernelContext\(scheme='(?P\w+)', " + + r"poly_order=(?P\w+), key_rns=(?P\w+), " + r"current_rns=(?P\w+), .*? label='(?P