Jitter Noise added to input being passed to experts in Switch Transformers #33969

@karan-uppal3

Description

System Info

  • transformers version: 4.44.2
  • Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.5
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
import torch.nn as nn
from transformers import (
    SwitchTransformersConfig,
    SwitchTransformersTop1Router,
)
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersDenseActDense


class MySwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_experts)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden state values (they can be interpreted as a scaling factor).

        2- Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign
        each expert its corresponding hidden states.

        """

        # Keep an untouched copy of the input so we can check whether the router mutates it in place
        prev_save = hidden_states.clone()

        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        # If the router left hidden_states untouched, this prints True and a zero mean difference
        print(torch.allclose(prev_save, hidden_states))
        print(torch.mean(prev_save - hidden_states))

        # The router might not map every token to an expert, which means that some hidden states
        # can be unchanged from one layer to another. That is why the hidden states are cloned before updating only the selected ones.

        next_states = hidden_states.clone()

        router_mask = router_mask.bool()
        batch_size, seq_len, num_experts = router_mask.shape
        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[
            0
        ].tolist()  # length: number of "activated" expert / value: index
        for idx in idx_mask:
            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
                hidden_states[router_mask[:, :, idx]]
            )

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)

config = SwitchTransformersConfig()
model = MySwitchTransformersSparseMLP(config)

model.train()
in_data = torch.ones(1, 1, 768)
out = model(in_data)

The output is

False
tensor(-0.0001)

whereas ideally it should print True and the mean difference should be zero.

This happens because SwitchTransformersTop1Router multiplies hidden_states in place by the jitter noise, so the noise is still present in the tensor that is subsequently passed to the experts.

if self.training and self.jitter_noise > 0:
    # Multiply the token inputs by the uniform distribution - adding some noise
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
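
For illustration, here is a minimal standalone sketch (router_like is a made-up name, not library code) of why the in-place multiplication leaks back to the caller: the router modifies the caller's tensor itself rather than a private copy, so the noise is still there when the tokens are dispatched to the experts.

import torch

def router_like(hidden_states, jitter_noise=0.01):
    # In-place multiply: this mutates the caller's tensor, not a local copy
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - jitter_noise, 1.0 + jitter_noise)

x = torch.ones(1, 1, 4)
router_like(x)
print(torch.allclose(x, torch.ones(1, 1, 4)))  # False: x itself now carries the jitter noise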

Expected behavior

Ideally, no jitter noise should be present in the input passed to the experts, so the reproduction above should print True and a mean difference of 0.
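
One way to get this behavior, sketched here as an illustration rather than as the actual upstream fix, is to let the router jitter an out-of-place copy so the tensor later dispatched to the experts stays untouched:

# Possible change inside SwitchTransformersTop1Router (an assumption, not the merged fix):
# an out-of-place multiply rebinds only the router's local reference, leaving the caller's tensor noise-free.
if self.training and self.jitter_noise > 0:
    hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
        1.0 - self.jitter_noise, 1.0 + self.jitter_noise
    )

With the multiplication done out of place, only the router's local view of hidden_states carries the noise, so the reproduction above would print True and a mean difference of 0.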
