Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,19 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
# https://huggingface.co/papers/2101.03961.
# We also store the previous dtype to cast back the output to the previous dtype
self.input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(self.dtype)

# Create a copy for applying jitter noise
routing_states = hidden_states.clone()
routing_states = routing_states.to(self.dtype)

if self.training and self.jitter_noise > 0:
# Multiply the token inputs by the uniform distribution - adding some noise
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
router_logits = self.classifier(hidden_states)
# Apply jitter noise only to the routing copy
routing_states *= torch.empty_like(routing_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)

# Use jittered states for routing decisions
router_logits = self.classifier(routing_states)

# Apply Softmax and cast back to the original `dtype`
router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
Expand Down Expand Up @@ -623,7 +631,7 @@ def _init_weights(self, module):
module.weight.data.fill_(factor * 1.0)
elif isinstance(
module,
(SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel),
SwitchTransformersModel | SwitchTransformersForConditionalGeneration | SwitchTransformersEncoderModel,
):
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,19 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
# https://huggingface.co/papers/2101.03961.
# We also store the previous dtype to cast back the output to the previous dtype
self.input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(self.dtype)

# Create a copy for applying jitter noise
routing_states = hidden_states.clone()
routing_states = routing_states.to(self.dtype)

if self.training and self.jitter_noise > 0:
# Multiply the token inputs by the uniform distribution - adding some noise
hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
router_logits = self.classifier(hidden_states)
# Apply jitter noise only to the routing copy
routing_states *= torch.empty_like(routing_states).uniform_(
1.0 - self.jitter_noise, 1.0 + self.jitter_noise
)

# Use jittered states for routing decisions
router_logits = self.classifier(routing_states)

# Apply Softmax and cast back to the original `dtype`
router_probs = nn.functional.softmax(router_logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
Expand Down Expand Up @@ -352,7 +360,7 @@ def _init_weights(self, module):
module.weight.data.fill_(factor * 1.0)
elif isinstance(
module,
(SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel),
SwitchTransformersModel | SwitchTransformersForConditionalGeneration | SwitchTransformersEncoderModel,
):
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,50 @@ def test_max_routing_capacity(self):

assert torch.sum(expert_index) <= batch_size * self.config.num_experts * self.config.expert_capacity

def test_jitter_noise_preserves_hidden_states(self):
r"""
Test that jitter noise is applied only to routing decisions and does not modify the original hidden states.
This tests the fix for the jitter noise issue where noise was corrupting the input hidden states.
"""
# Create a config with jitter noise enabled
config = SwitchTransformersConfig(
num_experts=2,
hidden_size=4,
d_ff=8,
router_jitter_noise=0.1, # Enable jitter noise
expert_capacity=4,
)

# Create router
router = SwitchTransformersTop1Router(config)
router.eval() # Set to eval mode first to test training mode separately

# Create input hidden states
hidden_states = torch.tensor([[[0.5, 0.2, 0.1, 0.3], [0.4, 0.6, 0.2, 0.8]]], dtype=torch.float32)

# Test in eval mode - no jitter noise should be applied
original_hidden_states = hidden_states.clone()
with torch.no_grad():
router_probs, expert_index, router_logits = router(hidden_states)

# Hidden states should remain unchanged in eval mode
self.assertTrue(torch.equal(hidden_states, original_hidden_states))

# Test in training mode - jitter noise should be applied only internally
router.train()
torch.manual_seed(42) # Set seed for reproducible results

original_hidden_states = hidden_states.clone()
with torch.no_grad():
router_probs_train, expert_index_train, router_logits_train = router(hidden_states)

# Hidden states should still remain unchanged after router call
self.assertTrue(torch.equal(hidden_states, original_hidden_states))

# Results should be different between eval and train mode due to jitter noise
# (though this might occasionally fail due to randomness, it's very unlikely with seed)
self.assertFalse(torch.allclose(router_logits, router_logits_train, atol=1e-5))


@slow
@require_torch
Expand Down