Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
0623bca
modified Poly.__repr__ to print out more details.
christopherngutierrez May 23, 2025
8e28dc9
Working loop interchange, initial checking
christopherngutierrez Jun 5, 2025
d1dd9b8
added a check to make sure that the high op that is target is support…
christopherngutierrez Jun 5, 2025
9dcd6ad
add test cases for kerngraph command
christopherngutierrez Jun 6, 2025
1f0c3d2
Modified to allow secondary loops to be None. Check if debug mode is …
christopherngutierrez Jun 6, 2025
911e2bf
Update pydantic requirement from ~=1.10.13 to ~=1.10.22 (#72)
dependabot[bot] May 20, 2025
0730427
Repository name change (#73)
faberga May 21, 2025
ec0913b
New repository structure (#74)
faberga May 22, 2025
e3b8dc9
Reference implementation of HERACLES Assembler tools (#75)
faberga May 23, 2025
d82596d
Bump the pip group across 1 directory with 3 updates (#76)
dependabot[bot] May 26, 2025
f190167
Update dependabot.yml (#79)
faberga May 28, 2025
8e3c06d
[hec-assembler]: Refactoring no_hbm option for Configurable ISA_spec …
joserochh Jun 4, 2025
a9b1ef1
[hec-assembler]: Enable Configurable Memory Specification (#80)
joserochh Jun 5, 2025
da4aea2
Bug Fix: Updates source-roots and pythonpath in config files (#82)
christopherngutierrez Jun 10, 2025
efaa0d1
[hec-assembler]: Fix suggested bank error (#78)
joserochh Jun 10, 2025
9d493c3
modified Poly.__repr__ to print out more details.
christopherngutierrez May 23, 2025
de62e23
Working loop interchange, initial checking
christopherngutierrez Jun 5, 2025
b17271a
Merge branch 'main' into experimental/christopherngutierrez/kerngraph…
christopherngutierrez Jun 10, 2025
0abd18a
removed old code from kerngen directory (from rebase)
christopherngutierrez Jun 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions p-isa_tools/kerngen/const/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for defining constants and enums used in the kernel generator"""
from enum import Enum


class LoopKey(Enum):
"""Sort keys for PIsaOp instructions"""

RNS = "rns"
PART = "part"
UNIT = "unit"

@classmethod
def from_str(cls, value: str) -> "LoopKey":
"""Convert a string to a LoopKey enum"""
if value is None:
raise ValueError("LoopKey cannot be None")
try:
return cls[value.upper()]
except KeyError:
raise ValueError(f"Invalid LoopKey: {value}") from None
2 changes: 1 addition & 1 deletion p-isa_tools/kerngen/high_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __call__(self, *args) -> str:
return self.expand(*args)

def __repr__(self) -> str:
return self.name
return f"Polys(name={self.name}, parts={self.parts}, rns={self.rns})"

@classmethod
def from_polys(cls, poly: "Polys", *, mode: str | None = None) -> "Polys":
Expand Down
2 changes: 2 additions & 0 deletions p-isa_tools/kerngen/kernel_optimization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
56 changes: 56 additions & 0 deletions p-isa_tools/kerngen/kernel_optimization/loops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for loop interchange optimization in P-ISA operations"""

import re
from const.options import LoopKey
from high_parser.pisa_operations import PIsaOp, Comment


def loop_interchange(
pisa_list: list[PIsaOp],
primary_key: LoopKey | None = LoopKey.PART,
secondary_key: LoopKey | None = LoopKey.RNS,
) -> list[PIsaOp]:
"""Batch pisa_list into groups and sort them by primary and optional secondary keys.

Args:
pisa_list: List of PIsaOp instructions
primary_key: Primary sort criterion from SortKey enum
secondary_key: Optional secondary sort criterion from SortKey enum

Returns:
List of processed PIsaOp instructions

Raises:
ValueError: If invalid sort key values provided
"""
if primary_key is None and secondary_key is None:
return pisa_list

def get_sort_value(pisa: PIsaOp, key: LoopKey) -> int:
match key:
case LoopKey.RNS:
return pisa.q
case LoopKey.PART:
match = re.search(r"_(\d+)_", str(pisa))
return int(match[1]) if match else 0
case LoopKey.UNIT:
match = re.search(r"_(\d+),", str(pisa))
return int(match[1]) if match else 0
case _:
raise ValueError(f"Invalid sort key value: {key}")

def get_sort_key(pisa: PIsaOp) -> tuple:
primary_value = get_sort_value(pisa, primary_key)
if secondary_key:
secondary_value = get_sort_value(pisa, secondary_key)
return (primary_value, secondary_value)
return (primary_value,)

# Filter out comments
pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
# Sort based on primary and optional secondary keys
pisa_list_wo_comments.sort(key=get_sort_key)
return pisa_list_wo_comments
2 changes: 2 additions & 0 deletions p-isa_tools/kerngen/kernel_parser/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
136 changes: 136 additions & 0 deletions p-isa_tools/kerngen/kernel_parser/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module for parsing kernel commands from Kerngen"""

import re
from high_parser.types import Immediate, KernelContext, Polys, Context
from pisa_generators.basic import Copy, HighOp, Add, Sub, Mul, Muli
from pisa_generators.ntt import NTT, INTT
from pisa_generators.square import Square
from pisa_generators.relin import Relin
from pisa_generators.rotate import Rotate
from pisa_generators.mod import Mod, ModUp
from pisa_generators.rescale import Rescale


class KernelParser:
"""Parser for kernel operations."""

high_op_map = {
"Add": Add,
"Mul": Mul,
"Muli": Muli,
"Copy": Copy,
"Sub": Sub,
"Square": Square,
"NTT": NTT,
"INTT": INTT,
"Mod": Mod,
"ModUp": ModUp,
"Relin": Relin,
"Rotate": Rotate,
"Rescale": Rescale,
}

@staticmethod
def parse_context(context_str: str) -> KernelContext:
"""Parse the context string and return a KernelContext object."""
context_match = re.search(
r"KernelContext\(scheme='(?P<scheme>\w+)', "
+ r"poly_order=(?P<poly_order>\w+), key_rns=(?P<key_rns>\w+), "
r"current_rns=(?P<current_rns>\w+), .*? label='(?P<label>\w+)'\)",
context_str,
)
if not context_match:
raise ValueError("Invalid context string format.")
return KernelContext.from_context(
Context(
scheme=context_match.group("scheme"),
poly_order=int(context_match.group("poly_order")),
key_rns=int(context_match.group("key_rns")),
current_rns=int(context_match.group("current_rns")),
max_rns=int(context_match.group("key_rns")) - 1,
),
label=context_match.group("label"),
)

@staticmethod
def parse_polys(polys_str: str) -> Polys:
"""Parse the Polys string and return a Polys object."""
polys_match = re.search(
r"Polys\(name=(.*?), parts=(\d+), rns=(\d+)\)", polys_str
)
if not polys_match:
raise ValueError("Invalid Polys string format.")
name, parts, rns = polys_match.groups()
return Polys(name=name, parts=int(parts), rns=int(rns))

@staticmethod
def parse_immediate(immediate_str: str) -> Immediate:
"""Parse the Immediate string and return an Immediate object."""
immediate_match = re.search(
r"Immediate\(name='(?P<name>\w+)', rns=(?P<rns>\w+)\)", immediate_str
)
if not immediate_match:
raise ValueError("Invalid Immediate string format.")
name, rns = immediate_match.group("name"), immediate_match.group("rns")
rns = None if rns == "None" else int(rns)
return Immediate(name=name, rns=rns)

@staticmethod
def parse_high_op(kernel_str: str) -> HighOp:
"""Parse a HighOp kernel string and return the corresponding object."""
pattern = (
r"### Kernel \(\d+\): (?P<op_type>\w+)\(context=(KernelContext\(.*?\)), "
r"output=(Polys\(.*?\)), input0=(Polys\(.*?\))"
)
has_second_input = False
# Check if the kernel string contains "input1" or not
if "input1" not in kernel_str:
# Match the operation type and its arguments
high_op_match = re.search(pattern, kernel_str)
else:
# Adjust the pattern to include input1
pattern += r", input1=(Polys\(.*?\)\)|Immediate\(.*?\)\))"
# Match the operation type and its arguments
high_op_match = re.search(pattern, kernel_str)
has_second_input = True

if not high_op_match:
raise ValueError(f"Invalid kernel string format: {kernel_str}.")

op_type = high_op_match.group("op_type")
context_str, output_str, input0_str = high_op_match.groups()[1:4]

if has_second_input:
input1_str = high_op_match.group(5)

# Parse the components
context = KernelParser.parse_context(context_str)
output = KernelParser.parse_polys(output_str)
input0 = KernelParser.parse_polys(input0_str)
if has_second_input:
if op_type == "Muli":
input1 = KernelParser.parse_immediate(input1_str)
else:
# For other operations, parse as Polys
input1 = KernelParser.parse_polys(input1_str)

if op_type not in KernelParser.high_op_map:
raise ValueError(f"Unsupported HighOp type: {op_type}")

# Instantiate the HighOp object
if has_second_input:
return KernelParser.high_op_map[op_type](
context=context, output=output, input0=input0, input1=input1
)
# For operations without a second input, we can ignore the input1 parameter
return KernelParser.high_op_map[op_type](
context=context, output=output, input0=input0
)

@staticmethod
def parse_kernel(kernel_str: str) -> HighOp:
"""Parse a kernel string and return the corresponding HighOp object."""
return KernelParser.parse_high_op(kernel_str)
114 changes: 114 additions & 0 deletions p-isa_tools/kerngen/kerngraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#! /usr/bin/env python3
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""
kerngraph.py

This script provides a command-line tool for parsing kernel strings from standard input using the KernelParser class.
Future improvements may include graph representation of the parsed kernels and optimization.

Functions:
parse_args():
Parses command-line arguments.
Returns:
argparse.Namespace: Parsed arguments including debug flag.

main(args):
Reads lines from standard input, parses each line as a kernel string using KernelParser,
and prints the successfully parsed kernel objects. If parsing fails for a line, an error
message is printed if debug mode is enabled.

Usage:
Run the script and provide kernel strings via standard input. Use the '-d' or '--debug' flag
to enable debug output for parsing errors.

Example:
$ cat bgv.add.high | ./kerngen.py | ./kerngraph.py
"""


import argparse
import sys
from kernel_parser.parser import KernelParser
from kernel_optimization.loops import loop_interchange
from const.options import LoopKey
from pisa_generators.basic import mixed_to_pisa_ops


def parse_args():
"""Parse arguments from the commandline"""
parser = argparse.ArgumentParser(description="Kernel Graph Parser")
parser.add_argument("-d", "--debug", action="store_true", help="Enable Debug Print")
parser.add_argument(
"-t",
"--target",
nargs="*",
default=[],
# Composition high ops such are ntt, mod, and relin are not currently supported
choices=["add", "sub", "mul", "muli", "copy"], # currently supports single ops
help="List of high_op names",
)
parser.add_argument(
"-p",
"--primary",
type=LoopKey,
default=LoopKey.PART,
choices=list(LoopKey),
help="Primary key for loop interchange (default: PART, options: RNS, PART))",
)
parser.add_argument(
"-s",
"--secondary",
type=LoopKey,
default=None,
choices=list(LoopKey) + list([None]),
help="Secondary key for loop interchange (default: None, Options: RNS, PART)",
)
parsed_args = parser.parse_args()
# verify that primary and secondary keys are not the same
if parsed_args.primary == parsed_args.secondary:
raise ValueError("Primary and secondary keys cannot be the same.")
return parser.parse_args()


def main(args):
"""Main function to read input and parse each line with KernelParser."""
input_lines = sys.stdin.read().strip().splitlines()
valid_kernels = []

for line in input_lines:
try:
kernel = KernelParser.parse_kernel(line)
valid_kernels.append(kernel)
except ValueError as e:
if args.debug:
print(f"Error parsing line: {line}\nReason: {e}")
continue # Skip invalid lines

if not valid_kernels:
print("No valid kernel strings were parsed.")
else:
if args.debug:
print(
f"# Reordered targets {args.target} with primary key {args.primary} and secondary key {args.secondary}"
)
for kernel in valid_kernels:
if args.target and any(
target.capitalize() in str(kernel) for target in args.target
):
kernel = loop_interchange(
kernel.to_pisa(),
primary_key=args.primary,
secondary_key=args.secondary,
)
for pisa in mixed_to_pisa_ops(kernel):
print(pisa)
else:
for pisa in kernel.to_pisa():
print(pisa)


if __name__ == "__main__":
cmdline_args = parse_args()
main(cmdline_args)
Loading