Created ReplicateKVHeadTransform to integrate KV-heads replication module within Qefficient library. #625
+327
−1
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The Transform enables KV-head replication for CausalLMs and VLMs as well.
The feature is enabled by passing n_kv_head_repeat parameter during initialization of the QEff wrapper class for the corresponding model.
n_kv_head_repeatparam acts as the multiplier for the number of repeats to be done to original count of KV heads. This operation also causes the config and the hash params of the respective model to update the num_key_value_heads parameter and add a paramter orig_kv_heads to it; It allows us to export the same model with different number of kv_heads without causing a hash conflict.Added tests for both CausalLMs and VLMs with this functionality to compare outputs of Pytorch HF model and the AIC model. Two new optional paramters
n_kv_head_repeatandtest_kv_replicateare added for testing purpose. Settingtest_kv_replicateto True performs a KV-head replication of every model such that the number of KV-heads and attention heads becomes equal. This was done to ensure tests don't fail due to misalignment issues when we simply repeat num_key_value_heads twice and thus cause a divisibility error on hum_heads.