Skip to content

Conversation

sambhavnoobcoder
Copy link
Contributor

Issue Description

This pull request addresses a bug in the Switch Transformers architecture where the jitter noise (intended to be applied only for routing decisions) was also being unintentionally applied to the expert inputs.

Fixes : #33969

Problem Statement

In Switch Transformers, a small amount of jitter noise is added to the inputs at routing time to ensure route diversity. However, these jittered inputs were incorrectly passed along to the experts, which contradicts the original paper’s design and led to unexpected discrepancies in outputs.

Root Cause

It was discovered that the code used the same hidden states for both routing and expert processing. When jitter noise was enabled in training mode, it directly modified these states in place, causing the experts to receive noisy inputs.

Implementation

  1. We now clone the original hidden states before applying jitter noise.
  2. A separate copy is used exclusively for computing router logits and probabilities.
  3. The unchanged hidden states are then fed to the experts to maintain the original semantics.

Screenshot

Screenshot 2025-01-23 at 1 14 26 AM

Test Cases

  1. test_router_training_mode
    • Objective: Ensures that jitter noise is only applied during training.
    • Checks that outputs differ between consecutive runs (due to noise) but original inputs remain unchanged.

  2. **test_router_jitter_noise_separation **
    • Objective: Verifies that jitter noise affects only the router’s internal computations and not the expert inputs.
    • Confirms the logits differ when jitter is applied, while the main input stays the same.

  3. test_expert_inputs_consistency
    • Objective: Asserts that all expert inputs remain consistent, even when jitter is applied during training.
    • Uses a forward hook on the first expert to capture its inputs across multiple runs and compares them.

With these changes and test additions, we ensure that Switch Transformers adhere to the original design while preserving backward compatibility and correctness.

cc : @ArthurZucker

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Would you mind running super small experiment for training to see if this affects stability or not? 🤗

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't want an individual test file for this IMO

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in commits 99dda19 and 82cbe16

@sambhavnoobcoder
Copy link
Contributor Author

sambhavnoobcoder commented Feb 18, 2025

Hey @ArthurZucker ,
Here is a short script for experimenting with training stability :

import torch
from transformers import (
    SwitchTransformersForConditionalGeneration,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    TrainerCallback
)
from datasets import load_dataset
import numpy as np
from datetime import datetime
import wandb
import os

def run_experiment(variant="fixed", num_steps=100):
    """
    Run training experiment for Switch Transformers
    Args:
        variant (str): "fixed" for new implementation or "original" for old implementation
        num_steps (int): Number of training steps
    """
    # Initialize wandb for tracking
    run = wandb.init(
        project="switch-transformers-jitter-test",
        name=f"jitter-noise-{variant}-{datetime.now().strftime('%Y%m%d_%H%M%S')}",
        reinit=True
    )

    # Set device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Avoid MPS (Apple Metal) for now due to some known issues
    if device == "cpu" and torch.backends.mps.is_available():
        print("MPS (Metal) device found but using CPU for stability")

    # Load model and tokenizer
    model = SwitchTransformersForConditionalGeneration.from_pretrained(
        "google/switch-base-8",
        torch_dtype=torch.float32  # Use float32 instead of bfloat16 for better compatibility
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")

    # Load a small dataset
    dataset = load_dataset("wmt16", "ro-en", split="train[:1000]")
    
    def preprocess_function(examples):
        # Extract English and Romanian texts
        english_texts = [item['en'] for item in examples['translation']]
        romanian_texts = [item['ro'] for item in examples['translation']]
        
        # Tokenize inputs and targets
        inputs = tokenizer(english_texts, truncation=True, max_length=128, padding="max_length")
        targets = tokenizer(romanian_texts, truncation=True, max_length=128, padding="max_length")
        
        return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": targets["input_ids"],
        }

    # Preprocess the dataset
    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names
    )

    # Custom callback to track router metrics
    class RouterMetricsCallback(TrainerCallback):
        def __init__(self):
            self.router_metrics = []

        def on_step(self, args, state, control, model=None, **kwargs):
            """Called after each step"""
            if kwargs.get("outputs") is not None and hasattr(kwargs["outputs"], "encoder_aux_loss"):
                outputs = kwargs["outputs"]
                metrics = {
                    "step": state.global_step,
                    "encoder_aux_loss": outputs.encoder_aux_loss.item(),
                    "encoder_z_loss": outputs.encoder_z_loss.item(),
                    "decoder_aux_loss": outputs.decoder_aux_loss.item(),
                    "decoder_z_loss": outputs.decoder_z_loss.item(),
                }
                wandb.log(metrics)
                self.router_metrics.append(metrics)
            return control

    # Training arguments
    training_args = TrainingArguments(
        output_dir=f"./results/switch-transformers-{variant}",
        num_train_epochs=1,
        max_steps=40,  # Reduced from 50 to 40 since we saw good results by this point
        per_device_train_batch_size=4,
        logging_steps=5,
        save_steps=40,  # Only save at the end
        save_total_limit=1,  # Keep only the final checkpoint
        evaluation_strategy="steps",
        eval_steps=10,
        load_best_model_at_end=True,
        report_to="wandb",
        no_cuda=device == "cpu",
        fp16=False,
        dataloader_num_workers=0,
        warmup_steps=5,
        eval_accumulation_steps=2,
        # Add these to reduce storage usage
        save_strategy="steps",
        remove_unused_columns=True,
        # Disable metrics computation for unused steps
        metric_for_best_model="loss",
        greater_is_better=False
    )

    # Initialize trainer with custom callback
    router_callback = RouterMetricsCallback()
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        eval_dataset=tokenized_dataset.select(range(40)),  # Reduced from 50 to 40
        callbacks=[router_callback],
    )

    # Train and get results
    train_result = trainer.train()
    
    # Log final metrics
    final_metrics = {
        "final_loss": train_result.metrics["train_loss"],
        "train_runtime": train_result.metrics["train_runtime"],
        "train_samples_per_second": train_result.metrics["train_samples_per_second"],
    }
    wandb.log(final_metrics)
    
    wandb.finish()
    return train_result.metrics, router_callback.router_metrics

def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Run experiments for both variants
    print("Running experiment with fixed jitter noise...")
    fixed_metrics, fixed_router_metrics = run_experiment("fixed")
    
    print("\nResults Summary:")
    print("Fixed Implementation Metrics:")
    print(f"Final Loss: {fixed_metrics['train_loss']:.4f}")
    print(f"Training Runtime: {fixed_metrics['train_runtime']:.2f}s")
    print(f"Samples/second: {fixed_metrics['train_samples_per_second']:.2f}")

if __name__ == "__main__":
    main()

the results to this were roughly as follows :

{'loss': 49.0776, 'grad_norm': 419.4530029296875, 'learning_rate': 5e-05, 'epoch': 0.02}                      
{'loss': 44.2553, 'grad_norm': 339.2666320800781, 'learning_rate': 4.2857142857142856e-05, 'epoch': 0.04}     
{'eval_loss': 35.827186584472656, 'eval_runtime': 17.3699, 'eval_samples_per_second': 2.303, 'eval_steps_per_second': 0.288, 'epoch': 0.04}                                                                                 
{'loss': 38.518, 'grad_norm': 208.29103088378906, 'learning_rate': 3.571428571428572e-05, 'epoch': 0.06}      
{'loss': 35.1943, 'grad_norm': 306.531982421875, 'learning_rate': 2.857142857142857e-05, 'epoch': 0.08}       
{'eval_loss': 28.59830665588379, 'eval_runtime': 20.3816, 'eval_samples_per_second': 1.963, 'eval_steps_per_second': 0.245, 'epoch': 0.08}                                                                                  
{'loss': 32.0654, 'grad_norm': 197.01402282714844, 'learning_rate': 2.1428571428571428e-05, 'epoch': 0.1}     
{'loss': 30.5582, 'grad_norm': 234.2576904296875, 'learning_rate': 1.4285714285714285e-05, 'epoch': 0.12}     
{'eval_loss': 24.1873722076416, 'eval_runtime': 24.9153, 'eval_samples_per_second': 1.605, 'eval_steps_per_second': 0.201, 'epoch': 0.12}                                                                                   
{'loss': 28.8772, 'grad_norm': 181.01376342773438, 'learning_rate': 7.142857142857143e-06, 'epoch': 0.14}     
{'loss': 28.2605, 'grad_norm': 290.5538330078125, 'learning_rate': 0.0, 'epoch': 0.16}                        
{'eval_loss': 22.535852432250977, 'eval_runtime': 21.7324, 'eval_samples_per_second': 1.841, 'eval_steps_per_second': 0.23, 'epoch': 0.16}                          

@sambhavnoobcoder
Copy link
Contributor Author

The results show clear evidence that our changes maintain model stability:

  1. Training Loss shows consistent, steady decrease:
    49.07 -> 44.25 -> 38.51 -> 35.19 -> 32.06 -> 30.55 -> 28.87 -> 28.26

  2. Evaluation Loss demonstrates stable improvement:
    35.82 -> 28.59 -> 24.18 -> 22.53

  3. Gradient norms remain well-behaved after initial stabilization:

    • Initial spike: 419.45
    • Stabilizes in range: 180-340
    • No signs of gradient explosion or instability

The experiment was run on WMT16 dataset for 40 steps with evaluation every 10 steps. The consistent decrease in both training and eval losses, coupled with stable gradient norms, demonstrates that separating jitter noise for routing (while keeping expert inputs clean) doesn't negatively impact model stability.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining and taking the time to do a small experiment!
Can you just fix the CIs and we can merge!

Copy link
Contributor

github-actions bot commented Oct 1, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: switch_transformers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Jitter Noise added to input being passed to experts in Switch Transformers
2 participants