diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
index 935152b4ff49..689e15535eb2 100644
--- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py
@@ -102,11 +102,19 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         # https://huggingface.co/papers/2101.03961.
         # We also store the previous dtype to cast back the output to the previous dtype
         self.input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(self.dtype)
+
+        # Create a copy for applying jitter noise
+        routing_states = hidden_states.clone()
+        routing_states = routing_states.to(self.dtype)
+
         if self.training and self.jitter_noise > 0:
-            # Multiply the token inputs by the uniform distribution - adding some noise
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
-        router_logits = self.classifier(hidden_states)
+            # Apply jitter noise only to the routing copy
+            routing_states *= torch.empty_like(routing_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
+
+        # Use jittered states for routing decisions
+        router_logits = self.classifier(routing_states)

         # Apply Softmax and cast back to the original `dtype`
         router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
@@ -623,7 +631,7 @@ def _init_weights(self, module):
             module.weight.data.fill_(factor * 1.0)
         elif isinstance(
             module,
-            (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel),
+            SwitchTransformersModel | SwitchTransformersForConditionalGeneration | SwitchTransformersEncoderModel,
         ):
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
             if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
diff --git a/src/transformers/models/switch_transformers/modular_switch_transformers.py b/src/transformers/models/switch_transformers/modular_switch_transformers.py
index cf4eaf0cedff..ec18790f0940 100644
--- a/src/transformers/models/switch_transformers/modular_switch_transformers.py
+++ b/src/transformers/models/switch_transformers/modular_switch_transformers.py
@@ -159,11 +159,19 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         # https://huggingface.co/papers/2101.03961.
         # We also store the previous dtype to cast back the output to the previous dtype
         self.input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(self.dtype)
+
+        # Create a copy for applying jitter noise
+        routing_states = hidden_states.clone()
+        routing_states = routing_states.to(self.dtype)
+
         if self.training and self.jitter_noise > 0:
-            # Multiply the token inputs by the uniform distribution - adding some noise
-            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
-        router_logits = self.classifier(hidden_states)
+            # Apply jitter noise only to the routing copy
+            routing_states *= torch.empty_like(routing_states).uniform_(
+                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
+            )
+
+        # Use jittered states for routing decisions
+        router_logits = self.classifier(routing_states)

         # Apply Softmax and cast back to the original `dtype`
         router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
@@ -352,7 +360,7 @@ def _init_weights(self, module):
             module.weight.data.fill_(factor * 1.0)
         elif isinstance(
             module,
-            (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel),
+            SwitchTransformersModel | SwitchTransformersForConditionalGeneration | SwitchTransformersEncoderModel,
         ):
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
             if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py
index 86238c053a35..2a3da6931911 100644
--- a/tests/models/switch_transformers/test_modeling_switch_transformers.py
+++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py
@@ -1024,6 +1024,50 @@ def test_max_routing_capacity(self):

         assert torch.sum(expert_index) <= batch_size * self.config.num_experts * self.config.expert_capacity

+    def test_jitter_noise_preserves_hidden_states(self):
+        r"""
+        Test that jitter noise is applied only to routing decisions and does not modify the original hidden states.
+        This tests the fix for the jitter noise issue where noise was corrupting the input hidden states.
+        """
+        # Create a config with jitter noise enabled
+        config = SwitchTransformersConfig(
+            num_experts=2,
+            hidden_size=4,
+            d_ff=8,
+            router_jitter_noise=0.1,  # Enable jitter noise
+            expert_capacity=4,
+        )
+
+        # Create router
+        router = SwitchTransformersTop1Router(config)
+        router.eval()  # Set to eval mode first to test training mode separately
+
+        # Create input hidden states
+        hidden_states = torch.tensor([[[0.5, 0.2, 0.1, 0.3], [0.4, 0.6, 0.2, 0.8]]], dtype=torch.float32)
+
+        # Test in eval mode - no jitter noise should be applied
+        original_hidden_states = hidden_states.clone()
+        with torch.no_grad():
+            router_probs, expert_index, router_logits = router(hidden_states)
+
+        # Hidden states should remain unchanged in eval mode
+        self.assertTrue(torch.equal(hidden_states, original_hidden_states))
+
+        # Test in training mode - jitter noise should be applied only internally
+        router.train()
+        torch.manual_seed(42)  # Set seed for reproducible results
+
+        original_hidden_states = hidden_states.clone()
+        with torch.no_grad():
+            router_probs_train, expert_index_train, router_logits_train = router(hidden_states)
+
+        # Hidden states should still remain unchanged after router call
+        self.assertTrue(torch.equal(hidden_states, original_hidden_states))
+
+        # Results should be different between eval and train mode due to jitter noise
+        # (though this might occasionally fail due to randomness, it's very unlikely with seed)
+        self.assertFalse(torch.allclose(router_logits, router_logits_train, atol=1e-5))
+

 @slow
 @require_torch