@@ -21,11 +21,11 @@ class Bnb4bitParametrization(nn.Module):
2121 The quantization state containing the necessary information for dequantization.
2222 """
2323
24- def __init__ (self , quant_state : F .QuantState , p_name = "unknown" ):
24+ def __init__ (self , quant_state : F .QuantState ):
2525 super ().__init__ ()
2626 self .quant_state = quant_state
27- self .p_name = p_name
2827
28+ @torch .no_grad ()
2929 def forward (self , quantized_param : torch .Tensor ) -> torch .Tensor :
3030 """
3131 Forward pass to dequantize the parameter.
@@ -55,13 +55,8 @@ def replace_parameter_4bit_prequantized(
5555 # Apply a parametrization to the module to handle dequantization.
5656 P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state ), unsafe = True )
5757
58- # Next, register state dict hook for saving.
59- module .register_state_dict_post_hook (
60- partial (
61- _parametrized_state_dict_post_hook ,
62- param_name = param_name ,
63- )
64- )
58+ # Next, register hooks.
59+ _register_parametrization_hooks (module , param_name )
6560
6661
6762def replace_parameter_4bit (
@@ -127,14 +122,35 @@ def replace_parameter_4bit(
127122 # Apply a parametrization to the module to handle dequantization.
128123 P .register_parametrization (module , param_name , Bnb4bitParametrization (quant_state ), unsafe = True )
129124
130- # Next, register state dict hook for saving.
125+ # Next, register hooks.
126+ _register_parametrization_hooks (module , param_name )
127+
128+
129+ def _disable_parametrization_cache (module : nn .Module , inputs : tuple [Any , ...], output : Any ):
130+ P ._cache_enabled -= 1
131+ if not P ._cache_enabled :
132+ P ._cache = {}
133+
134+
135+ def _enable_parametrization_cache (module : nn .Module , inputs : tuple [Any , ...]):
136+ P ._cache_enabled += 1
137+
138+
139+ def _register_parametrization_hooks (module : nn .Module , param_name : str ):
140+ # Register a state dict hook for saving.
131141 module .register_state_dict_post_hook (
132142 partial (
133143 _parametrized_state_dict_post_hook ,
134144 param_name = param_name ,
135145 )
136146 )
137147
148+ # Register hooks to enable caching for the dequantization parametrization.
149+ # This helps preserve time and memory when the same quantized parameter
150+ # is accessed multiple times in the forward computation.
151+ module .register_forward_pre_hook (_enable_parametrization_cache )
152+ module .register_forward_hook (_disable_parametrization_cache )
153+
138154
139155def _parametrized_state_dict_post_hook (
140156 module : nn .Module ,
0 commit comments