Skip to content

Commit 7449d71

Browse files
younesbelkada
and akx authored
[Core] Change 8-bit serialization weight format format (#1164)
* change 8-bit serialization weight format format * precimmit * pre-commit * fix * Update bitsandbytes/nn/modules.py Co-authored-by: Aarni Koskela <[email protected]> * Update bitsandbytes/nn/modules.py Co-authored-by: Aarni Koskela <[email protected]> * Update bitsandbytes/utils.py Co-authored-by: Aarni Koskela <[email protected]> * address feedback * lint --------- Co-authored-by: Aarni Koskela <[email protected]>
1 parent c54053d commit 7449d71

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

bitsandbytes/nn/modules.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
1515
from bitsandbytes.functional import QuantState
1616
from bitsandbytes.optim import GlobalOptimManager
17-
from bitsandbytes.utils import OutlierTracer
17+
from bitsandbytes.utils import (
18+
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
19+
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
20+
OutlierTracer,
21+
)
1822

1923
T = TypeVar("T", bound="torch.nn.Module")
2024

@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
619623
return
620624
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
621625

626+
if isinstance(weight_format, torch.Tensor):
627+
weight_format = weight_format.item()
628+
629+
# For new weights format storage type, we explicitly check
630+
# if weights_format is on the mapping
631+
if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
632+
raise ValueError(f"Expected supported weight format - got {weight_format}")
633+
elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
634+
weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
635+
622636
if weight_format != "row":
623637
tile_indices = get_tile_inds(weight_format, weight.device)
624638
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
711725
if not self.state.has_fp16_weights:
712726
if param_from_weight is not None:
713727
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
714-
destination[format_name] = "row"
728+
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
715729
elif param_from_state is not None and not layout_reordered:
716730
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
717-
destination[format_name] = "row"
731+
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
718732
elif param_from_state is not None:
719733
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
720-
destination[format_name] = self.state.formatB
734+
weights_format = self.state.formatB
735+
# At this point `weights_format` is an str
736+
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
737+
raise ValueError(f"Unrecognized weights format {weights_format}")
738+
739+
weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
740+
741+
destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
721742

722743
def _load_from_state_dict(
723744
self,

bitsandbytes/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data):
198198
unpacked_dict = json.loads(json_str)
199199

200200
return unpacked_dict
201+
202+
203+
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
204+
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}

0 commit comments

Comments (0)