diff --git a/kerngen/high_parser/options_handler.py b/kerngen/high_parser/options_handler.py new file mode 100644 index 00000000..af627535 --- /dev/null +++ b/kerngen/high_parser/options_handler.py @@ -0,0 +1,139 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""A module to process optional key/value dictionary parameters""" + +from abc import ABC, abstractmethod + + +class OptionsDict(ABC): + """Abstract class to hold the options key/value pairs""" + + op_name: str = "" + op_value = None + + @abstractmethod + def validate(self, value): + """Abstract method, which defines how to valudate a value""" + + +class OptionsIntDict(OptionsDict): + """Holds a key/value pair for options of type Int""" + + def __init__(self, name: str, min_val: int, max_val: int): + self.min_val = min_val + self.max_val = max_val + self._op_name = name + + def validate(self, value: int): + """Validate numeric options with min/max range""" + if self.min_val < value < self.max_val: + return True + return False + + @property + def op_value(self): + """Get op_value""" + return self._op_value + + @op_value.setter + def op_value(self, value: int): + """Set op_value""" + if self.validate(value): + self._op_value = int(value) + else: + raise ValueError( + "{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}" + ) + + +class OptionsIntBounds: + """Holds min/max/default values for options of type Int""" + + int_min: int + int_max: int + default: int | None + + def __init__(self, int_min: int, int_max: int, default: int | None): + self.int_min = int_min + self.int_max = int_max + self.default = default + + +class OptionsDictFactory(ABC): + """Abstract class that creates OptionsDict objects""" + + MAX_KRNS_DELTA = 128 + MAX_DIGIT = 3 + MIN_KRNS_DELTA = MIN_DIGIT = 0 + options = { + "krns_delta": OptionsIntBounds(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0), + "num_digits": OptionsIntBounds(MIN_DIGIT, MAX_DIGIT, None), + } + + @staticmethod + @abstractmethod + def create(name: str, value) -> OptionsDict: + """Abstract method, to define how to create an OptionsDict""" + + +class OptionsIntDictFactory(OptionsDictFactory): + """OptionsDict parameter factory for Int types""" + + @staticmethod + def create(name: str, value: int) -> OptionsIntDict: + """Create a OptionsInt object based on key/value pair""" + if name in OptionsIntDictFactory.options: + if isinstance(OptionsIntDictFactory.options[name], OptionsIntBounds): + options_int = OptionsIntDict( + name, + OptionsIntDictFactory.options[name].int_min, + OptionsIntDictFactory.options[name].int_max, + ) + options_int.op_value = value + # add other options types here + else: + raise KeyError(f"Invalid options name: '{name}'") + return options_int + + +class OptionsDictFactoryDispatcher: + """An object dispatcher based on key/value pair""" + + @staticmethod + def create(name: str, value) -> OptionsDict: + """Creat an OptionsDict object based on the type of value passed in""" + if value.isnumeric(): + value = int(value) + match value: + case int(): + return OptionsIntDictFactory.create(name, value) + case _: + raise ValueError(f"Current type '{type(value)}' is not supported.") + + +class OptionsDictParser: + """Parses key/value pairs and returns a dictionary of options""" + + @staticmethod + def __default_values(): + default_dict = {} + for key, val in OptionsDictFactory.options.items(): + default_dict[key] = val.default + return default_dict + + @staticmethod + def parse(options: list[str]): + """Parse the options list and return a dictionary with values""" + output_dict = OptionsDictParser.__default_values() + for option in options: + try: + key, value = option.split("=") + output_dict[key] = OptionsDictFactoryDispatcher.create( + key, value + ).op_value + except ValueError as err: + raise ValueError( + f"Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'" + ) from err + return output_dict diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index 5cca8d7c..ade46c5d 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -13,6 +13,8 @@ from .pisa_operations import PIsaOp +from .options_handler import OptionsDictParser + class PolyOutOfBoundsError(Exception): """Exception for Poly attributes being out of bounds""" @@ -149,18 +151,15 @@ class Context(BaseModel): scheme: str poly_order: int # the N max_rns: int + # optional vars for context key_rns: int | None + num_digits: int | None @classmethod def from_string(cls, line: str): """Construct context from a string""" - scheme, poly_order, max_rns, *optional = line.split() - try: - krns, *rest = optional - except ValueError: - krns = None - if optional != [] and rest != []: - raise ValueError(f"too many parameters for context given: {line}") + scheme, poly_order, max_rns, *optionals = line.split() + optional_dict = OptionsDictParser.parse(optionals) int_poly_order = int(poly_order) if ( int_poly_order < MIN_POLY_SIZE @@ -172,12 +171,15 @@ def from_string(cls, line: str): ) int_max_rns = int(max_rns) - int_key_rns = int_max_rns + int(krns) if krns else None + int_key_rns = int_max_rns + int_key_rns += optional_dict.pop("krns_delta") + return cls( scheme=scheme.upper(), poly_order=int_poly_order, max_rns=int_max_rns, key_rns=int_key_rns, + **optional_dict, ) @property diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index b5b113f7..19dc69c9 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -69,6 +69,49 @@ def test_multiple_contexts(kerngen_path): assert result.returncode != 0 +def test_context_options_without_key(kerngen_path): + """Test kerngen raises an exception when more than one context is given""" + input_string = "CONTEXT BGV 16384 4 1\nData a 2\n" + result = execute_process( + [kerngen_path], + data_in=input_string, + ) + assert not result.stdout + assert ( + "ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'" + in result.stderr + ) + assert result.returncode != 0 + + +def test_context_unsupported_options_variable(kerngen_path): + """Test kerngen raises an exception when more than one context is given""" + input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n" + result = execute_process( + [kerngen_path], + data_in=input_string, + ) + assert not result.stdout + assert "Invalid options name: 'test'" in result.stderr + assert result.returncode != 0 + + +@pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"]) +def test_context_option_invalid_values(kerngen_path, invalid): + """Test kerngen raises an exception if value is out of range for correct key""" + input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n" + result = execute_process( + [kerngen_path], + data_in=input_string, + ) + assert not result.stdout + assert ( + f"ValueError: Options must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'" + in result.stderr + ) + assert result.returncode != 0 + + def test_unrecognised_opname(kerngen_path): """Test kerngen raises an exception when receiving an unrecognised opname""" @@ -99,7 +142,7 @@ def test_invalid_scheme(kerngen_path): @pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18]) def test_invalid_poly_order(kerngen_path, invalid_poly): """Poly order should be powers of two >= 2^14 and <= 2^17""" - input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 2\nADD a b c\n" + input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n" result = execute_process( [kerngen_path], data_in=input_string,