Skip to content

Commit ad33005

Browse files
renamed optional.py to options_handler.py
1 parent c7c8268 commit ad33005

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""A module to process optional context parameters"""
5+
6+
from abc import ABC, abstractmethod
7+
8+
9+
class OptionalContext(ABC):
10+
"""Abstract class to hold optional parameters for context"""
11+
12+
op_name: str = ""
13+
op_value = None
14+
15+
@abstractmethod
16+
def validate(self, value):
17+
"""Abstract method, which defines how to valudate a value"""
18+
19+
20+
class OptionalInt(OptionalContext):
21+
"""Holds a key/value pair for optional context parameters of type Int"""
22+
23+
def __init__(self, name: str, min_val: int, max_val: int):
24+
self.min_val = min_val
25+
self.max_val = max_val
26+
self._op_name = name
27+
28+
def validate(self, value: int):
29+
"""Validate numeric options with min/max range"""
30+
if self.min_val < value < self.max_val:
31+
return True
32+
return False
33+
34+
@property
35+
def op_value(self):
36+
"""Get op_value"""
37+
return self._op_value
38+
39+
@op_value.setter
40+
def op_value(self, value: int):
41+
"""Set op_value"""
42+
if self.validate(value):
43+
self._op_value = int(value)
44+
else:
45+
raise ValueError(
46+
"{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}"
47+
)
48+
49+
50+
class OptionalIntMinMax:
51+
"""Holds min/max values for optional context parameters for type Int"""
52+
53+
int_min: int
54+
int_max: int
55+
default: int | None
56+
57+
def __init__(self, int_min: int, int_max: int, default: int | None):
58+
self.int_min = int_min
59+
self.int_max = int_max
60+
self.default = default
61+
62+
63+
class OptionalFactory(ABC):
64+
"""Abstract class that creates OptionaContext objects"""
65+
66+
MAX_KRNS_DELTA = 128
67+
MAX_DIGIT = 3
68+
MIN_KRNS_DELTA = MIN_DIGIT = 0
69+
optionals = {
70+
"krns_delta": OptionalIntMinMax(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0),
71+
"num_digits": OptionalIntMinMax(MIN_DIGIT, MAX_DIGIT, None),
72+
}
73+
74+
@staticmethod
75+
@abstractmethod
76+
def create(name: str, value) -> OptionalContext:
77+
"""Abstract method, to define how to create an OptionalContext"""
78+
79+
80+
class OptionalIntFactory(OptionalFactory):
81+
"""Optional context parameter factory for Int types"""
82+
83+
@staticmethod
84+
def create(name: str, value: int) -> OptionalInt:
85+
"""Create a OptionalInt object based on key/value pair"""
86+
if name in OptionalIntFactory.optionals:
87+
if isinstance(OptionalIntFactory.optionals[name], OptionalIntMinMax):
88+
optional_int = OptionalInt(
89+
name,
90+
OptionalIntFactory.optionals[name].int_min,
91+
OptionalIntFactory.optionals[name].int_max,
92+
)
93+
optional_int.op_value = value
94+
# add other optional types here
95+
else:
96+
raise KeyError(f"Invalid optional name for Context: '{name}'")
97+
return optional_int
98+
99+
100+
class OptionalFactoryDispatcher:
101+
"""An object dispatcher based on key/value pair for comptional context parameters"""
102+
103+
@staticmethod
104+
def create(name: str, value) -> OptionalContext:
105+
"""Creat an OptionalContext object based on the type of value passed in"""
106+
if value.isnumeric():
107+
value = int(value)
108+
match value:
109+
case int():
110+
return OptionalIntFactory.create(name, value)
111+
case _:
112+
raise ValueError(f"Current type '{type(value)}' is not supported.")
113+
114+
115+
class OptionalsParser:
116+
"""Parses key/value pairs and returns a dictionary of optiona parameters"""
117+
118+
@staticmethod
119+
def __default_values():
120+
default_dict = {}
121+
for key, val in OptionalFactory.optionals.items():
122+
default_dict[key] = val.default
123+
return default_dict
124+
125+
@staticmethod
126+
def parse(optionals: list[str]):
127+
"""Parse the optional parameter list and return a dictionary with values"""
128+
output_dict = OptionalsParser.__default_values()
129+
for option in optionals:
130+
try:
131+
key, value = option.split("=")
132+
output_dict[key] = OptionalFactoryDispatcher.create(key, value).op_value
133+
except ValueError as err:
134+
raise ValueError(
135+
f"Optional variables must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'"
136+
) from err
137+
return output_dict

kerngen/high_parser/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .pisa_operations import PIsaOp
1515

16-
from .optional import OptionalsParser
16+
from .options_handler import OptionalsParser
1717

1818

1919
class PolyOutOfBoundsError(Exception):

0 commit comments

Comments
 (0)