
Attribute error on _NotYetLoadedTensor loading checkpoint into quantized model #20119

@rasbt

Bug description

After upgrading lightning from 2.3.0.dev20240428 to 2.3.3, loading a checkpoint into a bnb.nf4-quantized model fails with AttributeError: '_NotYetLoadedTensor' object has no attribute 'data'.

What version are you seeing the problem on?

master

How to reproduce the bug

litgpt generate --quantize bnb.nf4 checkpoints/microsoft/phi-2

Error messages and logs

⚡ main ~/litgpt2 litgpt generate --quantize bnb.nf4 checkpoints/microsoft/phi-2 
{'checkpoint_dir': PosixPath('checkpoints/microsoft/phi-2'),
 'compile': False,
 'max_new_tokens': 50,
 'num_samples': 1,
 'precision': None,
 'prompt': 'What food do llamas eat?',
 'quantize': 'bnb.nf4',
 'temperature': 0.8,
 'top_k': 50,
 'top_p': 1.0}
Loading model 'checkpoints/microsoft/phi-2/lit_model.pth' with {'name': 'phi-2', 'hf_config': {'name': 'phi-2', 'org': 'microsoft'}, 'scale_embeddings': False, 'block_size': 2048, 'vocab_size': 50257, 'padding_multiple': 512, 'padded_vocab_size': 51200, 'n_layer': 32, 'n_head': 32, 'head_size': 80, 'n_embd': 2560, 'rotary_percentage': 0.4, 'parallel_residual': True, 'bias': True, 'lm_head_bias': True, 'n_query_groups': 32, 'shared_attention_norm': True, 'norm_class_name': 'LayerNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'GptNeoxMLP', 'gelu_approximate': 'tanh', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 32}
Time to instantiate model: 0.25 seconds.
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt2/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py", line 255, in main
    load_checkpoint(fabric, model, checkpoint_path)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/utils.py", line 362, in load_checkpoint
    model.load_state_dict(state_dict, strict=strict)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 168, in load_state_dict
    return self._original_module.load_state_dict(state_dict=state_dict, strict=strict, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2139, in load_state_dict
    load(self, state_dict)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2127, in load
    load(child, child_state_dict, child_prefix)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2121, in load
    module._load_from_state_dict(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1991, in _load_from_state_dict
    hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 72, in __call__
    return self.hook(*args, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 166, in _quantize_on_load_hook
    quantize_fn(weight)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/plugins/precision/bitsandbytes.py", line 320, in quantize_
    if weight.data.dtype == torch.uint8:
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/utilities/load.py", line 166, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: '_NotYetLoadedTensor' object has no attribute 'data'
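
To illustrate the failure mode in the last two frames, here is a minimal pure-Python sketch of a lazy placeholder whose `__getattr__` forwards only a fixed set of attribute names and raises for everything else. `LazyTensorStub`, `_FORWARDED`, and `safe_dtype` are hypothetical names made up for this example. This is not Lightning's actual `_NotYetLoadedTensor`, and the real class may forward a different set of attributes.

```python
class LazyTensorStub:
    """Hypothetical stand-in for a lazily loaded tensor (not Lightning's class)."""

    # Assumption: only metadata known without reading the file is forwarded.
    _FORWARDED = {"dtype", "shape"}

    def __init__(self, dtype, shape):
        self._meta = {"dtype": dtype, "shape": shape}

    def __getattr__(self, name):
        # __getattr__ runs only when normal attribute lookup has failed.
        if name in type(self)._FORWARDED:
            return self._meta[name]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


def safe_dtype(weight):
    """Read the dtype without going through `.data`."""
    return getattr(weight, "dtype", None)


stub = LazyTensorStub(dtype="float16", shape=(2560, 2560))

try:
    stub.data  # same access pattern as `weight.data.dtype` in the traceback
    error_message = None
except AttributeError as err:
    error_message = str(err)

print(error_message)
print(safe_dtype(stub))
```

In this sketch a caller that reads `dtype` directly never touches `.data`, which is one possible way to avoid the error. Whether that is the right fix for `quantize_` depends on what the real placeholder exposes.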

Environment

Current environment
* CUDA:
        - GPU:
                - NVIDIA A10G
        - available:         True
        - version:           12.1
* Lightning:
        - lightning:         2.3.3
        - lightning-cloud:   0.5.70
        - lightning-sdk:     0.1.10
        - lightning-utilities: 0.11.3.post0
        - pytorch-lightning: 2.3.3
        - torch:             2.2.1+cu121
        - torchmetrics:      1.3.1
        - torchvision:       0.17.1+cu121
* Packages:
        - absl-py:           2.1.0
        - accelerate:        0.32.1
        - aiohttp:           3.9.5
        - aiosignal:         1.3.1
        - annotated-types:   0.7.0
        - anyio:             4.4.0
        - argon2-cffi:       23.1.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.3.0
        - asttokens:         2.4.1
        - async-lru:         2.0.4
        - async-timeout:     4.0.3
        - attrs:             23.2.0
        - babel:             2.15.0
        - backoff:           2.2.1
        - beautifulsoup4:    4.12.3
        - bitsandbytes:      0.42.0
        - bleach:            6.1.0
        - boto3:             1.34.142
        - botocore:          1.34.142
        - cachetools:        5.3.3
        - certifi:           2024.7.4
        - cffi:              1.16.0
        - chardet:           5.2.0
        - charset-normalizer: 3.3.2
        - click:             8.1.7
        - colorama:          0.4.6
        - comm:              0.2.2
        - contourpy:         1.2.1
        - cycler:            0.12.1
        - dataproperty:      1.0.1
        - datasets:          2.20.0
        - debugpy:           1.8.2
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - dill:              0.3.8
        - dnspython:         2.6.1
        - docstring-parser:  0.16
        - email-validator:   2.2.0
        - evaluate:          0.4.2
        - exceptiongroup:    1.2.1
        - executing:         2.0.1
        - fastapi:           0.111.0
        - fastapi-cli:       0.0.4
        - fastjsonschema:    2.20.0
        - filelock:          3.15.4
        - fire:              0.6.0
        - fonttools:         4.53.1
        - fqdn:              1.5.1
        - frozenlist:        1.4.1
        - fsspec:            2024.5.0
        - google-auth:       2.32.0
        - google-auth-oauthlib: 1.2.1
        - grpcio:            1.64.1
        - h11:               0.14.0
        - hf-transfer:       0.1.6
        - httpcore:          1.0.5
        - httptools:         0.6.1
        - httpx:             0.27.0
        - huggingface-hub:   0.24.0
        - idna:              3.7
        - importlib-resources: 6.4.0
        - ipykernel:         6.26.0
        - ipython:           8.17.2
        - ipywidgets:        8.1.1
        - isoduration:       20.11.0
        - jedi:              0.19.1
        - jinja2:            3.1.4
        - jmespath:          1.0.1
        - joblib:            1.4.2
        - json5:             0.9.25
        - jsonargparse:      4.31.0
        - jsonlines:         4.0.0
        - jsonpointer:       3.0.0
        - jsonschema:        4.23.0
        - jsonschema-specifications: 2023.12.1
        - jupyter-client:    8.6.2
        - jupyter-core:      5.7.2
        - jupyter-events:    0.10.0
        - jupyter-lsp:       2.2.5
        - jupyter-server:    2.14.1
        - jupyter-server-terminals: 0.5.3
        - jupyterlab:        4.2.0
        - jupyterlab-pygments: 0.3.0
        - jupyterlab-server: 2.27.2
        - jupyterlab-widgets: 3.0.11
        - kiwisolver:        1.4.5
        - lightning:         2.3.3
        - lightning-cloud:   0.5.70
        - lightning-sdk:     0.1.10
        - lightning-utilities: 0.11.3.post0
        - litdata:           0.2.17
        - litgpt:            0.4.5
        - litserve:          0.1.3
        - lm-eval:           0.4.3
        - lxml:              5.2.2
        - markdown:          3.6
        - markdown-it-py:    3.0.0
        - markupsafe:        2.1.5
        - matplotlib:        3.8.2
        - matplotlib-inline: 0.1.7
        - mbstrdecoder:      1.1.3
        - mdurl:             0.1.2
        - mistune:           3.0.2
        - more-itertools:    10.3.0
        - mpmath:            1.3.0
        - multidict:         6.0.5
        - multiprocess:      0.70.16
        - nbclient:          0.10.0
        - nbconvert:         7.16.4
        - nbformat:          5.10.4
        - nest-asyncio:      1.6.0
        - networkx:          3.3
        - nltk:              3.8.1
        - notebook-shim:     0.2.4
        - numexpr:           2.10.1
        - numpy:             1.26.4
        - nvidia-cublas-cu12: 12.1.3.1
        - nvidia-cuda-cupti-cu12: 12.1.105
        - nvidia-cuda-nvrtc-cu12: 12.1.105
        - nvidia-cuda-runtime-cu12: 12.1.105
        - nvidia-cudnn-cu12: 8.9.2.26
        - nvidia-cufft-cu12: 11.0.2.54
        - nvidia-curand-cu12: 10.3.2.106
        - nvidia-cusolver-cu12: 11.4.5.107
        - nvidia-cusparse-cu12: 12.1.0.106
        - nvidia-nccl-cu12:  2.19.3
        - nvidia-nvjitlink-cu12: 12.5.82
        - nvidia-nvtx-cu12:  12.1.105
        - oauthlib:          3.2.2
        - orjson:            3.10.6
        - overrides:         7.7.0
        - packaging:         24.1
        - pandas:            2.1.4
        - pandocfilters:     1.5.1
        - parso:             0.8.4
        - pathvalidate:      3.2.0
        - peft:              0.11.1
        - pexpect:           4.9.0
        - pillow:            10.4.0
        - pip:               24.1.2
        - platformdirs:      4.2.2
        - portalocker:       2.10.1
        - prometheus-client: 0.20.0
        - prompt-toolkit:    3.0.47
        - protobuf:          4.23.4
        - psutil:            6.0.0
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           17.0.0
        - pyarrow-hotfix:    0.6
        - pyasn1:            0.6.0
        - pyasn1-modules:    0.4.0
        - pybind11:          2.13.1
        - pycparser:         2.22
        - pydantic:          2.8.2
        - pydantic-core:     2.20.1
        - pygments:          2.18.0
        - pyjwt:             2.8.0
        - pyparsing:         3.1.2
        - pytablewriter:     1.2.0
        - python-dateutil:   2.9.0.post0
        - python-dotenv:     1.0.1
        - python-json-logger: 2.0.7
        - python-multipart:  0.0.9
        - pytorch-lightning: 2.3.3
        - pytz:              2024.1
        - pyyaml:            6.0.1
        - pyzmq:             26.0.3
        - referencing:       0.35.1
        - regex:             2024.5.15
        - requests:          2.32.3
        - requests-oauthlib: 2.0.0
        - rfc3339-validator: 0.1.4
        - rfc3986-validator: 0.1.1
        - rich:              13.7.1
        - rouge-score:       0.1.2
        - rpds-py:           0.19.0
        - rsa:               4.9
        - s3transfer:        0.10.2
        - sacrebleu:         2.4.2
        - safetensors:       0.4.3
        - scikit-learn:      1.3.2
        - scipy:             1.11.4
        - send2trash:        1.8.3
        - sentencepiece:     0.2.0
        - setuptools:        69.5.1
        - shellingham:       1.5.4
        - simple-term-menu:  1.6.4
        - six:               1.16.0
        - sniffio:           1.3.1
        - soupsieve:         2.5
        - sqlitedict:        2.1.0
        - stack-data:        0.6.3
        - starlette:         0.37.2
        - sympy:             1.13.0
        - tabledata:         1.3.3
        - tabulate:          0.9.0
        - tcolorpy:          0.1.6
        - tensorboard:       2.15.1
        - tensorboard-data-server: 0.7.2
        - termcolor:         2.4.0
        - terminado:         0.18.1
        - threadpoolctl:     3.5.0
        - tinycss2:          1.3.0
        - tokenizers:        0.19.1
        - tomli:             2.0.1
        - torch:             2.2.1+cu121
        - torchmetrics:      1.3.1
        - torchvision:       0.17.1+cu121
        - tornado:           6.4.1
        - tqdm:              4.66.4
        - tqdm-multiprocess: 0.0.11
        - traitlets:         5.14.3
        - transformers:      4.42.4
        - triton:            2.2.0
        - typepy:            1.3.2
        - typer:             0.12.3
        - types-python-dateutil: 2.9.0.20240316
        - typeshed-client:   2.7.0
        - typing-extensions: 4.12.2
        - tzdata:            2024.1
        - ujson:             5.10.0
        - uri-template:      1.3.0
        - urllib3:           2.2.2
        - uvicorn:           0.30.1
        - uvloop:            0.19.0
        - watchfiles:        0.22.0
        - wcwidth:           0.2.13
        - webcolors:         24.6.0
        - webencodings:      0.5.1
        - websocket-client:  1.8.0
        - websockets:        12.0
        - werkzeug:          3.0.3
        - wheel:             0.43.0
        - widgetsnbextension: 4.0.11
        - word2number:       1.1
        - xxhash:            3.4.1
        - yarl:              1.9.4
        - zstandard:         0.23.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - 
        - processor:         x86_64
        - python:            3.10.10
        - release:           5.15.0-1064-aws
        - version:           #70~20.04.1-Ubuntu SMP Fri Jun 14 15:42:13 UTC 2024

More info

One thing to keep in mind: we currently pin bitsandbytes==0.42.0, because 0.43 fails earlier, in fabric.setup_module, with the following error:

/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py:207: UserWarning: LitGPT only supports bitsandbytes v0.42.0. This may result in errors when using quantization.
  warnings.warn(
Loading model 'checkpoints/microsoft/phi-2/lit_model.pth' with {'name': 'phi-2', 'hf_config': {'name': 'phi-2', 'org': 'microsoft'}, 'scale_embeddings': False, 'block_size': 2048, 'vocab_size': 50257, 'padding_multiple': 512, 'padded_vocab_size': 51200, 'n_layer': 32, 'n_head': 32, 'head_size': 80, 'n_embd': 2560, 'rotary_percentage': 0.4, 'parallel_residual': True, 'bias': True, 'lm_head_bias': True, 'n_query_groups': 32, 'shared_attention_norm': True, 'norm_class_name': 'LayerNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'GptNeoxMLP', 'gelu_approximate': 'tanh', 'intermediate_size': 10240, 'rope_condense_ratio': 1, 'rope_base': 10000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 32}
Time to instantiate model: 0.24 seconds.
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/teamspace/studios/this_studio/litgpt2/litgpt/__main__.py", line 71, in main
    CLI(parser_data)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 119, in CLI
    return _run_component(component, init.get(subcommand))
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/jsonargparse/_cli.py", line 204, in _run_component
    return component(**cfg)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/teamspace/studios/this_studio/litgpt2/litgpt/generate/base.py", line 252, in main
    model = fabric.setup_module(model)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 308, in setup_module
    module = self._move_model_to_device(model=module, optimizers=[])
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 976, in _move_model_to_device
    model = self.to_device(model)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 526, in to_device
    self._strategy.module_to_device(obj)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning/fabric/strategies/single_device.py", line 59, in module_to_device
    module.to(self.root_device)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1152, in to
    return self._apply(convert)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1150, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 324, in to
    return self._quantize(device)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 289, in _quantize
    w_4bit, quant_state = bnb.functional.quantize_4bit(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/bitsandbytes/functional.py", line 1234, in quantize_4bit
    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
ValueError: Blockwise quantization only supports 16/32-bit floats, but got torch.uint8

I wonder if these two issues can, or need to, be addressed one at a time: first support lightning 2.3.3 with bitsandbytes==0.42.0 to restore the litgpt quantization as is, and then see how we can upgrade to the most recent bitsandbytes version.
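
While testing the two combinations, a small helper like the one below could confirm which versions are actually installed in the environment. It is only a sketch: `EXPECTED`, `installed_versions`, and `mismatches` are hypothetical names, and the pins are the ones mentioned in this report.

```python
from importlib import metadata

# Pins taken from this report (assumption: these are the versions under test).
EXPECTED = {"lightning": "2.3.3", "bitsandbytes": "0.42.0"}


def installed_versions(names):
    """Return {package: version or None} for the given distribution names."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None  # package is not installed
    return found


def mismatches(found, expected):
    """Return {package: (installed, expected)} for every pin that differs."""
    return {
        name: (found.get(name), want)
        for name, want in expected.items()
        if found.get(name) != want
    }


report = mismatches(installed_versions(EXPECTED), EXPECTED)
for name, (have, want) in report.items():
    print(f"{name}: installed {have}, expected {want}")
```

An empty `report` means the environment matches both pins exactly; the comparison is a plain string equality check, so it does not understand version ranges.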

Any thoughts?

cc @awaelchli @carmocca
