@@ -7,7 +7,7 @@

 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )

         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac

         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(

         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)

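The `hessian` helper is dropped in favor of taking the Jacobian of the gradient, which exposes the `vectorize` switch. Assuming `jac=True` (so that `fgraph.outputs[-1]` is the gradient appended just above), the two formulations agree; a minimal standalone sketch with hypothetical variable names:

    import pytensor.tensor as pt
    from pytensor.gradient import grad, jacobian

    x = pt.vector("x")
    f = (x**2 * pt.exp(x)).sum()  # any scalar objective

    g = grad(f, x)                      # gradient, as appended when jac=True
    H = jacobian(g, x, vectorize=True)  # Hessian as the Jacobian of the gradient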
@@ -561,7 +568,10 @@ def L_op(self, inputs, outputs, output_grads):
         implicit_f = grad(inner_fx, inner_x)

         df_dx, *df_dtheta_columns = jacobian(
-            implicit_f, [inner_x, *inner_args], disconnected_inputs="ignore"
+            implicit_f,
+            [inner_x, *inner_args],
+            disconnected_inputs="ignore",
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
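`implict_optimization_grads` (name as spelled in this file) pushes these Jacobians through the implicit function theorem: at the optimum, `implicit_f(x*, theta) = 0`, so `dx*/dtheta = -df_dx^(-1) @ df_dtheta`. A hypothetical NumPy sketch of that identity, not the function's actual implementation:

    import numpy as np

    def implicit_grads_sketch(df_dx, df_dtheta):
        # df_dx: (n, n) Jacobian wrt x at the optimum; df_dtheta: (n, k) wrt parameters
        return -np.linalg.solve(df_dx, df_dtheta)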
@@ -581,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -590,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian (and/or Hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off: vectorized graphs can be
+        faster, but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize

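To make the docstring's trade-off concrete: both modes build the same Jacobian, differing only in the graph that computes it. A small sketch using the `jacobian` signature seen above:

    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    y = pt.exp(x) / pt.sum(pt.exp(x))  # softmax-like vector output

    J_scan = jacobian(y, x, vectorize=False)  # Scan over rows: less memory, sequential
    J_vmap = jacobian(y, x, vectorize=True)   # vectorized graph: faster, more memory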
@@ -624,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )

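End to end, the new flag threads from `minimize` down to `MinimizeOp`. A hedged usage sketch (the `pytensor.tensor.optimize` import path is assumed from context):

    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize

    x = pt.vector("x")
    a = pt.vector("a")  # parameter of the objective
    objective = ((x - a) ** 2).sum()

    x_star, success = minimize(
        objective, x, method="BFGS", jac=True, use_vectorized_jac=True
    )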
@@ -804,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -817,7 +833,11 @@ def __init__(
         self.fgraph = FunctionGraph([variables, *args], [equations])

         if jac:
-            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            jac_wrt_x = jacobian(
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
+            )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))

         self.jac = jac
@@ -897,8 +917,14 @@ def L_op(
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]

-        df_dx = jacobian(inner_fx, inner_x) if not self.jac else self.fgraph.outputs[1]
-        df_dtheta_columns = jacobian(inner_fx, inner_args, disconnected_inputs="ignore")
+        df_dx = (
+            jacobian(inner_fx, inner_x, vectorize=True)
+            if not self.jac
+            else self.fgraph.outputs[1]
+        )
+        df_dtheta_columns = jacobian(
+            inner_fx, inner_args, disconnected_inputs="ignore", vectorize=True
+        )

         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -917,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -935,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian matrix. If False, a scan will be
+        used instead. This comes down to a memory/compute trade-off: vectorized graphs can be faster, but
+        use more memory. Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
@@ -958,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )

     solution, success = cast(
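As with `minimize`, the flag is forwarded to `RootOp`. A hedged usage sketch under the same import-path assumption:

    import pytensor.tensor as pt
    from pytensor.tensor.optimize import root

    v = pt.vector("v")
    b = pt.vector("b")
    equations = v**3 - b  # solve v**3 = b elementwise

    v_star, success = root(equations, v, method="hybr", use_vectorized_jac=True)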