1313import sys
1414import tokenize
1515import types
16+ from collections import defaultdict
1617from pathlib import Path
1718from pathlib import PurePath
1819from typing import Callable
5657 astNum = ast .Num
5758
5859
60+ class Sentinel :
61+ pass
62+
63+
5964assertstate_key = StashKey ["AssertionState" ]()
6065
6166# pytest caches rewritten pycs in pycache dirs
6267PYTEST_TAG = f"{ sys .implementation .cache_tag } -pytest-{ version } "
6368PYC_EXT = ".py" + (__debug__ and "c" or "o" )
6469PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
6570
71+ # Special marker that denotes we have just left a scope definition
72+ _SCOPE_END_MARKER = Sentinel ()
73+
6674
6775class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
6876 """PEP302/PEP451 import hook which rewrites asserts."""
@@ -645,6 +653,8 @@ class AssertionRewriter(ast.NodeVisitor):
645653 .push_format_context() and .pop_format_context() which allows
646654 to build another %-formatted string while already building one.
647655
656+ :scope: A tuple containing the current scope used for variables_overwrite.
657+
648658 :variables_overwrite: A dict filled with references to variables
649659 that change value within an assert. This happens when a variable is
650660 reassigned with the walrus operator
@@ -666,7 +676,10 @@ def __init__(
666676 else :
667677 self .enable_assertion_pass_hook = False
668678 self .source = source
669- self .variables_overwrite : Dict [str , str ] = {}
679+ self .scope : tuple [ast .AST , ...] = ()
680+ self .variables_overwrite : defaultdict [
681+ tuple [ast .AST , ...], Dict [str , str ]
682+ ] = defaultdict (dict )
670683
671684 def run (self , mod : ast .Module ) -> None :
672685 """Find all assert statements in *mod* and rewrite them."""
@@ -732,9 +745,17 @@ def run(self, mod: ast.Module) -> None:
732745 mod .body [pos :pos ] = imports
733746
734747 # Collect asserts.
735- nodes : List [ast .AST ] = [mod ]
748+ self .scope = (mod ,)
749+ nodes : List [Union [ast .AST , Sentinel ]] = [mod ]
736750 while nodes :
737751 node = nodes .pop ()
752+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
753+ self .scope = tuple ((* self .scope , node ))
754+ nodes .append (_SCOPE_END_MARKER )
755+ if node == _SCOPE_END_MARKER :
756+ self .scope = self .scope [:- 1 ]
757+ continue
758+ assert isinstance (node , ast .AST )
738759 for name , field in ast .iter_fields (node ):
739760 if isinstance (field , list ):
740761 new : List [ast .AST ] = []
@@ -1005,7 +1026,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
10051026 ]
10061027 ):
10071028 pytest_temp = self .variable ()
1008- self .variables_overwrite [
1029+ self .variables_overwrite [self . scope ][
10091030 v .left .target .id
10101031 ] = v .left # type:ignore[assignment]
10111032 v .left .target .id = pytest_temp
@@ -1048,17 +1069,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10481069 new_args = []
10491070 new_kwargs = []
10501071 for arg in call .args :
1051- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1052- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1072+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1073+ self .scope , {}
1074+ ):
1075+ arg = self .variables_overwrite [self .scope ][
1076+ arg .id
1077+ ] # type:ignore[assignment]
10531078 res , expl = self .visit (arg )
10541079 arg_expls .append (expl )
10551080 new_args .append (res )
10561081 for keyword in call .keywords :
1057- if (
1058- isinstance (keyword .value , ast .Name )
1059- and keyword .value .id in self .variables_overwrite
1060- ):
1061- keyword .value = self .variables_overwrite [
1082+ if isinstance (
1083+ keyword .value , ast .Name
1084+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1085+ keyword .value = self .variables_overwrite [self .scope ][
10621086 keyword .value .id
10631087 ] # type:ignore[assignment]
10641088 res , expl = self .visit (keyword .value )
@@ -1094,12 +1118,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10941118 def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
10951119 self .push_format_context ()
10961120 # We first check if we have overwritten a variable in the previous assert
1097- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1098- comp .left = self .variables_overwrite [
1121+ if isinstance (
1122+ comp .left , ast .Name
1123+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1124+ comp .left = self .variables_overwrite [self .scope ][
10991125 comp .left .id
11001126 ] # type:ignore[assignment]
11011127 if isinstance (comp .left , namedExpr ):
1102- self .variables_overwrite [
1128+ self .variables_overwrite [self . scope ][
11031129 comp .left .target .id
11041130 ] = comp .left # type:ignore[assignment]
11051131 left_res , left_expl = self .visit (comp .left )
@@ -1119,7 +1145,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11191145 and next_operand .target .id == left_res .id
11201146 ):
11211147 next_operand .target .id = self .variable ()
1122- self .variables_overwrite [
1148+ self .variables_overwrite [self . scope ][
11231149 left_res .id
11241150 ] = next_operand # type:ignore[assignment]
11251151 next_res , next_expl = self .visit (next_operand )
0 commit comments