diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 968aa2080..804a7d597 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -8,7 +8,7 @@ import traceback import types from itertools import chain -from typing import Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from ...utils import ( BreakGraphError, @@ -72,16 +72,22 @@ VariableFactory, ) +if TYPE_CHECKING: + from typing import TypeVar + + VariableT = TypeVar("VariableT", bound=VariableBase) + + GuardedFunction = Tuple[types.CodeType, Guard] + GuardedFunctions = List[GuardedFunction] + CacheGetter = Callable[ + [types.FrameType, GuardedFunctions], Optional["CustomCode"] + ] + CustomCode = collections.namedtuple( "CustomCode", ["code", "disable_eval_frame"] ) -GuardedFunction = Tuple[types.CodeType, Guard] -GuardedFunctions = List[GuardedFunction] -CacheGetter = Callable[ - [types.FrameType, GuardedFunctions], Optional[CustomCode] -] dummy_guard: Guard = lambda frame: True SUPPORT_COMPARE_OP = { @@ -634,7 +640,7 @@ def indexof(self, instr: Instruction): """ return self._instructions.index(instr) - def pop(self) -> VariableBase: + def pop(self, *, var_type: type[VariableT] = VariableBase) -> VariableT: """ Pops the top value from the stack. @@ -642,7 +648,9 @@ def pop(self) -> VariableBase: The popped value. """ - return self._stack.pop() + var = self._stack.pop() + assert isinstance(var, var_type) + return var def peek(self) -> VariableBase: """ @@ -995,7 +1003,7 @@ def BUILD_CONST_KEY_MAP(self, instr: Instruction): assert map_size + 1 <= len( self._stack ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." 
- keys = self.pop().get_items() + keys = self.pop(var_type=ContainerVariable).get_items() assert len(keys) == map_size values = self.pop_n(map_size) self.push(self.build_map(keys, values)) @@ -1241,10 +1249,14 @@ def GET_ITER(self, instr: Instruction): ) def JUMP_FORWARD(self, instr): - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) def JUMP_ABSOLUTE(self, instr: Instruction): - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) def CONTAINS_OP(self, instr: Instruction): # It will only be 0 or 1 @@ -1264,7 +1276,9 @@ def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = not bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) else: self.pop() return @@ -1279,7 +1293,9 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) else: self.pop() return @@ -1294,7 +1310,9 @@ def POP_JUMP_IF_FALSE(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = not bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) return raise NotImplementException( "Currently don't support predicate a non-const / non-tensor obj." 
@@ -1307,7 +1325,9 @@ def POP_JUMP_IF_TRUE(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) return raise NotImplementException( "Currently don't support predicate a non-const / non-tensor obj." @@ -1389,7 +1409,8 @@ def DICT_UPDATE(self, instr: Instruction): ) def DICT_MERGE(self, instr: Instruction): - dict_value = self.pop() + # TODO: self._stack[index] should be replaced? + dict_value = self.pop(var_type=DictVariable) assert instr.arg > 0 for key in dict_value.get_wrapped_items().keys(): result = self._stack[-instr.arg].get_wrapped_items().get(key, None) @@ -1416,7 +1437,10 @@ def LIST_EXTEND(self, instr: Instruction): ) def LIST_TO_TUPLE(self, instr: Instruction): - list_value = self.pop() + # TODO(zrr1999): I think list_value should be a ListVariable instance, + # but return_value of get_wrapped_items method in ListVariable is a list instead of a tuple. 
+ # list_value = self.pop(var_type=ListVariable) + list_value = self.pop(var_type=ContainerVariable) self.push( TupleVariable( list_value.get_wrapped_items(), @@ -1517,7 +1541,8 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): self.indexof(instr) + 1, stack_size ) else_fn, else_inputs = self._create_resume_fn( - self.indexof(instr.jump_to), stack_size + self.indexof(instr.safe_getattr("jump_to", var_type=Instruction)), + stack_size, ) # gen call static fn opcode @@ -1718,11 +1743,18 @@ def _break_graph_in_for_loop( pycode_gen = PyCodeGen(self._frame) loop_body, loop_inputs = pycode_gen.gen_loop_body_between( - for_iter, loop_body_start_idx, self.indexof(for_iter.jump_to) + for_iter, + loop_body_start_idx, + self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ), ) after_loop_fn, fn_inputs = self._create_resume_fn( - self.indexof(for_iter.jump_to), len(self._stack) + self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ), + len(self._stack), ) total_inputs = OrderedSet(list(fn_inputs) + list(loop_inputs)) @@ -1822,7 +1854,9 @@ def _inline_call_for_loop( origin_instrs = get_instructions(pycode_gen._origin_code) start_idx = self.indexof(for_iter) - end_idx = self.indexof(for_iter.jump_to) + end_idx = self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ) inputs = list( analysis_inputs_outputs(origin_instrs, start_idx, end_idx) @@ -1850,7 +1884,10 @@ def _inline_call_for_loop( if ( instr.jump_to in origin_instrs - and origin_instrs.index(instr.jump_to) >= end_idx + and origin_instrs.index( + instr.safe_getattr("jump_to", var_type=Instruction) + ) + >= end_idx ): instr.jump_to = nop_for_break @@ -1899,11 +1936,11 @@ def STORE_ATTR(self, instr): ) def FOR_ITER(self, instr): - iterator = self.pop() + iterator = self.pop(var_type=IterVariable) backup_iter_idx = None start = self.indexof(instr) - end = self.indexof(instr.jump_to) + end = self.indexof(instr.safe_getattr("jump_to", 
var_type=Instruction)) for i in range(start, end): if self._instructions[i].opname == "RETURN_VALUE": raise NotImplementException( @@ -1911,7 +1948,7 @@ def FOR_ITER(self, instr): ) self._graph.add_global_guarded_variable(iterator) - # TODO need support TensorIterVariable.next + # TODO: need support TensorIterVariable.next try: if not isinstance( @@ -1923,8 +1960,12 @@ def FOR_ITER(self, instr): backup_iter_idx = iterator.idx self._inline_call_for_loop(iterator, instr) - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) except BreakGraphError as e: + # TODO: backup_iter_idx is not None? + # TODO: idx is not a member of IterVariable if backup_iter_idx: iterator.idx = backup_iter_idx self._graph.remove_global_guarded_variable(iterator) diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index f1d0baabb..af769aa2e 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -3,12 +3,13 @@ import dataclasses import dis import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from .opcode_info import ABS_JUMP, ALL_JUMP, REL_JUMP if TYPE_CHECKING: import types + from typing import Any @dataclasses.dataclass @@ -290,12 +291,15 @@ def calc_offset_from_bytecode_offset(bytecode_offset: int) -> int: return bytecode_offset // 2 -def replace_instr(instructions, instr, new_instr): +def replace_instr( + instructions: list[Instruction], instr: Instruction, new_instr +): idx = instructions.index(instr) + # TODO: maybe new_instr is a list? 
instructions[idx : idx + 1] = new_instr -def instrs_info(instrs, mark=None, range=None): +def instrs_info(instrs: dict[str, Instruction], mark=None, range=None): ret = [] start = -1 end = 1000000 diff --git a/sot/utils/utils.py b/sot/utils/utils.py index 9dcaf8c00..64fdf64a5 100644 --- a/sot/utils/utils.py +++ b/sot/utils/utils.py @@ -239,8 +239,8 @@ def __init__(self): def clear(self): self.graph_num = 0 self.op_num = 0 - self.graphs: list = [] - self.ops: list = [] + self.graphs = [] + self.ops = [] def get_graph_num(self): return self.graph_num @@ -257,6 +257,7 @@ def add_subgraph(self, program: Program): for op in block.ops: self.op_num += 1 sub_op.append(op) + # TODO: self.ops is a list, and sub_op is a list? self.ops.append(sub_op) def add_subgprah_info(self, strs):