@@ -14,7 +14,11 @@
 from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import OutlierTracer
+from bitsandbytes.utils import (
+    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    OutlierTracer,
+)
 
 T = TypeVar("T", bound="torch.nn.Module")
 
@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         return
     weight_format = state_dict.pop(f"{prefix}weight_format", "row")
 
+    if isinstance(weight_format, torch.Tensor):
+        weight_format = weight_format.item()
+
+    # For the new weight format storage type, explicitly check
+    # that weight_format is in the inverse mapping
+    if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        raise ValueError(f"Expected supported weight format - got {weight_format}")
+    elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
+
     if weight_format != "row":
         tile_indices = get_tile_inds(weight_format, weight.device)
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
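
Note on the load path above: the hook now accepts either the legacy string format (old checkpoints) or the new integer encoding. Below is a minimal, self-contained sketch of that decode logic; the exact contents of `LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING` live in `bitsandbytes.utils` and are not shown in this diff, so the dict values here are assumptions for illustration only.

```python
import torch

# Assumed contents of the mapping defined in bitsandbytes.utils (not shown in this diff).
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {v: k for k, v in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}

def decode_weight_format(weight_format):
    # Accept a 0-dim tensor (new checkpoints), a plain int, or a legacy string.
    if isinstance(weight_format, torch.Tensor):
        weight_format = weight_format.item()
    if isinstance(weight_format, int):
        if weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
            raise ValueError(f"Expected supported weight format - got {weight_format}")
        weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
    return weight_format

assert decode_weight_format(torch.tensor(0, dtype=torch.uint8)) == "row"
assert decode_weight_format("col_turing") == "col_turing"  # legacy string checkpoints still load
```

Checkpoints that stored the literal string skip the integer branch entirely, so older state dicts keep loading unchanged.
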
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = self.state.formatB
+                weights_format = self.state.formatB
+                # At this point `weights_format` is a str
+                if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+                    raise ValueError(f"Unrecognized weights format {weights_format}")
+
+                weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
+
+                destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
 
     def _load_from_state_dict(
         self,
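
Note on the save path: the format name from `self.state.formatB` is now written out as a 0-dim `torch.uint8` tensor instead of a raw string, presumably so every value in the state dict is a tensor. A hedged sketch of the encode step, reusing the same assumed mapping as above (only the `"row" -> 0` entry is actually visible in this diff):

```python
import torch

# Assumed contents of the mapping in bitsandbytes.utils; values here are illustrative.
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}

def encode_weight_format(weights_format: str) -> torch.Tensor:
    # Mirror of the save branch above: map the layout name to its integer id,
    # then store it as a 0-dim uint8 tensor.
    if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
        raise ValueError(f"Unrecognized weights format {weights_format}")
    return torch.tensor(LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format], dtype=torch.uint8)

print(encode_weight_format("col_ampere"))  # tensor(3, dtype=torch.uint8)
```

Together with the load-hook change above, this round-trips: the integer written here is decoded back to its string name when the state dict is loaded.
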