| 
 | 1 | +# Copyright (C) 2024 Intel Corporation  | 
 | 2 | +# SPDX-License-Identifier: Apache-2.0  | 
 | 3 | + | 
 | 4 | +"""Module for parsing kernel commands from Kerngen"""  | 
 | 5 | + | 
 | 6 | +import re  | 
 | 7 | +from high_parser.types import Immediate, KernelContext, Polys, Context  | 
 | 8 | +from pisa_generators.basic import Copy, HighOp, Add, Sub, Mul, Muli  | 
 | 9 | +from pisa_generators.ntt import NTT, INTT  | 
 | 10 | +from pisa_generators.square import Square  | 
 | 11 | +from pisa_generators.relin import Relin  | 
 | 12 | +from pisa_generators.rotate import Rotate  | 
 | 13 | +from pisa_generators.mod import Mod, ModUp  | 
 | 14 | +from pisa_generators.rescale import Rescale  | 
 | 15 | + | 
 | 16 | + | 
 | 17 | +class KernelParser:  | 
 | 18 | +    """Parser for kernel operations."""  | 
 | 19 | + | 
 | 20 | +    high_op_map = {  | 
 | 21 | +        "Add": Add,  | 
 | 22 | +        "Mul": Mul,  | 
 | 23 | +        "Muli": Muli,  | 
 | 24 | +        "Copy": Copy,  | 
 | 25 | +        "Sub": Sub,  | 
 | 26 | +        "Square": Square,  | 
 | 27 | +        "NTT": NTT,  | 
 | 28 | +        "INTT": INTT,  | 
 | 29 | +        "Mod": Mod,  | 
 | 30 | +        "ModUp": ModUp,  | 
 | 31 | +        "Relin": Relin,  | 
 | 32 | +        "Rotate": Rotate,  | 
 | 33 | +        "Rescale": Rescale,  | 
 | 34 | +    }  | 
 | 35 | + | 
 | 36 | +    @staticmethod  | 
 | 37 | +    def parse_context(context_str: str) -> KernelContext:  | 
 | 38 | +        """Parse the context string and return a KernelContext object."""  | 
 | 39 | +        context_match = re.search(  | 
 | 40 | +            r"KernelContext\(scheme='(?P<scheme>\w+)', "  | 
 | 41 | +            + r"poly_order=(?P<poly_order>\w+), key_rns=(?P<key_rns>\w+), "  | 
 | 42 | +            r"current_rns=(?P<current_rns>\w+), .*? label='(?P<label>\w+)'\)",  | 
 | 43 | +            context_str,  | 
 | 44 | +        )  | 
 | 45 | +        if not context_match:  | 
 | 46 | +            raise ValueError("Invalid context string format.")  | 
 | 47 | +        return KernelContext.from_context(  | 
 | 48 | +            Context(  | 
 | 49 | +                scheme=context_match.group("scheme"),  | 
 | 50 | +                poly_order=int(context_match.group("poly_order")),  | 
 | 51 | +                key_rns=int(context_match.group("key_rns")),  | 
 | 52 | +                current_rns=int(context_match.group("current_rns")),  | 
 | 53 | +                max_rns=int(context_match.group("key_rns")) - 1,  | 
 | 54 | +            ),  | 
 | 55 | +            label=context_match.group("label"),  | 
 | 56 | +        )  | 
 | 57 | + | 
 | 58 | +    @staticmethod  | 
 | 59 | +    def parse_polys(polys_str: str) -> Polys:  | 
 | 60 | +        """Parse the Polys string and return a Polys object."""  | 
 | 61 | +        polys_match = re.search(  | 
 | 62 | +            r"Polys\(name=(.*?), parts=(\d+), rns=(\d+)\)", polys_str  | 
 | 63 | +        )  | 
 | 64 | +        if not polys_match:  | 
 | 65 | +            raise ValueError("Invalid Polys string format.")  | 
 | 66 | +        name, parts, rns = polys_match.groups()  | 
 | 67 | +        return Polys(name=name, parts=int(parts), rns=int(rns))  | 
 | 68 | + | 
 | 69 | +    @staticmethod  | 
 | 70 | +    def parse_immediate(immediate_str: str) -> Immediate:  | 
 | 71 | +        """Parse the Immediate string and return an Immediate object."""  | 
 | 72 | +        immediate_match = re.search(  | 
 | 73 | +            r"Immediate\(name='(?P<name>\w+)', rns=(?P<rns>\w+)\)", immediate_str  | 
 | 74 | +        )  | 
 | 75 | +        if not immediate_match:  | 
 | 76 | +            raise ValueError("Invalid Immediate string format.")  | 
 | 77 | +        name, rns = immediate_match.group("name"), immediate_match.group("rns")  | 
 | 78 | +        rns = None if rns == "None" else int(rns)  | 
 | 79 | +        return Immediate(name=name, rns=rns)  | 
 | 80 | + | 
 | 81 | +    @staticmethod  | 
 | 82 | +    def parse_high_op(kernel_str: str) -> HighOp:  | 
 | 83 | +        """Parse a HighOp kernel string and return the corresponding object."""  | 
 | 84 | +        pattern = (  | 
 | 85 | +            r"### Kernel \(\d+\): (?P<op_type>\w+)\(context=(KernelContext\(.*?\)), "  | 
 | 86 | +            r"output=(Polys\(.*?\)), input0=(Polys\(.*?\))"  | 
 | 87 | +        )  | 
 | 88 | +        has_second_input = False  | 
 | 89 | +        # Check if the kernel string contains "input1" or not  | 
 | 90 | +        if "input1" not in kernel_str:  | 
 | 91 | +            # Match the operation type and its arguments  | 
 | 92 | +            high_op_match = re.search(pattern, kernel_str)  | 
 | 93 | +        else:  | 
 | 94 | +            # Adjust the pattern to include input1  | 
 | 95 | +            pattern += r", input1=(Polys\(.*?\)\)|Immediate\(.*?\)\))"  | 
 | 96 | +            # Match the operation type and its arguments  | 
 | 97 | +            high_op_match = re.search(pattern, kernel_str)  | 
 | 98 | +            has_second_input = True  | 
 | 99 | + | 
 | 100 | +        if not high_op_match:  | 
 | 101 | +            raise ValueError(f"Invalid kernel string format: {kernel_str}.")  | 
 | 102 | + | 
 | 103 | +        op_type = high_op_match.group("op_type")  | 
 | 104 | +        context_str, output_str, input0_str = high_op_match.groups()[1:4]  | 
 | 105 | + | 
 | 106 | +        if has_second_input:  | 
 | 107 | +            input1_str = high_op_match.group(5)  | 
 | 108 | + | 
 | 109 | +        # Parse the components  | 
 | 110 | +        context = KernelParser.parse_context(context_str)  | 
 | 111 | +        output = KernelParser.parse_polys(output_str)  | 
 | 112 | +        input0 = KernelParser.parse_polys(input0_str)  | 
 | 113 | +        if has_second_input:  | 
 | 114 | +            if op_type == "Muli":  | 
 | 115 | +                input1 = KernelParser.parse_immediate(input1_str)  | 
 | 116 | +            else:  | 
 | 117 | +                # For other operations, parse as Polys  | 
 | 118 | +                input1 = KernelParser.parse_polys(input1_str)  | 
 | 119 | + | 
 | 120 | +        if op_type not in KernelParser.high_op_map:  | 
 | 121 | +            raise ValueError(f"Unsupported HighOp type: {op_type}")  | 
 | 122 | + | 
 | 123 | +        # Instantiate the HighOp object  | 
 | 124 | +        if has_second_input:  | 
 | 125 | +            return KernelParser.high_op_map[op_type](  | 
 | 126 | +                context=context, output=output, input0=input0, input1=input1  | 
 | 127 | +            )  | 
 | 128 | +        # For operations without a second input, we can ignore the input1 parameter  | 
 | 129 | +        return KernelParser.high_op_map[op_type](  | 
 | 130 | +            context=context, output=output, input0=input0  | 
 | 131 | +        )  | 
 | 132 | + | 
 | 133 | +    @staticmethod  | 
 | 134 | +    def parse_kernel(kernel_str: str) -> HighOp:  | 
 | 135 | +        """Parse a kernel string and return the corresponding HighOp object."""  | 
 | 136 | +        return KernelParser.parse_high_op(kernel_str)  | 
0 commit comments