@@ -443,7 +443,7 @@ The following is an example that distributes dot products across additions.
 .. code::
 
     import pytensor
-    import pytensor.tensor as at
+    import pytensor.tensor as pt
     from pytensor.graph.rewriting.kanren import KanrenRelationSub
     from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
     from pytensor.graph.rewriting.utils import rewrite_graph
@@ -462,7 +462,7 @@ The following is an example that distributes dot products across additions.
     )
 
     # Tell `kanren` that `add` is associative
-    fact(associative, at.add)
+    fact(associative, pt.add)
 
 
     def dot_distributeo(in_lv, out_lv):
@@ -473,13 +473,13 @@ The following is an example that distributes dot products across additions.
             # Make sure the input is a `_dot`
             eq(in_lv, etuple(_dot, A_lv, add_term_lv)),
             # Make sure the term being `_dot`ed is an `add`
-            heado(at.add, add_term_lv),
+            heado(pt.add, add_term_lv),
             # Flatten the associative pairings of `add` operations
             assoc_flatten(add_term_lv, add_flat_lv),
             # Get the flattened `add` arguments
             tailo(add_cdr_lv, add_flat_lv),
             # Add all the `_dot`ed arguments and set the output
-            conso(at.add, dot_cdr_lv, out_lv),
+            conso(pt.add, dot_cdr_lv, out_lv),
             # Apply the `_dot` to all the flattened `add` arguments
             mapo(lambda x, y: conso(_dot, etuple(A_lv, x), y), add_cdr_lv, dot_cdr_lv),
         )
@@ -490,10 +490,10 @@ The following is an example that distributes dot products across additions.
 
 Below, we apply `dot_distribute_rewrite` to a few example graphs. First we create a simple test graph:
 
->>> x_at = at.vector("x")
->>> y_at = at.vector("y")
->>> A_at = at.matrix("A")
+>>> x_at = pt.vector("x")
+>>> y_at = pt.vector("y")
+>>> A_at = pt.matrix("A")
 >>> test_at = A_at.dot(x_at + y_at)
 >>> print(pytensor.pprint(test_at))
 (A @ (x + y))
 
@@ -506,18 +506,18 @@ Next we apply the rewrite to the graph:
 We see that the dot product has been distributed, as desired. Now, let's try a
 few more test cases:
 
->>> z_at = at.vector("z")
->>> w_at = at.vector("w")
+>>> z_at = pt.vector("z")
+>>> w_at = pt.vector("w")
 >>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at))
 >>> print(pytensor.pprint(test_at))
 (A @ ((x + y) + (z + w)))
 >>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
 >>> print(pytensor.pprint(res))
 (((A @ x) + (A @ y)) + ((A @ z) + (A @ w)))
 
->>> B_at = at.matrix("B")
->>> w_at = at.vector("w")
+>>> B_at = pt.matrix("B")
+>>> w_at = pt.vector("w")
 >>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at)))
 >>> print(pytensor.pprint(test_at))
 (A @ (x + (y + ((B @ z) + (B @ w)))))
 >>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False)
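
The definition of `dot_distribute_rewrite` itself falls in the lines elided between the hunks above (486-489). For readers following along, here is a minimal sketch of how such a rewriter is typically assembled from `dot_distributeo`, using the `KanrenRelationSub` and `EquilibriumGraphRewriter` classes imported at the top of the example; the `max_use_ratio` value is an assumption, not taken from this diff.

.. code::

    # Sketch only: reconstructs the elided definition under the stated assumptions.
    # `KanrenRelationSub` lifts the miniKanren relation into a node rewriter, and
    # `EquilibriumGraphRewriter` applies it repeatedly until the graph stops changing.
    dot_distribute_rewrite = EquilibriumGraphRewriter(
        [KanrenRelationSub(dot_distributeo)],
        max_use_ratio=2,  # assumed; bounds rewrite applications relative to graph size
    )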