Description
System Info
```
2023-05-24 23:09:53.575434: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
WARNING:tensorflow:From /usr/local/lib/python3.10/dist-packages/transformers/commands/env.py:63: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.config.list_physical_devices('GPU') instead.
2023-05-24 23:10:05.261610: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
```
Copy-and-paste the text below in your GitHub issue and FILL OUT the two last points.
- `transformers` version: 4.29.2
- Platform: Linux-5.15.107+-x86_64-with-glibc2.31
- Python version: 3.10.11
- Huggingface_hub version: 0.14.1
- Safetensors version: not installed
- PyTorch version (GPU?): 2.0.1+cu118 (True)
- Tensorflow version (GPU?): 2.12.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.6.9 (gpu)
- Jax version: 0.4.10
- JaxLib version: 0.4.10
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I was trying to use the shiny new MPT model from the Hugging Face Hub, from a specific revision:
```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
import transformers
import accelerate

model_name = 'mosaicml/mpt-7b'
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             trust_remote_code=True,
                                             revision="refs/pr/23",
                                             device_map="auto")
```
But I stumbled on this error after running the above code:
```
ValueError: MPTForCausalLM does not support device_map='auto' yet.
```
`device_map="auto"` was indeed not supported on the main branch, but a fix was added on the PR branch (hence the argument `revision="refs/pr/23"`).
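As a side note, a possible workaround (a hedged sketch, assuming `huggingface_hub`'s `snapshot_download` and that the PR branch's code really does support `device_map="auto"`) is to download the PR revision locally and load from the local path, which should bypass the Hub-side module resolution described below:

```python
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM

# Download the full repo at the PR revision into the local cache
# and get the path to that snapshot folder.
local_path = snapshot_download("mosaicml/mpt-7b", revision="refs/pr/23")

# Loading from a local path should use the .py files from that snapshot
# instead of re-resolving them from the main branch on the Hub.
model = AutoModelForCausalLM.from_pretrained(local_path,
                                             trust_remote_code=True,
                                             device_map="auto")
```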
After some investigation, I found that the model was indeed loading the `.py` files from `main`:
```
Downloading (…)main/modeling_mpt.py: 100%
17.4k/17.4k [00:00<00:00, 1.12MB/s]
Downloading (…)in/param_init_fns.py: 100%
12.6k/12.6k [00:00<00:00, 971kB/s]
Downloading (…)resolve/main/norm.py: 100%
2.56k/2.56k [00:00<00:00, 131kB/s]
```
```
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- norm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b:
- param_init_fns.py
- norm.py
```
You can see the `main/` files here. I manually checked the `modeling_mpt.py` file and it didn't contain the PR changes.
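For anyone wanting to reproduce that check, here is a hedged sketch (assuming `huggingface_hub`'s `hf_hub_download`) that fetches `modeling_mpt.py` from both revisions and compares them:

```python
from huggingface_hub import hf_hub_download

# Fetch modeling_mpt.py from both revisions into the local cache.
main_file = hf_hub_download("mosaicml/mpt-7b", "modeling_mpt.py")
pr_file = hf_hub_download("mosaicml/mpt-7b", "modeling_mpt.py",
                          revision="refs/pr/23")

# If the PR actually changed the file, these should differ; the file
# transformers loaded for me matched the `main` version.
with open(main_file) as f_main, open(pr_file) as f_pr:
    print("identical" if f_main.read() == f_pr.read() else "different")
```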
So I tried to find where the bug was inside the transformers package (my first time looking at the code), and I was a bit surprised!
Basically, the code rewrites the config values after reading them: it adds the repo id information (in `add_model_info_to_auto_map` in `utils/generic.py` of the transformers package), which seems normal:
"auto_map": {
    "AutoConfig": "mosaicml/mpt-7b--configuration_mpt.MPTConfig",
    "AutoModelForCausalLM": "mosaicml/mpt-7b--modeling_mpt.MPTForCausalLM"
  }
Notably, it adds the `"--"` separator string.
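To illustrate the transformation, here is a simplified sketch (not the actual `add_model_info_to_auto_map` implementation, which also handles tuple values and nested maps):

```python
def add_repo_to_auto_map(auto_map: dict, repo_id: str) -> dict:
    # Prefix every entry with "<repo_id>--" so transformers later knows
    # which repo the dynamic module came from.
    return {key: f"{repo_id}--{value}" for key, value in auto_map.items()}

auto_map = {
    "AutoConfig": "configuration_mpt.MPTConfig",
    "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM",
}
print(add_repo_to_auto_map(auto_map, "mosaicml/mpt-7b"))
# {'AutoConfig': 'mosaicml/mpt-7b--configuration_mpt.MPTConfig',
#  'AutoModelForCausalLM': 'mosaicml/mpt-7b--modeling_mpt.MPTForCausalLM'}
```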
Then, in `get_class_from_dynamic_module` (in `dynamic_module_utils.py`), there is:
```python
if "--" in class_reference:
    repo_id, class_reference = class_reference.split("--")
    # Invalidate revision since it's not relevant for this repo
    revision = "main"
```
So the revision becomes `"main"`, and from there the PR revision is lost.
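A minimal sketch of one possible fix (a hypothetical helper, not the actual transformers code; `requested_repo_id` is a made-up name for the repo the user originally asked for): only discard the user-supplied revision when the dynamic code actually lives in a different repo.

```python
def resolve_revision(class_reference: str, requested_repo_id: str,
                     revision: str) -> tuple[str, str]:
    """Hypothetical helper: keep the user's revision when the dynamic
    code lives in the same repo they asked for."""
    if "--" in class_reference:
        repo_id, class_reference = class_reference.split("--")
        if repo_id != requested_repo_id:
            # The code lives in a different repo, so the user's revision
            # genuinely does not apply there.
            revision = "main"
    return class_reference, revision

# With this guard, the PR revision survives for same-repo dynamic code:
print(resolve_revision("mosaicml/mpt-7b--modeling_mpt.MPTForCausalLM",
                       "mosaicml/mpt-7b", "refs/pr/23"))
# ('modeling_mpt.MPTForCausalLM', 'refs/pr/23')
```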
I suppose that if I open a PR simply removing the revision override, some people will not be happy?
Expected behavior
The expected behavior is to load the files from the PR branch, not from `main/`.