Fix Jitter Noise Passing to Experts in Switch Transformers #33969 #35847
base: main
Thanks! Would you mind running a super small training experiment to see if this affects stability or not? 🤗
we don't want an individual test file for this IMO
Hey @ArthurZucker,

import torch
from transformers import (
    SwitchTransformersForConditionalGeneration,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback,
)
from datasets import load_dataset
import numpy as np
from datetime import datetime
import wandb


def run_experiment(variant="fixed", num_steps=40):
    """
    Run a short training experiment for Switch Transformers.

    Args:
        variant (str): Label for the run, "fixed" for the new implementation or
            "original" for the old one. It only names the run and the output
            directory; the installed transformers code decides which one runs.
        num_steps (int): Number of training steps.
    """
    # Initialize wandb for tracking
    run = wandb.init(
        project="switch-transformers-jitter-test",
        name=f"jitter-noise-{variant}-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        reinit=True,
    )

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Avoid MPS (Apple Metal) for now due to some known issues
    if device == "cpu" and torch.backends.mps.is_available():
        print("MPS (Metal) device found but using CPU for stability")

    # Load model and tokenizer
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "google/switch-base-8",
        torch_dtype=torch.float32,  # Use float32 instead of bfloat16 for better compatibility
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")

    # Load a small dataset
    dataset = load_dataset("wmt16", "ro-en", split="train[:1000]")

    def preprocess_function(examples):
        # Extract English and Romanian texts
        english_texts = [item["en"] for item in examples["translation"]]
        romanian_texts = [item["ro"] for item in examples["translation"]]
        # Tokenize inputs and targets
        inputs = tokenizer(english_texts, truncation=True, max_length=128, padding="max_length")
        targets = tokenizer(romanian_texts, truncation=True, max_length=128, padding="max_length")
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": targets["input_ids"],
        }

    # Preprocess the dataset
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

    # Custom callback to track router metrics
    class RouterMetricsCallback(TrainerCallback):
        def __init__(self):
            self.router_metrics = []

        # TrainerCallback exposes on_step_end (there is no on_step hook).
        def on_step_end(self, args, state, control, model=None, **kwargs):
            """Called after each training step."""
            # Note: as far as I can tell, Trainer does not pass model outputs to
            # callbacks, so this guard is usually False and nothing is logged here.
            if kwargs.get("outputs") is not None and hasattr(kwargs["outputs"], "encoder_aux_loss"):
                outputs = kwargs["outputs"]
                metrics = {
                    "step": state.global_step,
                    "encoder_aux_loss": outputs.encoder_aux_loss.item(),
                    "encoder_z_loss": outputs.encoder_z_loss.item(),
                    "decoder_aux_loss": outputs.decoder_aux_loss.item(),
                    "decoder_z_loss": outputs.decoder_z_loss.item(),
                }
                wandb.log(metrics)
                self.router_metrics.append(metrics)
            return control

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./results/switch-transformers-{variant}",
        num_train_epochs=1,
        max_steps=num_steps,  # 40 by default (reduced from 50) since we saw good results by this point
        per_device_train_batch_size=4,
        logging_steps=5,
        save_steps=num_steps,  # Only save at the end; keep this a multiple of eval_steps
        save_total_limit=1,  # Keep only the final checkpoint
        evaluation_strategy="steps",  # Named eval_strategy in newer transformers releases
        eval_steps=10,
        load_best_model_at_end=True,
        report_to="wandb",
        no_cuda=device == "cpu",
        fp16=False,
        dataloader_num_workers=0,
        warmup_steps=5,
        eval_accumulation_steps=2,
        # Add these to reduce storage usage
        save_strategy="steps",
        remove_unused_columns=True,
        # Pick the best checkpoint by lowest eval loss
        metric_for_best_model="loss",
        greater_is_better=False,
    )

    # Initialize trainer with custom callback
    router_callback = RouterMetricsCallback()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        eval_dataset=tokenized_dataset.select(range(40)),  # Reduced from 50 to 40
        callbacks=[router_callback],
    )

    # Train and get results
    train_result = trainer.train()

    # Log final metrics
    final_metrics = {
        "final_loss": train_result.metrics["train_loss"],
        "train_runtime": train_result.metrics["train_runtime"],
        "train_samples_per_second": train_result.metrics["train_samples_per_second"],
    }
    wandb.log(final_metrics)
    wandb.finish()
    return train_result.metrics, router_callback.router_metrics


def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Run the experiment with the fixed implementation
    print("Running experiment with fixed jitter noise...")
    fixed_metrics, fixed_router_metrics = run_experiment("fixed")

    print("\nResults Summary:")
    print("Fixed Implementation Metrics:")
    print(f"Final Loss: {fixed_metrics['train_loss']:.4f}")
    print(f"Training Runtime: {fixed_metrics['train_runtime']:.2f}s")
    print(f"Samples/second: {fixed_metrics['train_samples_per_second']:.2f}")


if __name__ == "__main__":
    main()

The results were roughly as follows:
The results show clear evidence that our changes maintain model stability.
The experiment was run on the WMT16 dataset for 40 steps with evaluation every 10 steps. The consistent decrease in both training and eval losses, coupled with stable gradient norms, demonstrates that separating jitter noise for routing (while keeping expert inputs clean) doesn't negatively impact model stability.
Thanks for explaining and taking the time to do a small experiment!
Can you just fix the CIs and we can merge!
[For maintainers] Suggested jobs to run (before merge): run-slow: switch_transformers
Issue Description
This pull request addresses a bug in the Switch Transformers architecture where the jitter noise (intended to be applied only for routing decisions) was also being unintentionally applied to the expert inputs.
Fixes: #33969
Problem Statement
In Switch Transformers, a small amount of jitter noise is added to the inputs at routing time to ensure route diversity. However, these jittered inputs were incorrectly passed along to the experts, which contradicts the original paper’s design and led to unexpected discrepancies in outputs.
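The intended behaviour can be written compactly as below. This is only a sketch in notation introduced here for illustration ($x$ for a token's hidden state, $\epsilon$ for the jitter magnitude, $W_r$ for the router weights, $E_i$ for expert $i$), and it assumes the multiplicative uniform jitter used in the Switch Transformers reference code.

```latex
\begin{aligned}
u &\sim \mathcal{U}(1-\epsilon,\; 1+\epsilon), \\
\tilde{x} &= x \odot u, \\
p &= \operatorname{softmax}(W_r \tilde{x}), \qquad i = \arg\max_j p_j, \\
y_{\text{intended}} &= p_i \, E_i(x), \qquad y_{\text{bug}} = p_i \, E_i(\tilde{x}).
\end{aligned}
```

Only the routing probabilities $p$ should depend on the noise $u$; the expert itself should see the clean $x$.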
Root Cause
It was discovered that the code used the same hidden states for both routing and expert processing. When jitter noise was enabled in training mode, it directly modified these states in place, causing the experts to receive noisy inputs.
Implementation
Screenshot
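As a rough illustration of the idea behind the fix (not the actual diff in this PR), the sketch below uses a hypothetical `ToyRouter` module. The class name, sizes, and default `jitter_noise` value are assumptions made for the example. The key point is that the noise is multiplied into a new tensor used only by the router, instead of modifying `hidden_states` in place.

```python
import torch
import torch.nn as nn


class ToyRouter(nn.Module):
    """Hypothetical top-1 router; not the transformers implementation."""

    def __init__(self, hidden_size: int, num_experts: int, jitter_noise: float = 0.1):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_experts, bias=False)
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_input = hidden_states
        if self.training and self.jitter_noise > 0:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
            # Out-of-place multiply: this creates a new tensor, so the caller's
            # hidden_states (later fed to the experts) are left untouched.
            # An in-place `hidden_states *= noise` would leak the noise instead.
            router_input = hidden_states * noise
        logits = self.classifier(router_input)
        return logits.argmax(dim=-1)  # index of the chosen expert per token


torch.manual_seed(0)
router = ToyRouter(hidden_size=8, num_experts=4).train()
x = torch.randn(2, 3, 8)
x_before = x.clone()
expert_index = router(x)
inputs_unchanged = torch.equal(x, x_before)
```

With the out-of-place multiply, `inputs_unchanged` stays true in training mode, which is the property the tests in this PR are meant to guard.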
Test Cases
test_router_training_mode
• Objective: Ensures that jitter noise is only applied during training.
• Checks that outputs differ between consecutive runs (due to noise) but original inputs remain unchanged.
test_router_jitter_noise_separation
• Objective: Verifies that jitter noise affects only the router’s internal computations and not the expert inputs.
• Confirms the logits differ when jitter is applied, while the main input stays the same.
test_expert_inputs_consistency
• Objective: Asserts that all expert inputs remain consistent, even when jitter is applied during training.
• Uses a forward hook on the first expert to capture its inputs across multiple runs and compares them.
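The hook-based check described for test_expert_inputs_consistency could look roughly like the following. This is a minimal sketch on a hypothetical `ToySparseLayer` with a single expert, not the test added in this PR; the layer, its gating, and all sizes are assumptions chosen to keep the example self-contained.

```python
import torch
import torch.nn as nn


class ToySparseLayer(nn.Module):
    """Hypothetical one-expert layer used only to illustrate the hook."""

    def __init__(self, hidden_size: int = 8, jitter_noise: float = 0.1):
        super().__init__()
        self.router = nn.Linear(hidden_size, 1, bias=False)
        self.expert = nn.Linear(hidden_size, hidden_size)
        self.jitter_noise = jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_input = hidden_states
        if self.training and self.jitter_noise > 0:
            noise = torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
            router_input = hidden_states * noise  # noise stays on the routing path
        gate = torch.sigmoid(self.router(router_input))
        return gate * self.expert(hidden_states)  # expert sees clean inputs


torch.manual_seed(0)
layer = ToySparseLayer().train()
captured = []

# Forward hooks receive (module, positional_inputs, output); store a copy of
# the tensor the expert was actually called with.
handle = layer.expert.register_forward_hook(
    lambda module, args, output: captured.append(args[0].detach().clone())
)

x = torch.randn(2, 3, 8)
out1 = layer(x)
out2 = layer(x)
handle.remove()

expert_inputs_match = torch.equal(captured[0], captured[1])
outputs_differ = not torch.equal(out1, out2)
```

Here the two runs draw different noise, so the gated outputs differ, while the captured expert inputs are identical to each other and to `x`.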
With these changes and test additions, we ensure that Switch Transformers adhere to the original design while preserving backward compatibility and correctness.
cc: @ArthurZucker