System Info
- transformers version: 4.44.2
- Platform: Linux-6.5.0-35-generic-x86_64-with-glibc2.35
- Python version: 3.10.14
- Huggingface_hub version: 0.24.6
- Safetensors version: 0.4.5
- Accelerate version: 0.34.2
- Accelerate config: not found
- PyTorch version (GPU?): 2.3.0 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: No
- Using GPU in script?: No
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
import torch
import torch.nn as nn
from transformers import (
    SwitchTransformersConfig,
    SwitchTransformersTop1Router,
)
from transformers.models.switch_transformers.modeling_switch_transformers import SwitchTransformersDenseActDense


class MySwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1- Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden states values (can be interpreted as a scaling factor).

        2- Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign
        the corresponding hidden states to each expert.
        """
        prev_save = hidden_states.clone()

        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = torch.argmax(router_mask, dim=-1)

        print(torch.allclose(prev_save, hidden_states))
        print(torch.mean(prev_save - hidden_states))

        # The router might not map every token to an expert, which means that some hidden states
        # can be unchanged from one layer to another. That is why the hidden states are cloned
        # before updating only the selected ones.
        next_states = hidden_states.clone()

        router_mask = router_mask.bool()
        batch_size, seq_len, num_experts = router_mask.shape
        idx_mask = router_mask.transpose(1, 2).reshape(batch_size * seq_len, num_experts).sum(dim=0)
        idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # length: number of "activated" experts / value: index
        for idx in idx_mask:
            next_states[router_mask[:, :, idx]] = getattr(self.experts, "expert_{}".format(idx))(
                hidden_states[router_mask[:, :, idx]]
            )

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)


config = SwitchTransformersConfig()
model = MySwitchTransformersSparseMLP(config)
model.train()

in_data = torch.ones(1, 1, 768)
out = model(in_data)
The output is
False
tensor(-0.0001)
whereas ideally the first print should be True and the mean difference should be zero.
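The mutation can be narrowed down to the router call alone. Below is a minimal check I put together for illustration (not part of the original report); it uses the same default config as above and only invokes the router:

import torch
from transformers import SwitchTransformersConfig, SwitchTransformersTop1Router

config = SwitchTransformersConfig()  # default router_jitter_noise is non-zero
router = SwitchTransformersTop1Router(config)
router.train()  # jitter noise is only applied in training mode

x = torch.ones(1, 1, config.d_model)
x_before = x.clone()
router(x)

# If the router modifies its input in place, this prints False.
print(torch.allclose(x_before, x))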
This is because in SwitchTransformersTop1Router the hidden_states are multiplied in place by the jitter noise, so the perturbation persists when the same tensor is later passed to the experts.

transformers/src/transformers/models/switch_transformers/modeling_switch_transformers.py
Lines 159 to 161 in e71a01a

if self.training and self.jitter_noise > 0:
    # Multiply the token inputs by the uniform distribution - adding some noise
    hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
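A possible fix (just a sketch on my side, not an existing patch) would be a drop-in variant of the quoted lines that applies the noise out of place, so it only affects the tensor the router uses for its logits and not the caller's hidden_states:

# Sketch of a non-mutating variant of the snippet above (assumption, not the current code):
if self.training and self.jitter_noise > 0:
    # Out-of-place multiply: rebinds the local name and leaves the caller's tensor untouched
    hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
        1.0 - self.jitter_noise, 1.0 + self.jitter_noise
    )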
Expected behavior
Ideally, no jitter noise should be present in the hidden states passed to the experts, so the check above would print True and the mean difference would be 0.
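In the meantime, a user-side workaround (my own suggestion, not part of the expected fix) is to disable the noise entirely, so the router has nothing to multiply into the tensor the experts receive. Reusing the MySwitchTransformersSparseMLP class from the reproduction above:

import torch
from transformers import SwitchTransformersConfig

# Workaround sketch: turn the jitter noise off so the router cannot perturb the
# hidden states that are later dispatched to the experts.
config = SwitchTransformersConfig(router_jitter_noise=0.0)
model = MySwitchTransformersSparseMLP(config)  # class defined in the reproduction above
model.train()

in_data = torch.ones(1, 1, config.d_model)
out = model(in_data)  # the allclose check inside forward now prints True

Alternatively, passing hidden_states.clone() to self.router in the MLP forward keeps the noise for routing while leaving the expert inputs unperturbed, at the cost of an extra copy.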