@@ -561,13 +561,16 @@ def check_format(preset):
561561 return "keras"
562562
563563
564- def load_serialized_object (
565- preset ,
566- config_file = CONFIG_FILE ,
567- config_overrides = {},
568- ):
564+ def load_serialized_object (preset , config_file = CONFIG_FILE , ** kwargs ):
565+ kwargs = kwargs or {}
569566 config = load_config (preset , config_file )
570- config ["config" ] = {** config ["config" ], ** config_overrides }
567+
568+ # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
569+ # Ensure that `dtype` is properly configured.
570+ dtype = kwargs .pop ("dtype" , None )
571+ config = set_dtype_in_config (config , dtype )
572+
573+ config ["config" ] = {** config ["config" ], ** kwargs }
571574 return keras .saving .deserialize_keras_object (config )
572575
573576
@@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
590593 for weight in layer .weights :
591594 if getattr (weight , "_value" , None ) is not None :
592595 weight ._value .delete ()
596+
597+
598+ def set_dtype_in_config (config , dtype = None ):
599+ if dtype is None :
600+ return config
601+
602+ config = config .copy ()
603+ if "dtype" not in config ["config" ]:
604+ # Forward `dtype` to the config.
605+ config ["config" ]["dtype" ] = dtype
606+ elif (
607+ "dtype" in config ["config" ]
608+ and isinstance (config ["config" ]["dtype" ], dict )
609+ and "DTypePolicyMap" in config ["config" ]["dtype" ]["class_name" ]
610+ ):
611+ # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
612+ # policy.
613+ policy_map_config = config ["config" ]["dtype" ]["config" ]
614+ policy_map_config ["default_policy" ] = dtype
615+ for k in policy_map_config ["policy_map" ].keys ():
616+ policy_map_config ["policy_map" ][k ]["config" ]["source_name" ] = dtype
617+ return config
0 commit comments