29 changes: 25 additions & 4 deletions bitsandbytes/nn/modules.py
@@ -14,7 +14,11 @@
 from bitsandbytes.autograd._functions import get_tile_inds, undo_layout
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import OutlierTracer
+from bitsandbytes.utils import (
+    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
+    OutlierTracer,
+)
 
 T = TypeVar("T", bound="torch.nn.Module")

@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
         return
     weight_format = state_dict.pop(f"{prefix}weight_format", "row")
 
+    if isinstance(weight_format, torch.Tensor):
+        weight_format = weight_format.item()
+
+    # For the new weight format storage type, we explicitly check
+    # whether weight_format is in the mapping
+    if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        raise ValueError(f"Expected supported weight format - got {weight_format}")
+    elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+        weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format]
+
Comment on lines +629 to +634 (Contributor):

    This doesn't seem to make much sense? Why is the isinstance() check repeated?
Comment on lines +626 to +635 (Contributor):

    As said before, this should probably be a free-standing helper function.
     if weight_format != "row":
         tile_indices = get_tile_inds(weight_format, weight.device)
         state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         if not self.state.has_fp16_weights:
             if param_from_weight is not None:
                 destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None and not layout_reordered:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = "row"
+                destination[format_name] = torch.tensor(0, dtype=torch.uint8)
             elif param_from_state is not None:
                 destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
-                destination[format_name] = self.state.formatB
+                weights_format = self.state.formatB
+                # At this point `weights_format` is a str
+                if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
+                    raise ValueError(f"Unrecognized weights format {weights_format}")
+
+                weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
+
+                destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)

     def _load_from_state_dict(
         self,
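For illustration: with this change, the save path writes an integer code wrapped in a uint8 tensor to the state dict instead of a raw string. A standalone sketch of that encode step in isolation (encode_weight_format is a hypothetical name, not part of the PR):

import torch

from bitsandbytes.utils import LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING


def encode_weight_format(weights_format: str) -> torch.Tensor:
    # Mirror of the save path above: map a format name ("row", "col32", ...)
    # to the uint8 tensor stored under the weight_format key.
    if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
        raise ValueError(f"Unrecognized weights format {weights_format}")
    return torch.tensor(LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format], dtype=torch.uint8)


encode_weight_format("col_turing")  # tensor(2, dtype=torch.uint8)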
4 changes: 4 additions & 0 deletions bitsandbytes/utils.py
@@ -198,3 +198,7 @@ def unpack_tensor_to_dict(tensor_data):
     unpacked_dict = json.loads(json_str)
 
     return unpacked_dict
+
+
+LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
+INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
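A quick usage check of the two tables (a sketch, not part of the diff): since the inverse mapping is built directly from the forward one, every format name should survive the encode/decode round trip.

from bitsandbytes.utils import (
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
)

assert LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING["col_turing"] == 2
assert INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[0] == "row"

# Round trip holds for every entry.
for name, code in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items():
    assert INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[code] == name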