Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,10 @@ def LOAD_DEREF(self, instr: Instruction):
return
namemap = self._code.co_cellvars + self._code.co_freevars
name = namemap[instr.arg]
if isinstance(self._cells[name], NullVariable):
raise InnerError(
f"Deref variable '{name}' absent or have been deleted"
)
self.stack.push(self._cells[name].cell_content())

def COPY_FREE_VARS(self, instr: Instruction):
Expand Down Expand Up @@ -1005,6 +1009,15 @@ def STORE_DEREF(self, instr: Instruction):
name = namemap[instr.arg]
self._cells[name].set_value(self.stack.pop())

def DELETE_DEREF(self, instr: Instruction):
namemap = self._code.co_cellvars + self._code.co_freevars
name = namemap[instr.arg]
if isinstance(self._cells[name], NullVariable):
raise InnerError(
f"Deref variable '{name}' absent or have been deleted"
)
self._cells[name] = NullVariable()

def STORE_FAST(self, instr: Instruction):
"""
TODO: side effect may happen
Expand Down
113 changes: 70 additions & 43 deletions tests/test_19_closure.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import inspect
import types
import unittest

from test_case_base import TestCaseBase, strict_mode_guard

import paddle
from sot.psdb import check_no_breakgraph


def foo(x: int, y: paddle.Tensor):
Expand Down Expand Up @@ -100,6 +102,17 @@ def load_1(a, b=5):
return load_1(1)


@check_no_breakgraph
def closure_del():
x = 0

def load():
nonlocal x
del x

return load


import numpy as np


Expand Down Expand Up @@ -142,7 +155,7 @@ def func7(a, b):
return a + b


def foo7():
def test_builtin_decorator():
return func7(3, 5)


Expand All @@ -155,26 +168,24 @@ def closure():
return closure


class TestExecutor(TestCaseBase):
def test_closure(self):
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(foo2, paddle.to_tensor(2))
self.assert_results(foo3, paddle.to_tensor(2))
self.assert_results_with_global_check(
test_global, ["global_z"], paddle.to_tensor(2)
)
self.assert_results(foo5, paddle.to_tensor(2))
self.assert_results(foo6, paddle.to_tensor(2))
self.assert_results(numpy_sum, paddle.to_tensor(1))
with strict_mode_guard(0):
self.assert_results(
lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1)
)
def non_local_test(t: paddle.Tensor):
a = 1

def func1():
nonlocal a
t = a
a = 2
return t

class TestExecutor2(TestCaseBase):
def test_closure(self):
self.assert_results(foo7)
def func2():
nonlocal a
a = 1
return a

t += func1() # add 2
t += func2() # add 1
t += a # add 1
return t


# Side Effect.
Expand All @@ -195,40 +206,56 @@ def test_slice_in_for_loop(x, iter_num=3):
return out


class TestExecutor3(TestCaseBase):
class TestClosure(TestCaseBase):
def test_closure(self):
tx = paddle.to_tensor([1.0, 2.0, 3.0])
# need side effect of list.
# self.assert_results(test_slice_in_for_loop, tx)


def non_local_test(t: paddle.Tensor):
a = 1
self.assert_results(foo, 1, paddle.to_tensor(2))
self.assert_results(foo2, paddle.to_tensor(2))
self.assert_results(foo3, paddle.to_tensor(2))
self.assert_results(foo5, paddle.to_tensor(2))

def func1():
nonlocal a
t = a
a = 2
return t
def test_global(self):
self.assert_results_with_global_check(
test_global, ["global_z"], paddle.to_tensor(2)
)

def func2():
nonlocal a
a = 1
return a
def test_lambda(self):
with strict_mode_guard(0):
self.assert_results(
lambda_closure, paddle.to_tensor(2), paddle.to_tensor(1)
)

t += func1() # add 2
t += func2() # add 1
t += a # add 1
return t
def test_numpy(self):
self.assert_results(numpy_sum, paddle.to_tensor(1))

def test_del_deref(self):
def is_empty_cell(cell: types.CellType):
try:
cell.cell_contents # noqa: B018
return False
except ValueError as e:
if "Cell is empty" in str(e):
return True
return False

closure_del_func = closure_del()
# Why is it 0: Only one value x is stored in the closure_del method __Closure__ in
self.assertFalse(is_empty_cell(closure_del_func.__closure__[0]))
closure_del_func()
self.assertTrue(is_empty_cell(closure_del_func.__closure__[0]))

def test_decorator(self):
self.assert_results(test_builtin_decorator)
self.assert_results(foo6, paddle.to_tensor(2))

class TestExecutor4(TestCaseBase):
def test_closure(self):
def test_nolocal(self):
tx = paddle.to_tensor([1.0])
self.assert_results(non_local_test, tx)

def test_side_effect(self):
tx = paddle.to_tensor([1.0, 2.0, 3.0])
# need side effect of list.
self.assert_results_with_side_effects(test_slice_in_for_loop, tx)

class TestCreateClosure(TestCaseBase):
def test_create_closure(self):
closure = create_closure()
self.assert_results(closure)
Expand Down