
Commit 6acb4e4

Authored by gallilmaimon, with co-authors Cyrilvallez, ydshieh, yonigozlan, IlyasMoutawwakil, and others
Support BatchNorm in Hubert pos_conv_emb as in fairseq (#34389)
* Support BatchNorm in Hubert pos_conv_emb as in fairseq
* Correct the new defaults (#34377)
* Correct the new defaults
* CIs
* add check
* Update utils.py
* Update utils.py
* Add the max_length in generate test checking shape without passing length
* style
* CIs
* fix fx CI issue
* [auto. ping] Avoid sending empty info + add more team members (#34383)
* update
* update

---------

Co-authored-by: ydshieh <[email protected]>

* Fix glm (#34388)
* Fix duplicated
* fix import
* Use non nested images and batched text Idefics2/3 (#34222)
* add support for non nested images and add tests
* add tests error scenario
* fix style
* added single and no image to error tests
* Fix onnx non-expotable inplace aten op (#34376)
* fix onnx non-expotable inplace op
* mistral, qwen2, qwen2_vl, starcoder2
* fixup copies
* Fix right padding in LLaVA models (#34305)
* fix right pad llavas
* device mismatch
* no filter (#34391)
* no filter
* no filter
* no filter

---------

Co-authored-by: ydshieh <[email protected]>

* SynthID: better example (#34372)
* better example
* Update src/transformers/generation/configuration_utils.py
* Update src/transformers/generation/logits_process.py
* nits
* Tests: upgrade `test_eager_matches_sdpa_generate` (#34386)
* Fix bnb training test failure (#34414)
* Fix bnb training test: compatibility with OPTSdpaAttention
* Avoid check expected exception when it is on CUDA (#34408)
* update
* update

---------

Co-authored-by: ydshieh <[email protected]>

* Fix typos in agents_advanced.md (#34405)
* [docs] Cache implementations (#34325)

  cache

* [run-slow] hubert
* Support BatchNorm in Hubert pos_conv_emb as in fairseq

  Add conversion integration test, and make batchnorm explicit variable

* Support BatchNorm in Hubert pos_conv_emb as in fairseq

  fix make fixup styling changes

* [run-slow] hubert
* Support BatchNorm in Hubert pos_conv_emb as in fairseq
* [run-slow] hubert
* Support BatchNorm in Hubert pos_conv_emb as in fairseq

  Add conversion integration test, and make batchnorm explicit variable

* Support BatchNorm in Hubert pos_conv_emb as in fairseq

  fix make fixup styling changes

* [run-slow] hubert
* [run-slow] hubert

---------

Co-authored-by: Cyril Vallez <[email protected]>
Co-authored-by: Yih-Dar <[email protected]>
Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Yoni Gozlan <[email protected]>
Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: Raushan Turganbay <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
Co-authored-by: Rudy Delouya <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Yoach Lacombe <[email protected]>
1 parent 80f2b16 commit 6acb4e4

File tree

4 files changed: +77 −19 lines


src/transformers/models/hubert/configuration_hubert.py

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,8 @@ class HubertConfig(PretrainedConfig):
             embeddings layer.
         num_conv_pos_embedding_groups (`int`, *optional*, defaults to 16):
             Number of groups of 1D convolutional positional embeddings layer.
+        conv_pos_batch_norm (`bool`, *optional*, defaults to `False`):
+            Whether to use batch norm instead of weight norm in conv_pos
         do_stable_layer_norm (`bool`, *optional*, defaults to `False`):
             Whether do apply *stable* layer norm architecture of the Transformer encoder. `do_stable_layer_norm is
             True` corresponds to applying layer norm before the attention layer, whereas `do_stable_layer_norm is
@@ -182,6 +184,7 @@ def __init__(
         conv_bias=False,
         num_conv_pos_embeddings=128,
         num_conv_pos_embedding_groups=16,
+        conv_pos_batch_norm=False,
         do_stable_layer_norm=False,
         apply_spec_augment=True,
         mask_time_prob=0.05,
@@ -209,6 +212,7 @@ def __init__(
         self.conv_bias = conv_bias
         self.num_conv_pos_embeddings = num_conv_pos_embeddings
         self.num_conv_pos_embedding_groups = num_conv_pos_embedding_groups
+        self.conv_pos_batch_norm = conv_pos_batch_norm
         self.num_feat_extract_layers = len(self.conv_dim)
         self.num_hidden_layers = num_hidden_layers
         self.intermediate_size = intermediate_size
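For reference, enabling the new flag is a one-liner on the config side. A minimal sketch, assuming a transformers build that includes this commit:

from transformers import HubertConfig, HubertModel

# conv_pos_batch_norm=True replaces the weight-normalized positional
# convolution with one preceded by BatchNorm1d, as in the fairseq
# mHuBERT-25hz checkpoints.
config = HubertConfig(conv_pos_batch_norm=True)
model = HubertModel(config)
print(model.encoder.pos_conv_embed.batch_norm)  # BatchNorm1d(768, ...)

Checkpoints converted with the updated script (e.g., slprl/mhubert-base-25hz used in the new test) presumably carry this flag in their saved config, so `from_pretrained` picks it up automatically.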

src/transformers/models/hubert/convert_hubert_original_pytorch_checkpoint_to_pytorch.py

Lines changed: 14 additions & 1 deletion
@@ -38,7 +38,8 @@

 MAPPING = {
     "post_extract_proj": "feature_projection.projection",
-    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
+    "encoder.pos_conv.0": "encoder.pos_conv_embed.batch_norm",
+    "encoder.pos_conv.1": "encoder.pos_conv_embed.conv",
     "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
     "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
     "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
@@ -76,6 +77,12 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
         hf_pointer.weight_v.data = value
     elif weight_type == "bias":
         hf_pointer.bias.data = value
+    elif weight_type == "running_mean":
+        hf_pointer.running_mean.data = value
+    elif weight_type == "running_var":
+        hf_pointer.running_var.data = value
+    elif weight_type == "num_batches_tracked":
+        hf_pointer.num_batches_tracked.data = value
     else:
         hf_pointer.data = value

@@ -116,6 +123,12 @@ def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
                        weight_type = "weight"
                    elif "bias" in name:
                        weight_type = "bias"
+                   elif "running_mean" in name:
+                       weight_type = "running_mean"
+                   elif "running_var" in name:
+                       weight_type = "running_var"
+                   elif "num_batches_tracked" in name:
+                       weight_type = "num_batches_tracked"
                    else:
                        weight_type = None
                    set_recursively(hf_model, mapped_key, value, name, weight_type)
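The new branches simply extend the name-based dispatch so BatchNorm buffers survive conversion instead of falling through to the generic `hf_pointer.data = value` case. A standalone sketch of that dispatch logic (`infer_weight_type` is a hypothetical helper for illustration, not part of the script):

from typing import Optional

def infer_weight_type(name: str) -> Optional[str]:
    # Mirrors the elif chain above: BatchNorm buffers (running_mean,
    # running_var, num_batches_tracked) now get explicit types so
    # set_recursively can copy them onto the matching attributes.
    for key in ("weight_g", "weight_v", "weight", "bias",
                "running_mean", "running_var", "num_batches_tracked"):
        if key in name:
            return key
    return None

assert infer_weight_type("encoder.pos_conv.0.running_mean") == "running_mean"
assert infer_weight_type("encoder.pos_conv.0.num_batches_tracked") == "num_batches_tracked"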

src/transformers/models/hubert/modeling_hubert.py

Lines changed: 22 additions & 18 deletions
@@ -260,7 +260,6 @@ def forward(self, hidden_states):
         return hidden_states


-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
 class HubertPositionalConvEmbedding(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -272,32 +271,37 @@ def __init__(self, config):
             groups=config.num_conv_pos_embedding_groups,
         )

-        weight_norm = nn.utils.weight_norm
-        if hasattr(nn.utils.parametrizations, "weight_norm"):
-            weight_norm = nn.utils.parametrizations.weight_norm
+        self.batch_norm = None
+        if config.conv_pos_batch_norm:
+            self.batch_norm = nn.BatchNorm1d(config.hidden_size)
+        else:
+            weight_norm = nn.utils.weight_norm
+            if hasattr(nn.utils.parametrizations, "weight_norm"):
+                weight_norm = nn.utils.parametrizations.weight_norm

-        if is_deepspeed_zero3_enabled():
-            import deepspeed
+            if is_deepspeed_zero3_enabled():
+                import deepspeed

-            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
-                self.conv = weight_norm(self.conv, name="weight", dim=2)
-            if hasattr(self.conv, "parametrizations"):
-                weight_g = self.conv.parametrizations.weight.original0
-                weight_v = self.conv.parametrizations.weight.original1
+                with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+                    self.conv = weight_norm(self.conv, name="weight", dim=2)
+                if hasattr(self.conv, "parametrizations"):
+                    weight_g = self.conv.parametrizations.weight.original0
+                    weight_v = self.conv.parametrizations.weight.original1
+                else:
+                    weight_g = self.conv.weight_g
+                    weight_v = self.conv.weight_v
+                deepspeed.zero.register_external_parameter(self, weight_v)
+                deepspeed.zero.register_external_parameter(self, weight_g)
             else:
-                weight_g = self.conv.weight_g
-                weight_v = self.conv.weight_v
-            deepspeed.zero.register_external_parameter(self, weight_v)
-            deepspeed.zero.register_external_parameter(self, weight_g)
-        else:
-            self.conv = weight_norm(self.conv, name="weight", dim=2)
+                self.conv = weight_norm(self.conv, name="weight", dim=2)

         self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
         self.activation = ACT2FN[config.feat_extract_activation]

     def forward(self, hidden_states):
         hidden_states = hidden_states.transpose(1, 2)
-
+        if self.batch_norm is not None:
+            hidden_states = self.batch_norm(hidden_states)
         hidden_states = self.conv(hidden_states)
         hidden_states = self.padding(hidden_states)
         hidden_states = self.activation(hidden_states)
tests/models/hubert/test_modeling_hubert.py

Lines changed: 37 additions & 0 deletions
@@ -943,3 +943,40 @@ def test_inference_distilhubert(self):
         self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
         self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
         self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
+
+    def test_inference_hubert_25hz(self):
+        model = HubertModel.from_pretrained("slprl/mhubert-base-25hz").to(torch_device)
+
+        sample = self._load_datasamples(1)
+        input_speech = torch.tensor(sample[0], dtype=torch.float, device=torch_device).unsqueeze(0)
+
+        with torch.no_grad():
+            outputs = model(input_speech, output_hidden_states=True).hidden_states[11]
+
+        # expected outputs taken from the original textlesslib implementation by:
+        # model = SpeechEncoder.by_name(dense_model_name='mhubert-base-25hz', quantizer_model_name='kmeans',
+        #                               vocab_size=500, deduplicate=False, need_f0=False)
+        # model(wav)['dense']
+        expected_outputs_first = torch.tensor(
+            [
+                [0.0267, 0.1776, -0.1706, -0.4559],
+                [-0.2430, -0.2943, -0.1864, -0.1187],
+                [-0.1812, -0.4239, -0.1916, -0.0858],
+                [-0.1495, -0.4758, -0.4036, 0.0302],
+            ],
+            device=torch_device,
+        )
+        expected_outputs_last = torch.tensor(
+            [
+                [0.3366, -0.2734, -0.1415, -0.3055],
+                [0.2329, -0.3580, -0.1421, -0.3197],
+                [0.1631, -0.4301, -0.1965, -0.2956],
+                [0.3342, -0.2185, -0.2253, -0.2363],
+            ],
+            device=torch_device,
+        )
+        expected_output_sum = 1681.7603
+
+        self.assertTrue(torch.allclose(outputs[:, :4, :4], expected_outputs_first, atol=5e-3))
+        self.assertTrue(torch.allclose(outputs[:, -4:, -4:], expected_outputs_last, atol=5e-3))
+        self.assertTrue(abs(outputs.sum() - expected_output_sum) < 0.1)
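The expected tensors come from textlesslib, as the inline comment notes. A sketch of that reference computation (assuming textlesslib is installed; the import path is an assumption, while the `by_name` call is quoted from the test comment):

import torch
from textless.data.speech_encoder import SpeechEncoder  # import path assumed

encoder = SpeechEncoder.by_name(
    dense_model_name="mhubert-base-25hz",
    quantizer_model_name="kmeans",
    vocab_size=500,
    deduplicate=False,
    need_f0=False,
)
wav = torch.randn(16000)  # placeholder; the test uses a LibriSpeech sample
dense = encoder(wav)["dense"]  # reference for hidden_states[11] above

Like its neighbors in this file, the test is an integration test, so it is expected to run only in the slow suite (RUN_SLOW=1), consistent with the [run-slow] hubert markers in the commit message.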
