Skip to content

Commit ab478e8

Browse files
Add caching for parametrization
1 parent 179db08 commit ab478e8

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

bitsandbytes/nn/parametrize.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ class Bnb4bitParametrization(nn.Module):
2121
The quantization state containing the necessary information for dequantization.
2222
"""
2323

24-
def __init__(self, quant_state: F.QuantState, p_name="unknown"):
24+
def __init__(self, quant_state: F.QuantState):
2525
super().__init__()
2626
self.quant_state = quant_state
27-
self.p_name = p_name
2827

28+
@torch.no_grad()
2929
def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
3030
"""
3131
Forward pass to dequantize the parameter.
@@ -55,13 +55,8 @@ def replace_parameter_4bit_prequantized(
5555
# Apply a parametrization to the module to handle dequantization.
5656
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
5757

58-
# Next, register state dict hook for saving.
59-
module.register_state_dict_post_hook(
60-
partial(
61-
_parametrized_state_dict_post_hook,
62-
param_name=param_name,
63-
)
64-
)
58+
# Next, register hooks.
59+
_register_parametrization_hooks(module, param_name)
6560

6661

6762
def replace_parameter_4bit(
@@ -127,14 +122,35 @@ def replace_parameter_4bit(
127122
# Apply a parametrization to the module to handle dequantization.
128123
P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
129124

130-
# Next, register state dict hook for saving.
125+
# Next, register hooks.
126+
_register_parametrization_hooks(module, param_name)
127+
128+
129+
def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any):
    """Forward hook that releases one level of parametrization caching.

    Decrements the ``P._cache_enabled`` counter and, once it reaches a falsy
    value (zero), replaces ``P._cache`` with a fresh empty dict.

    ``module``, ``inputs`` and ``output`` are accepted only to match the
    forward-hook call signature; none of them is read.
    """
    remaining = P._cache_enabled - 1
    P._cache_enabled = remaining
    if remaining:
        # Some other caller still holds the cache open; leave it intact.
        return
    # No outstanding users remain: drop everything that was cached.
    P._cache = {}
133+
134+
135+
def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
    """Forward pre-hook that opens one level of parametrization caching.

    Increments the ``P._cache_enabled`` counter by one. ``module`` and
    ``inputs`` are accepted only to match the forward-pre-hook call
    signature; neither is read.
    """
    P._cache_enabled = P._cache_enabled + 1
137+
138+
139+
def _register_parametrization_hooks(module: nn.Module, param_name: str):
    """Attach the hooks needed by a parametrized quantized parameter.

    Registers a state dict post-hook bound to ``param_name`` and a pair of
    forward hooks that bracket each forward call with an enable/disable of
    the parametrization cache.

    Args:
        module: The module that owns the parametrized parameter.
        param_name: Name of the parametrized parameter; passed to the state
            dict post-hook as the ``param_name`` keyword argument.
    """
    # Register a state dict hook for saving.
    module.register_state_dict_post_hook(
        partial(
            _parametrized_state_dict_post_hook,
            param_name=param_name,
        )
    )

    # Register hooks to enable caching for the dequantization parametrization.
    # This saves time and memory when the same quantized parameter
    # is accessed multiple times in the forward computation.
    module.register_forward_pre_hook(_enable_parametrization_cache)
    # always_call=True makes the disable hook run even when forward raises.
    # Otherwise the enable/disable pair would become unbalanced after an
    # exception, leaving the cache counter incremented and the cache uncleared.
    module.register_forward_hook(_disable_parametrization_cache, always_call=True)
153+
138154

139155
def _parametrized_state_dict_post_hook(
140156
module: nn.Module,

0 commit comments

Comments
 (0)