Skip to content
139 changes: 139 additions & 0 deletions kerngen/high_parser/options_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""A module to process optional key/value dictionary parameters"""

from abc import ABC, abstractmethod


class OptionsDict(ABC):
"""Abstract class to hold the options key/value pairs"""

op_name: str = ""
op_value = None

@abstractmethod
def validate(self, value):
"""Abstract method, which defines how to valudate a value"""


class OptionsIntDict(OptionsDict):
"""Holds a key/value pair for options of type Int"""

def __init__(self, name: str, min_val: int, max_val: int):
self.min_val = min_val
self.max_val = max_val
self._op_name = name

def validate(self, value: int):
"""Validate numeric options with min/max range"""
if self.min_val < value < self.max_val:
return True
return False

@property
def op_value(self):
"""Get op_value"""
return self._op_value

@op_value.setter
def op_value(self, value: int):
"""Set op_value"""
if self.validate(value):
self._op_value = int(value)
else:
raise ValueError(
"{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}"
)


class OptionsIntBounds:
"""Holds min/max/default values for options of type Int"""

int_min: int
int_max: int
default: int | None

def __init__(self, int_min: int, int_max: int, default: int | None):
self.int_min = int_min
self.int_max = int_max
self.default = default


class OptionsDictFactory(ABC):
"""Abstract class that creates OptionsDict objects"""

MAX_KRNS_DELTA = 128
MAX_DIGIT = 3
MIN_KRNS_DELTA = MIN_DIGIT = 0
options = {
"krns_delta": OptionsIntBounds(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0),
"num_digits": OptionsIntBounds(MIN_DIGIT, MAX_DIGIT, None),
}

@staticmethod
@abstractmethod
def create(name: str, value) -> OptionsDict:
"""Abstract method, to define how to create an OptionsDict"""


class OptionsIntDictFactory(OptionsDictFactory):
"""OptionsDict parameter factory for Int types"""

@staticmethod
def create(name: str, value: int) -> OptionsIntDict:
"""Create a OptionsInt object based on key/value pair"""
if name in OptionsIntDictFactory.options:
if isinstance(OptionsIntDictFactory.options[name], OptionsIntBounds):
options_int = OptionsIntDict(
name,
OptionsIntDictFactory.options[name].int_min,
OptionsIntDictFactory.options[name].int_max,
)
options_int.op_value = value
# add other options types here
else:
raise KeyError(f"Invalid options name: '{name}'")
return options_int


class OptionsDictFactoryDispatcher:
"""An object dispatcher based on key/value pair"""

@staticmethod
def create(name: str, value) -> OptionsDict:
"""Creat an OptionsDict object based on the type of value passed in"""
if value.isnumeric():
value = int(value)
match value:
case int():
return OptionsIntDictFactory.create(name, value)
case _:
raise ValueError(f"Current type '{type(value)}' is not supported.")


class OptionsDictParser:
"""Parses key/value pairs and returns a dictionary of options"""

@staticmethod
def __default_values():
default_dict = {}
for key, val in OptionsDictFactory.options.items():
default_dict[key] = val.default
return default_dict

@staticmethod
def parse(options: list[str]):
"""Parse the options list and return a dictionary with values"""
output_dict = OptionsDictParser.__default_values()
for option in options:
try:
key, value = option.split("=")
output_dict[key] = OptionsDictFactoryDispatcher.create(
key, value
).op_value
except ValueError as err:
raise ValueError(
f"Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'"
) from err
return output_dict
18 changes: 10 additions & 8 deletions kerngen/high_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .pisa_operations import PIsaOp

from .options_handler import OptionsDictParser


class PolyOutOfBoundsError(Exception):
"""Exception for Poly attributes being out of bounds"""
Expand Down Expand Up @@ -149,18 +151,15 @@ class Context(BaseModel):
scheme: str
poly_order: int # the N
max_rns: int
# optional vars for context
key_rns: int | None
num_digits: int | None

@classmethod
def from_string(cls, line: str):
"""Construct context from a string"""
scheme, poly_order, max_rns, *optional = line.split()
try:
krns, *rest = optional
except ValueError:
krns = None
if optional != [] and rest != []:
raise ValueError(f"too many parameters for context given: {line}")
scheme, poly_order, max_rns, *optionals = line.split()
optional_dict = OptionsDictParser.parse(optionals)
int_poly_order = int(poly_order)
if (
int_poly_order < MIN_POLY_SIZE
Expand All @@ -172,12 +171,15 @@ def from_string(cls, line: str):
)

int_max_rns = int(max_rns)
int_key_rns = int_max_rns + int(krns) if krns else None
int_key_rns = int_max_rns
int_key_rns += optional_dict.pop("krns_delta")

return cls(
scheme=scheme.upper(),
poly_order=int_poly_order,
max_rns=int_max_rns,
key_rns=int_key_rns,
**optional_dict,
)

@property
Expand Down
45 changes: 44 additions & 1 deletion kerngen/tests/test_kerngen.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,49 @@ def test_multiple_contexts(kerngen_path):
assert result.returncode != 0


def test_context_options_without_key(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 1\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'"
in result.stderr
)
assert result.returncode != 0


def test_context_unsupported_options_variable(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert "Invalid options name: 'test'" in result.stderr
assert result.returncode != 0


@pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"])
def test_context_option_invalid_values(kerngen_path, invalid):
"""Test kerngen raises an exception if value is out of range for correct key"""
input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
f"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'"
in result.stderr
)
assert result.returncode != 0


def test_unrecognised_opname(kerngen_path):
"""Test kerngen raises an exception when receiving an unrecognised
opname"""
Expand Down Expand Up @@ -99,7 +142,7 @@ def test_invalid_scheme(kerngen_path):
@pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18])
def test_invalid_poly_order(kerngen_path, invalid_poly):
"""Poly order should be powers of two >= 2^14 and <= 2^17"""
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 2\nADD a b c\n"
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand Down