From 0f5f2c2a3aae3e085a721abfa210f7e6ecadae16 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 23 Jul 2023 09:56:04 +0000 Subject: [PATCH 1/6] GraphLogger todo --- sot/utils/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sot/utils/utils.py b/sot/utils/utils.py index f57f3ce53..758a644d0 100644 --- a/sot/utils/utils.py +++ b/sot/utils/utils.py @@ -237,8 +237,8 @@ def __init__(self): def clear(self): self.graph_num = 0 self.op_num = 0 - self.graphs: list = [] - self.ops: list = [] + self.graphs = [] + self.ops = [] def get_graph_num(self): return self.graph_num @@ -255,6 +255,7 @@ def add_subgraph(self, program: Program): for op in block.ops: self.op_num += 1 sub_op.append(op) + # TODO: self.ops is a list, and sub_op is a list? self.ops.append(sub_op) def add_subgprah_info(self, strs): From 7e71ca4b1358372f69d8f4a582e1c7779b49be31 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 23 Jul 2023 10:24:40 +0000 Subject: [PATCH 2/6] add safe_getattr --- .../executor/opcode_executor.py | 163 +++++++++++------- .../instruction_utils/instruction_utils.py | 23 ++- 2 files changed, 119 insertions(+), 67 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index c54a2ea8e..401185335 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -8,7 +8,7 @@ import traceback import types from itertools import chain -from typing import Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from ...utils import ( BreakGraphError, @@ -71,16 +71,22 @@ VariableFactory, ) +if TYPE_CHECKING: + from typing import TypeVar + + T = TypeVar("T") + + GuardedFunction = Tuple[types.CodeType, Guard] + GuardedFunctions = List[GuardedFunction] + CacheGetter = Callable[ + [types.FrameType, GuardedFunctions], Optional["CustomCode"] + ] + CustomCode = collections.namedtuple( "CustomCode", ["code", "disable_eval_frame"] ) -GuardedFunction = Tuple[types.CodeType, Guard] -GuardedFunctions = List[GuardedFunction] -CacheGetter = Callable[ - [types.FrameType, GuardedFunctions], Optional[CustomCode] -] dummy_guard: Guard = lambda frame: True SUPPORT_COMPARE_OP = { @@ -795,7 +801,7 @@ def NOP(self, instr: Instruction): pass def LOAD_ATTR(self, instr: Instruction): - attr_name = self._code.co_names[instr.arg] + attr_name = self._code.co_names[instr.get_arg()] obj = self.pop() self.push( BuiltinVariable( @@ -804,30 +810,30 @@ def LOAD_ATTR(self, instr: Instruction): ) def LOAD_CONST(self, instr: Instruction): - var = self._co_consts[instr.arg] + var = self._co_consts[instr.get_arg()] self.push(var) def LOAD_CLOSURE(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.arg] + name = namemap[instr.get_arg()] self.push(self._cells[name]) def LOAD_DEREF(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.arg] + name = namemap[instr.get_arg()] self.push(self._cells[name].cell_content()) def LOAD_FAST(self, instr: Instruction): - varname = self._code.co_varnames[instr.arg] + varname = self._code.co_varnames[instr.get_arg()] var = self._locals[varname] self.push(var) def DELETE_FAST(self, instr: Instruction): - varname = self._code.co_varnames[instr.arg] + varname = self._code.co_varnames[instr.get_arg()] del self._locals[varname] def LOAD_GLOBAL(self, instr: Instruction): - name = self._code.co_names[instr.arg] + name = self._code.co_names[instr.get_arg()] if name in self._globals.keys(): value = self._globals[name] else: @@ -835,7 +841,7 @@ def LOAD_GLOBAL(self, instr: Instruction): self.push(value) def LOAD_METHOD(self, instr: Instruction): - method_name = self._code.co_names[instr.arg] + method_name = self._code.co_names[instr.get_arg()] obj = self.pop() method = BuiltinVariable( @@ -853,7 +859,7 @@ def LOAD_METHOD(self, instr: Instruction): def STORE_DEREF(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.arg] + name = namemap[instr.get_arg()] self._cells[name].set_value(self.pop()) def STORE_FAST(self, instr: Instruction): @@ -861,13 +867,13 @@ def STORE_FAST(self, instr: Instruction): TODO: side effect may happen """ var = self.pop() - name = self._code.co_varnames[instr.arg] + name = self._code.co_varnames[instr.get_arg()] var.debug_name = name self._locals[name] = var def STORE_GLOBAL(self, instr: Instruction): var = self.pop() - name = self._code.co_names[instr.arg] + name = self._code.co_names[instr.get_arg()] var.debug_name = name self._locals[name] = var @@ -895,7 +901,7 @@ def DELETE_SUBSCR(self, instr: Instruction): ) def BUILD_LIST(self, instr: Instruction): - list_size = instr.arg + list_size = instr.get_arg() assert list_size <= len( self._stack ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." @@ -907,7 +913,7 @@ def BUILD_LIST(self, instr: Instruction): ) def BUILD_TUPLE(self, instr: Instruction): - tuple_size = instr.arg + tuple_size = instr.get_arg() assert tuple_size <= len( self._stack ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." @@ -921,7 +927,7 @@ def BUILD_TUPLE(self, instr: Instruction): ) def BUILD_STRING(self, instr: Instruction): - count = instr.arg + count = instr.get_arg() assert count <= len( self._stack ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." @@ -938,7 +944,7 @@ def BUILD_STRING(self, instr: Instruction): @call_break_graph_decorator(push_n=1) def BUILD_SLICE(self, instr: Instruction): - if instr.arg == 3: + if instr.get_arg() == 3: step = self.pop() else: step = None @@ -972,7 +978,7 @@ def build_map( ) def BUILD_MAP(self, instr: Instruction): - map_size = instr.arg + map_size = instr.get_arg() assert map_size * 2 <= len( self._stack ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." @@ -982,7 +988,7 @@ def BUILD_MAP(self, instr: Instruction): self.push(self.build_map(keys, values)) def BUILD_CONST_KEY_MAP(self, instr: Instruction): - map_size = instr.arg + map_size = instr.get_arg() assert map_size + 1 <= len( self._stack ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." @@ -992,7 +998,7 @@ def BUILD_CONST_KEY_MAP(self, instr: Instruction): self.push(self.build_map(keys, values)) def build_seq_unpack(self, instr: Instruction): - oparg = instr.arg + oparg = instr.get_arg() assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1023,7 +1029,7 @@ def BUILD_LIST_UNPACK(self, instr: Instruction): self.build_seq_unpack(instr) def BUILD_MAP_UNPACK(self, instr: Instruction): - oparg = instr.arg + oparg = instr.get_arg() assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1039,7 +1045,7 @@ def BUILD_MAP_UNPACK(self, instr: Instruction): ) def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): - oparg = instr.arg + oparg = instr.get_arg() assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1060,7 +1066,7 @@ def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): ) def CALL_FUNCTION(self, instr: Instruction): - n_args = instr.arg + n_args = instr.get_arg() assert n_args <= len(self._stack) args = self.pop_n(n_args) kwargs = {} @@ -1069,7 +1075,7 @@ def CALL_FUNCTION(self, instr: Instruction): self.push(ret) def CALL_FUNCTION_KW(self, instr: Instruction): - n_args = instr.arg + n_args = instr.get_arg() assert n_args + 2 <= len(self._stack) kwargs_keys = self.pop() @@ -1091,7 +1097,7 @@ def CALL_FUNCTION_KW(self, instr: Instruction): self.push(ret) def CALL_FUNCTION_EX(self, instr: Instruction): - flag = instr.arg + flag = instr.get_arg() if flag & 0x01: # has kwargs kwargs_variable = self.pop() assert isinstance(kwargs_variable, DictVariable) @@ -1108,7 +1114,7 @@ def CALL_FUNCTION_EX(self, instr: Instruction): self.push(ret) def CALL_METHOD(self, instr: Instruction): - n_args = instr.arg + n_args = instr.get_arg() assert n_args <= len(self._stack) args = self.pop_n(n_args) self_var = self.pop() @@ -1123,7 +1129,7 @@ def CALL_METHOD(self, instr: Instruction): push_n=1 ) # call instance, in, not in may call TensorVariable.get_py_value, which raise BreakGraphError def COMPARE_OP(self, instr: Instruction): - op = dis.cmp_op[instr.arg] + op = dis.cmp_op[instr.get_arg()] right, left = self.pop(), self.pop() self.push( BuiltinVariable( @@ -1134,9 +1140,9 @@ def COMPARE_OP(self, instr: Instruction): def IS_OP(self, instr: Instruction): # It will only be 0 or 1 - assert instr.arg == 0 or instr.arg == 1 + assert instr.get_arg() == 0 or instr.get_arg() == 1 right, left = self.pop(), self.pop() - op = "is" if instr.arg == 0 else "is not" + op = "is" if instr.get_arg() == 0 else "is not" self.push( BuiltinVariable( SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() @@ -1150,7 +1156,7 @@ def MAKE_FUNCTION(self, instr: Instruction): related_list = [fn_name, codeobj] - flag = instr.arg + flag = instr.get_arg() if flag & MF.MF_HAS_CLOSURE: # closure should be a tuple of Variables closure_variable = self.pop() @@ -1232,16 +1238,20 @@ def GET_ITER(self, instr: Instruction): ) def JUMP_FORWARD(self, instr): - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) def JUMP_ABSOLUTE(self, instr: Instruction): - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) def CONTAINS_OP(self, instr: Instruction): # It will only be 0 or 1 - assert instr.arg == 0 or instr.arg == 1 + assert instr.get_arg() == 0 or instr.get_arg() == 1 right, left = self.pop(), self.pop() - op = "in" if instr.arg == 0 else "not in" + op = "in" if instr.get_arg() == 0 else "not in" self.push( BuiltinVariable( SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() @@ -1255,7 +1265,9 @@ def JUMP_IF_FALSE_OR_POP(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = not bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) else: self.pop() return @@ -1270,7 +1282,9 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) else: self.pop() return @@ -1285,7 +1299,9 @@ def POP_JUMP_IF_FALSE(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = not bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) return raise NotImplementException( "Currently don't support predicate a non-const / non-tensor obj." @@ -1298,7 +1314,9 @@ def POP_JUMP_IF_TRUE(self, instr: Instruction): self._graph.add_global_guarded_variable(pred_obj) is_jump = bool(pred_obj) if is_jump: - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) return raise NotImplementException( "Currently don't support predicate a non-const / non-tensor obj." @@ -1325,14 +1343,14 @@ def UNPACK_SEQUENCE(self, instr: Instruction): ) assert ( - len(sequence) == instr.arg - ), f"Want unpack {sequence} to {instr.arg}, but the len is {len(sequence)}." + len(sequence) == instr.get_arg() + ), f"Want unpack {sequence} to {instr.get_arg()}, but the len is {len(sequence)}." - for i in range(instr.arg - 1, -1, -1): + for i in range(instr.get_arg() - 1, -1, -1): self.push(sequence[i]) def FORMAT_VALUE(self, instr: Instruction): - flag = instr.arg + flag = instr.get_arg() which_conversion = flag & FV.FVC_MASK have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) @@ -1374,36 +1392,38 @@ def FORMAT_VALUE(self, instr: Instruction): # NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet def DICT_UPDATE(self, instr: Instruction): dict_value = self.pop() - assert instr.arg > 0 + assert instr.get_arg() > 0 BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( - self._stack[-instr.arg], dict_value + self._stack[-instr.get_arg()], dict_value ) def DICT_MERGE(self, instr: Instruction): dict_value = self.pop() - assert instr.arg > 0 + assert instr.get_arg() > 0 for key in dict_value.get_wrapped_items().keys(): - result = self._stack[-instr.arg].get_wrapped_items().get(key, None) + result = ( + self._stack[-instr.get_arg()].get_wrapped_items().get(key, None) + ) if result is not None: raise InnerError( f"got multiple values for keyword argument '{key}'" ) BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( - self._stack[-instr.arg], dict_value + self._stack[-instr.get_arg()], dict_value ) def LIST_APPEND(self, instr: Instruction): list_value = self.pop() - assert instr.arg > 0 + assert instr.get_arg() > 0 BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( - self._stack[-instr.arg], list_value + self._stack[-instr.get_arg()], list_value ) def LIST_EXTEND(self, instr: Instruction): list_value = self.pop() - assert instr.arg > 0 + assert instr.get_arg() > 0 BuiltinVariable(list.extend, self._graph, tracker=DanglingTracker())( - self._stack[-instr.arg], list_value + self._stack[-instr.get_arg()], list_value ) def LIST_TO_TUPLE(self, instr: Instruction): @@ -1510,7 +1530,8 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): self.indexof(instr) + 1, stack_size ) else_fn, else_inputs = self._create_resume_fn( - self.indexof(instr.jump_to), stack_size + self.indexof(instr.safe_getattr("jump_to", var_type=Instruction)), + stack_size, ) # gen call static fn opcode @@ -1625,7 +1646,7 @@ def _break_graph_in_call( self._graph.pycode_gen.gen_pop_top() # gen graph break call fn opcode - stack_effect = dis.stack_effect(instr.opcode, instr.arg) + stack_effect = dis.stack_effect(instr.opcode, instr.get_arg()) pop_n = push_n - stack_effect for i, stack_arg in enumerate(self._stack): # Avoid passing NULL as a parameter to the resume function @@ -1711,11 +1732,18 @@ def _break_graph_in_for_loop( pycode_gen = PyCodeGen(self._frame) loop_body, loop_inputs = pycode_gen.gen_loop_body_between( - for_iter, loop_body_start_idx, self.indexof(for_iter.jump_to) + for_iter, + loop_body_start_idx, + self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ), ) after_loop_fn, fn_inputs = self._create_resume_fn( - self.indexof(for_iter.jump_to), len(self._stack) + self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ), + len(self._stack), ) total_inputs = OrderedSet(list(fn_inputs) + list(loop_inputs)) @@ -1815,7 +1843,9 @@ def _inline_call_for_loop( origin_instrs = get_instructions(pycode_gen._origin_code) start_idx = self.indexof(for_iter) - end_idx = self.indexof(for_iter.jump_to) + end_idx = self.indexof( + for_iter.safe_getattr("jump_to", var_type=Instruction) + ) inputs = list( analysis_inputs_outputs(origin_instrs, start_idx, end_idx) @@ -1843,7 +1873,10 @@ def _inline_call_for_loop( if ( instr.jump_to in origin_instrs - and origin_instrs.index(instr.jump_to) >= end_idx + and origin_instrs.index( + instr.safe_getattr("jump_to", var_type=Instruction) + ) + >= end_idx ): instr.jump_to = nop_for_break @@ -1874,7 +1907,7 @@ def _inline_call_for_loop( def STORE_ATTR(self, instr): obj = self.pop() val = self.pop() - key = self._code.co_names[instr.arg] + key = self._code.co_names[instr.get_arg()] if isinstance(obj, TensorVariable): # support tensor variable store attr, like: # t.stop_gradient = True @@ -1896,7 +1929,7 @@ def FOR_ITER(self, instr): backup_iter_idx = None start = self.indexof(instr) - end = self.indexof(instr.jump_to) + end = self.indexof(instr.safe_getattr("jump_to", var_type=Instruction)) for i in range(start, end): if self._instructions[i].opname == "RETURN_VALUE": raise NotImplementException( @@ -1915,7 +1948,9 @@ def FOR_ITER(self, instr): backup_iter_idx = iterator.idx self._inline_call_for_loop(iterator, instr) - self._lasti = self.indexof(instr.jump_to) + self._lasti = self.indexof( + instr.safe_getattr("jump_to", var_type=Instruction) + ) except BreakGraphError as e: if backup_iter_idx: iterator.idx = backup_iter_idx diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index f1d0baabb..82a63235b 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -3,12 +3,15 @@ import dataclasses import dis import sys -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from .opcode_info import ABS_JUMP, ALL_JUMP, REL_JUMP if TYPE_CHECKING: import types + from typing import Any, TypeVar + + T = TypeVar("T") @dataclasses.dataclass @@ -31,6 +34,17 @@ class Instruction: def __hash__(self): return id(self) + def safe_getattr(self, attr: str, *, var_type: type[T] | None = None) -> T: + retval = getattr(self, attr) + assert var_type is None or isinstance(retval, var_type) + return retval + + def get_arg(self) -> int: + return self.safe_getattr("arg", var_type=int) + + def get_argval(self, *, var_type: type[T] | None = None) -> T: + return self.safe_getattr("arg", var_type=var_type) + def gen_instr(name, arg=None, argval=None, gened=True, jump_to=None): return Instruction( @@ -290,12 +304,15 @@ def calc_offset_from_bytecode_offset(bytecode_offset: int) -> int: return bytecode_offset // 2 -def replace_instr(instructions, instr, new_instr): +def replace_instr( + instructions: list[Instruction], instr: Instruction, new_instr +): idx = instructions.index(instr) + # TODO: maybe new_instr is a lsit? instructions[idx : idx + 1] = new_instr -def instrs_info(instrs, mark=None, range=None): +def instrs_info(instrs: dict[str, Instruction], mark=None, range=None): ret = [] start = -1 end = 1000000 From 5855d0600df09fe37a655b681322c0b1cdba954c Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 23 Jul 2023 10:48:15 +0000 Subject: [PATCH 3/6] make pop safe --- .../executor/opcode_executor.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 401185335..0023caf1a 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -74,7 +74,7 @@ if TYPE_CHECKING: from typing import TypeVar - T = TypeVar("T") + VariableT = TypeVar("VariableT", bound=VariableBase) GuardedFunction = Tuple[types.CodeType, Guard] GuardedFunctions = List[GuardedFunction] @@ -635,7 +635,7 @@ def indexof(self, instr: Instruction): """ return self._instructions.index(instr) - def pop(self) -> VariableBase: + def pop(self, *, var_type: type[VariableT] = VariableBase) -> VariableT: """ Pops the top value from the stack. @@ -643,7 +643,9 @@ def pop(self) -> VariableBase: The popped value. """ - return self._stack.pop() + var = self._stack.pop() + assert isinstance(var, var_type) + return var def peek(self) -> VariableBase: """ @@ -992,7 +994,7 @@ def BUILD_CONST_KEY_MAP(self, instr: Instruction): assert map_size + 1 <= len( self._stack ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." - keys = self.pop().get_items() + keys = self.pop(var_type=ContainerVariable).get_items() assert len(keys) == map_size values = self.pop_n(map_size) self.push(self.build_map(keys, values)) @@ -1398,7 +1400,8 @@ def DICT_UPDATE(self, instr: Instruction): ) def DICT_MERGE(self, instr: Instruction): - dict_value = self.pop() + # TODO: self._stack[index] should be replaced? + dict_value = self.pop(var_type=DictVariable) assert instr.get_arg() > 0 for key in dict_value.get_wrapped_items().keys(): result = ( @@ -1427,7 +1430,10 @@ def LIST_EXTEND(self, instr: Instruction): ) def LIST_TO_TUPLE(self, instr: Instruction): - list_value = self.pop() + # TODO(zrr1999): I think list_value should a ListVariable instance, + # but return_value of get_wrapped_items method in ListVariable is a list instead of tuple. + # list_value = self.pop(var_type=ListVariable) + list_value = self.pop(var_type=ContainerVariable) self.push( TupleVariable( list_value.get_wrapped_items(), @@ -1925,7 +1931,7 @@ def STORE_ATTR(self, instr): ) def FOR_ITER(self, instr): - iterator = self.pop() + iterator = self.pop(var_type=IterVariable) backup_iter_idx = None start = self.indexof(instr) @@ -1937,7 +1943,7 @@ def FOR_ITER(self, instr): ) self._graph.add_global_guarded_variable(iterator) - # TODO need support TensorIterVariable.next + # TODO: need support TensorIterVariable.next try: if not isinstance( @@ -1952,6 +1958,8 @@ def FOR_ITER(self, instr): instr.safe_getattr("jump_to", var_type=Instruction) ) except BreakGraphError as e: + # TODO: backup_iter_idx is not None? + # TODO: idx is not a member of IterVariable if backup_iter_idx: iterator.idx = backup_iter_idx self._graph.remove_global_guarded_variable(iterator) From 2a878e6076ced2bd08718a84cfd0da423521e8aa Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Sun, 23 Jul 2023 11:03:27 +0000 Subject: [PATCH 4/6] fix bug --- sot/opcode_translator/executor/opcode_executor.py | 2 +- sot/opcode_translator/instruction_utils/instruction_utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 0023caf1a..46fde60a0 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1652,7 +1652,7 @@ def _break_graph_in_call( self._graph.pycode_gen.gen_pop_top() # gen graph break call fn opcode - stack_effect = dis.stack_effect(instr.opcode, instr.get_arg()) + stack_effect = dis.stack_effect(instr.opcode, instr.arg) pop_n = push_n - stack_effect for i, stack_arg in enumerate(self._stack): # Avoid passing NULL as a parameter to the resume function diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 82a63235b..95c9037df 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -36,7 +36,9 @@ def __hash__(self): def safe_getattr(self, attr: str, *, var_type: type[T] | None = None) -> T: retval = getattr(self, attr) - assert var_type is None or isinstance(retval, var_type) + assert var_type is None or isinstance( + retval, var_type + ), f"{attr} is not {var_type}, but {type(retval)}" return retval def get_arg(self) -> int: From f7c35daf7b6cff14ce1b0e2763b3ce63ad2eb474 Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 27 Jul 2023 07:15:56 +0000 Subject: [PATCH 5/6] remove get_arg --- .../instruction_utils/instruction_utils.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 95c9037df..af769aa2e 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -9,9 +9,7 @@ if TYPE_CHECKING: import types - from typing import Any, TypeVar - - T = TypeVar("T") + from typing import Any @dataclasses.dataclass @@ -34,19 +32,6 @@ class Instruction: def __hash__(self): return id(self) - def safe_getattr(self, attr: str, *, var_type: type[T] | None = None) -> T: - retval = getattr(self, attr) - assert var_type is None or isinstance( - retval, var_type - ), f"{attr} is not {var_type}, but {type(retval)}" - return retval - - def get_arg(self) -> int: - return self.safe_getattr("arg", var_type=int) - - def get_argval(self, *, var_type: type[T] | None = None) -> T: - return self.safe_getattr("arg", var_type=var_type) - def gen_instr(name, arg=None, argval=None, gened=True, jump_to=None): return Instruction( From 143f69cf888f9afb442015ab5c07725810d4933f Mon Sep 17 00:00:00 2001 From: Zhan Rongrui <2742392377@qq.com> Date: Thu, 27 Jul 2023 07:16:50 +0000 Subject: [PATCH 6/6] fix --- .../executor/opcode_executor.py | 90 +++++++++---------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index ab7c74265..804a7d597 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -808,7 +808,7 @@ def NOP(self, instr: Instruction): pass def LOAD_ATTR(self, instr: Instruction): - attr_name = self._code.co_names[instr.get_arg()] + attr_name = self._code.co_names[instr.arg] attr_name_var = ConstantVariable.wrap_literal(attr_name, self._graph) obj = self.pop() self.push( @@ -818,30 +818,30 @@ def LOAD_ATTR(self, instr: Instruction): ) def LOAD_CONST(self, instr: Instruction): - var = self._co_consts[instr.get_arg()] + var = self._co_consts[instr.arg] self.push(var) def LOAD_CLOSURE(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.get_arg()] + name = namemap[instr.arg] self.push(self._cells[name]) def LOAD_DEREF(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.get_arg()] + name = namemap[instr.arg] self.push(self._cells[name].cell_content()) def LOAD_FAST(self, instr: Instruction): - varname = self._code.co_varnames[instr.get_arg()] + varname = self._code.co_varnames[instr.arg] var = self._locals[varname] self.push(var) def DELETE_FAST(self, instr: Instruction): - varname = self._code.co_varnames[instr.get_arg()] + varname = self._code.co_varnames[instr.arg] del self._locals[varname] def LOAD_GLOBAL(self, instr: Instruction): - name = self._code.co_names[instr.get_arg()] + name = self._code.co_names[instr.arg] if name in self._globals.keys(): value = self._globals[name] else: @@ -849,7 +849,7 @@ def LOAD_GLOBAL(self, instr: Instruction): self.push(value) def LOAD_METHOD(self, instr: Instruction): - method_name = self._code.co_names[instr.get_arg()] + method_name = self._code.co_names[instr.arg] method_name_var = ConstantVariable.wrap_literal( method_name, self._graph ) @@ -870,7 +870,7 @@ def LOAD_METHOD(self, instr: Instruction): def STORE_DEREF(self, instr): namemap = self._code.co_cellvars + self._code.co_freevars - name = namemap[instr.get_arg()] + name = namemap[instr.arg] self._cells[name].set_value(self.pop()) def STORE_FAST(self, instr: Instruction): @@ -878,13 +878,13 @@ def STORE_FAST(self, instr: Instruction): TODO: side effect may happen """ var = self.pop() - name = self._code.co_varnames[instr.get_arg()] + name = self._code.co_varnames[instr.arg] var.debug_name = name self._locals[name] = var def STORE_GLOBAL(self, instr: Instruction): var = self.pop() - name = self._code.co_names[instr.get_arg()] + name = self._code.co_names[instr.arg] var.debug_name = name self._locals[name] = var @@ -912,7 +912,7 @@ def DELETE_SUBSCR(self, instr: Instruction): ) def BUILD_LIST(self, instr: Instruction): - list_size = instr.get_arg() + list_size = instr.arg assert list_size <= len( self._stack ), f"OpExecutor want BUILD_LIST with size {list_size}, but current stack do not have enough elems." @@ -924,7 +924,7 @@ def BUILD_LIST(self, instr: Instruction): ) def BUILD_TUPLE(self, instr: Instruction): - tuple_size = instr.get_arg() + tuple_size = instr.arg assert tuple_size <= len( self._stack ), f"OpExecutor want BUILD_TUPLE with size {tuple_size}, but current stack do not have enough elems." @@ -938,7 +938,7 @@ def BUILD_TUPLE(self, instr: Instruction): ) def BUILD_STRING(self, instr: Instruction): - count = instr.get_arg() + count = instr.arg assert count <= len( self._stack ), f"OpExecutor want BUILD_STRING with size {count}, but current stack do not have enough elems." @@ -955,7 +955,7 @@ def BUILD_STRING(self, instr: Instruction): @call_break_graph_decorator(push_n=1) def BUILD_SLICE(self, instr: Instruction): - if instr.get_arg() == 3: + if instr.arg == 3: step = self.pop() else: step = None @@ -989,7 +989,7 @@ def build_map( ) def BUILD_MAP(self, instr: Instruction): - map_size = instr.get_arg() + map_size = instr.arg assert map_size * 2 <= len( self._stack ), f"OpExecutor want BUILD_MAP with size {map_size} * 2, but current stack do not have enough elems." @@ -999,7 +999,7 @@ def BUILD_MAP(self, instr: Instruction): self.push(self.build_map(keys, values)) def BUILD_CONST_KEY_MAP(self, instr: Instruction): - map_size = instr.get_arg() + map_size = instr.arg assert map_size + 1 <= len( self._stack ), f"OpExecutor want BUILD_CONST_KEY_MAP with size {map_size} + 1, but current stack do not have enough elems." @@ -1009,7 +1009,7 @@ def BUILD_CONST_KEY_MAP(self, instr: Instruction): self.push(self.build_map(keys, values)) def build_seq_unpack(self, instr: Instruction): - oparg = instr.get_arg() + oparg = instr.arg assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1040,7 +1040,7 @@ def BUILD_LIST_UNPACK(self, instr: Instruction): self.build_seq_unpack(instr) def BUILD_MAP_UNPACK(self, instr: Instruction): - oparg = instr.get_arg() + oparg = instr.arg assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1056,7 +1056,7 @@ def BUILD_MAP_UNPACK(self, instr: Instruction): ) def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): - oparg = instr.get_arg() + oparg = instr.arg assert oparg <= len(self._stack) unpack_values = self.pop_n(oparg) @@ -1077,7 +1077,7 @@ def BUILD_MAP_UNPACK_WITH_CALL(self, instr: Instruction): ) def CALL_FUNCTION(self, instr: Instruction): - n_args = instr.get_arg() + n_args = instr.arg assert n_args <= len(self._stack) args = self.pop_n(n_args) kwargs = {} @@ -1086,7 +1086,7 @@ def CALL_FUNCTION(self, instr: Instruction): self.push(ret) def CALL_FUNCTION_KW(self, instr: Instruction): - n_args = instr.get_arg() + n_args = instr.arg assert n_args + 2 <= len(self._stack) kwargs_keys = self.pop() @@ -1108,7 +1108,7 @@ def CALL_FUNCTION_KW(self, instr: Instruction): self.push(ret) def CALL_FUNCTION_EX(self, instr: Instruction): - flag = instr.get_arg() + flag = instr.arg if flag & 0x01: # has kwargs kwargs_variable = self.pop() assert isinstance(kwargs_variable, DictVariable) @@ -1125,7 +1125,7 @@ def CALL_FUNCTION_EX(self, instr: Instruction): self.push(ret) def CALL_METHOD(self, instr: Instruction): - n_args = instr.get_arg() + n_args = instr.arg assert n_args <= len(self._stack) args = self.pop_n(n_args) self_var = self.pop() @@ -1140,7 +1140,7 @@ def CALL_METHOD(self, instr: Instruction): push_n=1 ) # call instance, in, not in may call TensorVariable.get_py_value, which raise BreakGraphError def COMPARE_OP(self, instr: Instruction): - op = dis.cmp_op[instr.get_arg()] + op = dis.cmp_op[instr.arg] right, left = self.pop(), self.pop() self.push( BuiltinVariable( @@ -1151,9 +1151,9 @@ def COMPARE_OP(self, instr: Instruction): def IS_OP(self, instr: Instruction): # It will only be 0 or 1 - assert instr.get_arg() == 0 or instr.get_arg() == 1 + assert instr.arg == 0 or instr.arg == 1 right, left = self.pop(), self.pop() - op = "is" if instr.get_arg() == 0 else "is not" + op = "is" if instr.arg == 0 else "is not" self.push( BuiltinVariable( SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() @@ -1167,7 +1167,7 @@ def MAKE_FUNCTION(self, instr: Instruction): related_list = [fn_name, codeobj] - flag = instr.get_arg() + flag = instr.arg if flag & MF.MF_HAS_CLOSURE: # closure should be a tuple of Variables closure_variable = self.pop() @@ -1260,9 +1260,9 @@ def JUMP_ABSOLUTE(self, instr: Instruction): def CONTAINS_OP(self, instr: Instruction): # It will only be 0 or 1 - assert instr.get_arg() == 0 or instr.get_arg() == 1 + assert instr.arg == 0 or instr.arg == 1 right, left = self.pop(), self.pop() - op = "in" if instr.get_arg() == 0 else "not in" + op = "in" if instr.arg == 0 else "not in" self.push( BuiltinVariable( SUPPORT_COMPARE_OP[op], self._graph, DanglingTracker() @@ -1354,14 +1354,14 @@ def UNPACK_SEQUENCE(self, instr: Instruction): ) assert ( - len(sequence) == instr.get_arg() - ), f"Want unpack {sequence} to {instr.get_arg()}, but the len is {len(sequence)}." + len(sequence) == instr.arg + ), f"Want unpack {sequence} to {instr.arg}, but the len is {len(sequence)}." - for i in range(instr.get_arg() - 1, -1, -1): + for i in range(instr.arg - 1, -1, -1): self.push(sequence[i]) def FORMAT_VALUE(self, instr: Instruction): - flag = instr.get_arg() + flag = instr.arg which_conversion = flag & FV.FVC_MASK have_fmt_spec = bool((flag & FV.FVS_MASK) == FV.FVS_HAVE_SPEC) @@ -1403,39 +1403,37 @@ def FORMAT_VALUE(self, instr: Instruction): # NOTE: This operation will generate SideEffects, and the mechanism has not been completed yet def DICT_UPDATE(self, instr: Instruction): dict_value = self.pop() - assert instr.get_arg() > 0 + assert instr.arg > 0 BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( - self._stack[-instr.get_arg()], dict_value + self._stack[-instr.arg], dict_value ) def DICT_MERGE(self, instr: Instruction): # TODO: self._stack[index] should be replaced? dict_value = self.pop(var_type=DictVariable) - assert instr.get_arg() > 0 + assert instr.arg > 0 for key in dict_value.get_wrapped_items().keys(): - result = ( - self._stack[-instr.get_arg()].get_wrapped_items().get(key, None) - ) + result = self._stack[-instr.arg].get_wrapped_items().get(key, None) if result is not None: raise InnerError( f"got multiple values for keyword argument '{key}'" ) BuiltinVariable(dict.update, self._graph, tracker=DanglingTracker())( - self._stack[-instr.get_arg()], dict_value + self._stack[-instr.arg], dict_value ) def LIST_APPEND(self, instr: Instruction): list_value = self.pop() - assert instr.get_arg() > 0 + assert instr.arg > 0 BuiltinVariable(list.append, self._graph, tracker=DanglingTracker())( - self._stack[-instr.get_arg()], list_value + self._stack[-instr.arg], list_value ) def LIST_EXTEND(self, instr: Instruction): list_value = self.pop() - assert instr.get_arg() > 0 + assert instr.arg > 0 BuiltinVariable(list.extend, self._graph, tracker=DanglingTracker())( - self._stack[-instr.get_arg()], list_value + self._stack[-instr.arg], list_value ) def LIST_TO_TUPLE(self, instr: Instruction): @@ -1920,7 +1918,7 @@ def _inline_call_for_loop( def STORE_ATTR(self, instr): obj = self.pop() val = self.pop() - key = self._code.co_names[instr.get_arg()] + key = self._code.co_names[instr.arg] if isinstance(obj, TensorVariable): # support tensor variable store attr, like: # t.stop_gradient = True