@@ -154,18 +154,17 @@ def test_optimization_levels__debug__(self):
154154 self .assertEqual (res .body [0 ].value .id , expected )
155155
156156 def test_optimization_levels_const_folding (self ):
157- folded = ('Expr' , (1 , 0 , 1 , 5 ), ('Constant' , (1 , 0 , 1 , 5 ), 3 , None ))
158- not_folded = ('Expr' , (1 , 0 , 1 , 5 ),
159- ('BinOp' , (1 , 0 , 1 , 5 ),
160- ('Constant' , (1 , 0 , 1 , 1 ), 1 , None ),
161- ('Add' ,),
162- ('Constant' , (1 , 4 , 1 , 5 ), 2 , None )))
157+ folded = ('Expr' , (1 , 0 , 1 , 6 ), ('Constant' , (1 , 0 , 1 , 6 ), (1 , 2 ), None ))
158+ not_folded = ('Expr' , (1 , 0 , 1 , 6 ),
159+ ('Tuple' , (1 , 0 , 1 , 6 ),
160+ [('Constant' , (1 , 1 , 1 , 2 ), 1 , None ),
161+ ('Constant' , (1 , 4 , 1 , 5 ), 2 , None )], ('Load' ,)))
163162
164163 cases = [(- 1 , not_folded ), (0 , not_folded ), (1 , folded ), (2 , folded )]
165164 for (optval , expected ) in cases :
166165 with self .subTest (optval = optval ):
167- tree1 = ast .parse ("1 + 2 " , optimize = optval )
168- tree2 = ast .parse (ast .parse ("1 + 2 " ), optimize = optval )
166+ tree1 = ast .parse ("(1, 2) " , optimize = optval )
167+ tree2 = ast .parse (ast .parse ("(1, 2) " ), optimize = optval )
169168 for tree in [tree1 , tree2 ]:
170169 res = to_tuple (tree .body [0 ])
171170 self .assertEqual (res , expected )
@@ -3089,27 +3088,6 @@ def test_cli_file_input(self):
30893088
30903089
30913090class ASTOptimiziationTests (unittest .TestCase ):
3092- binop = {
3093- "+" : ast .Add (),
3094- "-" : ast .Sub (),
3095- "*" : ast .Mult (),
3096- "/" : ast .Div (),
3097- "%" : ast .Mod (),
3098- "<<" : ast .LShift (),
3099- ">>" : ast .RShift (),
3100- "|" : ast .BitOr (),
3101- "^" : ast .BitXor (),
3102- "&" : ast .BitAnd (),
3103- "//" : ast .FloorDiv (),
3104- "**" : ast .Pow (),
3105- }
3106-
3107- unaryop = {
3108- "~" : ast .Invert (),
3109- "+" : ast .UAdd (),
3110- "-" : ast .USub (),
3111- }
3112-
31133091 def wrap_expr (self , expr ):
31143092 return ast .Module (body = [ast .Expr (value = expr )])
31153093
@@ -3141,83 +3119,6 @@ def assert_ast(self, code, non_optimized_target, optimized_target):
31413119 f"{ ast .dump (optimized_tree )} " ,
31423120 )
31433121
3144- def create_binop (self , operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3145- return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3146-
3147- def test_folding_binop (self ):
3148- code = "1 %s 1"
3149- operators = self .binop .keys ()
3150-
3151- for op in operators :
3152- result_code = code % op
3153- non_optimized_target = self .wrap_expr (self .create_binop (op ))
3154- optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3155-
3156- with self .subTest (
3157- result_code = result_code ,
3158- non_optimized_target = non_optimized_target ,
3159- optimized_target = optimized_target
3160- ):
3161- self .assert_ast (result_code , non_optimized_target , optimized_target )
3162-
3163- # Multiplication of constant tuples must be folded
3164- code = "(1,) * 3"
3165- non_optimized_target = self .wrap_expr (self .create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3166- optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3167-
3168- self .assert_ast (code , non_optimized_target , optimized_target )
3169-
3170- def test_folding_unaryop (self ):
3171- code = "%s1"
3172- operators = self .unaryop .keys ()
3173-
3174- def create_unaryop (operand ):
3175- return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3176-
3177- for op in operators :
3178- result_code = code % op
3179- non_optimized_target = self .wrap_expr (create_unaryop (op ))
3180- optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3181-
3182- with self .subTest (
3183- result_code = result_code ,
3184- non_optimized_target = non_optimized_target ,
3185- optimized_target = optimized_target
3186- ):
3187- self .assert_ast (result_code , non_optimized_target , optimized_target )
3188-
3189- def test_folding_not (self ):
3190- code = "not (1 %s (1,))"
3191- operators = {
3192- "in" : ast .In (),
3193- "is" : ast .Is (),
3194- }
3195- opt_operators = {
3196- "is" : ast .IsNot (),
3197- "in" : ast .NotIn (),
3198- }
3199-
3200- def create_notop (operand ):
3201- return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3202- left = ast .Constant (value = 1 ),
3203- ops = [operators [operand ]],
3204- comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3205- ))
3206-
3207- for op in operators .keys ():
3208- result_code = code % op
3209- non_optimized_target = self .wrap_expr (create_notop (op ))
3210- optimized_target = self .wrap_expr (
3211- ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3212- )
3213-
3214- with self .subTest (
3215- result_code = result_code ,
3216- non_optimized_target = non_optimized_target ,
3217- optimized_target = optimized_target
3218- ):
3219- self .assert_ast (result_code , non_optimized_target , optimized_target )
3220-
32213122 def test_folding_format (self ):
32223123 code = "'%s' % (a,)"
32233124
@@ -3247,9 +3148,9 @@ def test_folding_tuple(self):
32473148 self .assert_ast (code , non_optimized_target , optimized_target )
32483149
32493150 def test_folding_type_param_in_function_def (self ):
3250- code = "def foo[%s = 1 + 1 ](): pass"
3151+ code = "def foo[%s = (1, 2) ](): pass"
32513152
3252- unoptimized_binop = self . create_binop ( "+" )
3153+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
32533154 unoptimized_type_params = [
32543155 ("T" , "T" , ast .TypeVar ),
32553156 ("**P" , "P" , ast .ParamSpec ),
@@ -3263,23 +3164,23 @@ def test_folding_type_param_in_function_def(self):
32633164 name = 'foo' ,
32643165 args = ast .arguments (),
32653166 body = [ast .Pass ()],
3266- type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3167+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))]
32673168 )
32683169 )
32693170 non_optimized_target = self .wrap_statement (
32703171 ast .FunctionDef (
32713172 name = 'foo' ,
32723173 args = ast .arguments (),
32733174 body = [ast .Pass ()],
3274- type_params = [type_param (name = name , default_value = unoptimized_binop )]
3175+ type_params = [type_param (name = name , default_value = unoptimized_tuple )]
32753176 )
32763177 )
32773178 self .assert_ast (result_code , non_optimized_target , optimized_target )
32783179
32793180 def test_folding_type_param_in_class_def (self ):
3280- code = "class foo[%s = 1 + 1 ]: pass"
3181+ code = "class foo[%s = (1, 2) ]: pass"
32813182
3282- unoptimized_binop = self . create_binop ( "+" )
3183+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
32833184 unoptimized_type_params = [
32843185 ("T" , "T" , ast .TypeVar ),
32853186 ("**P" , "P" , ast .ParamSpec ),
@@ -3292,22 +3193,22 @@ def test_folding_type_param_in_class_def(self):
32923193 ast .ClassDef (
32933194 name = 'foo' ,
32943195 body = [ast .Pass ()],
3295- type_params = [type_param (name = name , default_value = ast .Constant (2 ))]
3196+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))]
32963197 )
32973198 )
32983199 non_optimized_target = self .wrap_statement (
32993200 ast .ClassDef (
33003201 name = 'foo' ,
33013202 body = [ast .Pass ()],
3302- type_params = [type_param (name = name , default_value = unoptimized_binop )]
3203+ type_params = [type_param (name = name , default_value = unoptimized_tuple )]
33033204 )
33043205 )
33053206 self .assert_ast (result_code , non_optimized_target , optimized_target )
33063207
33073208 def test_folding_type_param_in_type_alias (self ):
3308- code = "type foo[%s = 1 + 1 ] = 1"
3209+ code = "type foo[%s = (1, 2) ] = 1"
33093210
3310- unoptimized_binop = self . create_binop ( "+" )
3211+ unoptimized_tuple = ast . Tuple ( elts = [ ast . Constant ( 1 ), ast . Constant ( 2 )] )
33113212 unoptimized_type_params = [
33123213 ("T" , "T" , ast .TypeVar ),
33133214 ("**P" , "P" , ast .ParamSpec ),
@@ -3319,19 +3220,80 @@ def test_folding_type_param_in_type_alias(self):
33193220 optimized_target = self .wrap_statement (
33203221 ast .TypeAlias (
33213222 name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3322- type_params = [type_param (name = name , default_value = ast .Constant (2 ))],
3223+ type_params = [type_param (name = name , default_value = ast .Constant (( 1 , 2 ) ))],
33233224 value = ast .Constant (value = 1 ),
33243225 )
33253226 )
33263227 non_optimized_target = self .wrap_statement (
33273228 ast .TypeAlias (
33283229 name = ast .Name (id = 'foo' , ctx = ast .Store ()),
3329- type_params = [type_param (name = name , default_value = unoptimized_binop )],
3230+ type_params = [type_param (name = name , default_value = unoptimized_tuple )],
33303231 value = ast .Constant (value = 1 ),
33313232 )
33323233 )
33333234 self .assert_ast (result_code , non_optimized_target , optimized_target )
33343235
3236+ def test_folding_match_case_allowed_expressions (self ):
3237+ def get_match_case_values (node ):
3238+ result = []
3239+ if isinstance (node , ast .Constant ):
3240+ result .append (node .value )
3241+ elif isinstance (node , ast .MatchValue ):
3242+ result .extend (get_match_case_values (node .value ))
3243+ elif isinstance (node , ast .MatchMapping ):
3244+ for key in node .keys :
3245+ result .extend (get_match_case_values (key ))
3246+ elif isinstance (node , ast .MatchSequence ):
3247+ for pat in node .patterns :
3248+ result .extend (get_match_case_values (pat ))
3249+ else :
3250+ self .fail (f"Unexpected node { node } " )
3251+ return result
3252+
3253+ tests = [
3254+ ("-0" , [0 ]),
3255+ ("-0.1" , [- 0.1 ]),
3256+ ("-0j" , [complex (0 , 0 )]),
3257+ ("-0.1j" , [complex (0 , - 0.1 )]),
3258+ ("1 + 2j" , [complex (1 , 2 )]),
3259+ ("1 - 2j" , [complex (1 , - 2 )]),
3260+ ("1.1 + 2.1j" , [complex (1.1 , 2.1 )]),
3261+ ("1.1 - 2.1j" , [complex (1.1 , - 2.1 )]),
3262+ ("-0 + 1j" , [complex (0 , 1 )]),
3263+ ("-0 - 1j" , [complex (0 , - 1 )]),
3264+ ("-0.1 + 1.1j" , [complex (- 0.1 , 1.1 )]),
3265+ ("-0.1 - 1.1j" , [complex (- 0.1 , - 1.1 )]),
3266+ ("{-0: 0}" , [0 ]),
3267+ ("{-0.1: 0}" , [- 0.1 ]),
3268+ ("{-0j: 0}" , [complex (0 , 0 )]),
3269+ ("{-0.1j: 0}" , [complex (0 , - 0.1 )]),
3270+ ("{1 + 2j: 0}" , [complex (1 , 2 )]),
3271+ ("{1 - 2j: 0}" , [complex (1 , - 2 )]),
3272+ ("{1.1 + 2.1j: 0}" , [complex (1.1 , 2.1 )]),
3273+ ("{1.1 - 2.1j: 0}" , [complex (1.1 , - 2.1 )]),
3274+ ("{-0 + 1j: 0}" , [complex (0 , 1 )]),
3275+ ("{-0 - 1j: 0}" , [complex (0 , - 1 )]),
3276+ ("{-0.1 + 1.1j: 0}" , [complex (- 0.1 , 1.1 )]),
3277+ ("{-0.1 - 1.1j: 0}" , [complex (- 0.1 , - 1.1 )]),
3278+ ("{-0: 0, 0 + 1j: 0, 0.1 + 1j: 0}" , [0 , complex (0 , 1 ), complex (0.1 , 1 )]),
3279+ ("[-0, -0.1, -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3280+ ("[[[[-0, -0.1, -0j, -0.1j]]]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3281+ ("[[-0, -0.1], -0j, -0.1j]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3282+ ("[[-0, -0.1], [-0j, -0.1j]]" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3283+ ("(-0, -0.1, -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3284+ ("((((-0, -0.1, -0j, -0.1j))))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3285+ ("((-0, -0.1), -0j, -0.1j)" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3286+ ("((-0, -0.1), (-0j, -0.1j))" , [0 , - 0.1 , complex (0 , 0 ), complex (0 , - 0.1 )]),
3287+ ]
3288+ for match_expr , constants in tests :
3289+ with self .subTest (match_expr ):
3290+ src = f"match 0:\n \t case { match_expr } : pass"
3291+ tree = ast .parse (src , optimize = 1 )
3292+ match_stmt = tree .body [0 ]
3293+ case = match_stmt .cases [0 ]
3294+ values = get_match_case_values (case .pattern )
3295+ self .assertListEqual (constants , values )
3296+
33353297
33363298if __name__ == '__main__' :
33373299 if len (sys .argv ) > 1 and sys .argv [1 ] == '--snapshot-update' :
0 commit comments