@@ -393,6 +393,8 @@ def __init__(
393393 assert len (self .input_storage ) == len (self .maker .fgraph .inputs )
394394 assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
395395
396+ self .has_defaults = any (refeed for _ , refeed , _ in self .defaults )
397+
396398 # Group indexes of inputs that are potentially aliased to each other
397399 # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
398400 # even though there could be two distinct types that use the same kinds of underlying objects.
@@ -540,14 +542,40 @@ def __contains__(self, item):
540542 self ._value = ValueAttribute ()
541543 self ._container = ContainerAttribute ()
542544
543- # TODO: Get rid of all this `expanded_inputs` nonsense
544- assert len (self .maker .expanded_inputs ) == len (self .input_storage )
545+ update_storage = [
546+ container
547+ for inp , container in zip (
548+ self .maker .expanded_inputs , input_storage , strict = True
549+ )
550+ if inp .update is not None
551+ ]
552+ # Updates are the last inner outputs that are not returned by Function.__call__
553+ self .n_returned_outputs = len (self .output_storage ) - len (update_storage )
554+
555+ # Function.__call__ is responsible for updating the inputs, unless the vm promises to do it itself
556+ self .update_input_storage : tuple [int , Container ] = ()
557+ if getattr (vm , "need_update_inputs" , True ):
558+ self .update_input_storage = tuple (
559+ zip (
560+ range (self .n_returned_outputs , len (output_storage )),
561+ update_storage ,
562+ strict = True ,
563+ )
564+ )
545565
546- # This is used only when `vm.need_update_inputs` is `False`, because
547- # we're using one of the VM objects and it is putting updates back into
548- # the input containers all by itself.
549- self .n_returned_outputs = len (self .output_storage ) - sum (
550- inp .update is not None for inp in self .maker .expanded_inputs
566+ # In every function call we place inputs in the input_storage, and the vm places outputs in the output_storage
567+ # After the call, we want to erase (some of) these references, to allow Python to GC them if unused
568+ # Required input containers are the non-default inputs, must always be provided again, so we GC them
569+ self .clear_input_storage_data = tuple (
570+ container .storage for container in input_storage if container .required
571+ )
572+ # This is only done when `vm.allow_gc` is True, which can change at runtime.
573+ self .clear_output_storage_data = tuple (
574+ container .storage
575+ for container , variable in zip (
576+ self .output_storage , self .maker .fgraph .outputs , strict = True
577+ )
578+ if variable .owner is not None # Not a constant output
551579 )
552580
553581 for node in self .maker .fgraph .apply_nodes :
@@ -747,7 +775,7 @@ def checkSV(sv_ori, sv_rpl):
747775 elif isinstance (profile , str ):
748776 profile = pytensor .compile .profiling .ProfileStats (message = profile )
749777
750- f_cpy = maker . __class__ (
778+ f_cpy = type ( maker ) (
751779 inputs = ins ,
752780 outputs = outs ,
753781 fgraph = fg_cpy ,
@@ -765,6 +793,8 @@ def checkSV(sv_ori, sv_rpl):
765793 # check that.
766794 accept_inplace = True ,
767795 no_fgraph_prep = True ,
796+ output_keys = maker .output_keys ,
797+ name = name ,
768798 ).create (input_storage , storage_map = new_storage_map )
769799
770800 for in_ori , in_cpy , ori , cpy in zip (
@@ -797,8 +827,6 @@ def checkSV(sv_ori, sv_rpl):
797827
798828 f_cpy .trust_input = self .trust_input
799829 f_cpy .unpack_single = self .unpack_single
800- f_cpy .name = name
801- f_cpy .maker .fgraph .name = name
802830 return f_cpy
803831
804832 def _restore_defaults (self ):
@@ -808,7 +836,7 @@ def _restore_defaults(self):
808836 value = value .storage [0 ]
809837 self [i ] = value
810838
811- def __call__ (self , * args , ** kwargs ):
839+ def __call__ (self , * args , output_subset = None , ** kwargs ):
812840 """
813841 Evaluates value of a function on given arguments.
814842
@@ -836,20 +864,21 @@ def __call__(self, *args, **kwargs):
836864 List of outputs on indices/keys from ``output_subset`` or all of them,
837865 if ``output_subset`` is not passed.
838866 """
867+ trust_input = self .trust_input
839868 input_storage = self .input_storage
869+ vm = self .vm
840870 profile = self .profile
841871
842872 if profile :
843873 t0 = time .perf_counter ()
844874
845- output_subset = kwargs .pop ("output_subset" , None )
846875 if output_subset is not None :
847876 warnings .warn ("output_subset is deprecated." , FutureWarning )
848877 if self .output_keys is not None :
849878 output_subset = [self .output_keys .index (key ) for key in output_subset ]
850879
851880 # Reinitialize each container's 'provided' counter
852- if self . trust_input :
881+ if trust_input :
853882 for arg_container , arg in zip (input_storage , args , strict = False ):
854883 arg_container .storage [0 ] = arg
855884 else :
@@ -908,7 +937,7 @@ def __call__(self, *args, **kwargs):
908937 for k , arg in kwargs .items ():
909938 self [k ] = arg
910939
911- if not self . trust_input :
940+ if not trust_input :
912941 # Collect aliased inputs among the storage space
913942 for potential_group in self ._potential_aliased_input_groups :
914943 args_share_memory : list [list [int ]] = []
@@ -960,11 +989,7 @@ def __call__(self, *args, **kwargs):
960989 if profile :
961990 t0_fn = time .perf_counter ()
962991 try :
963- outputs = (
964- self .vm ()
965- if output_subset is None
966- else self .vm (output_subset = output_subset )
967- )
992+ outputs = vm () if output_subset is None else vm (output_subset = output_subset )
968993 except Exception :
969994 self ._restore_defaults ()
970995 if hasattr (self .vm , "position_of_error" ):
@@ -991,73 +1016,53 @@ def __call__(self, *args, **kwargs):
9911016
9921017 # Retrieve the values that were computed
9931018 if outputs is None :
994- outputs = [x .data for x in self .output_storage ]
995-
996- # Remove internal references to required inputs.
997- # These cannot be re-used anyway.
998- for arg_container in input_storage :
999- if arg_container .required :
1000- arg_container .storage [0 ] = None
1001-
1002- # if we are allowing garbage collection, remove the
1003- # output reference from the internal storage cells
1004- if getattr (self .vm , "allow_gc" , False ):
1005- # strict=False because we are in a hot loop
1006- for o_container , o_variable in zip (
1007- self .output_storage , self .maker .fgraph .outputs , strict = False
1008- ):
1009- if o_variable .owner is not None :
1010- # this node is the variable of computation
1011- # WARNING: This circumvents the 'readonly' attribute in x
1012- o_container .storage [0 ] = None
1013-
1014- if getattr (self .vm , "need_update_inputs" , True ):
1015- # Update the inputs that have an update function
1016- # strict=False because we are in a hot loop
1017- for input , storage in reversed (
1018- list (zip (self .maker .expanded_inputs , input_storage , strict = False ))
1019- ):
1020- if input .update is not None :
1021- storage .data = outputs .pop ()
1022- else :
1023- outputs = outputs [: self .n_returned_outputs ]
1019+ outputs = [x .storage [0 ] for x in self .output_storage ]
1020+
1021+ # Set updates and filter them out from the returned outputs
1022+ for i , input_storage in self .update_input_storage :
1023+ input_storage .storage [0 ] = outputs [i ]
1024+ outputs = outputs [: self .n_returned_outputs ]
1025+
1026+ # Remove input and output values from storage data
1027+ for storage_data in self .clear_input_storage_data :
1028+ storage_data [0 ] = None
1029+ if getattr (vm , "allow_gc" , False ):
1030+ for storage_data in self .clear_output_storage_data :
1031+ storage_data [0 ] = None
10241032
10251033 # Put default values back in the storage
1026- self ._restore_defaults ()
1034+ if self .has_defaults :
1035+ self ._restore_defaults ()
10271036
10281037 if profile :
10291038 dt_call = time .perf_counter () - t0
10301039 pytensor .compile .profiling .total_fct_exec_time += dt_call
10311040 self .maker .mode .call_time += dt_call
10321041 profile .fct_callcount += 1
10331042 profile .fct_call_time += dt_call
1034- if hasattr (self . vm , "update_profile" ):
1035- self . vm .update_profile (profile )
1043+ if hasattr (vm , "update_profile" ):
1044+ vm .update_profile (profile )
10361045 if profile .ignore_first_call :
10371046 profile .reset ()
10381047 profile .ignore_first_call = False
10391048
10401049 if self .return_none :
10411050 return None
1042- elif self .unpack_single and len (outputs ) == 1 and output_subset is None :
1043- return outputs [0 ]
1044- else :
1045- if self .output_keys is not None :
1046- assert len (self .output_keys ) == len (outputs )
10471051
1048- if output_subset is None :
1049- # strict=False because we are in a hot loop
1050- return dict (zip (self .output_keys , outputs , strict = False ))
1051- else :
1052- return {
1053- self .output_keys [index ]: outputs [index ]
1054- for index in output_subset
1055- }
1052+ if output_subset is not None :
1053+ outputs = [outputs [i ] for i in output_subset ]
10561054
1057- if output_subset is None :
1058- return outputs
1055+ if self .output_keys is None :
1056+ if self .unpack_single :
1057+ [out ] = outputs
1058+ return out
10591059 else :
1060- return [outputs [i ] for i in output_subset ]
1060+ return outputs
1061+ else :
1062+ output_keys = self .output_keys
1063+ if output_subset is not None :
1064+ output_keys = [output_keys [i ] for i in output_subset ]
1065+ return dict (zip (output_keys , outputs , strict = True ))
10611066
10621067 value = property (
10631068 lambda self : self ._value ,
@@ -1077,9 +1082,10 @@ def free(self):
10771082 # 1.no allow_gc return False
10781083 # 2.has allow_gc, if allow_gc is False, return True
10791084 if not getattr (self .vm , "allow_gc" , True ):
1080- for key in self .vm .storage_map :
1081- if not isinstance (key , Constant ):
1082- self .vm .storage_map [key ][0 ] = None
1085+ storage_map = self .vm .storage_map
1086+ for key , value in storage_map .items ():
1087+ if key .owner is not None : # Not a constant
1088+ value [0 ] = None
10831089
10841090 for node in self .nodes_with_inner_function :
10851091 if hasattr (node .fn , "free" ):
@@ -1091,10 +1097,6 @@ def get_shared(self):
10911097 """
10921098 return [i .variable for i in self .maker .inputs if i .implicit ]
10931099
1094- def sync_shared (self ):
1095- # NOTE: sync was needed on old gpu backend
1096- pass
1097-
10981100 def dprint (self , ** kwargs ):
10991101 """Debug print itself
11001102
0 commit comments