Conversation

@AlexKoff88 (Contributor):

cache_block_outputs enables collecting each block's outputs to speed up the GPTQ process. However, it does not work for some models, such as ChatGLM, where a LayerNorm is the first layer in the block. Just compare the module structures:

OPT structure:
model.decoder.layers.0.self_attn
model.decoder.layers.0.self_attn.k_proj
model.decoder.layers.0.self_attn.v_proj
model.decoder.layers.0.self_attn.q_proj
model.decoder.layers.0.self_attn.out_proj
model.decoder.layers.0.activation_fn
model.decoder.layers.0.self_attn_layer_norm
model.decoder.layers.0.fc1
model.decoder.layers.0.fc2
model.decoder.layers.0.final_layer_norm

ChatGLM structure:
transformer.encoder.layers.0
transformer.encoder.layers.0.input_layernorm
transformer.encoder.layers.0.self_attention
transformer.encoder.layers.0.self_attention.query_key_value
transformer.encoder.layers.0.self_attention.core_attention
transformer.encoder.layers.0.self_attention.core_attention.attention_dropout
transformer.encoder.layers.0.self_attention.dense
transformer.encoder.layers.0.post_attention_layernorm
transformer.encoder.layers.0.mlp
transformer.encoder.layers.0.mlp.dense_h_to_4h
transformer.encoder.layers.0.mlp.dense_4h_to_h
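
(For reference, listings like the two above can be produced by iterating over the model's named_modules(); the snippet below is a minimal sketch, the OPT checkpoint ID is only illustrative, and ChatGLM checkpoints additionally need trust_remote_code=True.)

# Minimal sketch: print the submodules of the first decoder block.
# The checkpoint ID is illustrative; swap in any causal LM to inspect its layout.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
for name, _ in model.named_modules():
    if name.startswith("model.decoder.layers.0"):
        print(name)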

The solution is to disable self-attention block output caching and to collect the inputs of the block being quantized by running the model from its beginning. This slows down the optimization a bit but works more reliably.

Related PR to Optimum: huggingface/optimum#1479
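
As a rough sketch of how the flag is intended to be used once this change and the Optimum-side PR are in place (the ChatGLM checkpoint ID and the c4 calibration dataset below are illustrative choices, and a GPTQ backend such as auto-gptq must be installed):

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "THUDM/chatglm2-6b"  # illustrative: a model whose block starts with a LayerNorm
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

quantization_config = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=tokenizer,
    cache_block_outputs=False,  # the option added in this PR: disable block-output caching
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=quantization_config,
)

Leaving cache_block_outputs at its default keeps the faster cached path for regular models such as OPT.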


@SunMarc (Member) left a comment:


Thanks @AlexKoff88, let's wait for the optimum PR to be merged. We might not need this argument.


@SunMarc (Member) left a comment:


LGTM! A few nits to fix.


@SunMarc (Member) commented on Oct 31, 2023:

@AlexKoff88 please run make style to fix the tests.

@SunMarc requested a review from amyeroberts on October 31, 2023 at 17:56.
@AlexKoff88 (Contributor, Author) commented:

> @AlexKoff88 please run make style to fix the tests.

Fixed


@amyeroberts (Contributor) left a comment:


Thanks for adding this!

Just a small nit on the docstring.

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@amyeroberts merged commit f9b4bea into huggingface:main on Nov 1, 2023.
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request on Nov 19, 2023:

Added cache_block_outputs option to enable GPTQ for non-regular models (huggingface#27032)

* Added cache_block_outputs option to enable GPTQ for non-regular models

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <[email protected]>

* Fixed style

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: amyeroberts <[email protected]>