diff --git a/assembler_tools/hec-assembler-tools/assembler/__init__.py b/assembler_tools/hec-assembler-tools/assembler/__init__.py new file mode 100644 index 00000000..4057dc01 --- /dev/null +++ b/assembler_tools/hec-assembler-tools/assembler/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 diff --git a/assembler_tools/hec-assembler-tools/assembler/common/counter.py b/assembler_tools/hec-assembler-tools/assembler/common/counter.py index 65ca61a7..0f16ec8d 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/counter.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/counter.py @@ -1,5 +1,18 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Counter module providing sequential value generation with reset capability. + +This module implements a Counter class that allows for the creation and management of +counters that generate sequential values. These counters can be individually reset +or collectively reset through the class interface, making them useful for various +numbering and sequence generation tasks in the assembler. +""" import itertools +from typing import Set, Optional + class Counter: """ @@ -18,7 +31,8 @@ class CounterIter: This iterator starts at a specified value and increments by a specified step. It can be reset to start over from its initial start value. """ - def __init__(self, start = 0, step = 1): + + def __init__(self, start=0, step=1): """ Initializes a new CounterIter object. @@ -28,7 +42,7 @@ def __init__(self, start = 0, step = 1): """ self.__start = start self.__step = step - self.__counter = None # itertools.counter + self.__counter = None # itertools.counter self.reset() def __next__(self): @@ -66,10 +80,10 @@ def reset(self): """ self.__counter = itertools.count(self.start, self.step) - __counters = set() + __counters: Set["Counter.CounterIter"] = set() @classmethod - def count(cls, start = 0, step = 1) -> CounterIter: + def count(cls, start=0, step=1) -> "Counter.CounterIter": """ Creates a new counter iterator that returns evenly spaced values. @@ -85,7 +99,7 @@ def count(cls, start = 0, step = 1) -> CounterIter: return retval @classmethod - def reset(cls, counter: CounterIter = None): + def reset(cls, counter: Optional["Counter.CounterIter"] = None): """ Reset the specified counter, or all counters if none is specified. @@ -93,9 +107,9 @@ def reset(cls, counter: CounterIter = None): over from their respective `start` values. Args: - counter (CounterIter, optional): The counter to reset. + counter (Optional[CounterIter], optional): The counter to reset. If None, all counters are reset. """ - counters_to_reset = cls.__counters if counter is None else { counter } + counters_to_reset = cls.__counters if counter is None else {counter} for c in counters_to_reset: c.reset() diff --git a/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py index 9d845d22..26f8bf71 100644 --- a/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py +++ b/assembler_tools/hec-assembler-tools/assembler/common/priority_queue.py @@ -1,8 +1,21 @@ -import heapq +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +""" +Priority queue implementation with support for task prioritization, ordering, and updating. + +This module provides a priority queue that allows tasks to be added with priorities, +and supports operations to update, remove, and retrieve tasks based on their priorities. +""" + +import heapq import bisect import itertools +from typing import List, Dict, Optional, Tuple, Any + class PriorityQueue: + # pylint: disable=invalid-name """ A priority queue implementation that supports task prioritization and ordering. @@ -17,6 +30,7 @@ class __PriorityQueueIter: This iterator allows for iterating over the tasks in the priority queue while ensuring that the queue's size does not change during iteration. """ + def __init__(self, pq, removed): """ Initializes the iterator with the priority queue and removed marker. @@ -45,16 +59,17 @@ def __next__(self): raise RuntimeError("PriorityQueue changed size during iteration.") # Skip all removed tasks - while self.__current < len(self.__pq) \ - and self.__pq[self.__current][-1] is self.__removed: + while ( + self.__current < len(self.__pq) + and self.__pq[self.__current][-1] is self.__removed + ): self.__current += 1 if self.__current >= len(self.__pq): raise StopIteration priority, _, task = self.__pq[self.__current] - self.__current += 1 # point to nex element + self.__current += 1 # point to nex element return (priority, task) - class __PriorityTracker: """ A helper class to track tasks by their priority. @@ -62,12 +77,17 @@ class __PriorityTracker: This class maintains a mapping of priorities to tasks and supports operations to add, find, and remove tasks based on their priority. """ + def __init__(self): """ Initializes the priority tracker with empty mappings. """ - self.__priority_dict = {} # dict(int, SortedList(task)): maps priority to unordered set of tasks with same priority - self.__priority_dict_set = {} # dict(int, set(task)): maps priority to unordered set of tasks with same priority + self.__priority_dict = ( + {} + ) # dict(int, SortedList(task)): maps priority to unordered set of tasks with same priority + self.__priority_dict_set = ( + {} + ) # dict(int, set(task)): maps priority to unordered set of tasks with same priority def find(self, priority: int) -> object: """ @@ -79,7 +99,11 @@ def find(self, priority: int) -> object: Returns: object: A task with the specified priority, or None if not found. """ - return next(iter(self.__priority_dict[priority]))[1] if priority in self.__priority_dict else None + return ( + next(iter(self.__priority_dict[priority]))[1] + if priority in self.__priority_dict + else None + ) def push(self, priority: int, tie_breaker: tuple, task: object): """ @@ -94,7 +118,7 @@ def push(self, priority: int, tie_breaker: tuple, task: object): ValueError: If the task is None. """ if task is None: - raise ValueError('`task` cannot be `None`.') + raise ValueError("`task` cannot be `None`.") if priority not in self.__priority_dict: self.__priority_dict[priority] = [] @@ -104,7 +128,7 @@ def push(self, priority: int, tie_breaker: tuple, task: object): bisect.insort_right(self.__priority_dict[priority], (tie_breaker, task)) self.__priority_dict_set[priority].add(task) - def pop(self, priority: int, task = None) -> object: + def pop(self, priority: int, task=None) -> object: """ Removes a task with the specified priority. @@ -126,12 +150,20 @@ def pop(self, priority: int, task = None) -> object: assert priority in self.__priority_dict_set if task: # Find index for task - idx = next((i for i, (_, contained_task) in enumerate(self.__priority_dict[priority]) if contained_task == task), - len(self.__priority_dict[priority])) + idx = next( + ( + i + for i, (_, contained_task) in enumerate( + self.__priority_dict[priority] + ) + if contained_task == task + ), + len(self.__priority_dict[priority]), + ) if idx >= len(self.__priority_dict[priority]): - raise ValueError('`task` not found in priority.') + raise ValueError("`task` not found in priority.") _, retval = self.__priority_dict[priority].pop(idx) - assert(retval == task) + assert retval == task else: # Remove first task _, retval = self.__priority_dict[priority].pop(0) @@ -144,38 +176,44 @@ def pop(self, priority: int, task = None) -> object: self.__priority_dict_set.pop(priority) return retval - __REMOVED = object() # Placeholder for a removed task + __REMOVED = object() # Placeholder for a removed task - def __init__(self, queue: list = None): + def __init__(self, queue: Optional[List[Tuple[int, Any]]] = None): """ Creates a new PriorityQueue object. Args: - queue (list, optional): A list of (priority, task) tuples to initialize the queue. + queue (Optional[List[Tuple[int, Any]]], optional): A list of (priority, task) tuples to initialize the queue. This is an O(len(queue)) operation. Raises: ValueError: If any task in the queue is None. """ # entry: [priority: int, nonce: int, task: hashable_object] - self.__pq = [] # list(entry) - List of entries arranged in a heap - self.__entry_finder = {} # dictionary(task: Hashable_object, entry) - mapping of tasks to entries - self.__priority_tracker = PriorityQueue.__PriorityTracker() # Tracks tasks by priority - self.__counter: int = itertools.count(1) # Unique sequence count + self.__pq: List[List[Any]] = ( + [] + ) # list(entry) - List of entries arranged in a heap + self.__entry_finder: Dict[Any, List[Any]] = ( + {} + ) # dictionary(task: Hashable_object, entry) - mapping of tasks to entries + self.__priority_tracker = ( + PriorityQueue.__PriorityTracker() + ) # Tracks tasks by priority + self.__counter = itertools.count(1) # Unique sequence count if queue: for priority, task in queue: if task is None: - raise ValueError('`queue`: tasks cannot be `None`.') + raise ValueError("`queue`: tasks cannot be `None`.") count = next(self.__counter) - entry = [priority, ((0, ), count), task] + entry = [priority, ((0,), count), task] self.__entry_finder[task] = entry - self.__priority_tracker.push(*entry)#priority, task) - self.__pq.append() + self.__priority_tracker.push(priority, ((0,), count), task) + heapq.heappush(self.__pq, entry) heapq.heapify(self.__pq) def __bool__(self): - """ + """ Returns True if the priority queue is not empty, False otherwise. Returns: @@ -220,19 +258,18 @@ def __repr__(self): Returns: str: A string representation of the queue. """ - return '<{} object at {}>(len={}, pq={})'.format(type(self).__name__, - hex(id(self)), - len(self), - self.__pq) + return f"<{type(self).__name__} object at {hex(id(self))}>(len={len(self)}, pq={self.__pq})" - def push(self, priority: int, task: object, tie_breaker: tuple = None): #ahead: bool = None): + def push( + self, priority: int, task: object, tie_breaker: Optional[Tuple[int, ...]] = None + ): """ Adds a new task or update the priority of an existing task. Args: priority (int): The priority of the task. task (object): The task to add or update. - tie_breaker (tuple, optional): A tuple of ints to use as a tie breaker for tasks + tie_breaker (Optional[Tuple[int, ...]], optional): A tuple of ints to use as a tie breaker for tasks of the same priority. Defaults to (0,) if None. Raises: @@ -240,17 +277,15 @@ def push(self, priority: int, task: object, tie_breaker: tuple = None): #ahead: TypeError: If the tie_breaker is not a tuple of ints or None. """ if task is None: - raise ValueError('`task` cannot be `None`.') - if tie_breaker is not None \ - and not all(isinstance(x, int) for x in tie_breaker): - raise TypeError('`tie_breaker` expected tuple of `int`s, or `None`.') + raise ValueError("`task` cannot be `None`.") + if tie_breaker is not None and not all(isinstance(x, int) for x in tie_breaker): + raise TypeError("`tie_breaker` expected tuple of `int`s, or `None`.") b_add_needed = True if task in self.__entry_finder: old_priority, (old_tie_breaker, _), _ = self.__entry_finder[task] if tie_breaker is None: tie_breaker = old_tie_breaker - if old_priority != priority \ - or tie_breaker != old_tie_breaker: + if old_priority != priority or tie_breaker != old_tie_breaker: self.remove(task) else: # same task without priority change detected: no need to add @@ -261,11 +296,13 @@ def push(self, priority: int, task: object, tie_breaker: tuple = None): #ahead: if b_add_needed: if len(self.__pq) == 0: - self.__counter: int = itertools.count(1) # restart sequence count when queue is empty + self.__counter = itertools.count( + 1 + ) # restart sequence count when queue is empty count = next(self.__counter) entry = [priority, (tie_breaker, count), task] self.__entry_finder[task] = entry - self.__priority_tracker.push(*entry)#priority, task) + self.__priority_tracker.push(priority, (tie_breaker, count), task) heapq.heappush(self.__pq, entry) def remove(self, task: object): @@ -281,15 +318,17 @@ def remove(self, task: object): # mark an existing task as PriorityQueue.__REMOVED. entry = self.__entry_finder.pop(task) priority, *_ = entry - self.__priority_tracker.pop(priority, task) # remove it from the priority tracker + self.__priority_tracker.pop( + priority, task + ) # remove it from the priority tracker entry[-1] = PriorityQueue.__REMOVED - def peek(self) -> tuple: + def peek(self) -> Optional[Tuple[int, Any]]: """ Returns the task with the lowest priority without removing it from the queue. Returns: - tuple: The (priority, task) pair of the task with the lowest priority, + Optional[Tuple[int, Any]]: The (priority, task) pair of the task with the lowest priority, or None if the queue is empty. """ # make sure head is not a removed task @@ -326,7 +365,7 @@ def pop(self) -> tuple: IndexError: If the queue is empty. """ task = PriorityQueue.__REMOVED - while task is PriorityQueue.__REMOVED: # make sure head is not a removed task + while task is PriorityQueue.__REMOVED: # make sure head is not a removed task priority, _, task = heapq.heappop(self.__pq) self.__entry_finder.pop(task) self.__priority_tracker.pop(priority, task) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py index a401ed56..7ac6bffe 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/__init__.py @@ -1,28 +1,35 @@ - -def tokenizeFromLine(line: str) -> list: - """ - Tokenizes a line of text and extracts any comment present. +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 - This function processes a line of text, removing line breaks and splitting the line - into tokens based on commas. It also identifies and extracts comments, which are - denoted by the pound symbol `#`. +"""Instruction module initialization and utilities.""" - Args: - line (str): Line of text to tokenize. +from typing import Tuple - Returns: - tuple: A tuple containing the tokens and the comment. The `tokens` are a tuple of strings, - and `comment` is a string. The `comment` is an empty string if no comment is found in the line. - """ - tokens = tuple() - comment = "" - if line: - line = ''.join(line.splitlines()) # remove line breaks - comment_idx = line.find('#') - if comment_idx >= 0: - # Found a comment - comment = line[comment_idx + 1:] - line = line[:comment_idx] - tokens = tuple(map(lambda s: s.strip(), line.split(','))) - retval = (tokens, comment) - return retval + +def tokenize_from_line(line: str) -> Tuple[Tuple[str, ...], str]: + """ + Tokenizes a line of text and extracts any comment present. + + This function processes a line of text, removing line breaks and splitting the line + into tokens based on commas. It also identifies and extracts comments, which are + denoted by the pound symbol `#`. + + Args: + line (str): Line of text to tokenize. + + Returns: + tuple: A tuple containing the tokens and the comment. The `tokens` are a tuple of strings, + and `comment` is a string. The `comment` is an empty string if no comment is found in the line. + """ + tokens: Tuple[str, ...] = tuple() + comment = "" + if line: + line = "".join(line.splitlines()) # remove line breaks + comment_idx = line.find("#") + if comment_idx >= 0: + # Found a comment + comment = line[comment_idx + 1 :] + line = line[:comment_idx] + tokens = tuple(map(lambda s: s.strip(), line.split(","))) + retval = (tokens, comment) + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py index 44526533..9d2b0837 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bload.py @@ -1,13 +1,17 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable + class Instruction(CInstruction): """ Encapsulates the `bload` CInstruction. The `bload` instruction loads metadata from the scratchpad to special registers in the register file. - + For more information, check the `bload` Specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bload.md @@ -18,7 +22,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name for the operation. @@ -27,15 +31,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "bload" - def __init__(self, - id: int, - col_num: int, - m_idx: int, - src: Variable, - mem_model, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + col_num: int, + m_idx: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `bload` CInstruction. @@ -53,7 +59,7 @@ def __init__(self, ValueError: If `mem_model` is None. """ if not mem_model: - raise ValueError('`mem_model` cannot be `None`.') + raise ValueError("`mem_model` cannot be `None`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -62,7 +68,7 @@ def __init__(self, self.col_num = col_num self.m_idx = m_idx self.__mem_model = mem_model - self._set_sources( [ src ] ) + self._set_sources([src]) def __repr__(self): """ @@ -71,19 +77,23 @@ def __repr__(self): Returns: str: A string representation. """ - assert(len(self.sources) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'col_num={}, m_idx={}, src={}, ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.col_num, - self.m_idx, - self.sources[0], - self.throughput, - self.latency) + assert len(self.sources) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "col_num={}, m_idx={}, src={}, " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.col_num, + self.m_idx, + self.sources[0], + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -96,7 +106,9 @@ def _set_dests(self, value): Raises: RuntimeError: Always, as `bload` does not have destination parameters. """ - raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") + raise RuntimeError( + f"Instruction `{self.name}` does not have destination parameters." + ) def _set_sources(self, value): """ @@ -109,9 +121,14 @@ def _set_sources(self, value): ValueError: If the value is not a list of the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -131,25 +148,34 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) - variable: Variable = self.sources[0] # expected sources to contain a Variable + variable: Variable = self.sources[0] # expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') + raise RuntimeError( + f'Null Access Violation: Variable "{variable}" not allocated in SPAD.' + ) if self.m_idx < 0: raise RuntimeError(f"Invalid negative index `m_idx`.") if self.col_num not in range(4): - raise RuntimeError(f"Invalid `col_num`: {self.col_num}. Must be in range [0, 4).") + raise RuntimeError( + f"Invalid `col_num`: {self.col_num}. Must be in range [0, 4)." + ) retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking = self.__mem_model.spad.getAccessTracking( + variable.spad_address + ) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bload spad_access_tracking.last_mload = None return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -162,20 +188,18 @@ def _toCASMISAFormat(self, *extra_args) -> str: Returns: str: The ASM format string of the instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # `op, target_idx, spad_src [# comment]` preamble = [] # Instruction sources - extra_args = (self.col_num, ) - extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + extra_args = (self.col_num,) + extra_args = tuple(src.to_casmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - extra_args = (self.m_idx, ) + extra_args - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + extra_args = (self.m_idx,) + extra_args + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py index e958ac60..31fa5108 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/bones.py @@ -1,22 +1,26 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable + class Instruction(CInstruction): """ Encapsulates a `bones` CInstruction. The `bones` instruction loads metadata of identity (one) from the scratchpad to the register file. - + For more information, check the `bones` Specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_bones.md - + Attributes: spad_src (int): SPAD address of the metadata variable to load. """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name for the operation. @@ -25,14 +29,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "bones" - def __init__(self, - id: int, - src_col_num: int, - src: Variable, - mem_model, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + src_col_num: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `bones` CInstruction. @@ -49,7 +55,7 @@ def __init__(self, ValueError: If `mem_model` is None. """ if not mem_model: - raise ValueError('`mem_model` cannot be `None`.') + raise ValueError("`mem_model` cannot be `None`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -57,7 +63,7 @@ def __init__(self, super().__init__(id, throughput, latency, comment=comment) self.src_col_num = src_col_num self.__mem_model = mem_model - self._set_sources( [ src ] ) + self._set_sources([src]) def __repr__(self): """ @@ -66,18 +72,22 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - assert(len(self.sources) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'src_col_num={}, src={}, ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.src_col_num, - self.sources[0], - self.throughput, - self.latency) + assert len(self.sources) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "src_col_num={}, src={}, " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.src_col_num, + self.sources[0], + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -103,9 +113,14 @@ def _set_sources(self, value): ValueError: If the value is not a list of the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -125,23 +140,30 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) - variable: Variable = self.sources[0] # Expected sources to contain a Variable. + variable: Variable = self.sources[0] # Expected sources to contain a Variable. if variable.spad_address < 0: - raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + raise RuntimeError( + f"Null Access Violation: Variable `{variable}` not allocated in SPAD." + ) if self.src_col_num < 0: raise RuntimeError("Invalid `src_col_num` negative `Ones` target index.") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address. - spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking = self.__mem_model.spad.getAccessTracking( + variable.spad_address + ) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bones. spad_access_tracking.last_mload = None return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -154,11 +176,11 @@ def _toCASMISAFormat(self, *extra_args) -> str: Returns: str: The ASM format string of the instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # `op, spad_src, src_col_num [# comment]` - return super()._toCASMISAFormat(self.src_col_num) + return super()._to_casmisa_format(self.src_col_num) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py index 529f94cd..124542b8 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cexit.py @@ -1,4 +1,8 @@ -from .cinstruction import CInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .cinstruction import CInstruction + class Instruction(CInstruction): """ @@ -11,7 +15,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name for the operation. @@ -20,11 +24,9 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "cexit" - def __init__(self, - id: int, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, id: int, throughput: int = None, latency: int = None, comment: str = "" + ): """ Constructs a new `cexit` CInstruction. @@ -47,13 +49,16 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -80,7 +85,7 @@ def _set_sources(self, value): """ raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -93,10 +98,10 @@ def _toCASMISAFormat(self, *extra_args) -> str: Returns: str: The ASM format string of the instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toCASMISAFormat() + return super()._to_casmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py index 79eed14b..0dc8bfe9 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cinstruction.py @@ -1,6 +1,13 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""CInstruction base class for C-instructions in the assembler.""" + +from typing import List, Any +from assembler.common.cycle_tracking import CycleType from ..instruction import BaseInstruction + class CInstruction(BaseInstruction): """ Represents a CInstruction, which is a type of BaseInstruction. @@ -18,22 +25,33 @@ class CInstruction(BaseInstruction): # Constructor # ----------- - def __init__(self, - id: int, - throughput : int, - latency : int, - comment: str = ""): + def __init__( + self, instruction_id: int, throughput: int, latency: int, comment: str = "" + ): """ Constructs a new CInstruction. Parameters: - id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. + instruction_id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. throughput (int): The throughput of the instruction. latency (int): The latency of the instruction. comment (str, optional): An optional comment for the instruction. """ - super().__init__(id, throughput, latency, comment=comment) + super().__init__(instruction_id, throughput, latency, comment=comment) + + @classmethod + def _get_op_name_asm(cls) -> str: + """ + Returns the ASM name for the operation. + + This method must be implemented by derived CInstruction classes. + Returns: + str: The ASM name for the operation. + """ + raise NotImplementedError( + "Derived CInstruction must implement _get_op_name_asm." + ) # Methods and properties # ---------------------- @@ -47,9 +65,9 @@ def _get_cycle_ready(self): Returns: CycleType: A CycleType object with bundle and cycle set to 0. """ - return CycleType(bundle = 0, cycle = 0) + return CycleType(bundle=0, cycle=0) - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to CInst ASM-ISA format. @@ -63,11 +81,9 @@ def _toCASMISAFormat(self, *extra_args) -> str: str: The CInst ASM-ISA format string of the instruction. """ - preamble = [] + preamble: List[Any] = [] # instruction sources - extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_casmisa_format() for src in self.sources) + extra_args # instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py index d5d2aa7c..34c27c74 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cload.py @@ -1,23 +1,26 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model import MemoryModel from assembler.memory_model.variable import Variable from assembler.memory_model.register_file import Register + class Instruction(CInstruction): """ Encapsulates a `cload` CInstruction. A `cload` instruction loads a word, corresponding to a single polynomial residue, from scratchpad memory into the register file memory. - + For more information, check the `cload` Specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cload.md """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name for the operation. @@ -26,14 +29,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "cload" - def __init__(self, - id: int, - dst: Register, - src: list, - mem_model: MemoryModel, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + dst: Register, + src: list, + mem_model: MemoryModel, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `cload` CInstruction. @@ -50,14 +55,16 @@ def __init__(self, Raises: AssertionError: If the destination register bank index is not 0. """ - assert(dst.bank.bank_index == 0) # We must be following convention of loading from SPAD into bank 0 + assert ( + dst.bank.bank_index == 0 + ) # We must be following convention of loading from SPAD into bank 0 if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) self.__mem_model = mem_model - self._set_dests([ dst ]) + self._set_dests([dst]) self._set_sources(src) def __repr__(self): @@ -67,17 +74,21 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - assert(len(self.dests) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={},' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests[0], - self.sources, - self.throughput, - self.latency) + assert len(self.dests) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "dst={}, src={}," + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests[0], + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -92,9 +103,14 @@ def _set_dests(self, value): TypeError: If the value is not a list of `Register` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Register) for x in value): raise TypeError("`value`: Expected list of `Register` objects.") super()._set_dests(value) @@ -110,9 +126,14 @@ def _set_sources(self, value): ValueError: If the value is not a list of the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -134,32 +155,46 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) - - variable: Variable = self.sources[0] # Expected sources to contain a Variable + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable target_register: Register = self.dests[0] if variable.spad_address < 0: - raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + raise RuntimeError( + f"Null Access Violation: Variable `{variable}` not allocated in SPAD." + ) # Cannot allocate variable to more than one register (memory coherence) - # and must not overrite a register that already contains a variable. + # and must not overwrite a register that already contains a variable. if variable.register: - raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") + raise RuntimeError( + f"Variable `{variable}` already allocated in register `{variable.register}`." + ) if target_register.contained_variable: - raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") + raise RuntimeError( + f"Register `{target_register}` already contains a Variable object." + ) retval = super()._schedule(cycle_count, schedule_id) # Perform the load target_register.allocateVariable(variable) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking = self.__mem_model.spad.getAccessTracking( + variable.spad_address + ) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after cload spad_access_tracking.last_mload = None if self.comment: - self.comment += ';' - self.comment += f' {variable.name}' + self.comment += ";" + self.comment += f" {variable.name}" return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py index 61e05f0d..26b42be6 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cnop.py @@ -1,13 +1,16 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .cinstruction import CInstruction + class Instruction(CInstruction): """ Represents a 'cnop' CInstruction from the ASM ISA specification. - This class is used to create a 'cnop' instruction, which is a type of - no-operation (NOP) instruction that inserts a specified number of idle - cycles during its execution. The instruction does not have any destination + This class is used to create a 'cnop' instruction, which is a type of + no-operation (NOP) instruction that inserts a specified number of idle + cycles during its execution. The instruction does not have any destination or source operands. For more information, check the `cnop` Specification: @@ -15,7 +18,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -24,15 +27,12 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "cnop" - def __init__(self, - id: int, - idle_cycles: int, - comment: str = ""): + def __init__(self, id: int, idle_cycles: int, comment: str = ""): """ Constructs a new 'cnop' CInstruction. Parameters: - id (int): User-defined ID for the instruction. It will be bundled + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. idle_cycles (int): Number of idle cycles to insert in the CInst execution. comment (str, optional): A comment for the instruction. Defaults to an empty string. @@ -45,15 +45,12 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, and throughput. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'idle_cycles={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput) + retval = ("<{}({}) object at {}>(id={}[0], " "idle_cycles={})").format( + type(self).__name__, self.name, hex(id(self)), self.id, self.throughput + ) return retval def _set_dests(self, value): @@ -80,7 +77,7 @@ def _set_sources(self, value): """ raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -94,12 +91,12 @@ def _toCASMISAFormat(self, *extra_args) -> str: AssertionError: If the number of destinations or sources is incorrect. ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # The idle cycles in the ASM ISA for 'nop' must be one less because decoding/scheduling # the instruction counts as a cycle. - return super()._toCASMISAFormat(self.throughput - 1) \ No newline at end of file + return super()._to_casmisa_format(self.throughput - 1) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py index a6908427..21161e67 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/cstore.py @@ -1,3 +1,5 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 from assembler.common.config import GlobalConfig from assembler.common.cycle_tracking import CycleType @@ -6,6 +8,7 @@ from assembler.memory_model.variable import Variable, DummyVariable from assembler.memory_model.register_file import Register + class Instruction(CInstruction): """ Encapsulates a `cstore` CInstruction. @@ -14,13 +17,13 @@ class Instruction(CInstruction): and stores it in SPAD. To accomplish this in scheduling, a `cstore` should be scheduled immediately after the `ifetch` for the bundle containing the matching `xstore`. - + For more information, check the `cstore` Specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/cinst/cinst_cstore.md """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -29,12 +32,14 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "cstore" - def __init__(self, - id: int, - mem_model: MemoryModel, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + mem_model: MemoryModel, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `cstore` CInstruction. @@ -53,7 +58,7 @@ def __init__(self, ValueError: If `mem_model` is not an instance of `MemoryModel`. """ if not isinstance(mem_model, MemoryModel): - raise ValueError('`mem_model` must be an instance of `MemoryModel`.') + raise ValueError("`mem_model` must be an instance of `MemoryModel`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -67,17 +72,21 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -91,9 +100,14 @@ def _set_dests(self, value): ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_dests(value) @@ -110,9 +124,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Register` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Register) for x in value): raise TypeError("`value`: Expected list of `Register` objects.") super()._set_sources(value) @@ -142,36 +161,41 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ spad = self.__mem_model.spad - var_name, (variable, self.__spad_addr) = self.__mem_model.store_buffer.pop() # Will raise IndexError if popping from empty queue - assert(var_name == variable.name) - assert self.__spad_addr >= 0 and (variable.spad_address < 0 or variable.spad_address == self.__spad_addr), \ - f'self.__spad_addr = {self.__spad_addr}; {variable.name}.spad_address = {variable.spad_address}' + var_name, (variable, self.__spad_addr) = ( + self.__mem_model.store_buffer.pop() + ) # Will raise IndexError if popping from empty queue + assert var_name == variable.name + assert self.__spad_addr >= 0 and ( + variable.spad_address < 0 or variable.spad_address == self.__spad_addr + ), f"self.__spad_addr = {self.__spad_addr}; {variable.name}.spad_address = {variable.spad_address}" retval = super()._schedule(cycle_count, schedule_id) # Perform the cstore if spad.buffer[self.__spad_addr] and spad.buffer[self.__spad_addr] != variable: if not isinstance(spad.buffer[self.__spad_addr], DummyVariable): - raise RuntimeError(f'SPAD location {self.__spad_addr} for instruction (`{self.name}`, id {self.id}) is occupied by variable {spad.buffer[self.__spad_addr]}.') + raise RuntimeError( + f"SPAD location {self.__spad_addr} for instruction (`{self.name}`, id {self.id}) is occupied by variable {spad.buffer[self.__spad_addr]}." + ) spad.deallocate(self.__spad_addr) - spad.allocateForce(self.__spad_addr, variable) # Allocate in SPAD + spad.allocateForce(self.__spad_addr, variable) # Allocate in SPAD # Track last access to SPAD address spad_access_tracking = spad.getAccessTracking(self.__spad_addr) spad_access_tracking.last_cstore = self - spad_access_tracking.last_mload = None # Last mload is now obsolete - variable.spad_dirty = True # Variable has new value in SPAD - + spad_access_tracking.last_mload = None # Last mload is now obsolete + variable.spad_dirty = True # Variable has new value in SPAD + if not GlobalConfig.hasHBM: # Used to track the variable name going into spad at the moment of cstore. # This is used to output var name instead of spad address when requested. # remove when we have spad and HBM back - self.__spad_addr = variable.toCASMISAFormat() + self.__spad_addr = variable.to_casmisa_format() if self.comment: - self.comment += ';' - self.comment += f' {variable.name}' + self.comment += ";" + self.comment += f" {variable.name}" return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to CInst ASM-ISA format. @@ -186,10 +210,10 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toCASMISAFormat(self.__spad_addr) \ No newline at end of file + return super()._to_casmisa_format(self.__spad_addr) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py index ce57054c..0268689a 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/csyncm.py @@ -1,8 +1,12 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction + class Instruction(CInstruction): """ Encapsulates a `csyncm` CInstruction. @@ -16,9 +20,9 @@ class Instruction(CInstruction): @classmethod def get_throughput(cls) -> int: return cls._OP_DEFAULT_THROUGHPUT - + @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -27,12 +31,14 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "csyncm" - def __init__(self, - id: int, - minstr, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + minstr, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `csyncm` CInstruction. @@ -53,25 +59,29 @@ def __init__(self, if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) - self.minstr = minstr # Instruction from the MINST queue for which to wait + self.minstr = minstr # Instruction from the MINST queue for which to wait def __repr__(self): """ Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, minstr, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'minstr={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - repr(self.minstr), - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "minstr={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + repr(self.minstr), + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -123,7 +133,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: retval = super()._schedule(cycle_count, schedule_id) return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -136,12 +146,12 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) - assert(self.minstr.is_scheduled) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES + assert self.minstr.is_scheduled if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # warnings.warn("`csyncm` instruction requires second pass to set correct instruction number.") - return super()._toCASMISAFormat(self.minstr.schedule_timing.index) \ No newline at end of file + return super()._to_casmisa_format(self.minstr.schedule_timing.index) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py index c4919186..1072221d 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/ifetch.py @@ -1,7 +1,10 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction + class Instruction(CInstruction): """ Encapsulates an `ifetch` CInstruction. @@ -17,7 +20,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -26,12 +29,14 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "ifetch" - def __init__(self, - id: int, - bundle_id: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + bundle_id: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `ifetch` CInstruction. @@ -51,25 +56,31 @@ def __init__(self, if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) - self.bundle_id = bundle_id # Instruction number from the MINST queue for which to wait + self.bundle_id = ( + bundle_id # Instruction number from the MINST queue for which to wait + ) def __repr__(self): """ Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, bundle_id, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'bundle_id={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.bundle_id, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "bundle_id={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.bundle_id, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -119,7 +130,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: retval = super()._schedule(cycle_count, schedule_id) return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -132,10 +143,10 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toCASMISAFormat(self.bundle_id) \ No newline at end of file + return super()._to_casmisa_format(self.bundle_id) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py index 3bf97a36..ad9fb69e 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgload.py @@ -1,8 +1,12 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable from assembler.memory_model.register_file import Register + class Instruction(CInstruction): """ Encapsulates `kg_load` CInstruction. @@ -23,11 +27,13 @@ class Instruction(CInstruction): """ @classmethod - def SetNumSources(cls, val): - cls._OP_NUM_SOURCES = val + 1 # Adding the keygen variable (since the actual instruction needs no sources) + def set_num_sources(cls, val): + cls._OP_NUM_SOURCES = ( + val + 1 + ) # Adding the keygen variable (since the actual instruction needs no sources) @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -36,13 +42,15 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "kg_load" - def __init__(self, - id: int, - dst: Register, - src: list, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + dst: Register, + src: list, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `kg_load` CInstruction. @@ -74,22 +82,26 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, column number, memory index, source, throughput, and latency. """ - assert(len(self.sources) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'col_num={}, m_idx={}, src={}, ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.col_num, - self.m_idx, - self.sources[0], - self.throughput, - self.latency) + assert len(self.sources) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "col_num={}, m_idx={}, src={}, " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.col_num, + self.m_idx, + self.sources[0], + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -104,9 +116,14 @@ def _set_dests(self, value): TypeError: If the list does not contain `Register` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Register) for x in value): raise TypeError("`value`: Expected list of `Register` objects.") super()._set_dests(value) @@ -123,9 +140,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -148,10 +170,16 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) - - variable: Variable = self.sources[0] # Expected sources to contain a Variable + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) + + variable: Variable = self.sources[0] # Expected sources to contain a Variable target_register: Register = self.dests[0] if variable.spad_address >= 0 or variable.hbm_address >= 0: @@ -159,21 +187,25 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: # Cannot allocate variable to more than one register (memory coherence) # and must not overwrite a register that already contains a variable. if variable.register: - raise RuntimeError(f"Variable `{variable}` already allocated in register `{variable.register}`.") + raise RuntimeError( + f"Variable `{variable}` already allocated in register `{variable.register}`." + ) if target_register.contained_variable: - raise RuntimeError(f"Register `{target_register}` already contains a Variable object.") + raise RuntimeError( + f"Register `{target_register}` already contains a Variable object." + ) retval = super()._schedule(cycle_count, schedule_id) # Variable generated, reflect the load target_register.allocateVariable(variable) if self.comment: - self.comment += ';' - self.comment += f' {variable.name}' + self.comment += ";" + self.comment += f" {variable.name}" return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -186,11 +218,11 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # `op, dest_reg [# comment]` preamble = [] @@ -198,7 +230,5 @@ def _toCASMISAFormat(self, *extra_args) -> str: # kg_load has no sources # Instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py index 1d2478cb..92f16661 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgseed.py @@ -1,7 +1,11 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable + class Instruction(CInstruction): """ Encapsulates `kg_seed` CInstruction. @@ -24,7 +28,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -33,14 +37,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "kg_seed" - def __init__(self, - id: int, - block_index: int, - src: Variable, - mem_model, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + block_index: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `kg_seed` CInstruction. @@ -63,7 +69,7 @@ def __init__(self, ValueError: If `mem_model` is `None`. """ if not mem_model: - raise ValueError('`mem_model` cannot be `None`.') + raise ValueError("`mem_model` cannot be `None`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -78,21 +84,25 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, block index, source, throughput, and latency. """ - assert(len(self.sources) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'col_num={}, m_idx={}, src={}, ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.block_index, - self.sources[0], - self.throughput, - self.latency) + assert len(self.sources) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "col_num={}, m_idx={}, src={}, " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.block_index, + self.sources[0], + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -105,7 +115,9 @@ def _set_dests(self, value): Raises: RuntimeError: Always raised as the instruction does not have destination parameters. """ - raise RuntimeError(f"Instruction `{self.name}` does not have destination parameters.") + raise RuntimeError( + f"Instruction `{self.name}` does not have destination parameters." + ) def _set_sources(self, value): """ @@ -119,9 +131,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -143,23 +160,32 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) - variable: Variable = self.sources[0] # Expected sources to contain a Variable + variable: Variable = self.sources[0] # Expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError(f'Null Access Violation: Variable "{variable}" not allocated in SPAD.') + raise RuntimeError( + f'Null Access Violation: Variable "{variable}" not allocated in SPAD.' + ) if self.block_index not in range(4): - raise RuntimeError(f"Invalid `block_index`: {self.block_index}. Must be in range [0, 4).") + raise RuntimeError( + f"Invalid `block_index`: {self.block_index}. Must be in range [0, 4)." + ) retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking = self.__mem_model.spad.getAccessTracking( + variable.spad_address + ) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after kg_seed spad_access_tracking.last_mload = None return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -172,19 +198,17 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # `op, spad_src, block_index [# comment]` preamble = [] # Instruction sources - extra_args = (self.block_index, ) - extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + extra_args = (self.block_index,) + extra_args = tuple(src.to_casmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py index 7de5ae12..a8eafca8 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/kgstart.py @@ -1,7 +1,11 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable + class Instruction(CInstruction): """ Encapsulates `kg_start` CInstruction. @@ -17,7 +21,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -26,11 +30,9 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "kg_start" - def __init__(self, - id: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, id: int, throughput: int = None, latency: int = None, comment: str = "" + ): """ Constructs a new `kg_start` CInstruction. @@ -54,16 +56,19 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -90,7 +95,7 @@ def _set_sources(self, value): """ raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -103,10 +108,10 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toCASMISAFormat() \ No newline at end of file + return super()._to_casmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py index ffe1111e..794b3d64 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/nload.py @@ -1,8 +1,11 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction from assembler.memory_model.variable import Variable + class Instruction(CInstruction): """ Encapsulates an `nload` CInstruction. @@ -19,7 +22,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -28,14 +31,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "nload" - def __init__(self, - id: int, - table_idx: int, - src: Variable, - mem_model, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + table_idx: int, + src: Variable, + mem_model, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `nload` CInstruction. @@ -58,7 +63,7 @@ def __init__(self, ValueError: If `mem_model` is `None`. """ if not mem_model: - raise ValueError('`mem_model` cannot be `None`.') + raise ValueError("`mem_model` cannot be `None`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -73,21 +78,25 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, table index, source, throughput, and latency. """ - assert(len(self.sources) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'table_idx={}, src={}, ' - 'mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.table_idx, - self.sources[0], - self.throughput, - self.latency) + assert len(self.sources) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "table_idx={}, src={}, " + "mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.table_idx, + self.sources[0], + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -114,9 +123,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -138,23 +152,30 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) - variable: Variable = self.sources[0] # Expected sources to contain a Variable + variable: Variable = self.sources[0] # Expected sources to contain a Variable if variable.spad_address < 0: - raise RuntimeError(f"Null Access Violation: Variable `{variable}` not allocated in SPAD.") + raise RuntimeError( + f"Null Access Violation: Variable `{variable}` not allocated in SPAD." + ) if self.table_idx < 0: raise RuntimeError("Invalid `table_idx` negative routing table index.") retval = super()._schedule(cycle_count, schedule_id) # Track last access to SPAD address - spad_access_tracking = self.__mem_model.spad.getAccessTracking(variable.spad_address) + spad_access_tracking = self.__mem_model.spad.getAccessTracking( + variable.spad_address + ) spad_access_tracking.last_cload = self # No need to sync to any previous MLoads after bones spad_access_tracking.last_mload = None return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -167,19 +188,17 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # `op, target_idx, spad_src [# comment]` preamble = [] # Instruction sources - extra_args = tuple(src.toCASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_casmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - extra_args = (self.table_idx, ) + extra_args - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + extra_args = (self.table_idx,) + extra_args + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py index 3f9d1804..a5085233 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/cinst/xinstfetch.py @@ -1,7 +1,11 @@ -from assembler.common import constants +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common import constants from assembler.common.cycle_tracking import CycleType from .cinstruction import CInstruction + class Instruction(CInstruction): """ Encapsulates an `xinstfetch` CInstruction. @@ -20,7 +24,7 @@ class Instruction(CInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -29,13 +33,15 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "xinstfetch" - def __init__(self, - id: int, - xq_dst: int, - hbm_src: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + xq_dst: int, + hbm_src: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `xinstfetch` CInstruction. @@ -66,20 +72,24 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, xq_dst, hbm_src, throughput, and latency. """ - assert(len(self.dests) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'xq_dst={}, hbm_src={},' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.xq_dst, - self.hbm_src, - self.throughput, - self.latency) + assert len(self.dests) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "xq_dst={}, hbm_src={}," + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.xq_dst, + self.hbm_src, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -123,17 +133,26 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - if self.xq_dst < 0 or self.xq_dst >= constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS: - raise RuntimeError(('Invalid `xq_dst` XINST queue destination address. Expected value in range ' - '[0, {}), but received {}.'. format(constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS, - self.xq_dst))) + if ( + self.xq_dst < 0 + or self.xq_dst >= constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS + ): + raise RuntimeError( + ( + "Invalid `xq_dst` XINST queue destination address. Expected value in range " + "[0, {}), but received {}.".format( + constants.MemoryModel.XINST_QUEUE_MAX_CAPACITY_WORDS, + self.xq_dst, + ) + ) + ) if self.hbm_src < 0: raise RuntimeError("Invalid `hbm_src` negative HBM address.") retval = super()._schedule(cycle_count, schedule_id) return retval - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -146,10 +165,10 @@ def _toCASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toCASMISAFormat(self.xq_dst, self.hbm_src) \ No newline at end of file + return super()._to_casmisa_format(self.xq_dst, self.hbm_src) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py index 5ed888c7..ecbf9e9b 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/instruction.py @@ -1,10 +1,16 @@ -from typing import final -from typing import NamedTuple +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""BaseInstruction and related classes for assembler instructions.""" +from typing import final, NamedTuple, List, Optional + +# pylint: disable=too-many-instance-attributes, too-many-public-methods from assembler.common.config import GlobalConfig from assembler.common.counter import Counter from assembler.common.cycle_tracking import CycleTracker, CycleType -from assembler.common.decorators import * +from assembler.common.decorators import classproperty + class ScheduleTiming(NamedTuple): """ @@ -14,9 +20,11 @@ class ScheduleTiming(NamedTuple): cycle (CycleType): The cycle in which the instruction was scheduled. index (int): The index for the instruction in its schedule listing. """ + cycle: CycleType index: int + class BaseInstruction(CycleTracker): """ The base class for all instructions. @@ -26,29 +34,29 @@ class BaseInstruction(CycleTracker): Class Properties: name (str): Returns the name of the represented operation. - OP_NAME_ASM (str): ASM-ISA name for the instruction. - OP_NAME_PISA (str): P-ISA name for the instruction. + op_name_asm (str): ASM-ISA name for the instruction. + op_name_pisa (str): P-ISA name for the instruction. Class Methods: - _get_name(cls) -> str: Derived classes should implement this method and return the correct + _get_name(self) -> str: Derived classes should implement this method and return the correct name for the instruction. Defaults to the ASM-ISA name. - _get_OP_NAME_ASM(cls) -> str: Derived classes should implement this method and return the correct + _get_op_name_asm(self) -> str: Derived classes should implement this method and return the correct ASM name for the operation. Default throws not implemented. - _get_OP_NAME_PISA(cls) -> str: Derived classes should implement this method and return the correct + _get_op_name_pisa(self) -> str: Derived classes should implement this method and return the correct P-ISA name for the operation. Defaults to the ASM-ISA name. Constructors: - __init__(self, id: int, throughput: int, latency: int, comment: str = ""): + __init__(self, id: int, throughput: int, latency: int, comment: str = ""): Initializes a new BaseInstruction object. Attributes: - _dests (list[CycleTracker]): List of destination objects. Derived classes can override + _dests (list[CycleTracker]): List of destination objects. Derived classes can override _set_dests to validate this attribute. _frozen_cisa (str): Contains frozen CInst in ASM ISA format after scheduling. Empty string if not frozen. _frozen_misa (str): Contains frozen MInst in ASM ISA format after scheduling. Empty string if not frozen. _frozen_pisa (str): Contains frozen P-ISA format after scheduling. Empty string if not frozen. _frozen_xisa (str): Contains frozen XInst in ASM ISA format after scheduling. Empty string if not frozen. - _sources (list[CycleTracker]): List of source objects. Derived classes can override + _sources (list[CycleTracker]): List of source objects. Derived classes can override _set_sources to validate this attribute. comment (str): Comment for the instruction. @@ -60,7 +68,7 @@ class BaseInstruction(CycleTracker): is_scheduled (bool): Returns whether the instruction has been scheduled (True) or not (False). latency (int): Returns the latency of the represented operation. This is the number of clock cycles before the results of the operation are ready in the destination. - schedule_timing (ScheduleTiming): Gets the cycle and index in which this instruction was scheduled or + schedule_timing (ScheduleTiming): Gets the cycle and index in which this instruction was scheduled or None if not scheduled yet. Index is subject to change and it is not final until the second pass of scheduling. sources (list): Gets or sets the list of source objects. The elements of the list are derived dependent. Calls _set_sources to set value. @@ -74,121 +82,109 @@ class BaseInstruction(CycleTracker): __str__(self): Returns a string representation of the BaseInstruction object. Methods: - _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: Schedules the instruction, simulating timings of executing this instruction. Derived classes should override with their scheduling functionality. - _toCASMISAFormat(self, *extra_args) -> str: Converts the instruction to CInst ASM-ISA format. + _to_casmisa_format(self, *extra_args) -> str: Converts the instruction to CInst ASM-ISA format. Derived classes should override with their functionality. - _toMASMISAFormat(self, *extra_args) -> str: Converts the instruction to MInst ASM-ISA format. + _to_masmisa_format(self, *extra_args) -> str: Converts the instruction to MInst ASM-ISA format. Derived classes should override with their functionality. - _toPISAFormat(self, *extra_args) -> str: Converts the instruction to P-ISA kernel format. + _to_pisa_format(self, *extra_args) -> str: Converts the instruction to P-ISA kernel format. Derived classes should override with their functionality. - _toXASMISAFormat(self, *extra_args) -> str: Converts the instruction to XInst ASM-ISA format. + _to_xasmisa_format(self, *extra_args) -> str: Converts the instruction to XInst ASM-ISA format. Derived classes should override with their functionality. freeze(self): Called immediately after _schedule() to freeze the instruction after scheduling to preserve the instruction string representation to output into the listing. Changes made to the instruction and its components after freezing are ignored. - schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: + schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: Schedules and freezes the instruction, simulating timings of executing this instruction. - toStringFormat(self, preamble, op_name: str, *extra_args) -> str: + to_string_format(self, preamble, op_name: str, *extra_args) -> str: Converts the instruction to a string format. - toPISAFormat(self) -> str: Converts the instruction to P-ISA kernel format. - toXASMISAFormat(self) -> str: Converts the instruction to ASM-ISA format. - toCASMISAFormat(self) -> str: Converts the instruction to CInst ASM-ISA format. - toMASMISAFormat(self) -> str: Converts the instruction to MInst ASM-ISA format. + to_pisa_format(self) -> str: Converts the instruction to P-ISA kernel format. + to_xasmisa_format(self) -> str: Converts the instruction to ASM-ISA format. + to_casmisa_format(self) -> str: Converts the instruction to CInst ASM-ISA format. + to_masmisa_format(self) -> str: Converts the instruction to MInst ASM-ISA format. """ + # To be initialized from ASM ISA spec - _OP_NUM_DESTS : int - _OP_NUM_SOURCES : int - _OP_DEFAULT_THROUGHPUT : int - _OP_DEFAULT_LATENCY : int + _OP_NUM_DESTS: int + _OP_NUM_SOURCES: int + _OP_DEFAULT_THROUGHPUT: int + _OP_DEFAULT_LATENCY: int - __id_count = Counter.count(0) # internal unique sequence counter to generate unique IDs + __id_count = Counter.count( + 0 + ) # internal unique sequence counter to generate unique IDs # Class methods and properties # ---------------------------- @classmethod def isa_spec_as_dict(cls) -> dict: - """ - Returns attributes as dictionary. - """ - dict = {"num_dests": cls._OP_NUM_DESTS, - "num_sources": cls._OP_NUM_SOURCES, - "default_throughput": cls._OP_DEFAULT_THROUGHPUT, - "default_latency": cls._OP_DEFAULT_LATENCY} - return dict - + """Returns attributes as dictionary.""" + spec = { + "num_dests": cls._OP_NUM_DESTS, + "num_sources": cls._OP_NUM_SOURCES, + "default_throughput": cls._OP_DEFAULT_THROUGHPUT, + "default_latency": cls._OP_DEFAULT_LATENCY, + } + return spec + @classmethod - def SetNumDests(cls, val): + def set_num_dests(cls, val): + """Set the number of destination operands.""" cls._OP_NUM_DESTS = val @classmethod - def SetNumSources(cls, val): + def set_num_sources(cls, val): + """Set the number of source operands.""" cls._OP_NUM_SOURCES = val @classmethod - def SetDefaultThroughput(cls, val): + def set_default_throughput(cls, val): + """Set the default throughput.""" cls._OP_DEFAULT_THROUGHPUT = val @classmethod - def SetDefaultLatency(cls, val): + def set_default_latency(cls, val): + """Set the default latency.""" cls._OP_DEFAULT_LATENCY = val @classproperty - def name(cls) -> str: - """ - Name for the instruction. - """ - return cls._get_name() + def name(self) -> str: + """Name for the instruction.""" + return self._get_name() @classmethod def _get_name(cls) -> str: - """ - Derived classes should implement this method and return correct - name for the instruction. Defaults to the ASM-ISA name. - """ - return cls.OP_NAME_ASM + """Derived classes should implement this method and return correct name for the instruction.""" + return cls.op_name_asm @classproperty - def OP_NAME_PISA(cls) -> str: - """ - P-ISA name for the instruction. - """ - return cls._get_OP_NAME_PISA() + def op_name_pisa(self) -> str: + """P-ISA name for the instruction.""" + return self._get_op_name_pisa() @classmethod - def _get_OP_NAME_PISA(cls) -> str: - """ - Derived classes should implement this method and return correct - P-ISA name for the operation. Defaults to the ASM-ISA name. - """ - return cls.OP_NAME_ASM + def _get_op_name_pisa(cls) -> str: + """Derived classes should implement this method and return correct P-ISA name for the operation.""" + return cls.op_name_asm @classproperty - def OP_NAME_ASM(cls) -> str: - """ - ASM-ISA name for instruction. - - Will throw if no ASM-ISA name for instruction. - """ - return cls._get_OP_NAME_ASM() + def op_name_asm(self) -> str: + """ASM-ISA name for instruction.""" + return self._get_op_name_asm() @classmethod - def _get_OP_NAME_ASM(cls) -> str: - """ - Derived classes should implement this method and return correct - ASM name for the operation. - """ - raise NotImplementedError('Abstract method not implemented.') + def _get_op_name_asm(cls) -> str: + """Derived classes should implement this method and return correct ASM name for the operation.""" + raise NotImplementedError("Abstract method not implemented.") # Constructor # ----------- - def __init__(self, - id: int, - throughput : int, - latency : int, - comment: str = ""): + def __init__( + self, instruction_id: int, throughput: int, latency: int, comment: str = "" + ): """ Initializes a new BaseInstruction object. @@ -208,62 +204,53 @@ def __init__(self, """ # validate inputs if throughput < 1: - raise ValueError(("`throughput`: must be a positive number, " - "but {} received.".format(throughput))) + raise ValueError( + ( + f"`throughput`: must be a positive number, " + f"but {throughput} received." + ) + ) if latency < throughput: - raise ValueError(("`latency`: cannot be less than throughput. " - "Expected, at least, {}, but {} received.".format(throughput, latency))) - - super().__init__((0, 0)) - - self.__id = (id, next(BaseInstruction.__id_count)) # Mix with unique sequence counter - self.__throughput = throughput # read_only throughput of the operation - self.__latency = latency # read_only latency of the operation - self._dests = [] - self._sources = [] - self.comment = " id: {}{}{}".format(self.__id, - "; " if comment.strip() else "", - comment) - self.__schedule_timing: ScheduleTiming = None # Tracks when was this instruction scheduled, or None if not scheduled yet - - self._frozen_pisa = "" # To contain frozen P-ISA format after scheduling - self._frozen_xisa = "" # To contain frozen XInst in ASM ISA format after scheduling - self._frozen_cisa = "" # To contain frozen CInst in ASM ISA format after scheduling - self._frozen_misa = "" # To contain frozen MInst in ASM ISA format after scheduling + raise ValueError( + ( + f"`latency`: cannot be less than throughput. " + f"Expected, at least, {throughput}, but {latency} received." + ) + ) + + super().__init__(CycleType(0, 0)) + self.__id = (instruction_id, next(BaseInstruction.__id_count)) + self.__throughput = throughput + self.__latency = latency + self._dests: List[CycleTracker] = [] + self._sources: List[CycleTracker] = [] + self.comment = f" id: {self.__id}{'; ' if comment.strip() else ''}{comment}" + self.__schedule_timing: Optional[ScheduleTiming] = None + self._frozen_pisa = "" + self._frozen_xisa = "" + self._frozen_cisa = "" + self._frozen_misa = "" def __repr__(self): - """ - Returns a string representation of the BaseInstruction object. - """ - retval = ('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.OP_NAME_PISA, - hex(id(self)), - self.id, - self.dests, - self.sources, - self.throughput, - self.latency) + """Returns a string representation of the BaseInstruction object.""" + retval = ( + f"<{type(self).__name__}({self.op_name_pisa}) object at {hex(id(self))}>(id={self.id}[0], " + f"dst={self.dests}, src={self.sources}, " + f"throughput={self.throughput}, latency={self.latency})" + ) return retval def __eq__(self, other): - """ - Checks equality between two BaseInstruction objects. - """ - return self is other #other.id == self.id + """Checks equality between two BaseInstruction objects.""" + return self is other def __hash__(self): - """ - Returns the hash of the BaseInstruction object. - """ + """Returns the hash of the BaseInstruction object.""" return hash(self.id) def __str__(self): - """ - Returns a string representation of the BaseInstruction object. - """ - return f'{self.name} {self.id}' + """Returns a string representation of the BaseInstruction object.""" + return f"{self.name} {self.id}" # Methods and properties # ---------------------------- @@ -299,9 +286,12 @@ def set_schedule_timing_index(self, value: int): ValueError: If the value is less than 0. """ if value < 0: - raise ValueError("`value`: expected a value of `0` or greater for `schedule_timing.index`.") - self.__schedule_timing = ScheduleTiming(cycle = self.__schedule_timing.cycle, - index=value) + raise ValueError( + "`value`: expected a value of `0` or greater for `schedule_timing.index`." + ) + self.__schedule_timing = ScheduleTiming( + cycle=self.__schedule_timing.cycle, index=value + ) @property def is_scheduled(self) -> bool: @@ -311,7 +301,7 @@ def is_scheduled(self) -> bool: Returns: bool: True if the instruction is scheduled, False otherwise. """ - return True if self.schedule_timing else False + return bool(self.schedule_timing) @property def throughput(self) -> int: @@ -365,7 +355,7 @@ def _set_dests(self, value): """ if not all(isinstance(x, CycleTracker) for x in value): raise ValueError("`value`: Expected list of `CycleTracker` objects.") - self._dests = [ x for x in value ] + self._dests = list(value) @property def sources(self) -> list: @@ -399,7 +389,7 @@ def _set_sources(self, value): """ if not all(isinstance(x, CycleTracker) for x in value): raise ValueError("`value`: Expected list of `CycleTracker` objects.") - self._sources = [ x for x in value ] + self._sources = list(value) def _get_cycle_ready(self): """ @@ -423,9 +413,9 @@ def _get_cycle_ready(self): retval = max(retval, *(src.cycle_ready for src in self.sources)) if self.dests: # dests cycle ready is a special case: - # dests are ready to be read or writen to at their cycle_ready, but instructions can + # dests are ready to be read or written to at their cycle_ready, but instructions can # start the following cycle when their dests are ready minus the latency of - # the instruction because the dests will be writen to in the last cycle of + # the instruction because the dests will be written to in the last cycle of # the instruction: # Cycle decode_phase write_phase dests_ready latency # 1 INST1 5 @@ -435,10 +425,12 @@ def _get_cycle_ready(self): # 5 INST6 INST1 5 # 6 INST7 INST2 INST1 5 # 7 INST8 INST3 INST2 5 - # INST1's dests are ready in cycle 6 and they are writen to in cycle 5. + # INST1's dests are ready in cycle 6 and they are written to in cycle 5. # If INST2 uses any INST1 dest as its dest, INST2 can start the cycle # following INST1, 2, because INST2 will write to the same dest in cycle 6. - retval = max(retval, *(dst.cycle_ready - self.latency + 1 for dst in self.dests)) + retval = max( + retval, *(dst.cycle_ready - self.latency + 1 for dst in self.dests) + ) return retval def freeze(self): @@ -463,12 +455,14 @@ def freeze(self): RuntimeError: If the instruction has not been scheduled yet. """ if not self.is_scheduled: - raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is not yet scheduled.") + raise RuntimeError( + f"Instruction `{self.name}` (id = {self.id}) is not yet scheduled." + ) - self._frozen_pisa = self._toPISAFormat() - self._frozen_xisa = self._toXASMISAFormat() - self._frozen_cisa = self._toCASMISAFormat() - self._frozen_misa = self._toMASMISAFormat() + self._frozen_pisa = self._to_pisa_format() + self._frozen_xisa = self._to_xasmisa_format() + self._frozen_cisa = self._to_casmisa_format() + self._frozen_misa = self._to_masmisa_format() def _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: """ @@ -493,17 +487,20 @@ def _schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: the current cycle counter. """ if self.is_scheduled: - raise RuntimeError(f"Instruction `{self.name}` (id = {self.id}) is already scheduled.") + raise RuntimeError( + f"Instruction `{self.name}` (id = {self.id}) is already scheduled." + ) if schedule_idx < 1: raise ValueError("`schedule_idx`: expected a value of `1` or greater.") if len(cycle_count) < 2: - raise ValueError("`cycle_count`: expected a pair/tuple with two components.") + raise ValueError( + "`cycle_count`: expected a pair/tuple with two components." + ) if cycle_count < self.cycle_ready: - raise RuntimeError(("Instruction {}, id: {}, not ready to schedule. " - "Ready cycle is {}, but current cycle is {}.").format(self.name, - self.id, - self.cycle_ready, - cycle_count)) + raise RuntimeError( + f"Instruction {self.name}, id: {self.id}, not ready to schedule. " + f"Ready cycle is {self.cycle_ready}, but current cycle is {cycle_count}." + ) self.__schedule_timing = ScheduleTiming(cycle_count, schedule_idx) return self.throughput @@ -534,10 +531,7 @@ def schedule(self, cycle_count: CycleType, schedule_idx: int) -> int: self.freeze() return retval - def toStringFormat(self, - preamble, - op_name: str, - *extra_args) -> str: + def to_string_format(self, preamble, op_name: str, *extra_args) -> str: """ Converts the instruction to a string format. @@ -555,16 +549,16 @@ def toStringFormat(self, raise ValueError("`op_name` cannot be empty.") retval = op_name if preamble: - retval = ('{}, '.format(', '.join(str(x) for x in preamble))) + retval + retval = f'{", ".join(str(x) for x in preamble)}, {retval}' if extra_args: - retval += ', {}'.format(', '.join([str(extra) for extra in extra_args])) + retval += f', {", ".join([str(extra) for extra in extra_args])}' if not GlobalConfig.suppressComments: if self.comment: - retval += ' #{}'.format(self.comment) + retval += f" #{self.comment}" return retval @final - def toPISAFormat(self) -> str: + def to_pisa_format(self) -> str: """ Converts the instruction to P-ISA kernel format. @@ -573,19 +567,19 @@ def toPISAFormat(self) -> str: `N, op, dst0 (bank), dst1 (bank), ..., dst_d (bank), src0 (bank), src1 (bank), ..., src_s (bank) [, extra0, extra1, ..., extra_e] [, res] [# comment]` where `extra_e` are instruction specific extra arguments. """ - return self._frozen_pisa if self._frozen_pisa else self._toPISAFormat() + return self._frozen_pisa if self._frozen_pisa else self._to_pisa_format() @final - def toXASMISAFormat(self) -> str: + def to_xasmisa_format(self) -> str: """ Converts the instruction to ASM-ISA format. If instruction is frozen, this returns the frozen result, otherwise, it attempts to generate the string representation on the fly. - Internally calls method `_toXASMISAFormat()`. + Internally calls method `_to_xasmisa_format()`. - Derived classes can override method `_toXASMISAFormat()` to provide their own conversion. + Derived classes can override method `_to_xasmisa_format()` to provide their own conversion. Returns: str: A string representation of the instruction in ASM-ISA format. The string has the form: @@ -594,19 +588,19 @@ def toXASMISAFormat(self) -> str: Since the residual is mandatory in the format, it is set to `0` in the output if the instruction does not support residual. """ - return self._frozen_xisa if self._frozen_xisa else self._toXASMISAFormat() + return self._frozen_xisa if self._frozen_xisa else self._to_xasmisa_format() @final - def toCASMISAFormat(self) -> str: + def to_casmisa_format(self) -> str: """ Converts the instruction to CInst ASM-ISA format. If instruction is frozen, this returns the frozen result, otherwise, it attempts to generate the string representation on the fly. - Internally calls method `_toCASMISAFormat()`. + Internally calls method `__to_casmisa_format()`. - Derived classes can override method `_toCASMISAFormat()` to provide their own conversion. + Derived classes can override method `__to_casmisa_format()` to provide their own conversion. Returns: str: A string representation of the instruction in ASM-ISA format. The string has the form: @@ -615,28 +609,28 @@ def toCASMISAFormat(self) -> str: Since the ring size is mandatory in the format, it is set to `0` in the output if the instruction does not support it. """ - return self._frozen_cisa if self._frozen_cisa else self._toCASMISAFormat() + return self._frozen_cisa if self._frozen_cisa else self._to_casmisa_format() @final - def toMASMISAFormat(self) -> str: + def to_masmisa_format(self) -> str: """ Converts the instruction to MInst ASM-ISA format. If instruction is frozen, this returns the frozen result, otherwise, it attempts to generate the string representation on the fly. - Internally calls method `_toMASMISAFormat()`. + Internally calls method `_to_masmisa_format()`. - Derived classes can override method `_toMASMISAFormat()` to provide their own conversion. + Derived classes can override method `_to_masmisa_format()` to provide their own conversion. Returns: str: A string representation of the instruction in ASM-ISA format. The string has the form: `op, dst0, dst1, ..., dst_d, src0, src1, ..., src_s [, extra0, extra1, ..., extra_e], [# comment]` where `extra_e` are instruction specific extra arguments. """ - return self._frozen_misa if self._frozen_misa else self._toMASMISAFormat() + return self._frozen_misa if self._frozen_misa else self._to_masmisa_format() - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: # pylint: disable=unused-argument """ Converts the instruction to P-ISA kernel format. @@ -648,7 +642,7 @@ def _toPISAFormat(self, *extra_args) -> str: """ return "" - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: # pylint: disable=unused-argument """ Converts the instruction to XInst ASM-ISA format. @@ -662,7 +656,7 @@ def _toXASMISAFormat(self, *extra_args) -> str: """ return "" - def _toCASMISAFormat(self, *extra_args) -> str: + def _to_casmisa_format(self, *extra_args) -> str: # pylint: disable=unused-argument """ Converts the instruction to CInst ASM-ISA format. @@ -674,7 +668,7 @@ def _toCASMISAFormat(self, *extra_args) -> str: """ return "" - def _toMASMISAFormat(self, *extra_args) -> str: + def _to_masmisa_format(self, *extra_args) -> str: # pylint: disable=unused-argument """ Converts the instruction to MInst ASM-ISA format. diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py index 6a2afe4b..23fe385c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/minstruction.py @@ -1,6 +1,10 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from ..instruction import BaseInstruction + class MInstruction(BaseInstruction): """ Represents a memory-level instruction (MInstruction). @@ -13,13 +17,9 @@ class MInstruction(BaseInstruction): count: Returns the MInstruction counter value for this instruction. """ - __minst_count = 0 # Internal Minst counter + __minst_count = 0 # Internal Minst counter - def __init__(self, - id: int, - throughput: int, - latency: int, - comment: str = ""): + def __init__(self, id: int, throughput: int, latency: int, comment: str = ""): """ Constructs a new MInstruction. @@ -56,7 +56,7 @@ def _get_cycle_ready(self): """ return CycleType(bundle=0, cycle=0) - def _toMASMISAFormat(self, *extra_args) -> str: + def _to_masmisa_format(self, *extra_args) -> str: """ Converts the instruction to MInst ASM-ISA format. @@ -69,9 +69,7 @@ def _toMASMISAFormat(self, *extra_args) -> str: str: The instruction in MInst ASM-ISA format. """ # Instruction sources - extra_args = tuple(src.toMASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_masmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toMASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(None, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_masmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(None, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py index 2618403a..f7075eab 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mload.py @@ -1,10 +1,13 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.config import GlobalConfig from assembler.common.cycle_tracking import CycleType from .minstruction import MInstruction from assembler.memory_model import MemoryModel from assembler.memory_model.variable import Variable + class Instruction(MInstruction): """ Encapsulates an `mload` MInstruction. @@ -22,7 +25,7 @@ class Instruction(MInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -31,14 +34,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "mload" - def __init__(self, - id: int, - src: list, - mem_model: MemoryModel, - dst_spad_addr: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + src: list, + mem_model: MemoryModel, + dst_spad_addr: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `mload` MInstruction. @@ -80,20 +85,24 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, source, destination SPAD address, throughput, and latency. """ - assert(len(self.dests) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'src={}, dst_spad_addr={}, mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.sources, - self.dst_spad_addr, - self.throughput, - self.latency) + assert len(self.dests) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "src={}, dst_spad_addr={}, mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.sources, + self.dst_spad_addr, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -120,9 +129,14 @@ def __internal_set_dests(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_dests(value) @@ -139,9 +153,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -167,9 +186,15 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert all(src == dst for src, dst in zip(self.sources, self.dests)) hbm = self.__mem_model.hbm spad = self.__mem_model.spad @@ -177,9 +202,13 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: variable: Variable = self.sources[0] if variable.spad_address >= 0: - raise RuntimeError("Source variable is already in SPAD. Cannot load a variable into SPAD more than once.") + raise RuntimeError( + "Source variable is already in SPAD. Cannot load a variable into SPAD more than once." + ) if variable.hbm_address < 0: - raise RuntimeError("Null reference exception: source variable is not in HBM.") + raise RuntimeError( + "Null reference exception: source variable is not in HBM." + ) retval = super()._schedule(cycle_count, schedule_id) # Perform the load @@ -188,7 +217,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: spad.getAccessTracking(self.dst_spad_addr).last_mload = self return retval - def _toMASMISAFormat(self, *extra_args) -> str: + def _to_masmisa_format(self, *extra_args) -> str: """ Converts the instruction to MInst ASM-ISA format. @@ -201,9 +230,7 @@ def _toMASMISAFormat(self, *extra_args) -> str: str: The instruction in MInst ASM-ISA format. """ # Instruction sources - extra_args = tuple(src.toMASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_masmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(None, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(None, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py index a5228193..4c521adc 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/mstore.py @@ -1,8 +1,12 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common.cycle_tracking import CycleType from .minstruction import MInstruction from assembler.memory_model import MemoryModel from assembler.memory_model.variable import Variable + class Instruction(MInstruction): """ Encapsulates an `mstore` MInstruction. @@ -21,7 +25,7 @@ class Instruction(MInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -30,14 +34,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "mstore" - def __init__(self, - id: int, - src: list, - mem_model: MemoryModel, - dst_hbm_addr: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + src: list, + mem_model: MemoryModel, + dst_hbm_addr: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `mstore` MInstruction. @@ -63,7 +69,7 @@ def __init__(self, ValueError: If `dst_hbm_addr` is negative. """ if dst_hbm_addr < 0: - raise ValueError('`dst_hbm_addr`: cannot be null address (negative).') + raise ValueError("`dst_hbm_addr`: cannot be null address (negative).") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: @@ -81,21 +87,25 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, source, destination HBM address, throughput, and latency. """ - assert(len(self.dests) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'src={}, dst_hbm_addr={}, mem_model, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.sources, - self.dst_hbm_addr, - # repr(self.__mem_model), - self.throughput, - self.latency) + assert len(self.dests) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "src={}, dst_hbm_addr={}, mem_model, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.sources, + self.dst_hbm_addr, + # repr(self.__mem_model), + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -122,9 +132,14 @@ def __internal_set_dests(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_dests(value) @@ -141,9 +156,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -169,9 +189,15 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert all(src == dst for src, dst in zip(self.sources, self.dests)) hbm = self.__mem_model.hbm spad = self.__mem_model.spad @@ -182,24 +208,29 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: if variable.hbm_address >= 0: if self.dst_hbm_addr != variable.hbm_address: - raise RuntimeError("Source variable is already in different HBM location. Cannot store a variable into HBM more than once.") - assert(hbm.buffer[variable.hbm_address] == variable) + raise RuntimeError( + "Source variable is already in different HBM location. Cannot store a variable into HBM more than once." + ) + assert hbm.buffer[variable.hbm_address] == variable if self.__source_spad_address < 0: - raise RuntimeError("Null reference exception: source variable is not in SPAD.") + raise RuntimeError( + "Null reference exception: source variable is not in SPAD." + ) if self.comment: - self.comment += ';' + self.comment += ";" # self.comment += ' variable "{}": HBM({}) <- SPAD({})'.format(variable.name, # self.dst_hbm_addr, # variable.spad_address) - self.comment += ' variable "{}" <- SPAD({})'.format(variable.name, - variable.spad_address) + self.comment += ' variable "{}" <- SPAD({})'.format( + variable.name, variable.spad_address + ) retval = super()._schedule(cycle_count, schedule_id) # Perform the store - if variable.hbm_address < 0: # Variable new to HBM + if variable.hbm_address < 0: # Variable new to HBM hbm.allocateForce(self.dst_hbm_addr, variable) - spad.deallocate(self.__source_spad_address) # Deallocate variable from SPAD + spad.deallocate(self.__source_spad_address) # Deallocate variable from SPAD # Track SPAD access spad_access_tracking = spad.getAccessTracking(self.__source_spad_address) spad_access_tracking.last_mstore = self @@ -209,7 +240,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: return retval - def _toMASMISAFormat(self, *extra_args) -> str: + def _to_masmisa_format(self, *extra_args) -> str: """ Converts the instruction to MInst ASM-ISA format. @@ -222,9 +253,7 @@ def _toMASMISAFormat(self, *extra_args) -> str: str: The instruction in MInst ASM-ISA format. """ # Instruction sources - extra_args = (self.__source_spad_address, ) + extra_args + extra_args = (self.__source_spad_address,) + extra_args # Instruction destinations - extra_args = tuple(dst.toMASMISAFormat() for dst in self.dests) + extra_args - return self.toStringFormat(None, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + extra_args = tuple(dst.to_masmisa_format() for dst in self.dests) + extra_args + return self.to_string_format(None, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py index d8b3718d..a45e6160 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/minst/msyncc.py @@ -1,6 +1,10 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .minstruction import MInstruction + class Instruction(MInstruction): """ Encapsulates an `msyncc` MInstruction. @@ -13,9 +17,9 @@ class Instruction(MInstruction): Attributes: cinstr: The instruction from the CINST queue for which to wait. """ - + @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -24,12 +28,14 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "msyncc" - def __init__(self, - id: int, - cinstr, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + cinstr, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `msyncc` CInstruction. @@ -50,26 +56,32 @@ def __init__(self, if not latency: latency = Instruction._OP_DEFAULT_LATENCY super().__init__(id, throughput, latency, comment=comment) - self.cinstr = cinstr # Instruction number from the MINST queue for which to wait + self.cinstr = ( + cinstr # Instruction number from the MINST queue for which to wait + ) def __repr__(self): """ Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, cinstr, throughput, and latency. """ - assert(len(self.dests) > 0) - retval=('<{}({}) object at {}>(id={}[0], ' - 'cinstr={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.OP_NAME_PISA, - hex(id(self)), - self.id, - repr(self.cinstr), - self.throughput, - self.latency) + assert len(self.dests) > 0 + retval = ( + "<{}({}) object at {}>(id={}[0], " + "cinstr={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.op_name_pisa, + hex(id(self)), + self.id, + repr(self.cinstr), + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -121,7 +133,7 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: retval = super()._schedule(cycle_count, schedule_id) return retval - def _toMASMISAFormat(self, *extra_args) -> str: + def _to_masmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -134,12 +146,12 @@ def _toMASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) - assert(self.cinstr.is_scheduled) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES + assert self.cinstr.is_scheduled if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # warnings.warn("`msyncc` instruction requires second pass to set correct instruction number.") - return super()._toMASMISAFormat(self.cinstr.schedule_timing.index) \ No newline at end of file + return super()._to_masmisa_format(self.cinstr.schedule_timing.index) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py index 02f91124..d9abaeab 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/__init__.py @@ -1,6 +1,25 @@ -from assembler.memory_model import MemoryModel +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.memory_model import MemoryModel from .xinstruction import XInstruction -from . import add, sub, mul, muli, mac, maci, ntt, intt, twntt, twintt, rshuffle, irshuffle, move, xstore, nop +from . import ( + add, + sub, + mul, + muli, + mac, + maci, + ntt, + intt, + twntt, + twintt, + rshuffle, + irshuffle, + move, + xstore, + nop, +) from . import exit as exit_mod from . import copy as copy_mod @@ -27,15 +46,29 @@ Nop = nop.Instruction # Collection of XInstructions with P-ISA or intermediate P-ISA equivalents -__PISA_INSTRUCTIONS = ( Add, Sub, Mul, Muli, Mac, Maci, NTT, iNTT, twNTT, twiNTT, rShuffle, irShuffle, Copy ) +__PISA_INSTRUCTIONS = ( + Add, + Sub, + Mul, + Muli, + Mac, + Maci, + NTT, + iNTT, + twNTT, + twiNTT, + rShuffle, + irShuffle, + Copy, +) # Collection of XInstructions with global cycle tracking -GLOBAL_CYCLE_TRACKING_INSTRUCTIONS = ( rShuffle, irShuffle, XStore ) +GLOBAL_CYCLE_TRACKING_INSTRUCTIONS = (rShuffle, irShuffle, XStore) + -def createFromParsedObj(mem_model: MemoryModel, - inst_type, - parsed_op, - new_id: int = 0) -> XInstruction: +def createFromParsedObj( + mem_model: MemoryModel, inst_type, parsed_op, new_id: int = 0 +) -> XInstruction: """ Creates an XInstruction object XInst from the specified namespace data. @@ -63,7 +96,7 @@ def createFromParsedObj(mem_model: MemoryModel, """ if not issubclass(inst_type, XInstruction): - raise ValueError('`inst_type`: expected a class derived from `XInstruction`.') + raise ValueError("`inst_type`: expected a class derived from `XInstruction`.") # Convert variable names into actual variable objects. @@ -84,14 +117,15 @@ def createFromParsedObj(mem_model: MemoryModel, # Prepare parsed object to add as arguments to instruction constructor. parsed_op.dst = dsts parsed_op.src = srcs - assert(parsed_op.op_name == inst_type.OP_NAME_PISA) + assert parsed_op.op_name == inst_type.op_name_pisa parsed_op = vars(parsed_op) - parsed_op.pop("op_name") # op name not needed: inst_type knows its name already + parsed_op.pop("op_name") # op name not needed: inst_type knows its name already return inst_type(new_id, **parsed_op) -def createFromPISALine(mem_model: MemoryModel, - line: str, - line_no: int = 0) -> XInstruction: + +def createFromPISALine( + mem_model: MemoryModel, line: str, line_no: int = 0 +) -> XInstruction: """ Parses an XInst from the specified string (in P-ISA kernel input format) and returns a XInstruction object encapsulating the resulting instruction. @@ -126,7 +160,7 @@ def createFromPISALine(mem_model: MemoryModel, for inst_type in __PISA_INSTRUCTIONS: parsed_op = inst_type.parseFromPISALine(line) if parsed_op: - assert(inst_type.OP_NAME_PISA == parsed_op.op_name) + assert inst_type.op_name_pisa == parsed_op.op_name # Convert parsed instruction into an actual instruction object. retval = createFromParsedObj(mem_model, inst_type, parsed_op, line_no) @@ -135,6 +169,6 @@ def createFromPISALine(mem_model: MemoryModel, break except Exception as ex: - raise Exception(f'line {line_no}: {line}.') from ex + raise Exception(f"line {line_no}: {line}.") from ex return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py index fb06d226..84af953b 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/add.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents an `add` instruction in the assembler with specific properties and methods for parsing, @@ -62,30 +66,32 @@ def parseFromPISALine(cls, line: str) -> list: comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: - retval = { "comment": tokens[1] } + retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -94,15 +100,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "add" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - res: int, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object with the given parameters. @@ -133,17 +141,21 @@ def __repr__(self): Returns: str: A string representation of object. """ - retval=('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -157,9 +169,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -175,14 +192,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -195,15 +217,15 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: Kernel format instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toPISAFormat() + return super()._to_pisa_format() - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -216,10 +238,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: ASM format instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py index 74022b32..7946e57f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/copy.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Encapsulates a `move` instruction when used to copy @@ -35,10 +39,10 @@ def _get_name(cls) -> str: Returns: str: PISA operation name. """ - return cls.OP_NAME_PISA + return cls.op_name_pisa @classmethod - def _get_OP_NAME_PISA(cls) -> str: + def _get_op_name_pisa(cls) -> str: """ Returns the operation name in PISA format. @@ -48,7 +52,7 @@ def _get_OP_NAME_PISA(cls) -> str: return "copy" @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -82,44 +86,51 @@ def parseFromPISALine(cls, line: str) -> list: comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: - retval = { "comment": tokens[1] } + retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) if len(instr_tokens) < cls._OP_NUM_TOKENS: # temporary warning to avoid syntax error during testing # REMOVE WARNING AND TURN IT TO ERROR DURING PRODUCTION - #--------------------------- - warnings.warn(f'Not enough tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + # --------------------------- + warnings.warn( + f'Not enough tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) pass else: # ignore "res", but make sure it exists (syntax) - assert(instr_tokens[params_end] is not None) + assert instr_tokens[params_end] is not None retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object with the given parameters. @@ -139,10 +150,12 @@ def __init__(self, throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: latency = Instruction._OP_DEFAULT_LATENCY - N = 0 # does not require ring-size + N = 0 # does not require ring-size super().__init__(id, N, throughput, latency, comment=comment) if dst[0].name == src[0].name: - raise ValueError(f'`dst`: Source and destination cannot be the same for instruction "{self.name}".') + raise ValueError( + f'`dst`: Source and destination cannot be the same for instruction "{self.name}".' + ) self._set_dests(dst) self._set_sources(src) @@ -153,16 +166,20 @@ def __repr__(self): Returns: str: A string representation of object. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -176,9 +193,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_dests(value) @@ -194,14 +216,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to P-ISA kernel format. @@ -214,15 +241,15 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: P-ISA kernel format instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toPISAFormat() + return super()._to_pisa_format() - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -235,10 +262,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: ASM format instruction. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py index 7bce54c1..f21334c7 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/exit.py @@ -1,4 +1,8 @@ -from .xinstruction import XInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .xinstruction import XInstruction + class Instruction(XInstruction): """ @@ -12,7 +16,7 @@ class Instruction(XInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -21,11 +25,9 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "bexit" - def __init__(self, - id: int, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, id: int, throughput: int = None, latency: int = None, comment: str = "" + ): """ Initializes an Instruction object with the given parameters. @@ -49,13 +51,16 @@ def __repr__(self): Returns: str: A string representation. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -68,7 +73,9 @@ def _set_dests(self, value): Raises: RuntimeError: Always raised as `bexit` does not have parameters. """ - raise RuntimeError(f"Instruction `{self.OP_NAME_PISA}` does not have parameters.") + raise RuntimeError( + f"Instruction `{self.op_name_pisa}` does not have parameters." + ) def _set_sources(self, value): """ @@ -80,9 +87,11 @@ def _set_sources(self, value): Raises: RuntimeError: Always raised as `bexit` does not have parameters. """ - raise RuntimeError(f"Instruction `{self.OP_NAME_PISA}` does not have parameters.") + raise RuntimeError( + f"Instruction `{self.op_name_pisa}` does not have parameters." + ) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ This instruction has no PISA equivalent. @@ -94,7 +103,7 @@ def _toPISAFormat(self, *extra_args) -> str: """ return None - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -107,10 +116,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py index c5bae42a..03d6e7ef 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/intt.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents an `intt` instruction in the assembler with specific properties and methods for parsing, @@ -70,31 +74,33 @@ def parseFromPISALine(cls, line: str) -> object: Returns None if an `intt` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: - retval = { "comment": tokens[1] } + retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["stage"] = int(instr_tokens[params_end]) retval["res"] = int(instr_tokens[params_end + 1]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -103,16 +109,18 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "intt" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - stage: int, - res: int, - comment: str = "", - throughput : int = None, - latency : int = None): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + stage: int, + res: int, + comment: str = "", + throughput: int = None, + latency: int = None, + ): """ Initializes an Instruction object with the given parameters. @@ -134,7 +142,7 @@ def __init__(self, super().__init__(id, N, throughput, latency, res=res, comment=comment) - self.__stage = stage # (read-only) stage + self.__stage = stage # (read-only) stage self._set_dests(dst) self._set_sources(src) @@ -145,17 +153,21 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval=('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -179,9 +191,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -197,14 +214,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -218,16 +240,16 @@ def _toPISAFormat(self, *extra_args) -> str: str: The instruction in kernel format. """ if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES # N, intt, dst_top (bank), dest_bot (bank), src_top (bank), src_bot (bank), src_tw (bank), stage, res # comment - retval = super()._toPISAFormat(self.stage) + retval = super()._to_pisa_format(self.stage) return retval - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -240,10 +262,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.stage) \ No newline at end of file + return super()._to_xasmisa_format(self.stage) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py index 7176a5cb..2c27d914 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/irshuffle.py @@ -1,4 +1,7 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace @@ -8,10 +11,11 @@ from assembler.memory_model.variable import Variable from . import rshuffle + class Instruction(XInstruction): """ Represents an instruction in the assembler with specific properties and methods for parsing, - scheduling, and formatting. This class is specifically designed to handle `irshuffle` + scheduling, and formatting. This class is specifically designed to handle `irshuffle` instruction within the assembler's instruction set architecture (ISA). Attributes: @@ -29,13 +33,17 @@ class Instruction(XInstruction): """ # To be initialized from ASM ISA spec - _OP_NUM_TOKENS : int - _OP_IRMOVE_LATENCY : int + _OP_NUM_TOKENS: int + _OP_IRMOVE_LATENCY: int _OP_IRMOVE_LATENCY_MAX: int _OP_IRMOVE_LATENCY_INC: int - __irshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among irshuffles - __rshuffle_global_cycle_ready = CycleType(0, 0) # private class attribute to track the cycle ready based on last rshuffle + __irshuffle_global_cycle_ready = CycleType( + 0, 0 + ) # private class attribute to track cycle ready among irshuffles + __rshuffle_global_cycle_ready = CycleType( + 0, 0 + ) # private class attribute to track the cycle ready based on last rshuffle @classmethod def isa_spec_as_dict(cls) -> dict: @@ -43,9 +51,13 @@ def isa_spec_as_dict(cls) -> dict: Returns isa_spec attributes as dictionary. """ dict = super().isa_spec_as_dict() - dict.update({"num_tokens": cls._OP_NUM_TOKENS, - "special_latency_max": cls._OP_IRMOVE_LATENCY_MAX, - "special_latency_increment": cls._OP_IRMOVE_LATENCY_INC}) + dict.update( + { + "num_tokens": cls._OP_NUM_TOKENS, + "special_latency_max": cls._OP_IRMOVE_LATENCY_MAX, + "special_latency_increment": cls._OP_IRMOVE_LATENCY_INC, + } + ) return dict @classmethod @@ -133,27 +145,29 @@ def parseFromPISALine(cls, line: str) -> object: Returns None if an `irshuffle` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: - retval = { "comment": tokens[1] } + retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) # ignore "res", but make sure it exists (syntax) - assert(instr_tokens[params_end] is not None) + assert instr_tokens[params_end] is not None retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod @@ -164,10 +178,10 @@ def _get_name(cls) -> str: Returns: str: The operation name in PISA format. """ - return cls.OP_NAME_PISA + return cls.op_name_pisa @classmethod - def _get_OP_NAME_PISA(cls) -> str: + def _get_op_name_pisa(cls) -> str: """ Returns the operation name in PISA format. @@ -177,7 +191,7 @@ def _get_OP_NAME_PISA(cls) -> str: return "irshuffle" @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -186,15 +200,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "rshuffle" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - wait_cyc: int = 0, - throughput : int = None, - latency : int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + wait_cyc: int = 0, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object with the given parameters. @@ -216,8 +232,12 @@ def __init__(self, if not latency: latency = Instruction._OP_DEFAULT_LATENCY if latency < Instruction._OP_IRMOVE_LATENCY: - raise ValueError((f'`latency`: expected a value greater than or equal to ' - '{Instruction._OP_IRMOVE_LATENCY}, but {latency} received.')) + raise ValueError( + ( + f"`latency`: expected a value greater than or equal to " + "{Instruction._OP_IRMOVE_LATENCY}, but {latency} received." + ) + ) super().__init__(id, N, throughput, latency, comment=comment) @@ -232,16 +252,18 @@ def __repr__(self): Returns: str: A string representation of the Instruction object. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={}, ' - 'wait_cyc={}, res={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests, - self.sources, - self.wait_cyc, - self.res) + retval = ( + "<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={}, res={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.wait_cyc, + self.res, + ) return retval @classmethod @@ -252,7 +274,7 @@ def __set_irshuffleGlobalCycleReady(cls, value: CycleType): Parameters: value (CycleType): The cycle type value to set. """ - if (value > cls.__irshuffle_global_cycle_ready): + if value > cls.__irshuffle_global_cycle_ready: cls.__irshuffle_global_cycle_ready = value @classmethod @@ -263,11 +285,11 @@ def set_rshuffleGlobalCycleReady(cls, value: CycleType): Parameters: value (CycleType): The cycle type value to set. """ - if (value > cls.__rshuffle_global_cycle_ready): + if value > cls.__rshuffle_global_cycle_ready: cls.__rshuffle_global_cycle_ready = value @classmethod - def reset_GlobalCycleReady(cls, value = CycleType(0, 0)): + def reset_GlobalCycleReady(cls, value=CycleType(0, 0)): """ Resets the global cycle ready for both irshuffle and rshuffle instructions. @@ -288,8 +310,12 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " - "but list with {len(value)} elements received.")) + raise ValueError( + ( + f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " + "but list with {len(value)} elements received." + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -305,8 +331,12 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " - "but list with {len(value)} elements received.")) + raise ValueError( + ( + f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " + "but list with {len(value)} elements received." + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) @@ -324,9 +354,11 @@ def _get_cycle_ready(self): # sources and the global cycles-ready for other rshuffles and other irshuffles. # An irshuffle cannot be within _OP_IRMOVE_LATENCY cycles from another irshuffle, # nor within _OP_DEFAULT_LATENCY cycles from another rshuffle. - return max(super()._get_cycle_ready(), - Instruction.__irshuffle_global_cycle_ready, - Instruction.__rshuffle_global_cycle_ready) + return max( + super()._get_cycle_ready(), + Instruction.__irshuffle_global_cycle_ready, + Instruction.__rshuffle_global_cycle_ready, + ) def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ @@ -350,13 +382,19 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ original_throughput = super()._schedule(cycle_count, schedule_id) retval = self.throughput + self.wait_cyc - assert(original_throughput <= retval) - Instruction.__set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_IRMOVE_LATENCY)) + assert original_throughput <= retval + Instruction.__set_irshuffleGlobalCycleReady( + CycleType( + cycle_count.bundle, cycle_count.cycle + Instruction._OP_IRMOVE_LATENCY + ) + ) # Avoid rshuffles and irshuffles in the same bundle - rshuffle.Instruction.set_irshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) + rshuffle.Instruction.set_irshuffleGlobalCycleReady( + CycleType(cycle_count.bundle + 1, 0) + ) return retval - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -369,16 +407,16 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: The instruction in kernel format. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, irshuffle, dst0, dst1, src0, src1, res=0 # comment - return super()._toPISAFormat(0) + return super()._to_pisa_format(0) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -391,11 +429,11 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # id[0], N, op, dst_register0, dst_register1, src_register0, src_register1, wait_cycle, data_type="intt", res=0 [# comment] - return super()._toXASMISAFormat(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) \ No newline at end of file + return super()._to_xasmisa_format(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py index 1d6d817a..82d7e531 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mac.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents a `mac` (multiply-accumulate) instruction in an assembly language. @@ -18,7 +22,7 @@ class Instruction(XInstruction): Methods: parseFromPISALine: Parses a `mac` instruction from a Kernel instruction string. """ - + # To be initialized from ASM ISA spec _OP_NUM_TOKENS: int @@ -36,18 +40,18 @@ def SetNumTokens(cls, val): cls._OP_NUM_TOKENS = val @classmethod - def SetNumSources(cls, val): + def set_num_sources(cls, val): cls._OP_NUM_SOURCES = val # In ASM ISA spec there are 3 sources, but src[0] == dst cls._OP_NUM_PISA_SOURCES = cls._OP_NUM_SOURCES - 1 - + @classmethod def parseFromPISALine(cls, line: str) -> list: """ Parses a 'mac' instruction from a Kernel instruction string. Parameters: - line (str): + line (str): String containing the instruction to parse. Instruction format: N, mac, dst (bank), src0 (bank), src1 (bank), res # comment Comment is optional. @@ -56,7 +60,7 @@ def parseFromPISALine(cls, line: str) -> list: "13, mac , c2_rlk_0_10_0 (3), coeff_0_0_0 (2), rlk_0_2_10_0 (0), 10" Returns: - list: + list: A list of tuples with a single element representing the parsed information, or an empty list if a 'mac' could not be parsed from the input. @@ -73,30 +77,32 @@ def parseFromPISALine(cls, line: str) -> list: - comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_PISA_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start + ) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -105,34 +111,36 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "mac" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object for a 'mac' operation. Parameters: - id (int): + id (int): The unique identifier for the instruction. - N (int): + N (int): The ring size. - dst (list): + dst (list): List of destination variables. - src (list): + src (list): List of source variables. - res (int): + res (int): The residual for the operation. - throughput (int, optional): + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. - latency (int, optional): + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. - comment (str, optional): + comment (str, optional): An optional comment for the instruction. """ if not throughput: @@ -152,17 +160,21 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -176,9 +188,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -194,14 +211,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -214,22 +236,22 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: The instruction in kernel format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") preamble = (self.N,) - extra_args = tuple(src.toPISAFormat() for src in self.sources[1:]) + extra_args - extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + extra_args = ( + tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args + ) + extra_args = tuple(dst.to_pisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res,) - return self.toStringFormat(preamble, - self.OP_NAME_PISA, - *extra_args) + return self.to_string_format(preamble, self.op_name_pisa, *extra_args) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -242,10 +264,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py index 455387f2..f52d41f4 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/maci.py @@ -1,14 +1,18 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents a `maci` (multiply-accumulate immediate) instruction in an assembly language. - + This class is responsible for parsing, representing, and converting 'maci' instructions according to a specific instruction set architecture (ISA) specification. @@ -33,18 +37,18 @@ def SetNumTokens(cls, val): cls._OP_NUM_TOKENS = val @classmethod - def SetNumSources(cls, val): + def set_num_sources(cls, val): cls._OP_NUM_SOURCES = val # In ASM ISA spec there are 2 sources, but src[0] == dst cls._OP_NUM_PISA_SOURCES = cls._OP_NUM_SOURCES - 1 - + @classmethod def parseFromPISALine(cls, line: str) -> list: """ Parses a 'maci' instruction from a Kernel instruction string. Parameters: - line (str): + line (str): String containing the instruction to parse. Instruction format: N, maci, dst (bank), src (bank), imm, res # comment Comment is optional. @@ -53,7 +57,7 @@ def parseFromPISALine(cls, line: str) -> list: "13, maci, coeff_0_1_3 (2), c2_4_3 (3), Qqr_extend_2_13_4, 13" Returns: - list: + list: A list of tuples with a single element representing the parsed information, or an empty list if a 'maci' could not be parsed from the input. @@ -71,31 +75,33 @@ def parseFromPISALine(cls, line: str) -> list: - comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_PISA_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_PISA_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_PISA_SOURCES, params_start + ) retval.update(dst_src) retval["imm"] = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -104,37 +110,39 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "maci" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - imm: str, - res: int, - comment: str = "", - throughput: int = None, - latency: int = None): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + imm: str, + res: int, + comment: str = "", + throughput: int = None, + latency: int = None, + ): """ Initializes an Instruction object for a 'maci' operation. Parameters: - id (int): + id (int): The unique identifier for the instruction. - N (int): + N (int): The ring size. - dst (list): + dst (list): List of destination variables. - src (list): + src (list): List of source variables. - imm (str): + imm (str): The immediate value identifier. - res (int): + res (int): The residual for the operation. - comment (str, optional): + comment (str, optional): An optional comment for the instruction. - throughput (int, optional): + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. - latency (int, optional): + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. """ if not throughput: @@ -155,18 +163,22 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], res={}, imm={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.imm, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, imm={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.imm, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -190,9 +202,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -208,14 +225,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -228,24 +250,24 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: The instruction in kernel format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, muli, dst (bank), src0 (bank), imm, res # comment preamble = (self.N,) extra_args = (self.imm,) - extra_args = tuple(src.toPISAFormat() for src in self.sources[1:]) + extra_args - extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + extra_args = ( + tuple(src.to_pisa_format() for src in self.sources[1:]) + extra_args + ) + extra_args = tuple(dst.to_pisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res,) - return self.toStringFormat(preamble, - self.OP_NAME_PISA, - *extra_args) + return self.to_string_format(preamble, self.op_name_pisa, *extra_args) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -258,10 +280,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.imm) \ No newline at end of file + return super()._to_xasmisa_format(self.imm) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py index 80a47f3c..bf3409c6 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/move.py @@ -1,8 +1,12 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .xinstruction import XInstruction from assembler.memory_model.variable import Variable, DummyVariable from assembler.memory_model.register_file import Register + class Instruction(XInstruction): """ Encapsulates a `move` instruction used to copy data from one register to a different one. @@ -15,7 +19,7 @@ class Instruction(XInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -24,32 +28,34 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "move" - def __init__(self, - id: int, - dst: Register, - src: list, - dummy_var: DummyVariable = None, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + dst: Register, + src: list, + dummy_var: DummyVariable = None, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `move` CInstruction. Parameters: - id (int): + id (int): User-defined ID for the instruction. It will be bundled with a nonce to form a unique ID. - dst (Register): + dst (Register): The destination register where to load the variable in `src`. - src (list of Variable): + src (list of Variable): A list containing a single Variable object indicating the source variable to move from its current register to `dst` register. - dummy_var (DummyVariable, optional): + dummy_var (DummyVariable, optional): A dummy variable used for marking registers as free. - throughput (int, optional): + throughput (int, optional): The throughput of the instruction. Defaults to the class-level default if not provided. - latency (int, optional): + latency (int, optional): The latency of the instruction. Defaults to the class-level default if not provided. - comment (str, optional): + comment (str, optional): An optional comment for the instruction. Raises: @@ -60,12 +66,19 @@ def __init__(self, if not latency: latency = Instruction._OP_DEFAULT_LATENCY if any(isinstance(v, DummyVariable) or not v.name for v in src): - raise ValueError(f"{Instruction.OP_NAME_ASM} cannot have dummy variable as source.") - if dst.contained_variable \ - and not isinstance(dst.contained_variable, DummyVariable): - raise ValueError("{}: destination register must be empty, but variable {}.{} found.".format(Instruction.OP_NAME_ASM, - dst.contained_variable.name, - dst.contained_variable.tag)) + raise ValueError( + f"{Instruction.op_name_asm} cannot have dummy variable as source." + ) + if dst.contained_variable and not isinstance( + dst.contained_variable, DummyVariable + ): + raise ValueError( + "{}: destination register must be empty, but variable {}.{} found.".format( + Instruction.op_name_asm, + dst.contained_variable.name, + dst.contained_variable.tag, + ) + ) N = 0 # Does not require ring-size super().__init__(id, N, throughput, latency, comment=comment) self.__dummy_var = dummy_var @@ -79,16 +92,20 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -103,9 +120,14 @@ def _set_dests(self, value): TypeError: If the list contains non-Register objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Register` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Register` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Register) for x in value): raise TypeError("`value`: Expected list of `Register` objects.") super()._set_dests(value) @@ -121,9 +143,14 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -147,32 +174,52 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction, i.e., the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) variable = self.sources[0] # Expected sources to contain a Variable target_register = self.dests[0] if isinstance(variable, Register): # Source and target types are swapped after scheduling # Instruction already scheduled: can only schedule once - assert(isinstance(target_register, Variable)) - raise RuntimeError(f'Instruction `{self.name}` (id = {self.id}) already scheduled.') - - if target_register.contained_variable \ - and not isinstance(target_register.contained_variable, DummyVariable): - raise RuntimeError(('Instruction `{}` (id = {}) ' - 'cannot be scheduled because target register `{}` is not empty: ' - 'contains variable "{}".').format(self.name, - self.id, - target_register.name, - target_register.contained_variable.name)) - - assert not target_register.contained_variable or self.__dummy_var == target_register.contained_variable + assert isinstance(target_register, Variable) + raise RuntimeError( + f"Instruction `{self.name}` (id = {self.id}) already scheduled." + ) + + if target_register.contained_variable and not isinstance( + target_register.contained_variable, DummyVariable + ): + raise RuntimeError( + ( + "Instruction `{}` (id = {}) " + "cannot be scheduled because target register `{}` is not empty: " + 'contains variable "{}".' + ).format( + self.name, + self.id, + target_register.name, + target_register.contained_variable.name, + ) + ) + + assert ( + not target_register.contained_variable + or self.__dummy_var == target_register.contained_variable + ) # Perform the move register_dirty = variable.register_dirty source_register = variable.register target_register.allocateVariable(variable) - source_register.allocateVariable(self.__dummy_var) # Mark source register as free for next bundle + source_register.allocateVariable( + self.__dummy_var + ) # Mark source register as free for next bundle assert source_register.bank.bank_index == 0 # Swap source and dest to keep the output format of the string instruction consistent self.sources[0] = source_register @@ -183,12 +230,12 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: variable.register_dirty = register_dirty # Preserve register dirty state if self.comment: - self.comment += ';' + self.comment += ";" self.comment += ' variable "{}"'.format(variable.name) return retval - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ This instruction has no PISA equivalent. @@ -200,7 +247,7 @@ def _toPISAFormat(self, *extra_args) -> str: """ return None - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -213,10 +260,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py index a17e2673..2698b78e 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/mul.py @@ -1,14 +1,18 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents a `mul` (multiply) instruction in an assembly language. - + This class is responsible for parsing, representing, and converting `mul` instructions according to a specific instruction set architecture (ISA) specification. @@ -63,30 +67,32 @@ def parseFromPISALine(cls, line: str) -> Namespace: - comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -95,15 +101,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "mul" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object for a 'mul' operation. @@ -133,17 +141,21 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -157,9 +169,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -175,14 +192,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -195,15 +217,15 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: The instruction in kernel format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toPISAFormat() + return super()._to_pisa_format() - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -216,10 +238,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py index eb3ce83c..3d1852aa 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/muli.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents a `muli` (multiply immediate) instruction in an assembly language. @@ -67,31 +71,33 @@ def parseFromPISALine(cls, line: str) -> list: - comment (str): String with the comment attached to the line (empty string if no comment). """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["imm"] = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -100,16 +106,18 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "muli" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - imm: str, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + imm: str, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object for a 'muli' operation. @@ -142,18 +150,22 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], res={}, imm={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.imm, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, imm={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.imm, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -177,9 +189,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -195,14 +212,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -215,16 +237,16 @@ def _toPISAFormat(self, *extra_args) -> str: Returns: str: The instruction in kernel format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, muli, dst (bank), src0 (bank), imm, res # comment - return super()._toPISAFormat(self.imm) + return super()._to_pisa_format(self.imm) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -237,10 +259,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.imm) \ No newline at end of file + return super()._to_xasmisa_format(self.imm) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py index 5c55c7f3..4c50de53 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/nop.py @@ -1,4 +1,8 @@ -from .xinstruction import XInstruction +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .xinstruction import XInstruction + class Instruction(XInstruction): """ @@ -12,7 +16,7 @@ class Instruction(XInstruction): """ @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -21,10 +25,7 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "nop" - def __init__(self, - id: int, - idle_cycles: int, - comment: str = ""): + def __init__(self, id: int, idle_cycles: int, comment: str = ""): """ Initializes an Instruction object for a 'nop' operation. @@ -44,12 +45,9 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], ' - 'idle_cycles={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.throughput) + retval = ("<{}({}) object at {}>(id={}[0], " "idle_cycles={})").format( + type(self).__name__, self.name, hex(id(self)), self.id, self.throughput + ) return retval def _set_dests(self, value): @@ -76,7 +74,7 @@ def _set_sources(self, value): """ raise RuntimeError(f"Instruction `{self.name}` does not have parameters.") - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Indicates that this instruction has no PISA equivalent. @@ -88,7 +86,7 @@ def _toPISAFormat(self, *extra_args) -> str: """ return None - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -101,12 +99,12 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # The idle cycles in the ASM ISA for `nop` must be one less because decoding/scheduling # the instruction counts as a cycle. - return super()._toXASMISAFormat(self.throughput - 1) \ No newline at end of file + return super()._to_xasmisa_format(self.throughput - 1) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py index 5b68c9d8..22c81924 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/ntt.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Represents an `ntt` (Number Theoretic Transform) instruction in an assembly language. @@ -64,31 +68,33 @@ def parseFromPISALine(cls, line: str) -> object: Returns None if an 'ntt' could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["stage"] = int(instr_tokens[params_end]) retval["res"] = int(instr_tokens[params_end + 1]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. @@ -97,16 +103,18 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "ntt" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - stage: int, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + stage: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Initializes an Instruction object. @@ -139,17 +147,21 @@ def __repr__(self): Returns: str: A string representation of the object. """ - retval = ('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -173,9 +185,14 @@ def _set_dests(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -191,14 +208,19 @@ def _set_sources(self, value): ValueError: If the list does not contain the expected number of Variable objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -212,16 +234,16 @@ def _toPISAFormat(self, *extra_args) -> str: str: The instruction in kernel format as a string. """ if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES # N, ntt, dst_top (bank), dest_bot (bank), src_top (bank), src_bot (bank), src_tw (bank), stage, res # comment - retval = super()._toPISAFormat(self.stage) + retval = super()._to_pisa_format(self.stage) return retval - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -234,10 +256,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Returns: str: The instruction in ASM format as a string. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.stage) \ No newline at end of file + return super()._to_xasmisa_format(self.stage) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py index 89e53617..c455e79f 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/parse_xntt.py @@ -1,4 +1,7 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace @@ -8,9 +11,8 @@ __xntt_id = 0 -def parseXNTTKernelLine(line: str, - op_name: str, - tw_separator: str) -> Namespace: + +def parseXNTTKernelLine(line: str, op_name: str, tw_separator: str) -> Namespace: """ Parses an `xntt` instruction from a P-ISA kernel instruction string. @@ -37,9 +39,9 @@ def parseXNTTKernelLine(line: str, None: If an `xntt` could not be parsed from the input. """ - OP_NUM_DESTS = 2 + OP_NUM_DESTS = 2 OP_NUM_SOURCES = 2 - OP_NUM_TOKENS = 8 + OP_NUM_TOKENS = 8 retval = None tokens = xinst.XInstruction.tokenizeFromPISALine(op_name, line) @@ -48,16 +50,17 @@ def parseXNTTKernelLine(line: str, instr_tokens = tokens[0] if len(instr_tokens) > OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{op_name}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{op_name}"', SyntaxWarning + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + OP_NUM_DESTS + OP_NUM_SOURCES - dst_src = xinst.XInstruction.parsePISASourceDestsFromTokens(instr_tokens, - OP_NUM_DESTS, - OP_NUM_SOURCES, - params_start) + dst_src = xinst.XInstruction.parsePISASourceDestsFromTokens( + instr_tokens, OP_NUM_DESTS, OP_NUM_SOURCES, params_start + ) retval.update(dst_src) twiddle = instr_tokens[params_end] retval["res"] = int(instr_tokens[params_end + 1]) @@ -65,18 +68,25 @@ def parseXNTTKernelLine(line: str, # Parse twiddle (w___, where "_" is the `tw_separator`) twiddle_tokens = list(map(lambda s: s.strip(), twiddle.split(tw_separator))) if len(twiddle_tokens) != 4: - raise ValueError(f'Error parsing twiddle information for "{op_name}" in line "{line}".') + raise ValueError( + f'Error parsing twiddle information for "{op_name}" in line "{line}".' + ) if twiddle_tokens[0] != "w": - raise ValueError(f'Invalid twiddle detected for "{op_name}" in line "{line}".') + raise ValueError( + f'Invalid twiddle detected for "{op_name}" in line "{line}".' + ) if int(twiddle_tokens[1]) != retval["res"]: - raise ValueError(f'Invalid "residual" component detected in twiddle information for "{op_name}" in line "{line}".') + raise ValueError( + f'Invalid "residual" component detected in twiddle information for "{op_name}" in line "{line}".' + ) retval["stage"] = int(twiddle_tokens[2]) retval["block"] = int(twiddle_tokens[3]) retval = Namespace(**retval) - assert(retval.op_name == op_name) + assert retval.op_name == op_name return retval + def __generateRMoveParsedOp(kntt_parsed_op: Namespace) -> (type, Namespace): """ Generates a namespace compatible with xrshuffle XInst constructor. @@ -96,23 +106,28 @@ def __generateRMoveParsedOp(kntt_parsed_op: Namespace) -> (type, Namespace): parsed_op["src"] = [] parsed_op["comment"] = "" - if kntt_parsed_op.op_name == xinst.NTT.OP_NAME_PISA: + if kntt_parsed_op.op_name == xinst.NTT.op_name_pisa: xrshuffle_type = xinst.rShuffle parsed_op["dst"] = [d for d in kntt_parsed_op.dst] - elif kntt_parsed_op.op_name == xinst.iNTT.OP_NAME_PISA: + elif kntt_parsed_op.op_name == xinst.iNTT.op_name_pisa: xrshuffle_type = xinst.irShuffle parsed_op["dst"] = [s for s in kntt_parsed_op.src] else: - raise ValueError('`kntt_parsed_op`: cannot process operation with name "{}".'.format(kntt_parsed_op.op_name)) + raise ValueError( + '`kntt_parsed_op`: cannot process operation with name "{}".'.format( + kntt_parsed_op.op_name + ) + ) - assert(xrshuffle_type) + assert xrshuffle_type parsed_op["src"] = parsed_op["dst"] - parsed_op["op_name"] = xrshuffle_type.OP_NAME_PISA + parsed_op["op_name"] = xrshuffle_type.op_name_pisa # rshuffle goes above corresponding intt or below corresponding ntt return xrshuffle_type, Namespace(**parsed_op) + def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: """ Generates a namespace compatible with twxntt XInst constructor. @@ -124,13 +139,13 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: tuple: A tuple containing the twxntt type, a Namespace with the parsed operation, and a tuple with the twiddle variable name and suggested bank. The twxntt type is None if a twxntt is not needed for the specified xntt. """ - global __xntt_id # TODO: replace by unique ID once it gets integrated into the P-ISA kernel. + global __xntt_id # TODO: replace by unique ID once it gets integrated into the P-ISA kernel. retval = None parsed_op = {} parsed_op["N"] = xntt_parsed_op.N - parsed_op["op_name"] = 'tw' + str(xntt_parsed_op.op_name) + parsed_op["op_name"] = "tw" + str(xntt_parsed_op.op_name) parsed_op["res"] = xntt_parsed_op.res parsed_op["stage"] = xntt_parsed_op.stage parsed_op["block"] = xntt_parsed_op.block @@ -140,11 +155,18 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: parsed_op["comment"] = "" # Find types depending on whether we are doing ntt or intt - twxntt_type = next((t for t in (xinst.twNTT, xinst.twiNTT) if t.OP_NAME_PISA == parsed_op["op_name"]), None) - assert(twxntt_type) + twxntt_type = next( + ( + t + for t in (xinst.twNTT, xinst.twiNTT) + if t.op_name_pisa == parsed_op["op_name"] + ), + None, + ) + assert twxntt_type # Adapted from legacy code add_tw_xntt - #------------------------------------- + # ------------------------------------- ringsize = int(parsed_op["N"]) rminustwo = ringsize - 2 @@ -153,16 +175,16 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: # Generate meta data look-up meta_rns_term = rns_term % constants.MemoryModel.MAX_RESIDUALS - mdata_word_sel = meta_rns_term >> 1 # 5bit word select + mdata_word_sel = meta_rns_term >> 1 # 5bit word select mdata_inword_res_sel = meta_rns_term & 1 mdata_inword_stage_sel = rminustwo - stage if twxntt_type == xinst.twiNTT: - mdata_inword_ntt_sel = 1 # Select intt field - else: # xinst.twNTT - mdata_inword_ntt_sel = 0 # Select ntt field - mdata_ptr = (mdata_word_sel << 6) - mdata_ptr |= (mdata_inword_res_sel << 5) - mdata_ptr |= (mdata_inword_ntt_sel << 4) + mdata_inword_ntt_sel = 1 # Select intt field + else: # xinst.twNTT + mdata_inword_ntt_sel = 0 # Select ntt field + mdata_ptr = mdata_word_sel << 6 + mdata_ptr |= mdata_inword_res_sel << 5 + mdata_ptr |= mdata_inword_ntt_sel << 4 mdata_ptr |= mdata_inword_stage_sel block = int(parsed_op["block"]) @@ -171,12 +193,20 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: __xntt_id += 1 # Generate twiddle variable name - tw_var_name_bank = ("w_gen_{}_{}_{}_{}".format(mdata_inword_ntt_sel, __xntt_id, rns_term, block), 1) - - meta_data_comment = "{} {} ".format(mdata_word_sel, mdata_inword_res_sel) - meta_data_comment += "{} {} w_{}_{}_{}".format(mdata_inword_ntt_sel, mdata_inword_stage_sel, - # hop_list[6] - parsed_op["res"], parsed_op["stage"], parsed_op["block"]) + tw_var_name_bank = ( + "w_gen_{}_{}_{}_{}".format(mdata_inword_ntt_sel, __xntt_id, rns_term, block), + 1, + ) + + meta_data_comment = "{} {} ".format(mdata_word_sel, mdata_inword_res_sel) + meta_data_comment += "{} {} w_{}_{}_{}".format( + mdata_inword_ntt_sel, + mdata_inword_stage_sel, + # hop_list[6] + parsed_op["res"], + parsed_op["stage"], + parsed_op["block"], + ) parsed_op["dst"] = [tw_var_name_bank] parsed_op["src"] = [tw_var_name_bank] @@ -194,9 +224,10 @@ def __generateTWNTTParsedOp(xntt_parsed_op: Namespace) -> Namespace: return retval, Namespace(**parsed_op), tw_var_name_bank -def generateXNTT(mem_model: MemoryModel, - xntt_parsed_op: Namespace, - new_id: int = 0) -> list: + +def generateXNTT( + mem_model: MemoryModel, xntt_parsed_op: Namespace, new_id: int = 0 +) -> list: """ Parses an `xntt` instruction from a P-ISA kernel instruction string. @@ -216,32 +247,51 @@ def generateXNTT(mem_model: MemoryModel, retval = [] # Find xntt type depending on whether we are doing ntt or intt - xntt_type = next((t for t in (xinst.NTT, xinst.iNTT) if t.OP_NAME_PISA == xntt_parsed_op.op_name), None) + xntt_type = next( + ( + t + for t in (xinst.NTT, xinst.iNTT) + if t.op_name_pisa == xntt_parsed_op.op_name + ), + None, + ) if not xntt_type: - raise ValueError('`xntt_parsed_op`: cannot process parsed kernel operation with name "{}".'.format(xntt_parsed_op.op_name)) + raise ValueError( + '`xntt_parsed_op`: cannot process parsed kernel operation with name "{}".'.format( + xntt_parsed_op.op_name + ) + ) # Generate twiddle instruction - #----------------------------- + # ----------------------------- - twxntt_type, twxntt_parsed_op, last_twxinput_name = __generateTWNTTParsedOp(xntt_parsed_op) + twxntt_type, twxntt_parsed_op, last_twxinput_name = __generateTWNTTParsedOp( + xntt_parsed_op + ) # print(twxntt_parsed_op) twxntt_inst = None if twxntt_type: - twxntt_inst = xinst.createFromParsedObj(mem_model, twxntt_type, twxntt_parsed_op, new_id) + twxntt_inst = xinst.createFromParsedObj( + mem_model, twxntt_type, twxntt_parsed_op, new_id + ) # Generate corresponding rshuffle - #----------------------------- + # ----------------------------- rshuffle_type, rshuffle_parsed_op = __generateRMoveParsedOp(xntt_parsed_op) - rshuffle_parsed_op.comment += (" " + twxntt_parsed_op.comment) if twxntt_parsed_op else "" - rshuffle_inst = xinst.createFromParsedObj(mem_model, rshuffle_type, rshuffle_parsed_op, new_id) + rshuffle_parsed_op.comment += ( + (" " + twxntt_parsed_op.comment) if twxntt_parsed_op else "" + ) + rshuffle_inst = xinst.createFromParsedObj( + mem_model, rshuffle_type, rshuffle_parsed_op, new_id + ) # Generate xntt instruction - #-------------------------- + # -------------------------- # Prepare arguments for ASM ntt instruction object construction if twxntt_parsed_op: - assert(twxntt_parsed_op.stage == xntt_parsed_op.stage) + assert twxntt_parsed_op.stage == xntt_parsed_op.stage delattr(xntt_parsed_op, "block") xntt_parsed_op.src.append(last_twxinput_name) xntt_parsed_op.comment += twxntt_parsed_op.comment if twxntt_parsed_op else "" @@ -250,18 +300,18 @@ def generateXNTT(mem_model: MemoryModel, xntt_inst = xinst.createFromParsedObj(mem_model, xntt_type, xntt_parsed_op, new_id) # Add instructions to return list - #-------------------------------- + # -------------------------------- - retval = [xntt_inst] # xntt + retval = [xntt_inst] # xntt - if xntt_type == xinst.iNTT: # rshuffle + if xntt_type == xinst.iNTT: # rshuffle # rshuffle goes above corresponding intt retval = [rshuffle_inst] + retval else: # rshuffle goes below corresponding ntt retval.append(rshuffle_inst) - if twxntt_inst: # twiddle + if twxntt_inst: # twiddle retval.append(twxntt_inst) - return retval \ No newline at end of file + return retval diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py index ff3b7090..378d6c70 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/rshuffle.py @@ -1,4 +1,7 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from assembler.common.cycle_tracking import CycleType @@ -7,6 +10,7 @@ from assembler.memory_model.variable import Variable from . import irshuffle + class Instruction(XInstruction): """ Encapsulates an `rshuffle` XInstruction. @@ -22,13 +26,17 @@ class Instruction(XInstruction): """ # To be initialized from ASM ISA spec - _OP_NUM_TOKENS : int - _OP_RMOVE_LATENCY : int - _OP_RMOVE_LATENCY_MAX: int - _OP_RMOVE_LATENCY_INC: int - - __rshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track cycle ready among rshuffles - __irshuffle_global_cycle_ready = CycleType(0, 0) # Private class attribute to track the cycle ready based on last irshuffle + _OP_NUM_TOKENS: int + _OP_REMOVE_LATENCY: int + _OP_REMOVE_LATENCY_MAX: int + _OP_REMOVE_LATENCY_INC: int + + __rshuffle_global_cycle_ready = CycleType( + 0, 0 + ) # Private class attribute to track cycle ready among rshuffles + __irshuffle_global_cycle_ready = CycleType( + 0, 0 + ) # Private class attribute to track the cycle ready based on last irshuffle @classmethod def isa_spec_as_dict(cls) -> dict: @@ -36,23 +44,27 @@ def isa_spec_as_dict(cls) -> dict: Returns isa_spec attributes as dictionary. """ dict = super().isa_spec_as_dict() - dict.update({"num_tokens": cls._OP_NUM_TOKENS, - "special_latency_max": cls._OP_RMOVE_LATENCY_MAX, - "special_latency_increment": cls._OP_RMOVE_LATENCY_INC}) + dict.update( + { + "num_tokens": cls._OP_NUM_TOKENS, + "special_latency_max": cls._OP_REMOVE_LATENCY_MAX, + "special_latency_increment": cls._OP_REMOVE_LATENCY_INC, + } + ) return dict - + @classmethod def SetNumTokens(cls, val): cls._OP_NUM_TOKENS = val @classmethod def SetSpecialLatencyMax(cls, val): - cls._OP_RMOVE_LATENCY_MAX = val + cls._OP_REMOVE_LATENCY_MAX = val @classmethod def SetSpecialLatencyIncrement(cls, val): - cls._OP_RMOVE_LATENCY_INC = val - cls._OP_RMOVE_LATENCY = cls._OP_RMOVE_LATENCY_INC + cls._OP_REMOVE_LATENCY_INC = val + cls._OP_REMOVE_LATENCY = cls._OP_REMOVE_LATENCY_INC @classproperty def SpecialLatency(cls): @@ -63,7 +75,7 @@ def SpecialLatency(cls): Returns: int: The special latency for rshuffle instructions. """ - return cls._OP_RMOVE_LATENCY + return cls._OP_REMOVE_LATENCY @classproperty def SpecialLatencyMax(cls): @@ -74,7 +86,7 @@ def SpecialLatencyMax(cls): Returns: int: The maximum special latency for rshuffle instructions. """ - return cls._OP_RMOVE_LATENCY_MAX + return cls._OP_REMOVE_LATENCY_MAX @classproperty def SpecialLatencyIncrement(cls): @@ -85,7 +97,7 @@ def SpecialLatencyIncrement(cls): Returns: int: The increment for special latency for rshuffle instructions. """ - return cls._OP_RMOVE_LATENCY_INC + return cls._OP_REMOVE_LATENCY_INC @classproperty def RSHUFFLE_DATA_TYPE(cls): @@ -126,31 +138,33 @@ def parseFromPISALine(cls, line: str) -> object: None: If an `rshuffle` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) # Ignore "res", but make sure it exists (syntax) - assert(instr_tokens[params_end] is not None) + assert instr_tokens[params_end] is not None retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -159,15 +173,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "rshuffle" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - wait_cyc: int = 0, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + wait_cyc: int = 0, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `rshuffle` XInstruction. @@ -188,9 +204,13 @@ def __init__(self, throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: latency = Instruction._OP_DEFAULT_LATENCY - if latency < Instruction._OP_RMOVE_LATENCY: - raise ValueError((f'`latency`: expected a value greater than or equal to ' - '{Instruction._OP_RMOVE_LATENCY}, but {latency} received.')) + if latency < Instruction._OP_REMOVE_LATENCY: + raise ValueError( + ( + f"`latency`: expected a value greater than or equal to " + "{Instruction._OP_REMOVE_LATENCY}, but {latency} received." + ) + ) super().__init__(id, N, throughput, latency, comment=comment) @@ -203,18 +223,20 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, destinations, sources, and wait cycles. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'dst={}, src={}, ' - 'wait_cyc={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests, - self.sources, - self.wait_cyc) + retval = ( + "<{}({}) object at {}>(id={}[0], " "dst={}, src={}, " "wait_cyc={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.sources, + self.wait_cyc, + ) return retval @classmethod @@ -225,7 +247,7 @@ def __set_rshuffleGlobalCycleReady(cls, value: CycleType): Parameters: value (CycleType): The cycle type value to set. """ - if (value > cls.__rshuffle_global_cycle_ready): + if value > cls.__rshuffle_global_cycle_ready: cls.__rshuffle_global_cycle_ready = value @classmethod @@ -236,7 +258,7 @@ def set_irshuffleGlobalCycleReady(cls, value: CycleType): Parameters: value (CycleType): The cycle type value to set. """ - if (value > cls.__irshuffle_global_cycle_ready): + if value > cls.__irshuffle_global_cycle_ready: cls.__irshuffle_global_cycle_ready = value @classmethod @@ -261,8 +283,12 @@ def _set_dests(self, value): ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " - "but list with {len(value)} elements received.")) + raise ValueError( + ( + f"`value`: Expected list of {Instruction._OP_NUM_DESTS} Variable objects, " + "but list with {len(value)} elements received." + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -278,8 +304,12 @@ def _set_sources(self, value): ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError((f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " - "but list with {len(value)} elements received.")) + raise ValueError( + ( + f"`value`: Expected list of {Instruction._OP_NUM_SOURCES} Variable objects, " + "but list with {len(value)} elements received." + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) @@ -295,11 +325,13 @@ def _get_cycle_ready(self): """ # This will return the maximum cycle ready among this instruction # sources and the global cycles-ready for other rshuffles and other irshuffles. - # An rshuffle cannot be within _OP_RMOVE_LATENCY cycles from another rshuffle, + # An rshuffle cannot be within _OP_REMOVE_LATENCY cycles from another rshuffle, # nor within _OP_DEFAULT_LATENCY cycles from another irshuffle. - return max(super()._get_cycle_ready(), - Instruction.__irshuffle_global_cycle_ready, - Instruction.__rshuffle_global_cycle_ready) + return max( + super()._get_cycle_ready(), + Instruction.__irshuffle_global_cycle_ready, + Instruction.__rshuffle_global_cycle_ready, + ) def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ @@ -324,13 +356,19 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ original_throughput = super()._schedule(cycle_count, schedule_id) retval = self.throughput + self.wait_cyc - assert(original_throughput <= retval) - Instruction.__set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + Instruction._OP_RMOVE_LATENCY)) + assert original_throughput <= retval + Instruction.__set_rshuffleGlobalCycleReady( + CycleType( + cycle_count.bundle, cycle_count.cycle + Instruction._OP_REMOVE_LATENCY + ) + ) # Avoid rshuffles and irshuffles in the same bundle - irshuffle.Instruction.set_rshuffleGlobalCycleReady(CycleType(cycle_count.bundle + 1, 0)) + irshuffle.Instruction.set_rshuffleGlobalCycleReady( + CycleType(cycle_count.bundle + 1, 0) + ) return retval - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -343,16 +381,16 @@ def _toPISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, rshuffle, dst0, dst1, src0, src1, res=0 # comment - return super()._toPISAFormat(0) + return super()._to_pisa_format(0) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -365,11 +403,11 @@ def _toXASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # id[0], N, op, dst_register0, dst_register1, src_register0, src_register1, wait_cycle, data_type="ntt", res=0 [# comment] - return super()._toXASMISAFormat(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) \ No newline at end of file + return super()._to_xasmisa_format(self.wait_cyc, self.RSHUFFLE_DATA_TYPE) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py index 3b7bcce3..d28febab 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/sub.py @@ -1,14 +1,18 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Encapsulates a `sub` XInstruction. - + This instruction performs element-wise polynomial subtraction. For more information, check the specification: @@ -61,30 +65,32 @@ def parseFromPISALine(cls, line: str) -> list: None: If a `sub` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["res"] = int(instr_tokens[params_end]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -93,15 +99,17 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "sub" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `sub` XInstruction. @@ -137,20 +145,24 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, residual, destinations, sources, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], res={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval def _set_dests(self, value): @@ -164,9 +176,14 @@ def _set_dests(self, value): ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -182,14 +199,19 @@ def _set_sources(self, value): ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -202,15 +224,15 @@ def _toPISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toPISAFormat() + return super()._to_pisa_format() - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -223,10 +245,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat() \ No newline at end of file + return super()._to_xasmisa_format() diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py index 6b14a5fe..4c5904af 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twintt.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Encapsulates a `twintt` XInstruction. @@ -72,21 +76,23 @@ def parseFromPISALine(cls, line: str) -> object: None: If a `twintt` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["tw_meta"] = int(instr_tokens[params_end]) retval["stage"] = int(instr_tokens[params_end + 1]) @@ -94,11 +100,11 @@ def parseFromPISALine(cls, line: str) -> object: retval["res"] = int(instr_tokens[params_end + 3]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -107,18 +113,20 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "twintt" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - tw_meta: int, - stage: int, - block: int, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + tw_meta: int, + stage: int, + block: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `twintt` XInstruction. @@ -142,9 +150,9 @@ def __init__(self, super().__init__(id, N, throughput, latency, res=res, comment=comment) - self.__tw_meta = tw_meta # (Read-only) tw_meta - self.__stage = stage # (Read-only) stage - self.__block = block # (Read-only) block + self.__tw_meta = tw_meta # (Read-only) tw_meta + self.__stage = stage # (Read-only) stage + self.__block = block # (Read-only) block self._set_dests(dst) self._set_sources(src) @@ -153,23 +161,27 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.tw_meta, - self.stage, - self.block, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.tw_meta, + self.stage, + self.block, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -213,9 +225,14 @@ def _set_dests(self, value): ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -231,14 +248,19 @@ def _set_sources(self, value): ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -251,19 +273,17 @@ def _toPISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, twintt, dst_tw, src_tw, tw_meta, stage, block, res # comment - retval = super()._toPISAFormat(self.tw_meta, - self.stage, - self.block) + retval = super()._to_pisa_format(self.tw_meta, self.stage, self.block) return retval - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -276,13 +296,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.tw_meta, - self.stage, - self.block, - self.N) \ No newline at end of file + return super()._to_xasmisa_format(self.tw_meta, self.stage, self.block, self.N) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py index 3494e5b5..33928f16 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/twntt.py @@ -1,10 +1,14 @@ -import warnings +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import warnings from argparse import Namespace from .xinstruction import XInstruction from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Encapsulates a `twntt` XInstruction. @@ -72,21 +76,23 @@ def parseFromPISALine(cls, line: str) -> object: None: If a `twntt` could not be parsed from the input. """ retval = None - tokens = XInstruction.tokenizeFromPISALine(cls.OP_NAME_PISA, line) + tokens = XInstruction.tokenizeFromPISALine(cls.op_name_pisa, line) if tokens: retval = {"comment": tokens[1]} instr_tokens = tokens[0] if len(instr_tokens) > cls._OP_NUM_TOKENS: - warnings.warn(f'Extra tokens detected for instruction "{cls.OP_NAME_PISA}"', SyntaxWarning) + warnings.warn( + f'Extra tokens detected for instruction "{cls.op_name_pisa}"', + SyntaxWarning, + ) retval["N"] = int(instr_tokens[0]) retval["op_name"] = instr_tokens[1] params_start = 2 params_end = params_start + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES - dst_src = cls.parsePISASourceDestsFromTokens(instr_tokens, - cls._OP_NUM_DESTS, - cls._OP_NUM_SOURCES, - params_start) + dst_src = cls.parsePISASourceDestsFromTokens( + instr_tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, params_start + ) retval.update(dst_src) retval["tw_meta"] = int(instr_tokens[params_end]) retval["stage"] = int(instr_tokens[params_end + 1]) @@ -94,11 +100,11 @@ def parseFromPISALine(cls, line: str) -> object: retval["res"] = int(instr_tokens[params_end + 3]) retval = Namespace(**retval) - assert(retval.op_name == cls.OP_NAME_PISA) + assert retval.op_name == cls.op_name_pisa return retval @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -107,18 +113,20 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "twntt" - def __init__(self, - id: int, - N: int, - dst: list, - src: list, - tw_meta: int, - stage: int, - block: int, - res: int, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + dst: list, + src: list, + tw_meta: int, + stage: int, + block: int, + res: int, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `twntt` XInstruction. @@ -152,9 +160,9 @@ def __init__(self, super().__init__(id, N, throughput, latency, res=res, comment=comment) - self.__tw_meta = tw_meta # (Read-only) tw_meta - self.__stage = stage # (Read-only) stage - self.__block = block # (Read-only) block + self.__tw_meta = tw_meta # (Read-only) tw_meta + self.__stage = stage # (Read-only) stage + self.__block = block # (Read-only) block self._set_dests(dst) self._set_sources(src) @@ -163,23 +171,27 @@ def __repr__(self): Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, residual, tw_meta, stage, block, destinations, sources, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, ' - 'dst={}, src={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.res, - self.tw_meta, - self.stage, - self.block, - self.dests, - self.sources, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], res={}, tw_meta={}, stage={}, block={}, " + "dst={}, src={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.res, + self.tw_meta, + self.stage, + self.block, + self.dests, + self.sources, + self.throughput, + self.latency, + ) return retval @property @@ -223,9 +235,14 @@ def _set_dests(self, value): ValueError: If the number of destinations is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_DESTS, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_DESTS, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_dests(value) @@ -241,14 +258,19 @@ def _set_sources(self, value): ValueError: If the number of sources is incorrect or if the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} Variable objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} Variable objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of Variable objects.") super()._set_sources(value) - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to kernel format. @@ -261,19 +283,17 @@ def _toPISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") # N, twntt, dst_tw, src_tw, tw_meta, stage, block, res # comment - retval = super()._toPISAFormat(self.tw_meta, - self.stage, - self.block) + retval = super()._to_pisa_format(self.tw_meta, self.stage, self.block) return retval - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM format. @@ -286,13 +306,10 @@ def _toXASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") - return super()._toXASMISAFormat(self.tw_meta, - self.stage, - self.block, - self.N) \ No newline at end of file + return super()._to_xasmisa_format(self.tw_meta, self.stage, self.block, self.N) diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py index a2bbce57..719205e5 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xinstruction.py @@ -1,4 +1,7 @@ -from argparse import Namespace +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from argparse import Namespace from assembler.common import constants from assembler.common.cycle_tracking import CycleType @@ -6,7 +9,8 @@ from assembler.memory_model.variable import Variable from assembler.memory_model.register_file import Register from ..instruction import BaseInstruction -from .. import tokenizeFromLine +from .. import tokenize_from_line + class XInstruction(BaseInstruction): """ @@ -38,16 +42,15 @@ def tokenizeFromPISALine(op_name: str, line: str) -> list: tuple: A tuple containing tokens (tuple of str) and comment (str), or None if the instruction cannot be parsed from the line. """ retval = None - tokens, comment = tokenizeFromLine(line) + tokens, comment = tokenize_from_line(line) if len(tokens) > 1 and tokens[1] == op_name: retval = (tokens, comment) return retval @staticmethod - def parsePISASourceDestsFromTokens(tokens: list, - num_dests: int, - num_sources: int, - offset: int = 0) -> dict: + def parsePISASourceDestsFromTokens( + tokens: list, num_dests: int, num_sources: int, offset: int = 0 + ) -> dict: """ Parses the sources and destinations for an instruction, given sources and destinations in tokens in P-ISA format. @@ -95,13 +98,15 @@ def reset_GlobalCycleReady(cls, value=CycleType(0, 0)): """ pass - def __init__(self, - id: int, - N: int, - throughput: int, - latency: int, - res: int = None, - comment: str = ""): + def __init__( + self, + id: int, + N: int, + throughput: int, + latency: int, + res: int = None, + comment: str = "", + ): """ Constructs a new XInstruction. @@ -121,8 +126,8 @@ def __init__(self, if res is not None and res >= constants.MemoryModel.MAX_RESIDUALS: comment = f"res = {res}" + ("; " + comment if comment else "") super().__init__(id, throughput, latency, comment=comment) - self.__n = N # Read-only ring size for the operation - self.__res = res # Read-only residual + self.__n = N # Read-only ring size for the operation + self.__res = res # Read-only residual @property def N(self) -> int: @@ -178,9 +183,11 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: # Check that variable is in register file if not v.register: # All variables must be in register before scheduling instruction - raise RuntimeError('Instruction( {}, id={} ): Variable {} not in register file.'.format(self.name, - self.id, - v.name)) + raise RuntimeError( + "Instruction( {}, id={} ): Variable {} not in register file.".format( + self.name, self.id, v.name + ) + ) # Update accessed cycle v.last_x_access = cycle_count # Remove this instruction from access list @@ -189,17 +196,22 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: if access_element.instruction_id == self.id: accessed_idx = idx break - assert(accessed_idx >= 0) - v.accessed_by_xinsts = v.accessed_by_xinsts[:accessed_idx] + v.accessed_by_xinsts[accessed_idx + 1:] + assert accessed_idx >= 0 + v.accessed_by_xinsts = ( + v.accessed_by_xinsts[:accessed_idx] + + v.accessed_by_xinsts[accessed_idx + 1 :] + ) # Update ready cycle and dirty state of dests for dst in self.dests: - dst.cycle_ready = CycleType(cycle_count.bundle, cycle_count.cycle + self.latency) + dst.cycle_ready = CycleType( + cycle_count.bundle, cycle_count.cycle + self.latency + ) dst.register_dirty = True return retval - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ Converts the instruction to P-ISA kernel format. @@ -212,15 +224,13 @@ def _toPISAFormat(self, *extra_args) -> str: str: The instruction in P-ISA kernel format. """ preamble = (self.N,) - extra_args = tuple(src.toPISAFormat() for src in self.sources) + extra_args - extra_args = tuple(dst.toPISAFormat() for dst in self.dests) + extra_args + extra_args = tuple(src.to_pisa_format() for src in self.sources) + extra_args + extra_args = tuple(dst.to_pisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res,) - return self.toStringFormat(preamble, - self.OP_NAME_PISA, - *extra_args) + return self.to_string_format(preamble, self.op_name_pisa, *extra_args) - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM-ISA format. @@ -235,21 +245,19 @@ def _toXASMISAFormat(self, *extra_args) -> str: # preamble = (self.id[0], self.N) preamble = (self.id[0],) # Instruction sources - extra_args = tuple(src.toXASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_xasmisa_format() for src in self.sources) + extra_args # Instruction destinations - extra_args = tuple(dst.toXASMISAFormat() for dst in self.dests) + extra_args + extra_args = tuple(dst.to_xasmisa_format() for dst in self.dests) + extra_args if self.res is not None: extra_args += (self.res % constants.MemoryModel.MAX_RESIDUALS,) - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) - + return self.to_string_format(preamble, self.op_name_asm, *extra_args) + @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the operation name in ASM format. Returns: str: ASM format operation. """ - return "default_op" # Provide a default operation name or a meaningful one if applicable \ No newline at end of file + return "default_op" # Provide a default operation name or a meaningful one if applicable diff --git a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py index e01b1d31..96899866 100644 --- a/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py +++ b/assembler_tools/hec-assembler-tools/assembler/instructions/xinst/xstore.py @@ -1,8 +1,12 @@ -from assembler.common.cycle_tracking import CycleType +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.common.cycle_tracking import CycleType from .xinstruction import XInstruction from assembler.memory_model import MemoryModel from assembler.memory_model.variable import Variable + class Instruction(XInstruction): """ Encapsulates an `xstore` MInstruction. @@ -13,7 +17,7 @@ class Instruction(XInstruction): For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md - + Attributes: dest_spad_address (int): The SPAD address where the source variable will be stored. @@ -21,10 +25,12 @@ class Instruction(XInstruction): reset_GlobalCycleReady: Resets the global cycle ready for `xstore` instructions. """ - __xstore_global_cycle_ready = CycleType(0, 0) # private class attribute to track cycle ready among xstores + __xstore_global_cycle_ready = CycleType( + 0, 0 + ) # private class attribute to track cycle ready among xstores @classmethod - def _get_OP_NAME_ASM(cls) -> str: + def _get_op_name_asm(cls) -> str: """ Returns the ASM name of the operation. @@ -33,14 +39,16 @@ def _get_OP_NAME_ASM(cls) -> str: """ return "xstore" - def __init__(self, - id: int, - src: list, - mem_model: MemoryModel, - dest_spad_addr: int = -1, - throughput: int = None, - latency: int = None, - comment: str = ""): + def __init__( + self, + id: int, + src: list, + mem_model: MemoryModel, + dest_spad_addr: int = -1, + throughput: int = None, + latency: int = None, + comment: str = "", + ): """ Constructs a new `xstore` MInstruction. @@ -66,41 +74,55 @@ def __init__(self, ValueError: If `mem_model` is not an instance of `MemoryModel` or if `dest_spad_addr` is invalid. """ if not isinstance(mem_model, MemoryModel): - raise ValueError('`mem_model` must be an instance of `MemoryModel`.') + raise ValueError("`mem_model` must be an instance of `MemoryModel`.") if not throughput: throughput = Instruction._OP_DEFAULT_THROUGHPUT if not latency: latency = Instruction._OP_DEFAULT_LATENCY - N = 0 # Does not require ring-size + N = 0 # Does not require ring-size super().__init__(id, N, throughput, latency, comment=comment) self.__mem_model = mem_model self._set_sources(src) self.__internal_set_dests(src) if dest_spad_addr < 0 and src[0].spad_address < 0: - raise ValueError('`dest_spad_addr` must be a valid SPAD address if source variable is not allocated in SPAD.') - if dest_spad_addr >= 0 and src[0].spad_address >= 0 and dest_spad_addr != src[0].spad_address: - raise ValueError('`dest_spad_addr` must be null SPAD address (negative) if source variable is allocated in SPAD.') - self.dest_spad_address = src[0].spad_address if dest_spad_addr < 0 else dest_spad_addr + raise ValueError( + "`dest_spad_addr` must be a valid SPAD address if source variable is not allocated in SPAD." + ) + if ( + dest_spad_addr >= 0 + and src[0].spad_address >= 0 + and dest_spad_addr != src[0].spad_address + ): + raise ValueError( + "`dest_spad_addr` must be null SPAD address (negative) if source variable is allocated in SPAD." + ) + self.dest_spad_address = ( + src[0].spad_address if dest_spad_addr < 0 else dest_spad_addr + ) def __repr__(self): """ Returns a string representation of the Instruction object. Returns: - str: A string representation of the Instruction object, including + str: A string representation of the Instruction object, including its type, name, memory address, ID, source, memory model, destination SPAD address, throughput, and latency. """ - retval=('<{}({}) object at {}>(id={}[0], ' - 'src={}, mem_model, dest_spad_addr={}, ' - 'throughput={}, latency={})').format(type(self).__name__, - self.name, - hex(id(self)), - self.id, - self.dests, - self.dest_spad_address, - self.throughput, - self.latency) + retval = ( + "<{}({}) object at {}>(id={}[0], " + "src={}, mem_model, dest_spad_addr={}, " + "throughput={}, latency={})" + ).format( + type(self).__name__, + self.name, + hex(id(self)), + self.id, + self.dests, + self.dest_spad_address, + self.throughput, + self.latency, + ) return retval @classmethod @@ -111,7 +133,7 @@ def __set_xstoreGlobalCycleReady(cls, value: CycleType): Parameters: value (CycleType): The cycle type value to set. """ - if (value > cls.__xstore_global_cycle_ready): + if value > cls.__xstore_global_cycle_ready: cls.__xstore_global_cycle_ready = value @classmethod @@ -148,9 +170,14 @@ def __internal_set_dests(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_DESTS: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_dests(value) @@ -167,9 +194,14 @@ def _set_sources(self, value): TypeError: If the list does not contain `Variable` objects. """ if len(value) != Instruction._OP_NUM_SOURCES: - raise ValueError(("`value`: Expected list of {} `Variable` objects, " - "but list with {} elements received.".format(Instruction._OP_NUM_SOURCES, - len(value)))) + raise ValueError( + ( + "`value`: Expected list of {} `Variable` objects, " + "but list with {} elements received.".format( + Instruction._OP_NUM_SOURCES, len(value) + ) + ) + ) if not all(isinstance(x, Variable) for x in value): raise ValueError("`value`: Expected list of `Variable` objects.") super()._set_sources(value) @@ -188,8 +220,7 @@ def _get_cycle_ready(self): # sources and the global cycles-ready for other xstores. # An xstore cannot be within _OP_DEFAULT_LATENCY cycles from another xstore # because they both use the SPAD-CE data channel. - return max(super()._get_cycle_ready(), - Instruction.__xstore_global_cycle_ready) + return max(super()._get_cycle_ready(), Instruction.__xstore_global_cycle_ready) def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: """ @@ -212,35 +243,54 @@ def _schedule(self, cycle_count: CycleType, schedule_id: int) -> int: int: The throughput for this instruction. i.e. the number of cycles by which to advance the current cycle counter. """ - assert(Instruction._OP_NUM_SOURCES > 0 and len(self.sources) == Instruction._OP_NUM_SOURCES) - assert(Instruction._OP_NUM_DESTS > 0 and len(self.dests) == Instruction._OP_NUM_DESTS) - assert(all(src == dst for src, dst in zip(self.sources, self.dests))) + assert ( + Instruction._OP_NUM_SOURCES > 0 + and len(self.sources) == Instruction._OP_NUM_SOURCES + ) + assert ( + Instruction._OP_NUM_DESTS > 0 + and len(self.dests) == Instruction._OP_NUM_DESTS + ) + assert all(src == dst for src, dst in zip(self.sources, self.dests)) if not isinstance(self.sources[0], Variable): - raise RuntimeError('XInstruction ({}, id = {}) already scheduled.'.format(self.name, self.id)) - - store_buffer_item = MemoryModel.StoreBufferValueType(variable=self.sources[0], - dest_spad_address=self.dest_spad_address) + raise RuntimeError( + "XInstruction ({}, id = {}) already scheduled.".format( + self.name, self.id + ) + ) + + store_buffer_item = MemoryModel.StoreBufferValueType( + variable=self.sources[0], dest_spad_address=self.dest_spad_address + ) register = self.sources[0].register retval = super()._schedule(cycle_count, schedule_id) # Perform xstore - register.register_dirty = False # Register has been flushed + register.register_dirty = False # Register has been flushed register.allocateVariable(None) - self.sources[0] = register # Make the register the source for freezing, since variable is no longer in it - self.__mem_model.store_buffer[store_buffer_item.variable.name] = store_buffer_item + self.sources[0] = ( + register # Make the register the source for freezing, since variable is no longer in it + ) + self.__mem_model.store_buffer[store_buffer_item.variable.name] = ( + store_buffer_item + ) # Matching CInst cstore completes the xstore if self.comment: - self.comment += ';' - self.comment += ' variable "{}": SPAD({}) <- {}'.format(store_buffer_item.variable.name, - store_buffer_item.dest_spad_address, - register.name) + self.comment += ";" + self.comment += ' variable "{}": SPAD({}) <- {}'.format( + store_buffer_item.variable.name, + store_buffer_item.dest_spad_address, + register.name, + ) # Set the global cycle ready for next xstore - Instruction.__set_xstoreGlobalCycleReady(CycleType(cycle_count.bundle, cycle_count.cycle + self.latency)) + Instruction.__set_xstoreGlobalCycleReady( + CycleType(cycle_count.bundle, cycle_count.cycle + self.latency) + ) return retval - def _toPISAFormat(self, *extra_args) -> str: + def _to_pisa_format(self, *extra_args) -> str: """ This instruction has no PISA equivalent. @@ -249,7 +299,7 @@ def _toPISAFormat(self, *extra_args) -> str: """ return None - def _toXASMISAFormat(self, *extra_args) -> str: + def _to_xasmisa_format(self, *extra_args) -> str: """ Converts the instruction to ASM-ISA format. @@ -266,18 +316,16 @@ def _toXASMISAFormat(self, *extra_args) -> str: Raises: ValueError: If extra arguments are provided. """ - assert(len(self.dests) == Instruction._OP_NUM_DESTS) - assert(len(self.sources) == Instruction._OP_NUM_SOURCES) + assert len(self.dests) == Instruction._OP_NUM_DESTS + assert len(self.sources) == Instruction._OP_NUM_SOURCES if extra_args: - raise ValueError('`extra_args` not supported.') + raise ValueError("`extra_args` not supported.") preamble = (self.id[0],) # Instruction sources - extra_args = tuple(src.toXASMISAFormat() for src in self.sources) + extra_args + extra_args = tuple(src.to_xasmisa_format() for src in self.sources) + extra_args # Instruction destinations - # extra_args = tuple(dst.toCASMISAFormat() for dst in self.dests) + extra_args + # extra_args = tuple(dst.to_casmisa_format() for dst in self.dests) + extra_args # extra_args += (0,) # res = 0 - return self.toStringFormat(preamble, - self.OP_NAME_ASM, - *extra_args) \ No newline at end of file + return self.to_string_format(preamble, self.op_name_asm, *extra_args) diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py index 7250116b..35c5bb19 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/mem_info.py @@ -1,8 +1,12 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common import constants -from assembler.instructions import tokenizeFromLine +from assembler.instructions import tokenize_from_line from assembler.memory_model.variable import Variable from . import MemoryModel + class MemInfoVariable: """ Represents a memory information variable with a name and an HBM address. @@ -10,9 +14,8 @@ class MemInfoVariable: This class encapsulates the details of a variable, including its name and the address in high-bandwidth memory (HBM) where it is stored. """ - def __init__(self, - var_name: str, - hbm_address: int): + + def __init__(self, var_name: str, hbm_address: int): """ Initializes a new MemInfoVariable object with a specified name and HBM address. @@ -44,8 +47,8 @@ def as_dict(self) -> dict: Returns: dict: A dictionary representation of the variable, including its name and HBM address. """ - return { 'var_name': self.var_name, - 'hbm_address': self.hbm_address } + return {"var_name": self.var_name, "hbm_address": self.hbm_address} + class MemInfoKeygenVariable(MemInfoVariable): """ @@ -54,10 +57,8 @@ class MemInfoKeygenVariable(MemInfoVariable): This class extends MemInfoVariable to include additional attributes for key generation, specifically the seed index and key index associated with the variable. """ - def __init__(self, - var_name: str, - seed_index: int, - key_index: int): + + def __init__(self, var_name: str, seed_index: int, key_index: int): """ Initializes a new MemInfoKeygenVariable object with a specified name, seed index, and key index. @@ -71,11 +72,11 @@ def __init__(self, """ super().__init__(var_name, -1) if seed_index < 0: - raise IndexError('seed_index: must be a zero-based index.') + raise IndexError("seed_index: must be a zero-based index.") if key_index < 0: - raise IndexError('key_index: must be a zero-based index.') + raise IndexError("key_index: must be a zero-based index.") self.seed_index = seed_index - self.key_index = key_index + self.key_index = key_index def as_dict(self) -> dict: """ @@ -84,9 +85,12 @@ def as_dict(self) -> dict: Returns: dict: A dictionary representation of the variable, including its name, seed index, and key index. """ - return { 'var_name': self.var_name, - 'seed_index': self.seed_index, - 'key_index': self.key_index } + return { + "var_name": self.var_name, + "seed_index": self.seed_index, + "key_index": self.key_index, + } + class MemInfo: """ @@ -119,9 +123,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed ones metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_ONES, - var_prefix=MemInfo.Const.Keyword.LOAD_ONES) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_ONES, + var_prefix=MemInfo.Const.Keyword.LOAD_ONES, + ) class NTTAuxTable: @classmethod @@ -135,9 +141,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_AUX_TABLE, + ) class NTTRoutingTable: @classmethod @@ -151,9 +159,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed NTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_NTT_ROUTING_TABLE, + ) class iNTTAuxTable: @classmethod @@ -167,9 +177,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT auxiliary table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_AUX_TABLE, + ) class iNTTRoutingTable: @classmethod @@ -183,9 +195,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed iNTT routing table metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, - var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + var_prefix=MemInfo.Const.Keyword.LOAD_iNTT_ROUTING_TABLE, + ) class Twiddle: @classmethod @@ -199,9 +213,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed twiddle metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_TWIDDLE, - var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_TWIDDLE, + var_prefix=MemInfo.Const.Keyword.LOAD_TWIDDLE, + ) class KeygenSeed: @classmethod @@ -215,16 +231,20 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: Returns: MemInfoVariable: The parsed keygen seed metadata variable. """ - return MemInfo.Metadata.parseMetaFieldFromMemLine(tokens, - MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, - var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED) + return MemInfo.Metadata.parseMetaFieldFromMemLine( + tokens, + MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + var_prefix=MemInfo.Const.Keyword.LOAD_KEYGEN_SEED, + ) @classmethod - def parseMetaFieldFromMemLine(cls, - tokens: list, - meta_field_name: str, - var_prefix: str = "meta", - var_extra: str = None) -> MemInfoVariable: + def parseMetaFieldFromMemLine( + cls, + tokens: list, + meta_field_name: str, + var_prefix: str = "meta", + var_extra: str = None, + ) -> MemInfoVariable: """ Parses a metadata variable name from a tokenized line. @@ -239,20 +259,21 @@ def parseMetaFieldFromMemLine(cls, """ retval = None if len(tokens) >= 3: - if tokens[0] == MemInfo.Const.Keyword.LOAD \ - and tokens[1] == meta_field_name: + if ( + tokens[0] == MemInfo.Const.Keyword.LOAD + and tokens[1] == meta_field_name + ): hbm_addr = int(tokens[2]) if len(tokens) >= 4 and tokens[3]: # name supplied in the tokenized line var_name = tokens[3] else: if var_extra is None: - var_extra = f'_{hbm_addr}' + var_extra = f"_{hbm_addr}" else: var_extra = var_extra.strip() - var_name = f'{var_prefix}{var_extra}' - retval = MemInfoVariable(var_name = var_name, - hbm_address = hbm_addr) + var_name = f"{var_prefix}{var_extra}" + retval = MemInfoVariable(var_name=var_name, hbm_address=hbm_addr) return retval def __init__(self, **kwargs): @@ -264,7 +285,9 @@ def __init__(self, **kwargs): """ self.__meta_dict = {} for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS: - self.__meta_dict[meta_field] = [ MemInfoVariable(**d) for d in kwargs.get(meta_field, []) ] + self.__meta_dict[meta_field] = [ + MemInfoVariable(**d) for d in kwargs.get(meta_field, []) + ] def __getitem__(self, key): """ @@ -278,7 +301,6 @@ def __getitem__(self, key): """ return self.__meta_dict[key] - @property def ones(self) -> list: """ @@ -365,11 +387,11 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: if len(tokens) >= 4: if tokens[0] == MemInfo.Const.Keyword.KEYGEN: seed_idx = int(tokens[1]) - key_idx = int(tokens[2]) + key_idx = int(tokens[2]) var_name = tokens[3] - retval = MemInfoKeygenVariable(var_name = var_name, - seed_index = seed_idx, - key_index = key_idx) + retval = MemInfoKeygenVariable( + var_name=var_name, seed_index=seed_idx, key_index=key_idx + ) return retval class Input: @@ -386,13 +408,16 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: """ retval = None if len(tokens) >= 4: - if tokens[0] == MemInfo.Const.Keyword.LOAD \ - and tokens[1] == MemInfo.Const.Keyword.LOAD_INPUT: + if ( + tokens[0] == MemInfo.Const.Keyword.LOAD + and tokens[1] == MemInfo.Const.Keyword.LOAD_INPUT + ): hbm_addr = int(tokens[2]) var_name = tokens[3] if Variable.validateName(var_name): - retval = MemInfoVariable(var_name = var_name, - hbm_address = hbm_addr) + retval = MemInfoVariable( + var_name=var_name, hbm_address=hbm_addr + ) return retval class Output: @@ -413,8 +438,9 @@ def parseFromMemLine(cls, tokens: list) -> MemInfoVariable: hbm_addr = int(tokens[2]) var_name = tokens[1] if Variable.validateName(var_name): - retval = MemInfoVariable(var_name = var_name, - hbm_address = hbm_addr) + retval = MemInfoVariable( + var_name=var_name, hbm_address=hbm_addr + ) return retval def __init__(self, **kwargs): @@ -428,10 +454,18 @@ def __init__(self, **kwargs): kwargs (dict): A dictionary as generated by the method MemInfo.as_dict(). This is provided as a shortcut to creating a MemInfo object from structured data such as the contents of a YAML file. """ - self.__keygens = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) ] - self.__inputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) ] - self.__outputs = [ MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) ] - self.__metadata = MemInfo.Metadata(**kwargs.get(MemInfo.Const.FIELD_METADATA, {})) + self.__keygens = [ + MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_KEYGENS, []) + ] + self.__inputs = [ + MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_INPUTS, []) + ] + self.__outputs = [ + MemInfoVariable(**d) for d in kwargs.get(MemInfo.Const.FIELD_OUTPUTS, []) + ] + self.__metadata = MemInfo.Metadata( + **kwargs.get(MemInfo.Const.FIELD_METADATA, {}) + ) self.validate() @classmethod @@ -453,20 +487,22 @@ def from_iter(cls, line_iter): retval = cls() - factory_dict = { MemInfo.Keygen: retval.keygens, - MemInfo.Input: retval.inputs, - MemInfo.Output: retval.outputs, - MemInfo.Metadata.KeygenSeed: retval.metadata.keygen_seeds, - MemInfo.Metadata.Ones: retval.metadata.ones, - MemInfo.Metadata.NTTAuxTable: retval.metadata.ntt_auxiliary_table, - MemInfo.Metadata.NTTRoutingTable: retval.metadata.ntt_routing_table, - MemInfo.Metadata.iNTTAuxTable: retval.metadata.intt_auxiliary_table, - MemInfo.Metadata.iNTTRoutingTable: retval.metadata.intt_routing_table, - MemInfo.Metadata.Twiddle: retval.metadata.twiddle } + factory_dict = { + MemInfo.Keygen: retval.keygens, + MemInfo.Input: retval.inputs, + MemInfo.Output: retval.outputs, + MemInfo.Metadata.KeygenSeed: retval.metadata.keygen_seeds, + MemInfo.Metadata.Ones: retval.metadata.ones, + MemInfo.Metadata.NTTAuxTable: retval.metadata.ntt_auxiliary_table, + MemInfo.Metadata.NTTRoutingTable: retval.metadata.ntt_routing_table, + MemInfo.Metadata.iNTTAuxTable: retval.metadata.intt_auxiliary_table, + MemInfo.Metadata.iNTTRoutingTable: retval.metadata.intt_routing_table, + MemInfo.Metadata.Twiddle: retval.metadata.twiddle, + } for line_no, s_line in enumerate(line_iter, 1): s_line = s_line.strip() - if s_line: # skip empty lines - tokens, _ = tokenizeFromLine(s_line) + if s_line: # skip empty lines + tokens, _ = tokenize_from_line(s_line) if tokens and len(tokens) > 0: b_parsed = False for mem_info_type in factory_dict: @@ -474,9 +510,11 @@ def from_iter(cls, line_iter): if miv is not None: factory_dict[mem_info_type].append(miv) b_parsed = True - break # next line + break # next line if not b_parsed: - raise RuntimeError(f'Could not parse line {line_no}: "{s_line}"') + raise RuntimeError( + f'Could not parse line {line_no}: "{s_line}"' + ) retval.validate() return retval @@ -527,11 +565,16 @@ def as_dict(self): Returns: dict: A dictionary representation of the MemInfo object. """ - return { MemInfo.Const.FIELD_KEYGENS: [ x.as_dict() for x in self.keygens ], - MemInfo.Const.FIELD_INPUTS: [ x.as_dict() for x in self.inputs ], - MemInfo.Const.FIELD_OUTPUTS: [ x.as_dict() for x in self.outputs ], - MemInfo.Const.FIELD_METADATA: { meta_field: [ x.as_dict() for x in self.metadata[meta_field] ] \ - for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS if self.metadata[meta_field] } } + return { + MemInfo.Const.FIELD_KEYGENS: [x.as_dict() for x in self.keygens], + MemInfo.Const.FIELD_INPUTS: [x.as_dict() for x in self.inputs], + MemInfo.Const.FIELD_OUTPUTS: [x.as_dict() for x in self.outputs], + MemInfo.Const.FIELD_METADATA: { + meta_field: [x.as_dict() for x in self.metadata[meta_field]] + for meta_field in MemInfo.Const.FIELD_METADATA_SUBFIELDS + if self.metadata[meta_field] + }, + } def validate(self): """ @@ -540,28 +583,48 @@ def validate(self): Raises: RuntimeError: If the validation fails due to inconsistent metadata or duplicate variable names. """ - if len(self.metadata.ones) * MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT != len(self.metadata.twiddle): - raise RuntimeError(('Expected {} times as many twiddles as ones metadata values, ' - 'but received {} twiddles and {} ones.').format(MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT, - len(self.metadata.twiddle), - len(self.metadata.ones))) + if len( + self.metadata.ones + ) * MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT != len(self.metadata.twiddle): + raise RuntimeError( + ( + "Expected {} times as many twiddles as ones metadata values, " + "but received {} twiddles and {} ones." + ).format( + MemoryModel.MAX_TWIDDLE_META_VARS_PER_SEGMENT, + len(self.metadata.twiddle), + len(self.metadata.ones), + ) + ) # Avoid duplicate variable names with different HBM addresses. mem_info_vars = {} - all_var_info = self.inputs + self.outputs \ - + self.metadata.intt_auxiliary_table + self.metadata.intt_routing_table \ - + self.metadata.ntt_auxiliary_table + self.metadata.ntt_routing_table \ - + self.metadata.ones + self.metadata.twiddle + all_var_info = ( + self.inputs + + self.outputs + + self.metadata.intt_auxiliary_table + + self.metadata.intt_routing_table + + self.metadata.ntt_auxiliary_table + + self.metadata.ntt_routing_table + + self.metadata.ones + + self.metadata.twiddle + ) for var_info in all_var_info: if var_info.var_name not in mem_info_vars: mem_info_vars[var_info.var_name] = var_info elif mem_info_vars[var_info.var_name].hbm_address != var_info.hbm_address: - raise RuntimeError(('Variable "{}" already allocated in HBM address {}, ' - 'but new allocation requested into address {}.').format(var_info.var_name, - mem_info_vars[var_info.var_name].hbm_address, - var_info.hbm_address)) - -def __allocateMemInfoVariable(mem_model: MemoryModel, - v_info: MemInfoVariable): + raise RuntimeError( + ( + 'Variable "{}" already allocated in HBM address {}, ' + "but new allocation requested into address {}." + ).format( + var_info.var_name, + mem_info_vars[var_info.var_name].hbm_address, + var_info.hbm_address, + ) + ) + + +def __allocateMemInfoVariable(mem_model: MemoryModel, v_info: MemInfoVariable): """ Allocates a memory information variable in the memory model. @@ -579,17 +642,27 @@ def __allocateMemInfoVariable(mem_model: MemoryModel, """ assert v_info.hbm_address >= 0 if v_info.var_name not in mem_model.variables: - raise RuntimeError(f'Variable {v_info.var_name} not in memory model. All variables used in mem info must be present in P-ISA kernel.') + raise RuntimeError( + f"Variable {v_info.var_name} not in memory model. All variables used in mem info must be present in P-ISA kernel." + ) if mem_model.variables[v_info.var_name].hbm_address < 0: - mem_model.hbm.allocateForce(v_info.hbm_address, mem_model.variables[v_info.var_name]) + mem_model.hbm.allocateForce( + v_info.hbm_address, mem_model.variables[v_info.var_name] + ) elif v_info.hbm_address != mem_model.variables[v_info.var_name].hbm_address: - raise RuntimeError(('Variable {} already allocated in HBM address {}, ' - 'but new allocation requested into address {}.').format(v_info.var_name, - mem_model.variables[v_info.var_name].hbm_address, - v_info.hbm_address)) - -def updateMemoryModelWithMemInfo(mem_model: MemoryModel, - mem_info: MemInfo): + raise RuntimeError( + ( + "Variable {} already allocated in HBM address {}, " + "but new allocation requested into address {}." + ).format( + v_info.var_name, + mem_model.variables[v_info.var_name].hbm_address, + v_info.hbm_address, + ) + ) + + +def updateMemoryModelWithMemInfo(mem_model: MemoryModel, mem_info: MemInfo): """ Updates the memory model with memory information. @@ -624,28 +697,28 @@ def updateMemoryModelWithMemInfo(mem_model: MemoryModel, # Shuffle meta vars if mem_info.metadata.ntt_auxiliary_table: - assert(len(mem_info.metadata.ntt_auxiliary_table) == 1) + assert len(mem_info.metadata.ntt_auxiliary_table) == 1 v_info = mem_info.metadata.ntt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) __allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_aux_table = v_info.var_name if mem_info.metadata.ntt_routing_table: - assert(len(mem_info.metadata.ntt_routing_table) == 1) + assert len(mem_info.metadata.ntt_routing_table) == 1 v_info = mem_info.metadata.ntt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) __allocateMemInfoVariable(mem_model, v_info) mem_model.meta_ntt_routing_table = v_info.var_name if mem_info.metadata.intt_auxiliary_table: - assert(len(mem_info.metadata.intt_auxiliary_table) == 1) + assert len(mem_info.metadata.intt_auxiliary_table) == 1 v_info = mem_info.metadata.intt_auxiliary_table[0] mem_model.retrieveVarAdd(v_info.var_name) __allocateMemInfoVariable(mem_model, v_info) mem_model.meta_intt_aux_table = v_info.var_name if mem_info.metadata.intt_routing_table: - assert(len(mem_info.metadata.intt_routing_table) == 1) + assert len(mem_info.metadata.intt_routing_table) == 1 v_info = mem_info.metadata.intt_routing_table[0] mem_model.retrieveVarAdd(v_info.var_name) __allocateMemInfoVariable(mem_model, v_info) diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py index 51868e8d..88be136c 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/register_file.py @@ -1,9 +1,12 @@ - +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from assembler.common import constants from assembler.common.cycle_tracking import CycleTracker from .variable import Variable from . import mem_utilities as utilities + class RegisterBank: """ Encapsulates a register bank. @@ -28,6 +31,7 @@ class __RBIterator: """ Allows iteration over the registers in a register bank. """ + def __init__(self, obj): assert obj is not None and obj.register_count > 0 self.__obj = obj @@ -43,9 +47,7 @@ def __next__(self): # Constructor # ----------- - def __init__(self, - bank_index: int, - register_range: range = None): + def __init__(self, bank_index: int, register_range: range = None): """ Constructs a new RegisterBank object. @@ -59,19 +61,31 @@ def __init__(self, ValueError: If the bank index is negative or if the register range is invalid. """ if bank_index < 0: - raise ValueError((f'`bank_index`: expected non-negative a index for bank, ' - f'but {bank_index} received.')) + raise ValueError( + ( + f"`bank_index`: expected non-negative a index for bank, " + f"but {bank_index} received." + ) + ) if not register_range: register_range = range(constants.MemoryModel.NUM_REGISTER_PER_BANKS) elif len(register_range) < 1: - raise ValueError((f'`register_range`: expected a range within [0, {constants.MemoryModel.NUM_REGISTER_PER_BANKS}) with, ' - f'at least, 1 element, but {register_range} received.')) + raise ValueError( + ( + f"`register_range`: expected a range within [0, {constants.MemoryModel.NUM_REGISTER_PER_BANKS}) with, " + f"at least, 1 element, but {register_range} received." + ) + ) elif abs(register_range.step) != 1: - raise ValueError((f'`register_range`: expected a range within step of 1 or -1, ' - f'but {register_range} received.')) + raise ValueError( + ( + f"`register_range`: expected a range within step of 1 or -1, " + f"but {register_range} received." + ) + ) self.__bank_index = bank_index # list of registers in this bank - self.__registers = [ Register(self, register_i) for register_i in register_range ] + self.__registers = [Register(self, register_i) for register_i in register_range] # Special methods # --------------- @@ -92,9 +106,9 @@ def __repr__(self): Returns: str: A string representation of the RegisterBank. """ - return '<{} object at {}>(bank_index = {})'.format(type(self).__name__, - hex(id(self)), - self.bank_index) + return "<{} object at {}>(bank_index = {})".format( + type(self).__name__, hex(id(self)), self.bank_index + ) # Methods and properties # ---------------------- @@ -133,18 +147,20 @@ def getRegister(self, idx: int): ValueError: If the index is out of range. """ if idx < -self.register_count or idx >= self.register_count: - raise ValueError((f'`idx`: expected an index for register in the range [-{self.register_count}, {self.register_count}), ' - f'but {idx} received.')) + raise ValueError( + ( + f"`idx`: expected an index for register in the range [-{self.register_count}, {self.register_count}), " + f"but {idx} received." + ) + ) return self.__registers[idx] - def findAvailableRegister(self, - live_var_names, - replacement_policy: str = None): + def findAvailableRegister(self, live_var_names, replacement_policy: str = None): """ Retrieve the next available register or propose a register to use if all are occupied. Args: - live_var_names (set or list): + live_var_names (set or list): A set of variable names containing the variables that are not available for replacement i.e. live variables. This is used to avoid replacing variables that were just allocated as dependencies for an upcoming instruction. @@ -160,9 +176,11 @@ def findAvailableRegister(self, Returns: Register: The first empty register, or the register to replace if all are occupied. Returns None if no suitable register is found. """ - retval_idx = utilities.findAvailableLocation((register.contained_variable for register in self.__registers), - live_var_names, - replacement_policy) + retval_idx = utilities.findAvailableLocation( + (register.contained_variable for register in self.__registers), + live_var_names, + replacement_policy, + ) return self.getRegister(retval_idx) if retval_idx >= 0 else None def dump(self, ostream): @@ -172,27 +190,26 @@ def dump(self, ostream): Args: ostream: The output stream to write the register bank state to. """ - print(f'Register bank, {self.bank_index}', file = ostream) - print(f'Number of registers, {self.register_count}', file = ostream) - print("", file = ostream) - print("register, variable, variable register, dirty", file = ostream) + print(f"Register bank, {self.bank_index}", file=ostream) + print(f"Number of registers, {self.register_count}", file=ostream) + print("", file=ostream) + print("register, variable, variable register, dirty", file=ostream) for idx in range(self.register_count): register = self.getRegister(idx) if not register: - print('ERROR: None Register') + print("ERROR: None Register") else: - var_data = 'None' + var_data = "None" variable = register.contained_variable if variable is not None: if variable.name: - var_data = '{}, {}'.format(variable.name, - variable.register, - variable.register_dirty) + var_data = "{}, {}".format( + variable.name, variable.register, variable.register_dirty + ) else: - var_data = f'Dummy_{variable.tag}' - print('{}, {}'.format(register.name, - var_data), - file = ostream) + var_data = f"Dummy_{variable.tag}" + print("{}, {}".format(register.name, var_data), file=ostream) + class Register(CycleTracker): """ @@ -218,9 +235,7 @@ class Register(CycleTracker): # Constructor # ----------- - def __init__(self, - bank: RegisterBank, - register_index: int): + def __init__(self, bank: RegisterBank, register_index: int): """ Initializes a new Register object. @@ -231,9 +246,16 @@ def __init__(self, Raises: ValueError: If the register index is out of the valid range. """ - if register_index < 0 or register_index >= constants.MemoryModel.NUM_REGISTERS_PER_BANK: - raise ValueError((f'`register_index`: expected an index for register in the range [0, {constants.MemoryModel.NUM_REGISTERS_PER_BANK}), ' - f'but {register_index} received.')) + if ( + register_index < 0 + or register_index >= constants.MemoryModel.NUM_REGISTERS_PER_BANK + ): + raise ValueError( + ( + f"`register_index`: expected an index for register in the range [0, {constants.MemoryModel.NUM_REGISTERS_PER_BANK}), " + f"but {register_index} received." + ) + ) super().__init__((0, 0)) self.register_dirty = False self.__bank = bank @@ -253,8 +275,9 @@ def __eq__(self, other): Returns: bool: True if the other Register is the same as this one, False otherwise. """ - return other is self \ - or (isinstance(other, Register) and other.name == self.name) + return other is self or ( + isinstance(other, Register) and other.name == self.name + ) def __hash__(self): """ @@ -284,10 +307,9 @@ def __repr__(self): var_section = "" if self.contained_variable: var_section = "Variable='{}'".format(self.contained_variable.name) - return '<{}({}) object at {}>({})'.format(type(self).__name__, - self.name, - hex(id(self)), - var_section) + return "<{}({}) object at {}>({})".format( + type(self).__name__, self.name, hex(id(self)), var_section + ) # Methods and properties # ---------------------- @@ -344,7 +366,7 @@ def _set_contained_variable(self, value): """ if value: if not isinstance(value, Variable): - raise ValueError('`value`: expected a `Variable`.') + raise ValueError("`value`: expected a `Variable`.") self.__contained_var = value # register no longer dirty because we are overwriting it with new variable (or None to clear) self.register_dirty = False @@ -362,7 +384,9 @@ def allocateVariable(self, variable: Variable = None): old_var: Variable = self.contained_variable if old_var: # make old variable aware that it is no longer in this register - assert(not old_var.register_dirty) # we should not be deallocating dirty variables + assert ( + not old_var.register_dirty + ) # we should not be deallocating dirty variables old_var.register = None if variable: # make variable aware of new register @@ -374,7 +398,7 @@ def allocateVariable(self, variable: Variable = None): self._set_contained_variable(variable) - def toCASMISAFormat(self) -> str: + def to_casmisa_format(self) -> str: """ Converts the register to CInst ASM-ISA format. @@ -383,7 +407,7 @@ def toCASMISAFormat(self) -> str: """ return self.name - def toXASMISAFormat(self) -> str: + def to_xasmisa_format(self) -> str: """ Converts the register to XInst ASM-ISA format. diff --git a/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py index 1bf179fc..cf51bee4 100644 --- a/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py +++ b/assembler_tools/hec-assembler-tools/assembler/memory_model/variable.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import re from typing import NamedTuple @@ -5,6 +8,7 @@ from assembler.common.config import GlobalConfig from assembler.common.cycle_tracking import CycleTracker, CycleType + class Variable(CycleTracker): """ Class to represent a variable within a memory model. @@ -47,6 +51,7 @@ class AccessElement(NamedTuple): index (int): The index of the instruction in the listing. instruction_id (tuple): The ID of the instruction. """ + index: int instruction_id: tuple @@ -70,7 +75,7 @@ def parseFromPISAFormat(cls, s_pisa: str): """ tokens = list(map(lambda s: s.strip(), s_pisa.split())) if len(tokens) > 2 or len(tokens) < 1: - raise ValueError(f'Invalid format for P-ISA variable: {s_pisa}.') + raise ValueError(f"Invalid format for P-ISA variable: {s_pisa}.") if len(tokens) < 2: # default to suggested bank -1 tokens.append(-1) @@ -94,17 +99,14 @@ def validateName(cls, name: str) -> bool: name = name.strip() if not name: retval = False - if retval and not re.search('^[A-Za-z_][A-Za-z0-9_]*', name): + if retval and not re.search("^[A-Za-z_][A-Za-z0-9_]*", name): retval = False return retval - # Constructor # ----------- - def __init__(self, - var_name: str, - suggested_bank: int = -1): + def __init__(self, var_name: str, suggested_bank: int = -1): """ Constructs a new Variable object with a specified name and suggested bank number. @@ -123,11 +125,16 @@ def __init__(self, self.__var_name = var_name.strip() # validate bank number if suggested_bank >= constants.MemoryModel.NUM_REGISTER_BANKS: - raise ValueError(("`suggested_bank`: Expected negative to indicate no " - "suggestion or a bank index less than {}, but {} received.").format( - constants.MemoryModel.NUM_REGISTER_BANKS, suggested_bank)) + raise ValueError( + ( + "`suggested_bank`: Expected negative to indicate no " + "suggestion or a bank index less than {}, but {} received." + ).format(constants.MemoryModel.NUM_REGISTER_BANKS, suggested_bank) + ) - super().__init__(CycleType(0, 0)) # cycle ready in the form (bundle, clock_cycle) + super().__init__( + CycleType(0, 0) + ) # cycle ready in the form (bundle, clock_cycle) self.__suggested_bank = suggested_bank # HBM data region address (zero-based word index) where this variable is stored. @@ -135,10 +142,12 @@ def __init__(self, self.hbm_address = -1 self.__spad_address = -1 self.__spad_dirty = False - self.__register = None # Register + self.__register = None # Register self.__register_dirty = False - self.accessed_by_xinsts = [] # list of AccessElements containing instruction IDs that access this variable - self.last_x_access = None # last xinstruction that accessed this variable + self.accessed_by_xinsts = ( + [] + ) # list of AccessElements containing instruction IDs that access this variable + self.last_x_access = None # last xinstruction that accessed this variable # Special methods # --------------- @@ -150,10 +159,9 @@ def __repr__(self): Returns: str: A string representation. """ - retval = '<{} object at {}>(var_name="{}", suggested_bank={})'.format(type(self).__name__, - hex(id(self)), - self.name, - self.suggested_bank) + retval = '<{} object at {}>(var_name="{}", suggested_bank={})'.format( + type(self).__name__, hex(id(self)), self.name, self.suggested_bank + ) return retval def __str__(self): @@ -221,9 +229,12 @@ def suggested_bank(self): @suggested_bank.setter def suggested_bank(self, value: int): if value >= constants.MemoryModel.NUM_REGISTER_BANKS: - raise ValueError('`value`: must be in range [0, {}), but {} received.'.format(constants.MemoryModel.NUM_REGISTER_BANKS, - str(value))) - if value >= 0: # ignore negative values + raise ValueError( + "`value`: must be in range [0, {}), but {} received.".format( + constants.MemoryModel.NUM_REGISTER_BANKS, str(value) + ) + ) + if value >= 0: # ignore negative values self.__suggested_bank = value @property @@ -242,14 +253,21 @@ def register(self, value): def _set_register(self, value): from .register_file import Register + if value: if not isinstance(value, Register): - raise ValueError(('`value`: expected a `Register`, but received a `{}`.'.format(type(value).__name__))) + raise ValueError( + ( + "`value`: expected a `Register`, but received a `{}`.".format( + type(value).__name__ + ) + ) + ) self.__register = value else: self.__register = None self.register_dirty = False - self.last_x_access = None # new Register, so, no XInst access yet + self.last_x_access = None # new Register, so, no XInst access yet @property def register_dirty(self) -> bool: @@ -281,7 +299,7 @@ def spad_address(self, value: int): self._set_spad_address(value) def _set_spad_address(self, value: int): - self.spad_dirty = False # SPAD is no longer dirty because we are overwriting it + self.spad_dirty = False # SPAD is no longer dirty because we are overwriting it if value < 0: self.__spad_address = -1 else: @@ -317,19 +335,19 @@ def _get_cycle_ready(self) -> CycleType: return retval - def toPISAFormat(self) -> str: + def to_pisa_format(self) -> str: """ Converts the variable to P-ISA kernel format. Returns: str: The P-ISA format of the variable. """ - retval = f'{self.name}' + retval = f"{self.name}" if self.suggested_bank >= 0: - retval += f' ({self.suggested_bank})' + retval += f" ({self.suggested_bank})" return retval - def toXASMISAFormat(self) -> str: + def to_xasmisa_format(self) -> str: """ Converts the variable to XInst ASM-ISA format. @@ -340,10 +358,12 @@ def toXASMISAFormat(self) -> str: RuntimeError: If the variable is not allocated to a register. """ if not self.register: - raise RuntimeError("`Variable` object not allocated to register. Cannot convert to XInst ASM-ISA format.") - return self.register.toXASMISAFormat() + raise RuntimeError( + "`Variable` object not allocated to register. Cannot convert to XInst ASM-ISA format." + ) + return self.register.to_xasmisa_format() - def toCASMISAFormat(self) -> str: + def to_casmisa_format(self) -> str: """ Converts the variable to CInst ASM-ISA format. @@ -354,10 +374,12 @@ def toCASMISAFormat(self) -> str: RuntimeError: If the variable is not stored in SPAD. """ if self.spad_address < 0: - raise RuntimeError("`Variable` object not allocated in SPAD. Cannot convert to CInst ASM-ISA format.") + raise RuntimeError( + "`Variable` object not allocated in SPAD. Cannot convert to CInst ASM-ISA format." + ) return self.spad_address if GlobalConfig.hasHBM else self.name - def toMASMISAFormat(self) -> str: + def to_masmisa_format(self) -> str: """ Converts the variable to MInst ASM-ISA format. @@ -368,9 +390,12 @@ def toMASMISAFormat(self) -> str: RuntimeError: If the variable is not stored in HBM. """ if self.hbm_address < 0: - raise RuntimeError("`Variable` object not allocated in HBM. Cannot convert to MInst ASM-ISA format.") + raise RuntimeError( + "`Variable` object not allocated in HBM. Cannot convert to MInst ASM-ISA format." + ) return self.name if GlobalConfig.useHBMPlaceHolders else self.hbm_address + def findVarByName(vars_lst, var_name: str) -> Variable: """ Finds the first variable in an iterable of Variable objects that matches the specified name. @@ -384,6 +409,7 @@ def findVarByName(vars_lst, var_name: str) -> Variable: """ return next((var for var in vars_lst if var.name == var_name), None) + class DummyVariable(Variable): """ Represents a dummy variable used as a placeholder. @@ -395,7 +421,7 @@ class DummyVariable(Variable): # Constructor # ----------- - def __init__(self, tag = None): + def __init__(self, tag=None): """ Initializes a new DummyVariable object. diff --git a/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py b/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py index 4476dd86..d3999932 100644 --- a/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py +++ b/assembler_tools/hec-assembler-tools/assembler/spec_config/isa_spec.py @@ -1,48 +1,52 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os import json import assembler.instructions.cinst as cinst import assembler.instructions.minst as minst import assembler.instructions.xinst as xinst + class ISASpecConfig: __target_cops = { - "bload" : cinst.bload.Instruction, - "bones" : cinst.bones.Instruction, - "exit" : cinst.cexit.Instruction, - "cload" : cinst.cload.Instruction, - "nop" : cinst.cnop.Instruction, - "cstore" : cinst.cstore.Instruction, - "csyncm" : cinst.csyncm.Instruction, - "ifetch" : cinst.ifetch.Instruction, - "kgload" : cinst.kgload.Instruction, - "kgseed" : cinst.kgseed.Instruction, - "kgstart" : cinst.kgstart.Instruction, - "nload" : cinst.nload.Instruction, + "bload": cinst.bload.Instruction, + "bones": cinst.bones.Instruction, + "exit": cinst.cexit.Instruction, + "cload": cinst.cload.Instruction, + "nop": cinst.cnop.Instruction, + "cstore": cinst.cstore.Instruction, + "csyncm": cinst.csyncm.Instruction, + "ifetch": cinst.ifetch.Instruction, + "kgload": cinst.kgload.Instruction, + "kgseed": cinst.kgseed.Instruction, + "kgstart": cinst.kgstart.Instruction, + "nload": cinst.nload.Instruction, "xinstfetch": cinst.xinstfetch.Instruction, } __target_xops = { - "add" : xinst.add.Instruction, - "copy" : xinst.copy_mod.Instruction, - "exit" : xinst.exit_mod.Instruction, - "intt" : xinst.intt.Instruction, + "add": xinst.add.Instruction, + "copy": xinst.copy_mod.Instruction, + "exit": xinst.exit_mod.Instruction, + "intt": xinst.intt.Instruction, "irshuffle": xinst.irshuffle.Instruction, - "mac" : xinst.mac.Instruction, - "maci" : xinst.maci.Instruction, - "move" : xinst.move.Instruction, - "mul" : xinst.mul.Instruction, - "muli" : xinst.muli.Instruction, - "nop" : xinst.nop.Instruction, - "ntt" : xinst.ntt.Instruction, - "rshuffle" : xinst.rshuffle.Instruction, - "sub" : xinst.sub.Instruction, - "twintt" : xinst.twintt.Instruction, - "twntt" : xinst.twntt.Instruction, - "xstore" : xinst.xstore.Instruction, + "mac": xinst.mac.Instruction, + "maci": xinst.maci.Instruction, + "move": xinst.move.Instruction, + "mul": xinst.mul.Instruction, + "muli": xinst.muli.Instruction, + "nop": xinst.nop.Instruction, + "ntt": xinst.ntt.Instruction, + "rshuffle": xinst.rshuffle.Instruction, + "sub": xinst.sub.Instruction, + "twintt": xinst.twintt.Instruction, + "twntt": xinst.twntt.Instruction, + "xstore": xinst.xstore.Instruction, } __target_mops = { - "mload" : minst.mload.Instruction, + "mload": minst.mload.Instruction, "mstore": minst.mstore.Instruction, "msyncc": minst.msyncc.Instruction, } @@ -50,16 +54,16 @@ class ISASpecConfig: _target_ops = { "xinst": __target_xops, "cinst": __target_cops, - "minst": __target_mops + "minst": __target_mops, } _target_attributes = { - "num_tokens" : "SetNumTokens", - "num_dests" : "SetNumDests", - "num_sources" : "SetNumSources", - "default_throughput" : "SetDefaultThroughput", - "default_latency" : "SetDefaultLatency", - "special_latency_max" : "SetSpecialLatencyMax", + "num_tokens": "SetNumTokens", + "num_dests": "set_num_dests", + "num_sources": "set_num_sources", + "default_throughput": "set_default_throughput", + "default_latency": "set_default_latency", + "special_latency_max": "SetSpecialLatencyMax", "special_latency_increment": "SetSpecialLatencyIncrement", } @@ -86,19 +90,19 @@ def dump_isa_spec_to_json(cls, filename): output_dict = {"isa_spec": isa_spec_dict} # Write the dictionary to a JSON file - with open(filename, 'w') as json_file: + with open(filename, "w") as json_file: json.dump(output_dict, json_file, indent=4) @classmethod def init_isa_spec_from_json(cls, filename): """ Updates ops' class attributes using methods specified in the target_attributes dictionary based on a JSON file. - This method checks wether values found on json file exists in target dictionaries. + This method checks whether values found on json file exists in target dictionaries. Args: filename (str): The name of the JSON file to read from. """ - with open(filename, 'r') as json_file: + with open(filename, "r") as json_file: data = json.load(json_file) # Check for the "isa_spec" section @@ -109,11 +113,15 @@ def init_isa_spec_from_json(cls, filename): for inst_type, ops in cls._target_ops.items(): if inst_type not in isa_spec: - raise ValueError(f"Instruction type '{inst_type}' is not found in the JSON file.") + raise ValueError( + f"Instruction type '{inst_type}' is not found in the JSON file." + ) for op_name, op in ops.items(): if op_name not in isa_spec[inst_type]: - raise ValueError(f"Operation '{op_name}' is not found in the JSON file for instruction type '{inst_type}'.") + raise ValueError( + f"Operation '{op_name}' is not found in the JSON file for instruction type '{inst_type}'." + ) attributes = isa_spec[inst_type][op_name] @@ -124,7 +132,7 @@ def init_isa_spec_from_json(cls, filename): setter(value) else: raise ValueError(f"Attribute '{attr_name}' is not recognized.") - + @classmethod def initialize_isa_spec(cls, module_dir, isa_spec_file): @@ -137,8 +145,8 @@ def initialize_isa_spec(cls, module_dir, isa_spec_file): f"Required ISA Spec file not found: {isa_spec_file}\n" "Please provide a valid path using the `isa_spec` option, " "or use a valid default file at: `/config/isa_spec.json`." - ) - + ) + cls.init_isa_spec_from_json(isa_spec_file) return isa_spec_file diff --git a/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py index b1cf68b8..79024ed3 100644 --- a/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py +++ b/assembler_tools/hec-assembler-tools/assembler/stages/preprocessor.py @@ -1,13 +1,19 @@ -import networkx as nx +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Preprocessing utilities for HERACLES assembler stages.""" + +from typing import Tuple +import networkx as nx from assembler.common.constants import Constants from assembler.instructions import xinst from assembler.instructions.xinst.xinstruction import XInstruction from assembler.instructions.xinst import parse_xntt from assembler.memory_model import MemoryModel -from assembler.memory_model import variable -def __dependencyGraphForVars(insts_list: list) -> (nx.Graph, set, set): + +def dependency_graph_for_vars(insts_list: list) -> Tuple[nx.Graph, set, set]: """ Given the listing of instructions, this method returns the dependency graph for the variables in the listing and the sets of destination and source variables. @@ -36,7 +42,7 @@ def __dependencyGraphForVars(insts_list: list) -> (nx.Graph, set, set): all_sources_vars = set() for inst in insts_list: - extra_sources = [] + extra_sources: list = [] for idx, v in enumerate(inst.dests): all_dests_vars.add(v.name) if v.name not in retval: @@ -44,7 +50,9 @@ def __dependencyGraphForVars(insts_list: list) -> (nx.Graph, set, set): for v_i in range(idx + 1, len(inst.dests)): v_next = inst.dests[v_i] if v.name == v_next.name: - raise RuntimeError(f"Cannot write to the same variable in the same instruction more than once: {inst.toPISAFormat()}") + raise RuntimeError( + f"Cannot write to the same variable in the same instruction more than once: {inst.to_pisa_format()}" + ) if not retval.has_edge(v.name, v_next.name): retval.add_edge(v.name, v_next.name) # Mac deps already handled in the Mac instructions themselves @@ -59,16 +67,18 @@ def __dependencyGraphForVars(insts_list: list) -> (nx.Graph, set, set): for v_i in range(idx + 1, len(inst_all_sources)): v_next = inst_all_sources[v_i] if v.name == v_next.name: - raise RuntimeError(f"Cannot read from the same variable in the same instruction more than once: {inst.toPISAFormat()}") + raise RuntimeError( + f"Cannot read from the same variable in the same instruction more than once: {inst.to_pisa_format()}" + ) if not retval.has_edge(v.name, v_next.name): retval.add_edge(v.name, v_next.name) return retval, all_dests_vars, all_sources_vars -def injectVariableCopy(mem_model: MemoryModel, - insts_list: list, - instruction_idx: int, - var_name: str) -> int: + +def inject_variable_copy( + mem_model: MemoryModel, insts_list: list, instruction_idx: int, var_name: str +) -> int: """ Injects a copy of a variable into the instruction list at the specified index. @@ -85,20 +95,24 @@ def injectVariableCopy(mem_model: MemoryModel, IndexError: If the instruction index is out of range. """ if instruction_idx < 0 or instruction_idx >= len(insts_list): - raise IndexError(f'instruction_idx: Expected index in range [0, {len(insts_list)}), but received {instruction_idx}.') + raise IndexError( + f"instruction_idx: Expected index in range [0, {len(insts_list)}), but received {instruction_idx}." + ) last_instruction: XInstruction = insts_list[instruction_idx] last_instruction_sources = last_instruction.sources[:] - for idx, variable in enumerate(last_instruction_sources): - if variable.name == var_name: + for idx, src_var in enumerate(last_instruction_sources): + if src_var.name == var_name: # Find next available temp var name temp_name = mem_model.findUniqueVarName() temp_var = mem_model.retrieveVarAdd(temp_name, -1) # Copy source var into temp - copy_xinst = xinst.Copy(id = last_instruction.id[1], - N = 0, - dst = [ temp_var ], - src = [ variable ], - comment='Injected copy for bank reduction.') + copy_xinst = xinst.Copy( + id=last_instruction.id[1], + N=0, + dst=[temp_var], + src=[src_var], + comment="Injected copy for bank reduction.", + ) insts_list.insert(instruction_idx, copy_xinst) # Replace src by temp last_instruction.sources[idx] = temp_var @@ -106,9 +120,8 @@ def injectVariableCopy(mem_model: MemoryModel, return instruction_idx -def reduceVarDepsByVar(mem_model: MemoryModel, - insts_list: list, - var_name: str): + +def reduce_var_deps_by_var(mem_model: MemoryModel, insts_list: list, var_name: str): """ Reduces variable dependencies by injecting copies of the specified variable. @@ -123,34 +136,43 @@ def reduceVarDepsByVar(mem_model: MemoryModel, # * care with mac instructions while last_pos < len(insts_list): if var_name in (v.name for v in insts_list[last_pos].sources): - last_instruction: XInstruction = insts_list[last_pos] + last_instruction = insts_list[last_pos] if isinstance(last_instruction, (xinst.Mac, xinst.Maci)): # Check if the conflicting variable is the accumulator if last_instruction.sources[0].name == var_name: # Turn all other variables into copies - for variable in last_instruction.sources[1:]: - last_pos = injectVariableCopy(mem_model, insts_list, last_pos, variable.name) + for src_var in last_instruction.sources[1:]: + last_pos = inject_variable_copy( + mem_model, insts_list, last_pos, src_var.name + ) assert last_instruction == insts_list[last_pos] - last_instruction = None # avoid further processing of instruction + last_instruction = None # avoid further processing of instruction last_pos += 1 continue # If conflict variable was not the accumulator, proceed to change the other variables # Skip copy, twxntt and xrshuffle - if not isinstance(last_instruction, (xinst.twiNTT, - xinst.twiNTT, - xinst.irShuffle, - xinst.rShuffle, - xinst.Copy)): + if not isinstance( + last_instruction, + ( + xinst.twiNTT, + xinst.twiNTT, + xinst.irShuffle, + xinst.rShuffle, + xinst.Copy, + ), + ): # Break up indicated variable in sources into a temp copy - last_pos = injectVariableCopy(mem_model, insts_list, last_pos, var_name) + last_pos = inject_variable_copy( + mem_model, insts_list, last_pos, var_name + ) assert last_instruction == insts_list[last_pos] last_pos += 1 -def assignRegisterBanksToVars(mem_model: MemoryModel, - insts_list: list, - use_bank0: bool, - verbose = False) -> str: + +def assign_register_banks_to_vars( + mem_model: MemoryModel, insts_list: list, use_bank0: bool, verbose=False +) -> str: """ Assigns register banks to variables using vertex coloring graph algorithm. @@ -187,45 +209,61 @@ def assignRegisterBanksToVars(mem_model: MemoryModel, """ reduced_vars = set() needs_reduction = True - pass_counter = 0 while needs_reduction: - pass_counter += 1 - if verbose: - print(f"Pass {pass_counter}") # Extract the dependency graph for variables - dep_graph_vars, dest_names, source_names = __dependencyGraphForVars(insts_list) - only_sources = source_names - dest_names # Find which variables are ever only used as sources - color_dict = nx.greedy_color(dep_graph_vars) # Do coloring + dep_graph_vars, dest_names, source_names = dependency_graph_for_vars(insts_list) + only_sources = ( + source_names - dest_names + ) # Find which variables are ever only used as sources + color_dict = nx.greedy_color(dep_graph_vars) # Do coloring needs_reduction = False for var_name, bank in color_dict.items(): if bank > 2: if var_name in reduced_vars: - raise RuntimeError(('Found invalid bank {} > 2 for variable {} already reduced.').format(bank, - var_name)) + raise RuntimeError( + f"Found invalid bank {bank} > 2 for variable {var_name} already reduced." + ) # DEBUG print if verbose: - print('Variable {} ({}) requires reduction.'.format(var_name, bank)) - reduceVarDepsByVar(mem_model, insts_list, var_name) - reduced_vars.add(var_name) # Track reduced variable + print(f"Variable {var_name} ({bank}) requires reduction.") + reduce_var_deps_by_var(mem_model, insts_list, var_name) + reduced_vars.add(var_name) # Track reduced variable needs_reduction = True # Assign banks based on coloring algo results for v in mem_model.variables.values(): - if not mem_model.isMetaVar(v.name): # Skip meta variables - assert(v.name in color_dict) + if not mem_model.isMetaVar(v.name): # Skip meta variables + assert v.name in color_dict bank = color_dict[v.name] - assert bank < 3, f'{v.name}, {bank}' + assert bank < 3, f"{v.name}, {bank}" # If requested, keep vars used only as sources in bank 0 - v.suggested_bank = bank + (0 if use_bank0 and (v.name in only_sources) else 1) + v.suggested_bank = bank + ( + 0 if use_bank0 and (v.name in only_sources) else 1 + ) retval: str = mem_model.findUniqueVarName() return retval -def preprocessPISAKernelListing(mem_model: MemoryModel, - line_iter, - progress_verbose: bool = False) -> list: + +def ntt_kernel_grammar(line): + """Parse NTT kernel grammar from a line.""" + return parse_xntt.parseXNTTKernelLine( + line, xinst.NTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR + ) + + +def intt_kernel_grammar(line): + """Parse INTT kernel grammar from a line.""" + return parse_xntt.parseXNTTKernelLine( + line, xinst.iNTT.op_name_pisa, Constants.TW_GRAMMAR_SEPARATOR + ) + + +def preprocess_pisa_kernel_listing( + mem_model: MemoryModel, line_iter, progress_verbose: bool = False +) -> list: """ Parses a P-ISA kernel listing, given as an iterator for strings, where each is a line representing a P-ISA instruction. @@ -250,9 +288,6 @@ def preprocessPISAKernelListing(mem_model: MemoryModel, Variables in `mem_model` collection of variables will be modified to reflect assigned bank in `suggested_bank` attribute. """ - NTT_KERNEL_GRAMMAR = lambda line: parse_xntt.parseXNTTKernelLine(line, xinst.NTT.OP_NAME_PISA, Constants.TW_GRAMMAR_SEPARATOR) - iNTT_KERNEL_GRAMMAR = lambda line: parse_xntt.parseXNTTKernelLine(line, xinst.iNTT.OP_NAME_PISA, Constants.TW_GRAMMAR_SEPARATOR) - retval = [] if progress_verbose: @@ -265,26 +300,28 @@ def preprocessPISAKernelListing(mem_model: MemoryModel, parsed_insts = None if not parsed_insts: - parsed_op = NTT_KERNEL_GRAMMAR(s_line) + parsed_op = ntt_kernel_grammar(s_line) if not parsed_op: - parsed_op = iNTT_KERNEL_GRAMMAR(s_line) + parsed_op = intt_kernel_grammar(s_line) if parsed_op: # Instruction is a P-ISA xntt - parsed_insts = parse_xntt.generateXNTT(mem_model, - parsed_op, - new_id = line_no) + parsed_insts = parse_xntt.generateXNTT( + mem_model, parsed_op, new_id=line_no + ) if not parsed_insts: # Instruction is one that is represented by single XInst inst = xinst.createFromPISALine(mem_model, s_line, line_no) if inst: - parsed_insts = [ inst ] + parsed_insts = [inst] if not parsed_insts: - raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + raise SyntaxError( + f"Line {line_no}: unable to parse kernel instruction:\n{s_line}" + ) retval += parsed_insts if progress_verbose: print(f"{num_input_insts}") - return retval \ No newline at end of file + return retval diff --git a/assembler_tools/hec-assembler-tools/debug_tools/main.py b/assembler_tools/hec-assembler-tools/debug_tools/main.py index 7ac6efed..a46c0f3a 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/main.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/main.py @@ -1,3 +1,6 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import os import sys import time @@ -12,7 +15,9 @@ from assembler.stages.asm_scheduler import scheduleASMISAInstructions from assembler.memory_model import MemoryModel from assembler.memory_model import mem_info -from assembler.isa_spec import SpecConfig +from assembler.spec_config.isa_spec import ISASpecConfig +from assembler.spec_config.mem_spec import MemSpecConfig + def parse_args(): """ @@ -24,19 +29,43 @@ def parse_args(): Returns: argparse.Namespace: Parsed command-line arguments. """ - parser = argparse.ArgumentParser( - description=("Main Test.\n")) + parser = argparse.ArgumentParser(description="Main Test.\n") parser.add_argument("--mem_file", default="", help="Input memory file.") - parser.add_argument("--prefix", default="", dest="base_names", nargs='+', help="One or more input prefix to process.") - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + parser.add_argument( + "--prefix", + default="", + dest="base_names", + nargs="+", + help="One or more input prefix to process.", + ) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument( + "--mem_spec", + default="", + dest="mem_spec_file", + help=("Input memory specification (.json) file."), + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) args = parser.parse_args() return args + def main_readmem(args): """ Reads and processes memory information from a file. @@ -48,10 +77,12 @@ def main_readmem(args): if args.mem_file: mem_filename = args.mem_file else: - raise argparse.ArgumentError(None, "Please provide input memory file using `--mem_file` option.") + raise argparse.ArgumentError( + None, "Please provide input memory file using `--mem_file` option." + ) mem_meta_info = None - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) if mem_meta_info: @@ -69,10 +100,10 @@ def main_readmem(args): else: print("None") -def asmisa_preprocessing(input_filename: str, - output_filename: str, - b_use_bank_0: bool, - b_verbose=True) -> int: + +def asmisa_preprocessing( + input_filename: str, output_filename: str, b_use_bank_0: bool, b_verbose=True +) -> int: """ Preprocess P-ISA kernel and save the intermediate result. @@ -86,46 +117,54 @@ def asmisa_preprocessing(input_filename: str, int: The time taken for preprocessing in seconds. """ if b_verbose: - print('Preprocessing P-ISA kernel...') + print("Preprocessing P-ISA kernel...") - hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, - constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS) + hec_mem_model = MemoryModel( + constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, + constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS, + constants.MemoryModel.NUM_REGISTER_BANKS, + ) start_time = time.time() - with open(input_filename, 'r') as insts: - insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, - insts, - progress_verbose=b_verbose) + with open(input_filename, "r") as insts: + insts_listing = preprocessor.preprocess_pisa_kernel_listing( + hec_mem_model, insts, progress_verbose=b_verbose + ) if b_verbose: print("Assigning register banks to variables...") - preprocessor.assignRegisterBanksToVars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) + preprocessor.assign_register_banks_to_vars( + hec_mem_model, insts_listing, use_bank0=b_use_bank_0 + ) retval_timing = time.time() - start_time if b_verbose: print("Saving intermediate...") - with open(output_filename, 'w') as outnum: + with open(output_filename, "w") as outnum: for inst in insts_listing: - inst_line = inst.toPISAFormat() # + f" # {inst.id}" + inst_line = inst.to_pisa_format() # + f" # {inst.id}" if inst_line: print(inst_line, file=outnum) return retval_timing -def asmisa_assembly(output_xinst_filename: str, - output_cinst_filename: str, - output_minst_filename: str, - output_mem_filename: str, - input_filename: str, - mem_filename: str, - max_bundle_size: int, - hbm_capcity_words: int, - spad_capacity_words: int, - num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS, - register_range: range = None, - b_verbose=True) -> tuple: + +def asmisa_assembly( + output_xinst_filename: str, + output_cinst_filename: str, + output_minst_filename: str, + output_mem_filename: str, + input_filename: str, + mem_filename: str, + max_bundle_size: int, + hbm_capacity_words: int, + spad_capacity_words: int, + num_register_banks: int, + register_range: range = None, + b_verbose=True, +) -> tuple: """ Assembles ASM-ISA instructions from preprocessed P-ISA kernel. @@ -137,7 +176,7 @@ def asmisa_assembly(output_xinst_filename: str, input_filename (str): The input file containing the preprocessed P-ISA kernel. mem_filename (str): The file containing memory information. max_bundle_size (int): Maximum number of instructions in a bundle. - hbm_capcity_words (int): Capacity of HBM in words. + hbm_capacity_words (int): Capacity of HBM in words. spad_capacity_words (int): Capacity of SPAD in words. num_register_banks (int): Number of register banks. register_range (range): Range of registers. @@ -150,10 +189,12 @@ def asmisa_assembly(output_xinst_filename: str, print("Assembling!") print("Reloading kernel from intermediate...") - hec_mem_model = MemoryModel(hbm_capcity_words, spad_capacity_words, num_register_banks, register_range) + hec_mem_model = MemoryModel( + hbm_capacity_words, spad_capacity_words, num_register_banks, register_range + ) insts_listing = [] - with open(input_filename, 'r') as insts: + with open(input_filename, "r") as insts: for line_no, s_line in enumerate(insts, 1): parsed_insts = None if GlobalConfig.debugVerbose: @@ -165,13 +206,17 @@ def asmisa_assembly(output_xinst_filename: str, parsed_insts = [inst] if not parsed_insts: - raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + raise SyntaxError( + "Line {}: unable to parse kernel instruction:\n{}".format( + line_no, s_line + ) + ) insts_listing += parsed_insts if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) @@ -184,11 +229,13 @@ def asmisa_assembly(output_xinst_filename: str, if b_verbose: print("Scheduling ASM-ISA instructions...") start_time = time.time() - minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions(dep_graph, - max_bundle_size, - hec_mem_model, - constants.Constants.REPLACEMENT_POLICY_FTBU, - b_verbose) + minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions( + dep_graph, + max_bundle_size, + hec_mem_model, + constants.Constants.REPLACEMENT_POLICY_FTBU, + b_verbose, + ) sched_end = time.time() - start_time num_nops = 0 num_xinsts = 0 @@ -202,37 +249,38 @@ def asmisa_assembly(output_xinst_filename: str, if b_verbose: print("Saving minst...") - with open(output_minst_filename, 'w') as outnum: + with open(output_minst_filename, "w") as outnum: for idx, inst in enumerate(minsts): - inst_line = inst.toMASMISAFormat() + inst_line = inst.to_masmisa_format() if inst_line: print(f"{idx}, {inst_line}", file=outnum) if b_verbose: print("Saving cinst...") - with open(output_cinst_filename, 'w') as outnum: + with open(output_cinst_filename, "w") as outnum: for idx, inst in enumerate(cinsts): - inst_line = inst.toCASMISAFormat() + inst_line = inst.to_casmisa_format() if inst_line: print(f"{idx}, {inst_line}", file=outnum) if b_verbose: print("Saving xinst...") - with open(output_xinst_filename, 'w') as outnum: + with open(output_xinst_filename, "w") as outnum: for bundle_i, bundle_data in enumerate(xinsts): for inst in bundle_data[0]: - inst_line = inst.toXASMISAFormat() + inst_line = inst.to_xasmisa_format() if inst_line: print(f"F{bundle_i}, {inst_line}", file=outnum) if output_mem_filename: if b_verbose: print("Saving mem...") - with open(output_mem_filename, 'w') as outnum: + with open(output_mem_filename, "w") as outnum: mem_meta_info.exportLegacyMem(outnum) return num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end + def main_asmisa(args): """ Main function to run ASM-ISA assembly process. @@ -246,7 +294,7 @@ def main_asmisa(args): GlobalConfig.useXInstFetch = False max_bundle_size = constants.Constants.MAX_BUNDLE_SIZE - hbm_capcity_words = constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2 + hbm_capacity_words = constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2 spad_capacity_words = constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS num_register_banks = constants.MemoryModel.NUM_REGISTER_BANKS register_range = None @@ -255,21 +303,23 @@ def main_asmisa(args): if len(args.base_names) > 0: all_base_names = args.base_names else: - raise argparse.ArgumentError(f"Please provide one or more input file prefixes using `--prefix` option.") + raise argparse.ArgumentError( + message=f"Please provide one or more input file prefixes using `--prefix` option." + ) for base_name in all_base_names: - in_kernel = f'{base_name}.csv' - mem_kernel = f'{base_name}.tw.mem' - mid_kernel = f'{base_name}.tw.csv' - out_xinst = f'{base_name}.xinst' - out_cinst = f'{base_name}.cinst' - out_minst = f'{base_name}.minst' - out_mem = f'{base_name}.mem' if b_use_old_mem_file else None + in_kernel = f"{base_name}.csv" + mem_kernel = f"{base_name}.tw.mem" + mid_kernel = f"{base_name}.tw.csv" + out_xinst = f"{base_name}.xinst" + out_cinst = f"{base_name}.cinst" + out_minst = f"{base_name}.minst" + out_mem = f"{base_name}.mem" if b_use_old_mem_file else None if b_verbose: print("Verbose mode: ON") - print('Input:', in_kernel) + print("Input:", in_kernel) # Preprocessing insts_end = asmisa_preprocessing(in_kernel, mid_kernel, b_use_bank_0, b_verbose) @@ -277,19 +327,20 @@ def main_asmisa(args): if b_verbose: print() - num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = \ - asmisa_assembly(out_xinst, - out_cinst, - out_minst, - out_mem, - mid_kernel, - mem_kernel, - max_bundle_size, - hbm_capcity_words, - spad_capacity_words, - num_register_banks, - register_range, - b_verbose=b_verbose) + num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = asmisa_assembly( + out_xinst, + out_cinst, + out_minst, + out_mem, + mid_kernel, + mem_kernel, + max_bundle_size, + hbm_capacity_words, + spad_capacity_words, + num_register_banks, + register_range, + b_verbose=b_verbose, + ) if b_verbose: print(f"Input: {in_kernel}") @@ -304,6 +355,7 @@ def main_asmisa(args): print("Complete") + def main_pisa(args): """ Main function to run P-ISA scheduling process. @@ -311,57 +363,57 @@ def main_pisa(args): b_use_bank_0: bool = False b_verbose = True if args.verbose > 0 else False - max_bundle_size = 8 - hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2, - 16, - 4, - range(8)) - + hec_mem_model = MemoryModel( + constants.MemoryModel.HBM.MAX_CAPACITY_WORDS // 2, 16, 4, range(8) + ) + if len(args.base_names) == 1: base_name = args.base_names[0] else: - raise argparse.ArgumentError(None, f"Please provide an input file prefix using `--prefix` option.") - + raise argparse.ArgumentError( + None, f"Please provide an input file prefix using `--prefix` option." + ) + print("HBM") print(hec_mem_model.hbm.CAPACITY / constants.Constants.GIGABYTE, "GB") print(hec_mem_model.hbm.CAPACITY_WORDS, "words") print() - - in_kernel = f'{base_name}.csv' - mid_kernel = f'{base_name}.tw.csv' - out_kernel = f'{base_name}.tw.new.csv' - out_xinst = f'{base_name}.xinst' - out_cinst = f'{base_name}.cinst' - out_minst = f'{base_name}.minst' + in_kernel = f"{base_name}.csv" + mid_kernel = f"{base_name}.tw.csv" + out_kernel = f"{base_name}.tw.new.csv" insts_listing = [] start_time = time.time() # Read input kernel and pre-process P-ISA: # Resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; # Variables used in the kernel will be automatically assigned to banks. - with open(in_kernel, 'r') as insts: - insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, - insts, - progress_verbose=b_verbose) + with open(in_kernel, "r") as insts: + insts_listing = preprocessor.preprocessPISAKernelListing( + hec_mem_model, insts, progress_verbose=b_verbose + ) print("Assigning register banks to variables...") - preprocessor.assignRegisterBanksToVars(hec_mem_model, insts_listing, use_bank0=b_use_bank_0) + preprocessor.assignRegisterBanksToVars( + hec_mem_model, insts_listing, use_bank0=b_use_bank_0 + ) - hec_mem_model.output_variables.update(v_name for v_name in hec_mem_model.variables if 'output' in v_name) + hec_mem_model.output_variables.update( + v_name for v_name in hec_mem_model.variables if "output" in v_name + ) insts_end = time.time() - start_time print("Saving intermediate...") - with open(mid_kernel, 'w') as outnum: + with open(mid_kernel, "w") as outnum: for inst in insts_listing: - inst_line = inst.toPISAFormat() + f" # {inst.id}" + inst_line = inst.to_pisa_format() + f" # {inst.id}" if inst_line: print(inst_line, file=outnum) - #print("Reloading kernel from intermediate...") - #insts_listing = [] - #with open(mid_kernel, 'r') as insts: + # print("Reloading kernel from intermediate...") + # insts_listing = [] + # with open(mid_kernel, 'r') as insts: # for line_no, s_line in enumerate(insts, 1): # parsed_insts = None # if line_no % 100 == 0: @@ -387,13 +439,15 @@ def main_pisa(args): print("Scheduling P-ISA instructions...") start_time = time.time() - pisa_insts_schedule, num_idle_cycles, num_nops = schedulePISAInstructions(dep_graph, progress_verbose=b_verbose) + pisa_insts_schedule, num_idle_cycles, num_nops = schedulePISAInstructions( + dep_graph, progress_verbose=b_verbose + ) sched_end = time.time() - start_time print("Saving...") - with open(out_kernel, 'w') as outnum: + with open(out_kernel, "w") as outnum: for idx, inst in enumerate(pisa_insts_schedule): - inst_line = inst.toPISAFormat() + inst_line = inst.to_pisa_format() if inst_line: print(inst_line, file=outnum) @@ -410,16 +464,18 @@ def main_pisa(args): print("Complete") + if __name__ == "__main__": module_dir = os.path.dirname(__file__) module_name = os.path.basename(__file__) - sys.path.append(os.path.join(module_dir,'xinst_timing_check')) - print(module_dir,'xinst_timing_check') + sys.path.append(os.path.join(module_dir, "xinst_timing_check")) + print(module_dir, "xinst_timing_check") args = parse_args() - repo_dir = os.path.join(module_dir,"..") - args.isa_spec_file = SpecConfig.initialize_isa_spec(repo_dir, args.isa_spec_file) + repo_dir = os.path.join(module_dir, "..") + args.isa_spec_file = ISASpecConfig.initialize_isa_spec(repo_dir, args.isa_spec_file) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec(repo_dir, args.mem_spec_file) if args.verbose > 0: print(f"ISA Spec: {args.isa_spec_file}") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py index 4a2dae7a..519ef37f 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/inject_bundles.py @@ -1,13 +1,17 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import argparse import os from xinst import xinstruction -from spec_config import XTC_SpecConfig +from spec_config import XTCSpecConfig # Injects dummy bundles after bundle 1 NUM_BUNDLE_INSTRUCTIONS = 64 + def makeUniquePath(path: str): """ Normalizes and expand a given file path. @@ -20,6 +24,7 @@ def makeUniquePath(path: str): """ return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + def transferNextBundle(xinst_in_stream, xinst_out_stream, bundle_number): """ Transfers the next bundle of instructions from input to output stream. @@ -35,22 +40,25 @@ def transferNextBundle(xinst_in_stream, xinst_out_stream, bundle_number): assert s_line # Split line into tokens - tokens, comment = xinstruction.tokenizeFromLine(s_line) + tokens, comment = xinstruction.tokenize_from_line(s_line) tokens = list(tokens) tokens[0] = f"F{bundle_number}" - s_line = ', '.join(tokens) + s_line = ", ".join(tokens) if comment: s_line += f" # {comment}" print(s_line, file=xinst_out_stream) -def main(nbundles: int, - input_dir: str, - output_dir: str, - input_prefix: str = None, - output_prefix: str = None, - b_use_exit: bool = True): + +def main( + nbundles: int, + input_dir: str, + output_dir: str, + input_prefix: str = None, + output_prefix: str = None, + b_use_exit: bool = True, +): """ Main function to inject dummy bundles into instruction files. @@ -71,12 +79,12 @@ def main(nbundles: int, if not output_prefix: output_prefix = os.path.basename(output_dir) - print('Input dir:', input_dir) - print('Input prefix:', input_prefix) - print('Output dir:', output_dir) - print('Output prefix:', output_prefix) - print('Dummy bundles to insert:', nbundles) - print('Use bexit:', b_use_exit) + print("Input dir:", input_dir) + print("Input prefix:", input_prefix) + print("Output dir:", output_dir) + print("Output prefix:", output_prefix) + print("Dummy bundles to insert:", nbundles) + print("Use bexit:", b_use_exit) xinst_file_i = os.path.join(input_dir, input_prefix + ".xinst") cinst_file_i = os.path.join(input_dir, input_prefix + ".cinst") @@ -86,12 +94,12 @@ def main(nbundles: int, cinst_file_o = os.path.join(output_dir, output_prefix + ".cinst") minst_file_o = os.path.join(output_dir, output_prefix + ".minst") - with open(xinst_file_i, 'r') as f_xinst_file_i, \ - open(cinst_file_i, 'r') as f_cinst_file_i, \ - open(minst_file_i, 'r') as f_minst_file_i: - with open(xinst_file_o, 'w') as f_xinst_file_o, \ - open(cinst_file_o, 'w') as f_cinst_file_o, \ - open(minst_file_o, 'w') as f_minst_file_o: + with open(xinst_file_i, "r") as f_xinst_file_i, open( + cinst_file_i, "r" + ) as f_cinst_file_i, open(minst_file_i, "r") as f_minst_file_i: + with open(xinst_file_o, "w") as f_xinst_file_o, open( + cinst_file_o, "w" + ) as f_cinst_file_o, open(minst_file_o, "w") as f_minst_file_o: current_bundle = 0 @@ -105,18 +113,22 @@ def main(nbundles: int, print(line, file=f_xinst_file_o) # Split line into tokens - tokens, _ = xinstruction.tokenizeFromLine(line) + tokens, _ = xinstruction.tokenize_from_line(line) # Must be bundle 0 assert int(tokens[0][1:]) == current_bundle - if tokens[2] == 'xstore': + if tokens[2] == "xstore": # Encountered xstore num_xstores += 1 cinst_line_no = 0 - cinst_insertion_line_start = 0 # Track which line we started inserting dummy bundles into CInstQ - cinst_insertion_line_count = 0 # Track how many lines of dummy bundles were inserted into CInstQ + cinst_insertion_line_start = ( + 0 # Track which line we started inserting dummy bundles into CInstQ + ) + cinst_insertion_line_count = ( + 0 # Track how many lines of dummy bundles were inserted into CInstQ + ) # Read cinst until first bundle is over while True: # do-while @@ -128,11 +140,11 @@ def main(nbundles: int, print(line, file=f_cinst_file_o) # Split line into tokens - tokens, _ = xinstruction.tokenizeFromLine(line) + tokens, _ = xinstruction.tokenize_from_line(line) cinst_line_no += 1 - if tokens[1] == 'ifetch': + if tokens[1] == "ifetch": # Encountered first ifetch assert int(tokens[2]) == current_bundle break @@ -147,9 +159,9 @@ def main(nbundles: int, print(line, file=f_cinst_file_o) # Split line into tokens - tokens, _ = xinstruction.tokenizeFromLine(line) + tokens, _ = xinstruction.tokenize_from_line(line) # Must be a matching cstore - assert tokens[1] == 'cstore' + assert tokens[1] == "cstore" cinst_line_no += 1 @@ -166,13 +178,19 @@ def main(nbundles: int, if idx % 5000 == 0: print("{}% - {}/{}".format(idx * 100 // nbundles, idx, nbundles)) # Cinst - print(f"{cinst_line_no}, ifetch, {current_bundle} # dummy bundle {idx + 1}", file=f_cinst_file_o) + print( + f"{cinst_line_no}, ifetch, {current_bundle} # dummy bundle {idx + 1}", + file=f_cinst_file_o, + ) print(f"{cinst_line_no + 1}, cnop, 70", file=f_cinst_file_o) cinst_line_no += 2 # Xinst if b_use_exit: - print(f"F{current_bundle}, 0, bexit # dummy bundle", file=f_xinst_file_o) + print( + f"F{current_bundle}, 0, bexit # dummy bundle", + file=f_xinst_file_o, + ) else: print(f"F{current_bundle}, 0, nop, 0", file=f_xinst_file_o) for _ in range(NUM_BUNDLE_INSTRUCTIONS - 1): @@ -187,7 +205,7 @@ def main(nbundles: int, # Complete CInstQ and XInstQ print() - print('Transferring remaining CInstQ and XInstQ...') + print("Transferring remaining CInstQ and XInstQ...") print(cinst_line_no) while True: # do-while if cinst_line_no % 50000 == 0: @@ -198,12 +216,12 @@ def main(nbundles: int, break # Split line into tokens - tokens, comment = xinstruction.tokenizeFromLine(line) + tokens, comment = xinstruction.tokenize_from_line(line) tokens = list(tokens) tokens[0] = str(cinst_line_no) # Output line with correct line and bundle number - if tokens[1] == 'ifetch': + if tokens[1] == "ifetch": # Ensure fetching correct bundle tokens[2] = str(current_bundle) @@ -211,7 +229,7 @@ def main(nbundles: int, transferNextBundle(f_xinst_file_i, f_xinst_file_o, current_bundle) current_bundle += 1 - line = ', '.join(tokens) + line = ", ".join(tokens) if comment: line += f" # {comment}" @@ -222,24 +240,26 @@ def main(nbundles: int, # Fix sync points in MInstQ print() - print('Fixing MInstQ sync points...') + print("Fixing MInstQ sync points...") for idx, line in enumerate(f_minst_file_i): if idx % 5000 == 0: print(idx) - tokens, comment = xinstruction.tokenizeFromLine(line) - assert int(tokens[0]) == idx, 'Unexpected line number mismatch in MInstQ.' + tokens, comment = xinstruction.tokenize_from_line(line) + assert ( + int(tokens[0]) == idx + ), "Unexpected line number mismatch in MInstQ." tokens = list(tokens) # Process sync instruction - if tokens[1] == 'msyncc': + if tokens[1] == "msyncc": ctarget_line_no = int(tokens[2]) if ctarget_line_no >= cinst_insertion_line_start: ctarget_line_no += cinst_insertion_line_count tokens[2] = str(ctarget_line_no) # Transfer minst line to output file - line = ', '.join(tokens) + line = ", ".join(tokens) if comment: line += f" # {comment}" @@ -247,6 +267,7 @@ def main(nbundles: int, print(idx) + if __name__ == "__main__": module_dir = os.path.dirname(__file__) module_name = os.path.basename(__file__) @@ -257,18 +278,30 @@ def main(nbundles: int, parser.add_argument("output_dir") parser.add_argument("input_prefix", nargs="?") parser.add_argument("output_prefix", nargs="?") - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("-b", "--dummy_bundles", dest='nbundles', type=int, default=0) - parser.add_argument("-ne", "--skip_exit", dest='b_use_exit', action='store_false') + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument("-b", "--dummy_bundles", dest="nbundles", type=int, default=0) + parser.add_argument("-ne", "--skip_exit", dest="b_use_exit", action="store_false") args = parser.parse_args() - args.isa_spec_file = XTC_SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + args.isa_spec_file = XTCSpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) print(f"ISA Spec: {args.isa_spec_file}") print() - main(args.nbundles, args.input_dir, args.output_dir, - args.input_prefix, args.output_prefix, args.b_use_exit) + main( + args.nbundles, + args.input_dir, + args.output_dir, + args.input_prefix, + args.output_prefix, + args.b_use_exit, + ) print() - print(module_name, "- Complete") \ No newline at end of file + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py index 04a7d98d..51b1ce05 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/spec_config.py @@ -1,38 +1,58 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Configuration for the Xinst Timing Check ISA specification.""" + import os -import xinst -from assembler.isa_spec import SpecConfig +from assembler.spec_config.isa_spec import ISASpecConfig + +import xinst # pylint: disable=import-error + -class XTC_SpecConfig (SpecConfig): +class XTCSpecConfig(ISASpecConfig): + """ + Configuration class for the Xinst Timing Check ISA specification. + This class defines the target operations, attributes, and methods for + initializing the ISA specification from a JSON file. + + Methods: + dump_isa_spec_to_json: Unimplemented for this child class. + initialize_isa_spec: Initializes the ISA specification from a JSON file. + dump_isa_spec_to_json: Unimplemented for this child class. + Attributes: + __target_xops: Dictionary mapping operation names to their corresponding + instruction classes in the xinst module. + _target_ops: Dictionary containing the target operations for this ISA spec. + _target_attributes: Dictionary mapping attribute names to their setter methods. + """ __target_xops = { - "add" : xinst.add.Instruction, - "exit" : xinst.exit_mod.Instruction, - "intt" : xinst.intt.Instruction, - "mac" : xinst.mac.Instruction, - "maci" : xinst.maci.Instruction, - "move" : xinst.move.Instruction, - "mul" : xinst.mul.Instruction, - "muli" : xinst.muli.Instruction, - "nop" : xinst.nop.Instruction, - "ntt" : xinst.ntt.Instruction, - "rshuffle" : xinst.rshuffle.Instruction, - "sub" : xinst.sub.Instruction, - "twintt" : xinst.twintt.Instruction, - "twntt" : xinst.twntt.Instruction, - "xstore" : xinst.xstore.Instruction, + "add": xinst.add.Instruction, + "exit": xinst.exit_mod.Instruction, + "intt": xinst.intt.Instruction, + "mac": xinst.mac.Instruction, + "maci": xinst.maci.Instruction, + "move": xinst.move.Instruction, + "mul": xinst.mul.Instruction, + "muli": xinst.muli.Instruction, + "nop": xinst.nop.Instruction, + "ntt": xinst.ntt.Instruction, + "rshuffle": xinst.rshuffle.Instruction, + "sub": xinst.sub.Instruction, + "twintt": xinst.twintt.Instruction, + "twntt": xinst.twntt.Instruction, + "xstore": xinst.xstore.Instruction, } - _target_ops = { - "xinst": __target_xops - } + _target_ops = {"xinst": __target_xops} _target_attributes = { - "num_tokens" : "SetNumTokens", - "num_dests" : "SetNumDests", - "num_sources" : "SetNumSources", - "default_throughput" : "SetDefaultThroughput", - "default_latency" : "SetDefaultLatency", - "special_latency_max" : "SetSpecialLatencyMax", + "num_tokens": "SetNumTokens", + "num_dests": "set_num_dests", + "num_sources": "set_num_sources", + "default_throughput": "set_default_throughput", + "default_latency": "set_default_latency", + "special_latency_max": "SetSpecialLatencyMax", "special_latency_increment": "SetSpecialLatencyIncrement", } @@ -47,16 +67,16 @@ def dump_isa_spec_to_json(cls, filename): def initialize_isa_spec(cls, module_dir, isa_spec_file): if not isa_spec_file: - isa_spec_file = os.path.join(module_dir, "../../config/isa_spec.json") - isa_spec_file = os.path.abspath(isa_spec_file) + isa_spec_file = os.path.join(module_dir, "../../config/isa_spec.json") + isa_spec_file = os.path.abspath(isa_spec_file) if not os.path.exists(isa_spec_file): - raise FileNotFoundError( + raise FileNotFoundError( f"Required ISA Spec file not found: {isa_spec_file}\n" "Please provide a valid path using the `isa_spec` option, " "or use a valid default file at: `/config/isa_spec.json`." - ) - + ) + cls.init_isa_spec_from_json(isa_spec_file) return isa_spec_file diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py index d5a47482..1d0aa938 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xinstruction.py @@ -1,33 +1,37 @@ -import re +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import re from assembler.common.decorators import * -from assembler.instructions import tokenizeFromLine +from assembler.instructions import tokenize_from_line + class XInstruction: # To be initialized from ASM ISA spec - _OP_NUM_DESTS : int - _OP_NUM_SOURCES : int - _OP_DEFAULT_THROUGHPUT : int - _OP_DEFAULT_LATENCY : int + _OP_NUM_DESTS: int + _OP_NUM_SOURCES: int + _OP_DEFAULT_THROUGHPUT: int + _OP_DEFAULT_LATENCY: int @classmethod def SetNumTokens(cls, val): pass @classmethod - def SetNumDests(cls, val): + def set_num_dests(cls, val): cls._OP_NUM_DESTS = val @classmethod - def SetNumSources(cls, val): + def set_num_sources(cls, val): cls._OP_NUM_SOURCES = val @classmethod - def SetDefaultThroughput(cls, val): + def set_default_throughput(cls, val): cls._OP_DEFAULT_THROUGHPUT = val @classmethod - def SetDefaultLatency(cls, val): + def set_default_latency(cls, val): cls._OP_DEFAULT_LATENCY = val # Static methods @@ -47,13 +51,15 @@ def tokenizeFromASMISALine(op_name: str, line: str) -> list: None if instruction cannot be parsed from the line. """ retval = None - tokens, comment = tokenizeFromLine(line) + tokens, comment = tokenize_from_line(line) if len(tokens) > 2 and tokens[2] == op_name: retval = (tokens, comment) return retval @staticmethod - def parseASMISASourceDestsFromTokens(tokens: list, num_dests: int, num_sources: int, offset: int = 0) -> dict: + def parseASMISASourceDestsFromTokens( + tokens: list, num_dests: int, num_sources: int, offset: int = 0 + ) -> dict: """ Parses the sources and destinations for an instruction, given sources and destinations in tokens in P-ISA format. @@ -79,20 +85,20 @@ def parseASMISASourceDestsFromTokens(tokens: list, num_dests: int, num_sources: dst = [] for dst_tokens in tokens[dst_start:dst_end]: if not re.search("r[0-9]+b[0-3]", dst_tokens): - raise ValueError(f'Invalid register name: `{dst_tokens}`.') + raise ValueError(f"Invalid register name: `{dst_tokens}`.") # Parse rXXbXX into a tuple of the form (reg, bank) tmp = dst_tokens[1:] - reg = tuple(map(int, tmp.split('b'))) + reg = tuple(map(int, tmp.split("b"))) dst.append(reg) src_start = dst_end src_end = src_start + num_sources src = [] for src_tokens in tokens[src_start:src_end]: if not re.search("r[0-9]+b[0-3]", src_tokens): - raise ValueError(f'Invalid register name: `{src_tokens}`.') + raise ValueError(f"Invalid register name: `{src_tokens}`.") # Parse rXXbXX into a tuple of the form (reg, bank) tmp = src_tokens[1:] - reg = tuple(map(int, tmp.split('b'))) + reg = tuple(map(int, tmp.split("b"))) src.append(reg) if dst: retval["dst"] = dst @@ -118,20 +124,22 @@ def _get_name(cls) -> str: Raises: NotImplementedError: If the method is not implemented in a derived class. """ - raise NotImplementedError('Abstract base') + raise NotImplementedError("Abstract base") # Constructor # ----------- - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = "", + ): """ Initializes an XInstruction object. @@ -161,18 +169,16 @@ def __str__(self): Returns: str: The string representation of the instruction. """ - retval = "f{}, {}, {}".format(self.bundle, - self.pisa_instr, - self.name) + retval = "f{}, {}, {}".format(self.bundle, self.pisa_instr, self.name) if self.dsts: - dsts = ['r{}b{}'.format(r, b) for r, b in self.dsts] - retval += ', {}'.format(', '.join(dsts)) + dsts = ["r{}b{}".format(r, b) for r, b in self.dsts] + retval += ", {}".format(", ".join(dsts)) if self.srcs: - srcs = ['r{}b{}'.format(r, b) for r, b in self.srcs] - retval += ', {}'.format(', '.join(srcs)) + srcs = ["r{}b{}".format(r, b) for r, b in self.srcs] + retval += ", {}".format(", ".join(srcs)) if self.other: - retval += ', {}'.format(', '.join(self.other)) + retval += ", {}".format(", ".join(self.other)) if self.comment: - retval += f' # {self.comment}' + retval += f" # {self.comment}" - return retval \ No newline at end of file + return retval diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py index 79a8d697..b1317ba9 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xinst/xstore.py @@ -1,20 +1,24 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + from .xinstruction import XInstruction + class Instruction(XInstruction): """ Represents a `xstore` Instruction. This instruction transfers data from a register into the intermediate data buffer for subsequent transfer into SPAD. - + For more information, check the specification: https://github.com/IntelLabs/hec-assembler-tools/blob/master/docsrc/inst_spec/xinst/xinst_xstore.md Methods: fromASMISALine: Parses an ASM ISA line to create an Instruction instance. """ - + @classmethod - def SetNumDests(cls, val): + def set_num_dests(cls, val): cls._OP_NUM_DESTS = 0 @classmethod @@ -36,16 +40,22 @@ def fromASMISALine(cls, line: str) -> list: if tokens: tokens, comment = tokens if len(tokens) < 3 or tokens[2] != cls.name: - raise ValueError('`line`: could not parse f{cls.name} from specified line.') - dst_src_map = XInstruction.parseASMISASourceDestsFromTokens(tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3) - retval = cls(int(tokens[0][1:]), # bundle - int(tokens[1]), # pisa - [], - dst_src_map['src'], - cls._OP_DEFAULT_THROUGHPUT, - cls._OP_DEFAULT_LATENCY, - tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES:], - comment) + raise ValueError( + "`line`: could not parse f{cls.name} from specified line." + ) + dst_src_map = XInstruction.parseASMISASourceDestsFromTokens( + tokens, cls._OP_NUM_DESTS, cls._OP_NUM_SOURCES, 3 + ) + retval = cls( + int(tokens[0][1:]), # bundle + int(tokens[1]), # pisa + [], + dst_src_map["src"], + cls._OP_DEFAULT_THROUGHPUT, + cls._OP_DEFAULT_LATENCY, + tokens[3 + cls._OP_NUM_DESTS + cls._OP_NUM_SOURCES :], + comment, + ) return retval @classmethod @@ -58,15 +68,17 @@ def _get_name(cls) -> str: """ return "xstore" - def __init__(self, - bundle: int, - pisa_instr: int, - dsts: list, - srcs: list, - throughput: int, - latency: int, - other: list = [], - comment: str = ""): + def __init__( + self, + bundle: int, + pisa_instr: int, + dsts: list, + srcs: list, + throughput: int, + latency: int, + other: list = [], + comment: str = "", + ): """ Initializes an Instruction instance. @@ -80,4 +92,6 @@ def __init__(self, other (list, optional): Additional parameters. Defaults to an empty list. comment (str, optional): A comment associated with the instruction. Defaults to an empty string. """ - super().__init__(bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment) \ No newline at end of file + super().__init__( + bundle, pisa_instr, dsts, srcs, throughput, latency, other, comment + ) diff --git a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py index 73ce674c..de3a6df4 100644 --- a/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py +++ b/assembler_tools/hec-assembler-tools/debug_tools/xinst_timing_check/xtiming_check.py @@ -1,8 +1,12 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + import argparse import os +from typing import List, Optional, Tuple, TypeVar import xinst -from spec_config import XTC_SpecConfig +from spec_config import XTCSpecConfig # Checks timing for register access. # - Checks if a register is being read from before its write completes. @@ -11,6 +15,7 @@ NUM_BUNDLE_INSTRUCTIONS = 64 + def makeUniquePath(path: str): """ Normalizes and expand a given file path. @@ -23,6 +28,7 @@ def makeUniquePath(path: str): """ return os.path.normcase(os.path.realpath(os.path.expanduser(path))) + def computeXBundleLatency(xinstr_bundle: list) -> int: """ Computes the latency of a bundle of XInstructions. @@ -37,9 +43,9 @@ def computeXBundleLatency(xinstr_bundle: list) -> int: RuntimeError: If the bundle size is invalid. """ if len(xinstr_bundle) != NUM_BUNDLE_INSTRUCTIONS: - raise RuntimeError('Invalid bundle size for bundle. Expected {} instructions, but {} found.'.format(bundle_id, - NUM_BUNDLE_INSTRUCTIONS, - len(xinstrs[idx:]))) + raise RuntimeError( + f"Invalid bundle size for bundle. Expected {NUM_BUNDLE_INSTRUCTIONS} instructions, but {len(xinstr_bundle)} found." + ) current_bundle_cycle_count = 0 # Tracks number of cycles since last sync point (bundle start is the first sync point) current_bundle_latency = 0 for xinstr in xinstr_bundle: @@ -60,6 +66,7 @@ def computeXBundleLatency(xinstr_bundle: list) -> int: return current_bundle_latency + def computeXBundleLatencies(xinstrs: list) -> list: """ Computes latencies for all bundles of XInstructions. @@ -70,14 +77,18 @@ def computeXBundleLatencies(xinstrs: list) -> list: Returns: list: A list of latencies for each bundle. """ - print('WARNING: Check latency for `exit` XInstruction.') - print('Computing x bundle latencies') + print("WARNING: Check latency for `exit` XInstruction.") + print("Computing x bundle latencies") retval = [] total_xinstr = len(xinstrs) bundle_id = 0 while xinstrs: if bundle_id % 1000 == 0: - print("{}% - {}/{}".format((total_xinstr - len(xinstrs)) * 100 // total_xinstr, (total_xinstr - len(xinstrs)), total_xinstr)) + print( + f"{(total_xinstr - len(xinstrs)) * 100 // total_xinstr}% " + f"- {(total_xinstr - len(xinstrs))}" + f"/{total_xinstr}" + ) bundle = xinstrs[:NUM_BUNDLE_INSTRUCTIONS] xinstrs = xinstrs[NUM_BUNDLE_INSTRUCTIONS:] assert bundle[0].bundle == bundle_id and bundle[-1].bundle == bundle_id @@ -88,6 +99,7 @@ def computeXBundleLatencies(xinstrs: list) -> list: return retval + def computeCBundleLatencies(cinstr_lines) -> list: """ Computes latencies for all bundles of CInstructions. @@ -98,7 +110,7 @@ def computeCBundleLatencies(cinstr_lines) -> list: Returns: list: A list of latencies for each bundle. """ - print('Computing c bundle latencies') + print("Computing c bundle latencies") retval = [] bundle_id = 0 bundle_latency = 0 @@ -108,40 +120,43 @@ def computeCBundleLatencies(cinstr_lines) -> list: if c_line.strip(): # remove comment and tokenize - s_split = [s.strip() for s in c_line.split("#")[0].split(',')] - if bundle_id < 0 and ('cnop' not in s_split[1]): - raise RuntimeError('Invalid CInstruction detected after end of CInstQ') - if 'ifetch' == s_split[1]: + s_split = [s.strip() for s in c_line.split("#")[0].split(",")] + if bundle_id < 0 and ("cnop" not in s_split[1]): + raise RuntimeError("Invalid CInstruction detected after end of CInstQ") + if "ifetch" == s_split[1]: # New bundle - assert int(s_split[2]) == bundle_id, f'ifetch, {s_split[2]} | expected {bundle_id}' + assert ( + int(s_split[2]) == bundle_id + ), f"ifetch, {s_split[2]} | expected {bundle_id}" retval.append(bundle_latency) bundle_id += 1 bundle_latency = 0 - elif 'exit' in s_split[1]: + elif "exit" in s_split[1]: # CInstQ terminate retval.append(bundle_latency) bundle_id = -1 # Will assert if more instructions after exit - elif 'cstore' == s_split[1]: + elif "cstore" == s_split[1]: # Reset latency bundle_latency = 0 else: instruction_throughput = 1 - if 'nop' in s_split[1]: + if "nop" in s_split[1]: instruction_throughput = int(s_split[2]) - elif 'cload' in s_split[1]: + elif "cload" in s_split[1]: instruction_throughput = 4 - elif 'nload' in s_split[1]: + elif "nload" in s_split[1]: instruction_throughput = 4 bundle_latency += instruction_throughput return retval[1:] -def main(input_dir: str, input_prefix: str = None): + +def main(input_dir: str, input_prefix: Optional[str] = None): """ Main function to check timing for register access and synchronization. Parameters: input_dir (str): Directory containing input files. - input_prefix (str): Prefix for input files. + input_prefix (Optional[str]): Prefix for input files. """ print("Starting") @@ -149,20 +164,20 @@ def main(input_dir: str, input_prefix: str = None): if not input_prefix: input_prefix = os.path.basename(input_dir) - print('Input dir:', input_dir) - print('Input prefix:', input_prefix) + print("Input dir:", input_dir) + print("Input prefix:", input_prefix) xinst_file = os.path.join(input_dir, input_prefix + ".xinst") cinst_file = os.path.join(input_dir, input_prefix + ".cinst") - xinstrs = [] - with open(xinst_file, 'r') as f_in: + xinstrs: List[xinst.XInstruction] = [] + with open(xinst_file, "r") as f_in: for idx, line in enumerate(f_in): if idx % 50000 == 0: print(idx) if line.strip(): # Remove comment - s_split = line.split("#")[0].split(',') + s_split = line.split("#")[0].split(",") # Parse the line into an instruction instr_name = s_split[2].strip() b_parsed = False @@ -173,68 +188,77 @@ def main(input_dir: str, input_prefix: str = None): b_parsed = True break if not b_parsed: - raise ValueError(f'Could not parse line f{idx + 1}: {line}') + raise ValueError(f"Could not parse line f{idx + 1}: {line}") # Check synchronization between C and X queues print("--------------") print("Checking synchronization between C and X queues...") xbundle_cycles = computeXBundleLatencies(xinstrs) - with open(cinst_file, 'r') as f_in: + with open(cinst_file, "r") as f_in: cbundle_cycles = computeCBundleLatencies(f_in) if len(xbundle_cycles) != len(cbundle_cycles): - raise RuntimeError('Mismatched bundles: {} xbundles vs. {} cbundles'.format(len(xbundle_cycles), - len(cbundle_cycles))) + raise RuntimeError( + "Mismatched bundles: {} xbundles vs. {} cbundles".format( + len(xbundle_cycles), len(cbundle_cycles) + ) + ) print("Comparing latencies...") bundle_cycles_violation_list = [] for idx in range(len(xbundle_cycles)): if xbundle_cycles[idx] > cbundle_cycles[idx]: - bundle_cycles_violation_list.append('Bundle {} | X {} cycles; C {} cycles'.format(idx, - xbundle_cycles[idx], - cbundle_cycles[idx])) + bundle_cycles_violation_list.append( + "Bundle {} | X {} cycles; C {} cycles".format( + idx, xbundle_cycles[idx], cbundle_cycles[idx] + ) + ) # Check timings for register access print("--------------") print("Checking timings for register access...") - violation_lst = [] # list(tuple(xinstr_idx, violating_idx, register: str, cycle_counter)) + violation_lst: List[Tuple[int, int, str, int]] = ( + [] + ) # list(tuple(xinstr_idx, violating_idx, register: str, cycle_counter)) for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) # Check bank conflict - banks = set() for r, b in xinstr.srcs: if b in banks: - violation_lst.append((idx + 1, f"Bank conflict source {b}", xinstr.name)) + violation_lst.append((idx + 1, idx + 1, f"Bank conflict source {b}", 0)) break banks.add(b) banks = set() for r, b in xinstr.dsts: if b in banks: - violation_lst.append((idx + 1, f"Bank conflict dests {b}", xinstr.name)) + violation_lst.append((idx + 1, idx + 1, f"Bank conflict dests {b}", 0)) break banks.add(b) - if xinstr.name == 'move': + if xinstr.name == "move": # Make sure move is only moving from bank zero src_bank = xinstr.srcs[0][1] dst_bank = xinstr.dsts[0][1] if src_bank != 0: - violation_lst.append((idx + 1, f"Move bank error sources {src_bank}", xinstr.name)) + violation_lst.append( + (idx + 1, idx + 1, f"Move bank error sources {src_bank}", 0) + ) if dst_bank == src_bank: - violation_lst.append((idx + 1, f"Move bank error dests {dst_bank}", xinstr.name)) + violation_lst.append( + (idx + 1, idx + 1, f"Move bank error dests {dst_bank}", 0) + ) # Check timing - cycle_counter = xinstr.throughput for jdx in range(idx + 1, len(xinstrs)): if cycle_counter >= xinstr.latency: break # Instruction outputs are ready next_xinstr = xinstrs[jdx] if next_xinstr.bundle != xinstr.bundle: - assert(next_xinstr.bundle == xinstr.bundle + 1) + assert next_xinstr.bundle == xinstr.bundle + 1 break # Different bundle # Check @@ -242,7 +266,9 @@ def main(input_dir: str, input_prefix: str = None): for reg in xinstr.dsts: if reg in all_next_regs: # Register is not ready and still used by an instruction - violation_lst.append((idx + 1, jdx + 1, f"r{reg[0]}b{reg[1]}", cycle_counter)) + violation_lst.append( + (idx + 1, jdx + 1, f"r{reg[0]}b{reg[1]}", cycle_counter) + ) cycle_counter += next_xinstr.throughput @@ -251,7 +277,9 @@ def main(input_dir: str, input_prefix: str = None): # Check rshuffle separation print("--------------") print("Checking rshuffle separation...") - rshuffle_violation_lst = [] # list(tuple(xinstr_idx, violating_idx, data_types: str, cycle_counter)) + rshuffle_violation_lst: List[Tuple[int, int, str, int]] = ( + [] + ) # list(tuple(xinstr_idx, violating_idx, data_types: str, cycle_counter)) print("WARNING: No distinction between `rshuffle` and `irshuffle`.") for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: @@ -264,18 +292,34 @@ def main(input_dir: str, input_prefix: str = None): break # Instruction outputs are ready next_xinstr = xinstrs[jdx] if next_xinstr.bundle != xinstr.bundle: - assert(next_xinstr.bundle == xinstr.bundle + 1) + assert next_xinstr.bundle == xinstr.bundle + 1 break # Different bundle # Check if isinstance(next_xinstr, xinst.rShuffle): if next_xinstr.data_type != xinstr.data_type: # Mixing ntt and intt rshuffle inside the latency of first rshuffle - rshuffle_violation_lst.append((idx + 1, jdx + 1, f"{xinstr.data_type} != {next_xinstr.data_type}", cycle_counter)) - elif cycle_counter < xinstr.special_latency_max \ - and cycle_counter % xinstr.special_latency_increment != 0: + rshuffle_violation_lst.append( + ( + idx + 1, + jdx + 1, + f"{xinstr.data_type} != {next_xinstr.data_type}", + cycle_counter, + ) + ) + elif ( + cycle_counter < xinstr.special_latency_max + and cycle_counter % xinstr.special_latency_increment != 0 + ): # Same data type - rshuffle_violation_lst.append((idx + 1, jdx + 1, f"{xinstr.data_type} == {next_xinstr.data_type}", cycle_counter)) + rshuffle_violation_lst.append( + ( + idx + 1, + jdx + 1, + f"{xinstr.data_type} == {next_xinstr.data_type}", + cycle_counter, + ) + ) cycle_counter += next_xinstr.throughput @@ -284,7 +328,9 @@ def main(input_dir: str, input_prefix: str = None): # Check bank conflicts with rshuffle print("--------------") print("Checking bank conflicts with rshuffle...") - rshuffle_bank_violation_lst = [] # list(tuple(xinstr_idx, violating_idx, banks: str, cycle_counter)) + rshuffle_bank_violation_lst: List[Tuple[int, int, str, int]] = ( + [] + ) # list(tuple(xinstr_idx, violating_idx, banks: str, cycle_counter)) for idx, xinstr in enumerate(xinstrs): if idx % 50000 == 0: print("{}% - {}/{}".format(idx * 100 // len(xinstrs), idx, len(xinstrs))) @@ -299,7 +345,7 @@ def main(input_dir: str, input_prefix: str = None): break # Instruction outputs are ready next_xinstr = xinstrs[jdx] if next_xinstr.bundle != xinstr.bundle: - assert(next_xinstr.bundle == xinstr.bundle + 1) + assert next_xinstr.bundle == xinstr.bundle + 1 break # Different bundle # Check if cycle_counter + next_xinstr.latency - 1 == rshuffle_write_cycle: @@ -307,7 +353,16 @@ def main(input_dir: str, input_prefix: str = None): # Check for bank conflicts next_xinstr_banks = set(bank for _, bank in next_xinstr.dsts) if rshuffle_banks & next_xinstr_banks: - rshuffle_bank_violation_lst.append((idx + 1, jdx + 1, "{} | banks: {}".format(next_xinstr.name, rshuffle_banks & next_xinstr_banks), cycle_counter)) + rshuffle_bank_violation_lst.append( + ( + idx + 1, + jdx + 1, + "{} | banks: {}".format( + next_xinstr.name, rshuffle_banks & next_xinstr_banks + ), + cycle_counter, + ) + ) cycle_counter += next_xinstr.throughput @@ -318,36 +373,37 @@ def main(input_dir: str, input_prefix: str = None): if bundle_cycles_violation_list: # Log violation list print() - for x in bundle_cycles_violation_list: - print(x) - s_error_msgs.append('Bundle cycle violations detected.') + for violation in bundle_cycles_violation_list: + print(violation) + s_error_msgs.append("Bundle cycle violations detected.") if violation_lst: # Log violation list print() - for x in violation_lst: - print(x) - s_error_msgs.append('Register access violations detected.') + for violation in violation_lst: + print(violation) + s_error_msgs.append("Register access violations detected.") if rshuffle_violation_lst: # Log violation list print() - for x in rshuffle_violation_lst: - print(x) - s_error_msgs.append('rShuffle special latency violations detected.') + for violation in rshuffle_violation_lst: + print(violation) + s_error_msgs.append("rShuffle special latency violations detected.") if rshuffle_bank_violation_lst: # Log violation list print() - for x in rshuffle_bank_violation_lst: - print(x) - s_error_msgs.append('rShuffle bank access violations detected.') + for violation in rshuffle_bank_violation_lst: + print(violation) + s_error_msgs.append("rShuffle bank access violations detected.") if s_error_msgs: - raise RuntimeError('\n'.join(s_error_msgs)) + raise RuntimeError("\n".join(s_error_msgs)) print() - print('No timing errors found.') + print("No timing errors found.") + if __name__ == "__main__": module_dir = os.path.dirname(__file__) @@ -357,15 +413,21 @@ def main(input_dir: str, input_prefix: str = None): parser = argparse.ArgumentParser() parser.add_argument("input_dir") parser.add_argument("input_prefix", nargs="?") - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) args = parser.parse_args() - args.isa_spec_file = XTC_SpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) + args.isa_spec_file = XTCSpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) print(f"ISA Spec: {args.isa_spec_file}") print() main(args.input_dir, args.input_prefix) print() - print(module_name, "- Complete") \ No newline at end of file + print(module_name, "- Complete") diff --git a/assembler_tools/hec-assembler-tools/he_as.py b/assembler_tools/hec-assembler-tools/he_as.py index 1facbf66..8c3e07b4 100644 --- a/assembler_tools/hec-assembler-tools/he_as.py +++ b/assembler_tools/hec-assembler-tools/he_as.py @@ -1,4 +1,7 @@ #! /usr/bin/env python3 +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """ This module provides functionality for assembling pre-processed P-ISA kernel programs into valid assembly code for execution queues: MINST, CINST, and XINST. @@ -50,6 +53,7 @@ DEFAULT_MINST_FILE_EXT = "minst" DEFAULT_MEM_FILE_EXT = "mem" + class AssemblerRunConfig(RunConfig): """ Maintains the configuration data for the run. @@ -59,7 +63,7 @@ class AssemblerRunConfig(RunConfig): Returns the configuration as a dictionary. """ - __initialized = False # specifies whether static members have been initialized + __initialized = False # specifies whether static members have been initialized # contains the dictionary of all configuration items supported and their # default value (or None if no default) __default_config = {} @@ -98,13 +102,15 @@ def __init__(self, **kwargs): for config_name, default_value in self.__default_config.items(): value = kwargs.get(config_name) if value is not None: - assert(not hasattr(self, config_name)) + assert not hasattr(self, config_name) setattr(self, config_name, value) else: if not hasattr(self, config_name): setattr(self, config_name, default_value) if getattr(self, config_name) is None: - raise TypeError(f'Expected value for configuration `{config_name}`, but `None` received.') + raise TypeError( + f"Expected value for configuration `{config_name}`, but `None` received." + ) # class members self.input_prefix = "" @@ -120,8 +126,9 @@ def __init__(self, **kwargs): self.input_prefix = os.path.splitext(os.path.basename(self.input_file))[0] if not self.input_mem_file: - self.input_mem_file = "{}.{}".format(os.path.join(input_dir, self.input_prefix), - DEFAULT_MEM_FILE_EXT) + self.input_mem_file = "{}.{}".format( + os.path.join(input_dir, self.input_prefix), DEFAULT_MEM_FILE_EXT + ) self.input_mem_file = makeUniquePath(self.input_mem_file) @classmethod @@ -130,23 +137,23 @@ def init_default_config(cls): Initializes static members of the class. """ if not cls.__initialized: - cls.__default_config["input_file"] = None - cls.__default_config["input_mem_file"] = "" - cls.__default_config["output_dir"] = "" - cls.__default_config["output_prefix"] = "" - cls.__default_config["has_hbm"] = True - cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB - cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB - cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY - cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch + cls.__default_config["input_file"] = None + cls.__default_config["input_mem_file"] = "" + cls.__default_config["output_dir"] = "" + cls.__default_config["output_prefix"] = "" + cls.__default_config["has_hbm"] = True + cls.__default_config["hbm_size"] = cls.DEFAULT_HBM_SIZE_KB + cls.__default_config["spad_size"] = cls.DEFAULT_SPAD_SIZE_KB + cls.__default_config["repl_policy"] = cls.DEFAULT_REPL_POLICY + cls.__default_config["use_xinstfetch"] = GlobalConfig.useXInstFetch cls.__default_config["suppress_comments"] = GlobalConfig.suppressComments - cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose + cls.__default_config["debug_verbose"] = GlobalConfig.debugVerbose cls.__initialized = True def __str__(self): """ Provides a string representation of the configuration. - + Returns: str: The string for the configuration. """ @@ -166,14 +173,22 @@ def as_dict(self) -> dict: """ retval = super().as_dict() tmp_self_dict = vars(self) - retval.update({ config_name: tmp_self_dict[config_name] for config_name in self.__default_config }) + retval.update( + { + config_name: tmp_self_dict[config_name] + for config_name in self.__default_config + } + ) return retval -def asmisaAssemble(run_config, - output_minst_filename: str, - output_cinst_filename: str, - output_xinst_filename: str, - b_verbose=True) -> tuple: + +def asmisaAssemble( + run_config, + output_minst_filename: str, + output_cinst_filename: str, + output_xinst_filename: str, + b_verbose=True, +) -> tuple: """ Assembles the P-ISA kernel into ASM-ISA instructions and saves them to specified output files. @@ -193,21 +208,27 @@ def asmisaAssemble(run_config, max_bundle_size = 64 - input_filename: str = run_config.input_file - mem_filename: str = run_config.input_mem_file - hbm_capcity_words: int = constants.convertBytes2Words(run_config.hbm_size * constants.Constants.KILOBYTE) - spad_capacity_words: int = constants.convertBytes2Words(run_config.spad_size * constants.Constants.KILOBYTE) - num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS - register_range: range = None + input_filename: str = run_config.input_file + mem_filename: str = run_config.input_mem_file + hbm_capacity_words: int = constants.convertBytes2Words( + run_config.hbm_size * constants.Constants.KILOBYTE + ) + spad_capacity_words: int = constants.convertBytes2Words( + run_config.spad_size * constants.Constants.KILOBYTE + ) + num_register_banks: int = constants.MemoryModel.NUM_REGISTER_BANKS + register_range: range = None if b_verbose: print("Assembling!") print("Reloading kernel from intermediate...") - hec_mem_model = MemoryModel(hbm_capcity_words, spad_capacity_words, num_register_banks, register_range) + hec_mem_model = MemoryModel( + hbm_capacity_words, spad_capacity_words, num_register_banks, register_range + ) insts_listing = [] - with open(input_filename, 'r') as insts: + with open(input_filename, "r") as insts: for line_no, s_line in enumerate(insts, 1): parsed_insts = None if GlobalConfig.debugVerbose: @@ -216,35 +237,44 @@ def asmisaAssemble(run_config, # instruction is one that is represented by single XInst inst = xinst.createFromPISALine(hec_mem_model, s_line, line_no) if inst: - parsed_insts = [ inst ] + parsed_insts = [inst] if not parsed_insts: - raise SyntaxError("Line {}: unable to parse kernel instruction:\n{}".format(line_no, s_line)) + raise SyntaxError( + "Line {}: unable to parse kernel instruction:\n{}".format( + line_no, s_line + ) + ) insts_listing += parsed_insts if b_verbose: print("Interpreting variable meta information...") - with open(mem_filename, 'r') as mem_ifnum: + with open(mem_filename, "r") as mem_ifnum: mem_meta_info = mem_info.MemInfo.from_iter(mem_ifnum) mem_info.updateMemoryModelWithMemInfo(hec_mem_model, mem_meta_info) if b_verbose: print("Generating dependency graph...") start_time = time.time() - dep_graph = scheduler.generateInstrDependencyGraph(insts_listing, - sys.stdout if b_verbose else None) - scheduler.enforceKeygenOrdering(dep_graph, hec_mem_model, sys.stdout if b_verbose else None) + dep_graph = scheduler.generateInstrDependencyGraph( + insts_listing, sys.stdout if b_verbose else None + ) + scheduler.enforceKeygenOrdering( + dep_graph, hec_mem_model, sys.stdout if b_verbose else None + ) deps_end = time.time() - start_time if b_verbose: print("Preparing to schedule ASM-ISA instructions...") start_time = time.time() - minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions(dep_graph, - max_bundle_size, # max number of instructions in a bundle - hec_mem_model, - run_config.repl_policy, - b_verbose) + minsts, cinsts, xinsts, num_idle_cycles = scheduleASMISAInstructions( + dep_graph, + max_bundle_size, # max number of instructions in a bundle + hec_mem_model, + run_config.repl_policy, + b_verbose, + ) sched_end = time.time() - start_time num_nops = 0 num_xinsts = 0 @@ -252,37 +282,38 @@ def asmisaAssemble(run_config, for xinstr in bundle_xinsts: num_xinsts += 1 if isinstance(xinstr, xinst.Exit): - break # stop counting instructions after bundle exit + break # stop counting instructions after bundle exit if isinstance(xinstr, xinst.Nop): num_nops += 1 if b_verbose: print("Saving minst...") - with open(output_minst_filename, 'w') as outnum: + with open(output_minst_filename, "w") as outnum: for idx, inst in enumerate(minsts): - inst_line = inst.toMASMISAFormat() + inst_line = inst.to_masmisa_format() if inst_line: print(f"{idx}, {inst_line}", file=outnum) if b_verbose: print("Saving cinst...") - with open(output_cinst_filename, 'w') as outnum: + with open(output_cinst_filename, "w") as outnum: for idx, inst in enumerate(cinsts): - inst_line = inst.toCASMISAFormat() + inst_line = inst.to_casmisa_format() if inst_line: print(f"{idx}, {inst_line}", file=outnum) if b_verbose: print("Saving xinst...") - with open(output_xinst_filename, 'w') as outnum: + with open(output_xinst_filename, "w") as outnum: for bundle_i, bundle_data in enumerate(xinsts): for inst in bundle_data[0]: - inst_line = inst.toXASMISAFormat() + inst_line = inst.to_xasmisa_format() if inst_line: print(f"F{bundle_i}, {inst_line}", file=outnum) return num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end + def main(config: AssemblerRunConfig, verbose: bool = False): """ Executes the assembly process using the provided configuration. @@ -303,27 +334,29 @@ def main(config: AssemblerRunConfig, verbose: bool = False): config = AssemblerRunConfig(**config.as_dict()) # create output directory to store outputs (if it doesn't already exist) - pathlib.Path(config.output_dir).mkdir(exist_ok = True, parents=True) + pathlib.Path(config.output_dir).mkdir(exist_ok=True, parents=True) # initialize output filenames - output_basef = os.path.join(config.output_dir, config.output_prefix) \ - if config.output_prefix \ - else os.path.join(config.output_dir, config.input_prefix) + output_basef = ( + os.path.join(config.output_dir, config.output_prefix) + if config.output_prefix + else os.path.join(config.output_dir, config.input_prefix) + ) - output_xinst_file = f'{output_basef}.{DEFAULT_XINST_FILE_EXT}' - output_cinst_file = f'{output_basef}.{DEFAULT_CINST_FILE_EXT}' - output_minst_file = f'{output_basef}.{DEFAULT_MINST_FILE_EXT}' + output_xinst_file = f"{output_basef}.{DEFAULT_XINST_FILE_EXT}" + output_cinst_file = f"{output_basef}.{DEFAULT_CINST_FILE_EXT}" + output_minst_file = f"{output_basef}.{DEFAULT_MINST_FILE_EXT}" # test output is writable for filename in (output_minst_file, output_cinst_file, output_xinst_file): try: - with open(filename, 'w') as outnum: + with open(filename, "w") as outnum: print("", file=outnum) except Exception as ex: raise Exception(f'Failed to write to output location "{filename}"') from ex - GlobalConfig.useHBMPlaceHolders = True #config.use_hbm_placeholders + GlobalConfig.useHBMPlaceHolders = True # config.use_hbm_placeholders GlobalConfig.useXInstFetch = config.use_xinstfetch GlobalConfig.suppressComments = config.suppress_comments GlobalConfig.hasHBM = config.has_hbm @@ -331,12 +364,13 @@ def main(config: AssemblerRunConfig, verbose: bool = False): Counter.reset() - num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = \ - asmisaAssemble(config, - output_minst_file, - output_cinst_file, - output_xinst_file, - b_verbose=verbose) + num_xinsts, num_nops, num_idle_cycles, deps_end, sched_end = asmisaAssemble( + config, + output_minst_file, + output_cinst_file, + output_xinst_file, + b_verbose=verbose, + ) if verbose: print(f"Output:") @@ -348,6 +382,7 @@ def main(config: AssemblerRunConfig, verbose: bool = False): print(f"--- Minimum idle cycles: {num_idle_cycles} ---") print(f"--- Minimum nops required: {num_nops} ---") + def parse_args(): """ Parses command-line arguments for the assembler script. @@ -359,50 +394,114 @@ def parse_args(): argparse.Namespace: Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description=("HERACLES Assembler.\n" - "The assembler takes a pre-processed P-ISA kernel program and generates " - "valid assembly code for each of the three execution queues: MINST, CINST, and XINST.")) - parser.add_argument("input_file", - help=("Input pre-processed P-ISA kernel file. " - "File must be the result of pre-processing a P-ISA kernel with he_prep.py")) - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("--mem_spec", default="", dest="mem_spec_file", - help=("Input Mem specification (.json) file.")) - parser.add_argument("--input_mem_file", default="", help=("Input memory mapping file associated with the kernel. " - "Defaults to the same name as the input file, but with `.mem` extension.")) - parser.add_argument("--output_dir", default="", help=("Directory where to store all intermediate files and final output. " - "This will be created if it doesn't exists. " - "Defaults to the same directory as the input file.")) - parser.add_argument("--output_prefix", default="", help=("Prefix for the output files. " - "Defaults to the same the input file without extension.")) + description=( + "HERACLES Assembler.\n" + "The assembler takes a pre-processed P-ISA kernel program and generates " + "valid assembly code for each of the three execution queues: MINST, CINST, and XINST." + ) + ) + parser.add_argument( + "input_file", + help=( + "Input pre-processed P-ISA kernel file. " + "File must be the result of pre-processing a P-ISA kernel with he_prep.py" + ), + ) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument( + "--mem_spec", + default="", + dest="mem_spec_file", + help=("Input Mem specification (.json) file."), + ) + parser.add_argument( + "--input_mem_file", + default="", + help=( + "Input memory mapping file associated with the kernel. " + "Defaults to the same name as the input file, but with `.mem` extension." + ), + ) + parser.add_argument( + "--output_dir", + default="", + help=( + "Directory where to store all intermediate files and final output. " + "This will be created if it doesn't exists. " + "Defaults to the same directory as the input file." + ), + ) + parser.add_argument( + "--output_prefix", + default="", + help=( + "Prefix for the output files. " + "Defaults to the same the input file without extension." + ), + ) parser.add_argument("--spad_size", type=int, help="Scratchpad size in KB.") parser.add_argument("--hbm_size", type=int, help="HBM size in KB.") - parser.add_argument("--no_hbm", dest="has_hbm", action="store_false", - help="If set, this flag tells he_prep there is no HBM in the target chip.") - parser.add_argument("--repl_policy", choices=constants.Constants.REPLACEMENT_POLICIES, - help="Replacement policy for cache evictions.") - parser.add_argument("--use_xinstfetch", dest="use_xinstfetch", action="store_true", - help=("When enabled, `xinstfetch` instructions are generated in the CInstQ.")) - parser.add_argument("--suppress_comments", "--no_comments", dest="suppress_comments", action="store_true", - help=("When enabled, no comments will be emited on the output generated by the assembler.")) - parser.add_argument("-v", "--verbose", dest="debug_verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) + parser.add_argument( + "--no_hbm", + dest="has_hbm", + action="store_false", + help="If set, this flag tells he_prep there is no HBM in the target chip.", + ) + parser.add_argument( + "--repl_policy", + choices=constants.Constants.REPLACEMENT_POLICIES, + help="Replacement policy for cache evictions.", + ) + parser.add_argument( + "--use_xinstfetch", + dest="use_xinstfetch", + action="store_true", + help=("When enabled, `xinstfetch` instructions are generated in the CInstQ."), + ) + parser.add_argument( + "--suppress_comments", + "--no_comments", + dest="suppress_comments", + action="store_true", + help=( + "When enabled, no comments will be emitted on the output generated by the assembler." + ), + ) + parser.add_argument( + "-v", + "--verbose", + dest="debug_verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) args = parser.parse_args() return args + if __name__ == "__main__": module_dir = os.path.dirname(__file__) module_name = os.path.basename(__file__) # Initialize Defaults args = parse_args() - args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) - args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec( + module_dir, args.mem_spec_file + ) - config = AssemblerRunConfig(**vars(args)) # convert argsparser into a dictionary + config = AssemblerRunConfig(**vars(args)) # convert argsparser into a dictionary if args.debug_verbose > 0: print(module_name) @@ -413,7 +512,7 @@ def parse_args(): print("=================") print() - main(config, verbose = args.debug_verbose > 1) + main(config, verbose=args.debug_verbose > 1) if args.debug_verbose > 0: print() diff --git a/assembler_tools/hec-assembler-tools/he_prep.py b/assembler_tools/hec-assembler-tools/he_prep.py index 234058cf..1dae2ff1 100644 --- a/assembler_tools/hec-assembler-tools/he_prep.py +++ b/assembler_tools/hec-assembler-tools/he_prep.py @@ -1,10 +1,13 @@ #! /usr/bin/env python3 +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """ This module provides functionality for preprocessing P-ISA abstract kernels before further assembling for HERACLES. Functions: - __savePISAListing(out_stream, instr_listing: list) + save_pisa_listing(out_stream, instr_listing: list) Stores instructions to a stream in P-ISA format. main(output_file_name: str, input_file_name: str, b_verbose: bool) @@ -14,8 +17,8 @@ Parses command-line arguments for the preprocessing script. Usage: - This script is intended to be run as a standalone program. It requires specific command-line arguments - to specify input and output files and verbosity options for the preprocessing process. + This script is intended to be run as a standalone program. It requires specific command-line + arguments to specify input and output files and verbosity options for the preprocessing process. """ import argparse @@ -28,8 +31,8 @@ from assembler.stages import preprocessor from assembler.memory_model import MemoryModel -def __savePISAListing(out_stream, - instr_listing: list): + +def save_pisa_listing(out_stream, instr_listing: list): """ Stores the instructions to a stream in P-ISA format. @@ -44,13 +47,12 @@ def __savePISAListing(out_stream, None """ for inst in instr_listing: - inst_line = inst.toPISAFormat() + inst_line = inst.to_pisa_format() if inst_line: print(inst_line, file=out_stream) -def main(output_file_name: str, - input_file_name: str, - b_verbose: bool): + +def main(output_file_name: str, input_file_name: str, b_verbose: bool): """ Preprocesses the P-ISA kernel and saves the output to a specified file. @@ -66,42 +68,45 @@ def main(output_file_name: str, None """ # used for timings - insts_end: int = 0 + insts_end: float = 0.0 # check for default `output_file_name` # e.g. of default # input_file_name = /path/to/some/file.csv # output_file_name = /path/to/some/file.tw.csv if not output_file_name: - output_file_name = os.path.splitext(input_file_name) - output_file_name = ''.join(output_file_name[:-1] + (".tw",) + output_file_name[-1:]) + root, ext = os.path.splitext(input_file_name) + output_file_name = root + ".tw" + ext - hec_mem_model = MemoryModel(constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, - constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS, - constants.MemoryModel.NUM_REGISTER_BANKS) + hec_mem_model = MemoryModel( + constants.MemoryModel.HBM.MAX_CAPACITY_WORDS, + constants.MemoryModel.SPAD.MAX_CAPACITY_WORDS, + constants.MemoryModel.NUM_REGISTER_BANKS, + ) insts_listing = [] start_time = time.time() # read input kernel and pre-process P-ISA: # resulting instructions will be correctly transformed and ready to be converted into ASM-ISA instructions; # variables used in the kernel will be automatically assigned to banks. - with open(input_file_name, 'r') as insts: - insts_listing = preprocessor.preprocessPISAKernelListing(hec_mem_model, - insts, - progress_verbose=b_verbose) - num_input_instr: int = len(insts_listing) # track number of instructions in input kernel + with open(input_file_name, "r", encoding="utf-8") as insts: + insts_listing = preprocessor.preprocess_pisa_kernel_listing( + hec_mem_model, insts, progress_verbose=b_verbose + ) + num_input_instr: int = len( + insts_listing + ) # track number of instructions in input kernel if b_verbose: print("Assigning register banks to variables...") - preprocessor.assignRegisterBanksToVars(hec_mem_model, - insts_listing, - use_bank0=False, - verbose=b_verbose) + preprocessor.assign_register_banks_to_vars( + hec_mem_model, insts_listing, use_bank0=False, verbose=b_verbose + ) insts_end = time.time() - start_time if b_verbose: print("Saving...") - with open(output_file_name, 'w') as outnum: - __savePISAListing(outnum, insts_listing) + with open(output_file_name, "w", encoding="utf-8") as outnum: + save_pisa_listing(outnum, insts_listing) if b_verbose: print(f"Input: {input_file_name}") @@ -110,6 +115,7 @@ def main(output_file_name: str, print(f"Instructions in output: {len(insts_listing)}") print(f"--- Generation time: {insts_end} seconds ---") + def parse_args(): """ Parses command-line arguments for the preprocessing script. @@ -121,19 +127,44 @@ def parse_args(): argparse.Namespace: Parsed command-line arguments. """ parser = argparse.ArgumentParser( - description="HERACLES Assembling Pre-processor.\nThis program performs the preprocessing of P-ISA abstract kernels before further assembling.") - parser.add_argument("input_file_name", help="Input abstract kernel file to which to add twiddle factors.") - parser.add_argument("output_file_name", nargs="?", help="Output file name. Defaults to .tw.") - parser.add_argument("--isa_spec", default="", dest="isa_spec_file", - help=("Input ISA specification (.json) file.")) - parser.add_argument("--mem_spec", default="", dest="mem_spec_file", - help=("Input Mem specification (.json) file.")) - parser.add_argument("-v", "--verbose", dest="verbose", action="count", default=0, - help=("If enabled, extra information and progress reports are printed to stdout. " - "Increase level of verbosity by specifying flag multiple times, e.g. -vv")) - args = parser.parse_args() - - return args + description="HERACLES Assembling Pre-processor.\nThis program performs the preprocessing of P-ISA abstract kernels before further assembling." + ) + parser.add_argument( + "input_file_name", + help="Input abstract kernel file to which to add twiddle factors.", + ) + parser.add_argument( + "output_file_name", + nargs="?", + help="Output file name. Defaults to .tw.", + ) + parser.add_argument( + "--isa_spec", + default="", + dest="isa_spec_file", + help=("Input ISA specification (.json) file."), + ) + parser.add_argument( + "--mem_spec", + default="", + dest="mem_spec_file", + help=("Input Mem specification (.json) file."), + ) + parser.add_argument( + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help=( + "If enabled, extra information and progress reports are printed to stdout. " + "Increase level of verbosity by specifying flag multiple times, e.g. -vv" + ), + ) + p_args = parser.parse_args() + + return p_args + if __name__ == "__main__": module_dir = os.path.dirname(__file__) @@ -141,20 +172,26 @@ def parse_args(): args = parse_args() - args.isa_spec_file = ISASpecConfig.initialize_isa_spec(module_dir, args.isa_spec_file) - args.mem_spec_file = MemSpecConfig.initialize_mem_spec(module_dir, args.mem_spec_file) + args.isa_spec_file = ISASpecConfig.initialize_isa_spec( + module_dir, args.isa_spec_file + ) + args.mem_spec_file = MemSpecConfig.initialize_mem_spec( + module_dir, args.mem_spec_file + ) if args.verbose > 0: print(module_name) print() - print("Input: {0}".format(args.input_file_name)) - print("Output: {0}".format(args.output_file_name)) - print("ISA Spec: {0}".format(args.isa_spec_file)) - print("Mem Spec: {0}".format(args.mem_spec_file)) - - main(output_file_name=args.output_file_name, - input_file_name=args.input_file_name, - b_verbose=(args.verbose > 1)) + print(f"Input: {args.input_file_name}") + print(f"Output: {args.output_file_name}") + print(f"ISA Spec: {args.isa_spec_file}") + print(f"Mem Spec: {args.mem_spec_file}") + + main( + output_file_name=args.output_file_name, + input_file_name=args.input_file_name, + b_verbose=(args.verbose > 1), + ) if args.verbose > 0: print() diff --git a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py index 66afd290..135608cc 100644 --- a/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py +++ b/assembler_tools/hec-assembler-tools/linker/instructions/__init__.py @@ -1,6 +1,10 @@ -from assembler.instructions import tokenizeFromLine +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from assembler.instructions import tokenize_from_line from linker.instructions.instruction import BaseInstruction + def fromStrLine(line: str, factory) -> BaseInstruction: """ Parses an instruction from a line of text. @@ -13,7 +17,7 @@ def fromStrLine(line: str, factory) -> BaseInstruction: parsed from the specified input line. """ retval = None - tokens, comment = tokenizeFromLine(line) + tokens, comment = tokenize_from_line(line) for instr_type in factory: try: retval = instr_type(tokens, comment) diff --git a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py index e67823e8..c87ad818 100644 --- a/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py +++ b/assembler_tools/hec-assembler-tools/tests/unit_tests/test_he_prep.py @@ -26,15 +26,17 @@ def test_main_assigns_and_saves(monkeypatch, tmp_path): output_file = tmp_path / "output.csv" dummy_model = object() - dummy_insts = [mock.Mock(toPISAFormat=mock.Mock(return_value="inst1"))] + dummy_insts = [mock.Mock(to_pisa_format=mock.Mock(return_value="inst1"))] monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) monkeypatch.setattr( he_prep.preprocessor, - "preprocessPISAKernelListing", + "preprocess_pisa_kernel_listing", mock.Mock(return_value=dummy_insts), ) - monkeypatch.setattr(he_prep.preprocessor, "assignRegisterBanksToVars", mock.Mock()) + monkeypatch.setattr( + he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock() + ) he_prep.main(str(output_file), str(input_file), b_verbose=False) # Output file should contain the instruction @@ -76,9 +78,13 @@ def test_main_no_instructions(monkeypatch): dummy_model = object() monkeypatch.setattr(he_prep, "MemoryModel", mock.Mock(return_value=dummy_model)) monkeypatch.setattr( - he_prep.preprocessor, "preprocessPISAKernelListing", mock.Mock(return_value=[]) + he_prep.preprocessor, + "preprocess_pisa_kernel_listing", + mock.Mock(return_value=[]), + ) + monkeypatch.setattr( + he_prep.preprocessor, "assign_register_banks_to_vars", mock.Mock() ) - monkeypatch.setattr(he_prep.preprocessor, "assignRegisterBanksToVars", mock.Mock()) he_prep.main(output_file, input_file, b_verbose=False) diff --git a/p-isa_tools/kerngen/tests/test_kerngraph_parser.py b/p-isa_tools/kerngen/tests/test_kerngraph_parser.py index 6e4df667..4ab17100 100644 --- a/p-isa_tools/kerngen/tests/test_kerngraph_parser.py +++ b/p-isa_tools/kerngen/tests/test_kerngraph_parser.py @@ -1,6 +1,3 @@ -# Copyright (C) 2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0