1+ from textwrap import dedent , indent
2+
13from pytensor .configdefaults import config
24
35
@@ -8,51 +10,49 @@ def make_declare(loop_orders, dtypes, sub):
810 """
911 decl = ""
1012 for i , (loop_order , dtype ) in enumerate (zip (loop_orders , dtypes )):
11- var = sub [f"lv{ int ( i ) } " ] # input name corresponding to ith loop variable
13+ var = sub [f"lv{ i } " ] # input name corresponding to ith loop variable
1214 # we declare an iteration variable
1315 # and an integer for the number of dimensions
14- decl += f"""
15- { dtype } * { var } _iter;
16- """
16+ decl += f"{ dtype } * { var } _iter;\n "
1717 for j , value in enumerate (loop_order ):
1818 if value != "x" :
1919 # If the dimension is not broadcasted, we declare
2020 # the number of elements in that dimension,
2121 # the stride in that dimension,
2222 # and the jump from an iteration to the next
2323 decl += f"""
24- npy_intp { var } _n{ int ( value ) } ;
25- ssize_t { var } _stride{ int ( value ) } ;
26- int { var } _jump{ int ( value ) } _{ int ( j ) } ;
24+ npy_intp { var } _n{ value } ;
25+ ssize_t { var } _stride{ value } ;
26+ int { var } _jump{ value } _{ j } ;
2727 """
2828
2929 else :
3030 # if the dimension is broadcasted, we only need
3131 # the jump (arbitrary length and stride = 0)
32- decl += f"""
33- int { var } _jump{ value } _{ int (j )} ;
34- """
32+ decl += f"int { var } _jump{ value } _{ j } ;\n "
3533
3634 return decl
3735
3836
3937def make_checks (loop_orders , dtypes , sub ):
4038 init = ""
4139 for i , (loop_order , dtype ) in enumerate (zip (loop_orders , dtypes )):
42- var = f"%( lv{ int ( i ) } )s"
40+ var = sub [ f" lv{ i } " ]
4341 # List of dimensions of var that are not broadcasted
4442 nonx = [x for x in loop_order if x != "x" ]
4543 if nonx :
4644 # If there are dimensions that are not broadcasted
4745 # this is a check that the number of dimensions of the
4846 # tensor is as expected.
4947 min_nd = max (nonx ) + 1
50- init += f"""
51- if (PyArray_NDIM({ var } ) < { min_nd } ) {{
52- PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
53- %(fail)s
54- }}
55- """
48+ init += dedent (
49+ f"""
50+ if (PyArray_NDIM({ var } ) < { min_nd } ) {{
51+ PyErr_SetString(PyExc_ValueError, "Not enough dimensions on input.");
52+ { indent (sub ["fail" ], " " * 12 )}
53+ }}
54+ """
55+ )
5656
5757 # In loop j, adjust represents the difference of values of the
5858 # data pointer between the beginning and the end of the
@@ -75,9 +75,7 @@ def make_checks(loop_orders, dtypes, sub):
7575 adjust = f"{ var } _n{ index } *{ var } _stride{ index } "
7676 else :
7777 jump = f"-({ adjust } )"
78- init += f"""
79- { var } _jump{ index } _{ j } = { jump } ;
80- """
78+ init += f"{ var } _jump{ index } _{ j } = { jump } ;\n "
8179 adjust = "0"
8280 check = ""
8381
@@ -101,34 +99,36 @@ def make_checks(loop_orders, dtypes, sub):
10199
102100 j0 , x0 = to_compare [0 ]
103101 for j , x in to_compare [1 :]:
104- check += f"""
105- if (%(lv{ j0 } )s_n{ x0 } != %(lv{ j } )s_n{ x } )
106- {{
107- if (%(lv{ j0 } )s_n{ x0 } == 1 || %(lv{ j } )s_n{ x } == 1)
102+ check += dedent (
103+ f"""
104+ if ({ sub [f"lv{ j0 } " ]} _n{ x0 } != { sub [f"lv{ j } " ]} _n{ x } )
108105 {{
109- PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
110- { j0 } ,
111- { x0 } ,
112- (long long int) %(lv{ j0 } )s_n{ x0 } ,
113- { j } ,
114- { x } ,
115- (long long int) %(lv{ j } )s_n{ x }
116- );
117- }} else {{
118- PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
106+ if ({ sub [f"lv{ j0 } " ]} _n{ x0 } == 1 || { sub [f"lv{ j } " ]} _n{ x } == 1)
107+ {{
108+ PyErr_Format(PyExc_ValueError, "{ runtime_broadcast_error_msg } ",
119109 { j0 } ,
120110 { x0 } ,
121- (long long int) %( lv{ j0 } )s_n { x0 } ,
111+ (long long int) { sub [ f" lv{ j0 } " ] } _n { x0 } ,
122112 { j } ,
123113 { x } ,
124- (long long int) %(lv{ j } )s_n{ x }
125- );
114+ (long long int) { sub [f"lv{ j } " ]} _n{ x }
115+ );
116+ }} else {{
117+ PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
118+ { j0 } ,
119+ { x0 } ,
120+ (long long int) { sub [f"lv{ j0 } " ]} _n{ x0 } ,
121+ { j } ,
122+ { x } ,
123+ (long long int) { sub [f"lv{ j } " ]} _n{ x }
124+ );
125+ }}
126+ { sub ["fail" ]}
126127 }}
127- %(fail)s
128- }}
129- """
128+ """
129+ )
130130
131- return init % sub + check % sub
131+ return init + check
132132
133133
134134def compute_output_dims_lengths (array_name : str , loop_orders , sub ) -> str :
@@ -144,7 +144,7 @@ def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
144144 # Borrow the length of the first non-broadcastable input dimension
145145 for j , candidate in enumerate (candidates ):
146146 if candidate != "x" :
147- var = sub [f"lv{ int ( j ) } " ]
147+ var = sub [f"lv{ j } " ]
148148 dims_c_code += f"{ array_name } [{ i } ] = { var } _n{ candidate } ;\n "
149149 break
150150 # If none is non-broadcastable, the output dimension has a length of 1
@@ -177,35 +177,37 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
177177 # way that its contiguous dimensions match one of the input's
178178 # contiguous dimensions, or the dimension with the smallest
179179 # stride. Right now, it is allocated to be C_CONTIGUOUS.
180- return f"""
181- {{
182- npy_intp dims[{ nd } ];
183- //npy_intp* dims = (npy_intp*)malloc({ nd } * sizeof(npy_intp));
184- { init_dims }
185- if (!{ olv } ) {{
186- { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims,
187- { type } ,
188- { fortran } );
189- }}
190- else {{
191- PyArray_Dims new_dims;
192- new_dims.len = { nd } ;
193- new_dims.ptr = dims;
194- PyObject* success = PyArray_Resize({ olv } , &new_dims, 0, NPY_CORDER);
195- if (!success) {{
196- // If we can't resize the ndarray we have we can allocate a new one.
197- PyErr_Clear();
198- Py_XDECREF({ olv } );
199- { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims, { type } , 0);
200- }} else {{
201- Py_DECREF(success);
180+ return dedent (
181+ f"""
182+ {{
183+ npy_intp dims[{ nd } ];
184+ { init_dims }
185+ if (!{ olv } ) {{
186+ { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } ,
187+ dims,
188+ { type } ,
189+ { fortran } );
190+ }}
191+ else {{
192+ PyArray_Dims new_dims;
193+ new_dims.len = { nd } ;
194+ new_dims.ptr = dims;
195+ PyObject* success = PyArray_Resize({ olv } , &new_dims, 0, NPY_CORDER);
196+ if (!success) {{
197+ // If we can't resize the ndarray we have we can allocate a new one.
198+ PyErr_Clear();
199+ Py_XDECREF({ olv } );
200+ { olv } = (PyArrayObject*)PyArray_EMPTY({ nd } , dims, { type } , 0);
201+ }} else {{
202+ Py_DECREF(success);
203+ }}
204+ }}
205+ if (!{ olv } ) {{
206+ { fail }
202207 }}
203208 }}
204- if (!{ olv } ) {{
205- { fail }
206- }}
207- }}
208- """
209+ """
210+ )
209211
210212
211213def make_loop (loop_orders , dtypes , loop_tasks , sub , openmp = None ):
@@ -235,11 +237,11 @@ def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
235237 """
236238
237239 def loop_over (preloop , code , indices , i ):
238- iterv = f"ITER_{ int ( i ) } "
240+ iterv = f"ITER_{ i } "
239241 update = ""
240242 suitable_n = "1"
241243 for j , index in enumerate (indices ):
242- var = sub [f"lv{ int ( j ) } " ]
244+ var = sub [f"lv{ j } " ]
243245 dtype = dtypes [j ]
244246 update += f"{ dtype } &{ var } _i = * ( { var } _iter + { iterv } * { var } _jump{ index } _{ i } );\n "
245247
@@ -305,21 +307,21 @@ def make_reordered_loop(
305307 nnested = len (init_loop_orders [0 ])
306308
307309 # This is the var from which we'll get the loop order
308- ovar = sub [f"lv{ int ( olv_index ) } " ]
310+ ovar = sub [f"lv{ olv_index } " ]
309311
310312 # The loops are ordered by (decreasing) absolute values of ovar's strides.
311313 # The first element of each pair is the absolute value of the stride
312314 # The second element correspond to the index in the initial loop order
313315 order_loops = f"""
314- std::vector< std::pair<int, int> > { ovar } _loops({ int ( nnested ) } );
316+ std::vector< std::pair<int, int> > { ovar } _loops({ nnested } );
315317 std::vector< std::pair<int, int> >::iterator { ovar } _loops_it = { ovar } _loops.begin();
316318 """
317319
318320 # Fill the loop vector with the appropriate <stride, index> pairs
319321 for i , index in enumerate (init_loop_orders [olv_index ]):
320322 if index != "x" :
321323 order_loops += f"""
322- { ovar } _loops_it->first = abs(PyArray_STRIDES({ ovar } )[{ int ( index ) } ]);
324+ { ovar } _loops_it->first = abs(PyArray_STRIDES({ ovar } )[{ index } ]);
323325 """
324326 else :
325327 # Stride is 0 when dimension is broadcastable
@@ -328,7 +330,7 @@ def make_reordered_loop(
328330 """
329331
330332 order_loops += f"""
331- { ovar } _loops_it->second = { int ( i ) } ;
333+ { ovar } _loops_it->second = { i } ;
332334 ++{ ovar } _loops_it;
333335 """
334336
@@ -352,7 +354,7 @@ def make_reordered_loop(
352354
353355 for i in range (nnested ):
354356 declare_totals += f"""
355- int TOTAL_{ int ( i ) } = init_totals[{ ovar } _loops_it->second];
357+ int TOTAL_{ i } = init_totals[{ ovar } _loops_it->second];
356358 ++{ ovar } _loops_it;
357359 """
358360
@@ -365,7 +367,7 @@ def get_loop_strides(loop_order, i):
365367 specified loop_order.
366368
367369 """
368- var = sub [f"lv{ int ( i ) } " ]
370+ var = sub [f"lv{ i } " ]
369371 r = []
370372 for index in loop_order :
371373 # Note: the stride variable is not declared for broadcasted variables
@@ -383,7 +385,7 @@ def get_loop_strides(loop_order, i):
383385 )
384386
385387 declare_strides = f"""
386- int init_strides[{ int ( nvars ) } ][{ int ( nnested ) } ] = {{
388+ int init_strides[{ nvars } ][{ nnested } ] = {{
387389 { strides }
388390 }};"""
389391
@@ -394,33 +396,33 @@ def get_loop_strides(loop_order, i):
394396 """
395397
396398 for i in range (nvars ):
397- var = sub [f"lv{ int ( i ) } " ]
399+ var = sub [f"lv{ i } " ]
398400 declare_strides += f"""
399401 { ovar } _loops_rit = { ovar } _loops.rbegin();"""
400402 for j in reversed (range (nnested )):
401403 declare_strides += f"""
402- int { var } _stride_l{ int ( j ) } = init_strides[{ int ( i ) } ][{ ovar } _loops_rit->second];
404+ int { var } _stride_l{ j } = init_strides[{ i } ][{ ovar } _loops_rit->second];
403405 ++{ ovar } _loops_rit;
404406 """
405407
406408 declare_iter = ""
407409 for i , dtype in enumerate (dtypes ):
408- var = sub [f"lv{ int ( i ) } " ]
410+ var = sub [f"lv{ i } " ]
409411 declare_iter += f"{ var } _iter = ({ dtype } *)(PyArray_DATA({ var } ));\n "
410412
411413 pointer_update = ""
412414 for j , dtype in enumerate (dtypes ):
413- var = sub [f"lv{ int ( j ) } " ]
415+ var = sub [f"lv{ j } " ]
414416 pointer_update += f"{ dtype } &{ var } _i = * ( { var } _iter"
415417 for i in reversed (range (nnested )):
416- iterv = f"ITER_{ int ( i ) } "
417- pointer_update += f"+{ var } _stride_l{ int ( i ) } *{ iterv } "
418+ iterv = f"ITER_{ i } "
419+ pointer_update += f"+{ var } _stride_l{ i } *{ iterv } "
418420 pointer_update += ");\n "
419421
420422 loop = inner_task
421423 for i in reversed (range (nnested )):
422- iterv = f"ITER_{ int ( i ) } "
423- total = f"TOTAL_{ int ( i ) } "
424+ iterv = f"ITER_{ i } "
425+ total = f"TOTAL_{ i } "
424426 update = ""
425427 forloop = ""
426428 # The pointers are defined only in the most inner loop
@@ -434,36 +436,14 @@ def get_loop_strides(loop_order, i):
434436
435437 loop = f"""
436438 { forloop }
437- {{ // begin loop { int ( i ) }
439+ {{ // begin loop { i }
438440 { update }
439441 { loop }
440- }} // end loop { int ( i ) }
442+ }} // end loop { i }
441443 """
442444
443- return f"{{\n { order_loops } \n { declare_totals } \n { declare_strides } \n { declare_iter } \n { loop } \n }}\n "
444-
445-
446- # print make_declare(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
447- # ('double', 'int', 'float'),
448- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
449-
450- # print make_checks(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
451- # ('double', 'int', 'float'),
452- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
453-
454- # print make_alloc(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
455- # 'double',
456- # dict(olv='out', lv0='x', lv1='y', lv2='z', fail="FAIL;"))
457-
458- # print make_loop(((0, 1, 2, 3), ('x', 1, 0, 3), ('x', 'x', 'x', 0)),
459- # ('double', 'int', 'float'),
460- # (("C00;", "C%01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
461- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
462-
463- # print make_loop(((0, 1, 2, 3), (3, 'x', 0, 'x'), (0, 'x', 'x', 'x')),
464- # ('double', 'int', 'float'),
465- # (("C00;", "C01;"), ("C10;", "C11;"), ("C20;", "C21;"), ("C30;", "C31;"),"C4;"),
466- # dict(lv0='x', lv1='y', lv2='z', fail="FAIL;"))
445+ code = "\n " .join ((order_loops , declare_totals , declare_strides , declare_iter , loop ))
446+ return f"{{\n { code } \n }}\n "
467447
468448
469449##################
0 commit comments