11from copy import copy
2+ from textwrap import dedent
23
34import numpy as np
45from numpy .core .numeric import normalize_axis_tuple
@@ -1448,116 +1449,114 @@ def infer_shape(self, fgraph, node, shapes):
14481449 return ((),)
14491450 return ([ishape [i ] for i in range (node .inputs [0 ].type .ndim ) if i not in axis ],)
14501451
1451- def _c_all (self , node , name , inames , onames , sub ):
1452- input = node .inputs [0 ]
1453- output = node .outputs [0 ]
1452+ def _c_all (self , node , name , input_names , output_names , sub ):
1453+ [inp ] = node .inputs
1454+ [out ] = node .outputs
1455+ ndim = inp .type .ndim
14541456
1455- iname = inames [ 0 ]
1456- oname = onames [ 0 ]
1457+ [ inp_name ] = input_names
1458+ [ out_name ] = output_names
14571459
1458- idtype = input .type .dtype_specs ()[1 ]
1459- odtype = output .type .dtype_specs ()[1 ]
1460+ inp_dtype = inp .type .dtype_specs ()[1 ]
1461+ out_dtype = out .type .dtype_specs ()[1 ]
14601462
14611463 acc_dtype = getattr (self , "acc_dtype" , None )
14621464
14631465 if acc_dtype is not None :
14641466 if acc_dtype == "float16" :
14651467 raise MethodNotDefined ("no c_code for float16" )
14661468 acc_type = TensorType (shape = node .outputs [0 ].type .shape , dtype = acc_dtype )
1467- adtype = acc_type .dtype_specs ()[1 ]
1469+ acc_dtype = acc_type .dtype_specs ()[1 ]
14681470 else :
1469- adtype = odtype
1471+ acc_dtype = out_dtype
14701472
14711473 axis = self .axis
14721474 if axis is None :
1473- axis = list (range (input .type .ndim ))
1475+ axis = list (range (inp .type .ndim ))
14741476
14751477 if len (axis ) == 0 :
1478+ # This is just an Elemwise cast operation
14761479 # The acc_dtype is never a downcast compared to the input dtype
14771480 # So we just need a cast to the output dtype.
1478- var = pytensor .tensor .basic .cast (input , node .outputs [0 ].dtype )
1479- if var is input :
1480- var = Elemwise (scalar_identity )(input )
1481+ var = pytensor .tensor .basic .cast (inp , node .outputs [0 ].dtype )
1482+ if var is inp :
1483+ var = Elemwise (scalar_identity )(inp )
14811484 assert var .dtype == node .outputs [0 ].dtype
1482- return var .owner .op ._c_all (var .owner , name , inames , onames , sub )
1483-
1484- order1 = [i for i in range (input .type .ndim ) if i not in axis ]
1485- order = order1 + list (axis )
1485+ return var .owner .op ._c_all (var .owner , name , input_names , output_names , sub )
14861486
1487- nnested = len (order1 )
1487+ inp_dims = list (range (ndim ))
1488+ non_reduced_dims = [i for i in inp_dims if i not in axis ]
1489+ counter = iter (range (ndim ))
1490+ acc_dims = ["x" if i in axis else next (counter ) for i in range (ndim )]
14881491
1489- sub = dict (sub )
1490- for i , (input , iname ) in enumerate (zip (node .inputs , inames )):
1491- sub [f"lv{ i } " ] = iname
1492+ sub = sub .copy ()
1493+ sub ["lv0" ] = inp_name
1494+ sub ["lv1" ] = out_name
1495+ sub ["olv" ] = out_name
14921496
1493- decl = ""
1494- if adtype != odtype :
1497+ if acc_dtype != out_dtype :
14951498 # Create an accumulator variable different from the output
1496- aname = "acc"
1497- decl = acc_type .c_declare (aname , sub )
1498- decl += acc_type .c_init (aname , sub )
1499+ acc_name = "acc"
1500+ setup = acc_type .c_declare (acc_name , sub ) + acc_type .c_init (acc_name , sub )
14991501 else :
15001502 # the output is the accumulator variable
1501- aname = oname
1502-
1503- decl += cgen .make_declare ([order ], [idtype ], sub )
1504- checks = cgen .make_checks ([order ], [idtype ], sub )
1505-
1506- alloc = ""
1507- i += 1
1508- sub [f"lv{ i } " ] = oname
1509- sub ["olv" ] = oname
1510-
1511- # Allocate output buffer
1512- alloc += cgen .make_declare (
1513- [list (range (nnested )) + ["x" ] * len (axis )], [odtype ], dict (sub , lv0 = oname )
1514- )
1515- alloc += cgen .make_alloc ([order1 ], odtype , sub )
1516- alloc += cgen .make_checks (
1517- [list (range (nnested )) + ["x" ] * len (axis )], [odtype ], dict (sub , lv0 = oname )
1503+ acc_name = out_name
1504+ setup = ""
1505+
1506+ # Define strides of input array
1507+ setup += cgen .make_declare (
1508+ [inp_dims ], [inp_dtype ], sub , compute_stride_jump = False
1509+ ) + cgen .make_checks ([inp_dims ], [inp_dtype ], sub , compute_stride_jump = False )
1510+
1511+ # Define strides of output array and allocate it
1512+ out_sub = sub | {"lv0" : out_name }
1513+ alloc = (
1514+ cgen .make_declare (
1515+ [acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
1516+ )
1517+ + cgen .make_alloc ([non_reduced_dims ], out_dtype , sub )
1518+ + cgen .make_checks (
1519+ [acc_dims ], [out_dtype ], out_sub , compute_stride_jump = False
1520+ )
15181521 )
15191522
1520- if adtype != odtype :
1521- # Allocate accumulation buffer
1522- sub [f"lv { i } " ] = aname
1523- sub ["olv" ] = aname
1523+ if acc_dtype != out_dtype :
1524+ # Define strides of accumulation buffer and allocate it
1525+ sub ["lv1 " ] = acc_name
1526+ sub ["olv" ] = acc_name
15241527
1525- alloc += cgen .make_declare (
1526- [list (range (nnested )) + ["x" ] * len (axis )],
1527- [adtype ],
1528- dict (sub , lv0 = aname ),
1529- )
1530- alloc += cgen .make_alloc ([order1 ], adtype , sub )
1531- alloc += cgen .make_checks (
1532- [list (range (nnested )) + ["x" ] * len (axis )],
1533- [adtype ],
1534- dict (sub , lv0 = aname ),
1528+ acc_sub = sub | {"lv0" : acc_name }
1529+ alloc += (
1530+ cgen .make_declare (
1531+ [acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
1532+ )
1533+ + cgen .make_alloc ([non_reduced_dims ], acc_dtype , sub )
1534+ + cgen .make_checks (
1535+ [acc_dims ], [acc_dtype ], acc_sub , compute_stride_jump = False
1536+ )
15351537 )
15361538
15371539 identity = self .scalar_op .identity
1538-
15391540 if np .isposinf (identity ):
1540- if input .type .dtype in ("float32" , "float64" ):
1541+ if inp .type .dtype in ("float32" , "float64" ):
15411542 identity = "__builtin_inf()"
1542- elif input .type .dtype .startswith ("uint" ) or input .type .dtype == "bool" :
1543+ elif inp .type .dtype .startswith ("uint" ) or inp .type .dtype == "bool" :
15431544 identity = "1"
15441545 else :
1545- identity = "NPY_MAX_" + str (input .type .dtype ).upper ()
1546+ identity = "NPY_MAX_" + str (inp .type .dtype ).upper ()
15461547 elif np .isneginf (identity ):
1547- if input .type .dtype in ("float32" , "float64" ):
1548+ if inp .type .dtype in ("float32" , "float64" ):
15481549 identity = "-__builtin_inf()"
1549- elif input .type .dtype .startswith ("uint" ) or input .type .dtype == "bool" :
1550+ elif inp .type .dtype .startswith ("uint" ) or inp .type .dtype == "bool" :
15501551 identity = "0"
15511552 else :
1552- identity = "NPY_MIN_" + str (input .type .dtype ).upper ()
1553+ identity = "NPY_MIN_" + str (inp .type .dtype ).upper ()
15531554 elif identity is None :
15541555 raise TypeError (f"The { self .scalar_op } does not define an identity." )
15551556
1556- task0_decl = f"{ adtype } & { aname } _i = *{ aname } _iter;\n { aname } _i = { identity } ;"
1557-
1558- task1_decl = f"{ idtype } & { inames [0 ]} _i = *{ inames [0 ]} _iter;\n "
1557+ initial_value = f"{ acc_name } _i = { identity } ;"
15591558
1560- task1_code = self .scalar_op .c_code (
1559+ inner_task = self .scalar_op .c_code (
15611560 Apply (
15621561 self .scalar_op ,
15631562 [
@@ -1570,44 +1569,45 @@ def _c_all(self, node, name, inames, onames, sub):
15701569 ],
15711570 ),
15721571 None ,
1573- [f"{ aname } _i" , f"{ inames [ 0 ] } _i" ],
1574- [f"{ aname } _i" ],
1572+ [f"{ acc_name } _i" , f"{ inp_name } _i" ],
1573+ [f"{ acc_name } _i" ],
15751574 sub ,
15761575 )
1577- code1 = f"""
1578- {{
1579- { task1_decl }
1580- { task1_code }
1581- }}
1582- """
15831576
1584- if node .inputs [0 ].type .ndim :
1585- if len (axis ) == 1 :
1586- all_code = [("" , "" )] * nnested + [(task0_decl , code1 ), "" ]
1587- else :
1588- all_code = (
1589- [("" , "" )] * nnested
1590- + [(task0_decl , "" )]
1591- + [("" , "" )] * (len (axis ) - 2 )
1592- + [("" , code1 ), "" ]
1593- )
1577+ if out .type .ndim == 0 :
1578+ # Simple case where everything is reduced, no need for loop ordering
1579+ loop = cgen .make_complete_loop_careduce (
1580+ inp_var = inp_name ,
1581+ acc_var = acc_name ,
1582+ inp_dtype = inp_dtype ,
1583+ acc_dtype = acc_dtype ,
1584+ initial_value = initial_value ,
1585+ inner_task = inner_task ,
1586+ fail_code = sub ["fail" ],
1587+ )
15941588 else :
1595- all_code = [task0_decl + code1 ]
1596- loop = cgen .make_loop_careduce (
1597- [order , list (range (nnested )) + ["x" ] * len (axis )],
1598- [idtype , adtype ],
1599- all_code ,
1600- sub ,
1601- )
1589+ loop = cgen .make_reordered_loop_careduce (
1590+ inp_var = inp_name ,
1591+ acc_var = acc_name ,
1592+ inp_dtype = inp_dtype ,
1593+ acc_dtype = acc_dtype ,
1594+ inp_ndim = ndim ,
1595+ reduction_axes = axis ,
1596+ initial_value = initial_value ,
1597+ inner_task = inner_task ,
1598+ )
16021599
1603- end = ""
1604- if adtype != odtype :
1605- end = f"""
1606- PyArray_CopyInto({ oname } , { aname } );
1607- """
1608- end += acc_type .c_cleanup (aname , sub )
1600+ if acc_dtype != out_dtype :
1601+ cast = dedent (
1602+ f"""
1603+ PyArray_CopyInto({ out_name } , { acc_name } );
1604+ { acc_type .c_cleanup (acc_name , sub )}
1605+ """
1606+ )
1607+ else :
1608+ cast = ""
16091609
1610- return decl , checks , alloc , loop , end
1610+ return setup , alloc , loop , cast
16111611
16121612 def c_code (self , node , name , inames , onames , sub ):
16131613 code = "\n " .join (self ._c_all (node , name , inames , onames , sub ))
@@ -1619,7 +1619,7 @@ def c_headers(self, **kwargs):
16191619
16201620 def c_code_cache_version_apply (self , node ):
16211621 # the version corresponding to the c code in this Op
1622- version = [9 ]
1622+ version = [10 ]
16231623
16241624 # now we insert versions for the ops on which we depend...
16251625 scalar_node = Apply (
0 commit comments