-
-
Notifications
You must be signed in to change notification settings - Fork 793
[Core] Change 8-bit serialization weight format format
#1164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
e93bb3f
change 8-bit serialization weight format format
younesbelkada c4d8af2
precimmit
younesbelkada 4bf0af5
pre-commit
younesbelkada 8a5668d
fix
younesbelkada ff8b9a7
Update bitsandbytes/nn/modules.py
younesbelkada 9956f5b
Update bitsandbytes/nn/modules.py
younesbelkada 3272a28
Update bitsandbytes/utils.py
younesbelkada e92be59
address feedback
younesbelkada 8f2f57b
Merge branch 'fix-8bit-serialization' of https://github.com/TimDettme…
younesbelkada 6006274
lint
younesbelkada File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,11 @@ | |
| from bitsandbytes.autograd._functions import get_tile_inds, undo_layout | ||
| from bitsandbytes.functional import QuantState | ||
| from bitsandbytes.optim import GlobalOptimManager | ||
| from bitsandbytes.utils import OutlierTracer | ||
| from bitsandbytes.utils import ( | ||
| INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, | ||
| LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, | ||
| OutlierTracer, | ||
| ) | ||
|
|
||
| T = TypeVar("T", bound="torch.nn.Module") | ||
|
|
||
|
|
@@ -619,6 +623,16 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k | |
| return | ||
| weight_format = state_dict.pop(f"{prefix}weight_format", "row") | ||
|
|
||
| if isinstance(weight_format, torch.Tensor): | ||
| weight_format = weight_format.item() | ||
|
|
||
| # For new weights format storage type, we explicitly check | ||
| # if weights_format is on the mapping | ||
| if isinstance(weight_format, int) and weight_format not in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: | ||
| raise ValueError(f"Expected supported weight format - got {weight_format}") | ||
| elif isinstance(weight_format, int) and weight_format in INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: | ||
| weight_format = INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weight_format] | ||
|
|
||
|
Comment on lines
+626
to
+635
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As said before, this should probably be a free-standing helper function. |
||
| if weight_format != "row": | ||
| tile_indices = get_tile_inds(weight_format, weight.device) | ||
| state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices) | ||
|
|
@@ -711,13 +725,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): | |
| if not self.state.has_fp16_weights: | ||
| if param_from_weight is not None: | ||
| destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() | ||
| destination[format_name] = "row" | ||
| destination[format_name] = torch.tensor(0, dtype=torch.uint8) | ||
| elif param_from_state is not None and not layout_reordered: | ||
| destination[key_name] = param_from_state if keep_vars else param_from_state.detach() | ||
| destination[format_name] = "row" | ||
| destination[format_name] = torch.tensor(0, dtype=torch.uint8) | ||
| elif param_from_state is not None: | ||
| destination[key_name] = param_from_state if keep_vars else param_from_state.detach() | ||
| destination[format_name] = self.state.formatB | ||
| weights_format = self.state.formatB | ||
| # At this point `weights_format` is an str | ||
| if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING: | ||
| raise ValueError(f"Unrecognized weights format {weights_format}") | ||
|
|
||
| weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format] | ||
|
|
||
| destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8) | ||
|
|
||
| def _load_from_state_dict( | ||
| self, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't seem to make much sense? Why is the
isinstance()check repeated?