Description
🐛 Describe the bug
huggingface/transformers#37070 brought a clean solution to correctly init weights for most composite models.
However, it uses an isinstance(module, XXXRMSNorm) check to fill initial values into RMSNorm modules, which raises a TypeError if we patch RMSNorm with partial(LigerRMSNorm, ...), because a partial object is not a type.
Take gemma for example:
If we make a partial object to match the default arguments for GemmaRMSNorm

Liger-Kernel/src/liger_kernel/transformers/monkey_patch.py
Lines 630 to 636 in cd6ec32

    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)

    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma

it will break isinstance(module, GemmaRMSNorm):

TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union

Solution
Instead of creating a partial object, we can define a subclass with the desired default arguments, like gemma3 did.
Liger-Kernel/src/liger_kernel/transformers/gema3_rms.py
Lines 4 to 8 in cd6ec32

    class LigerRMSNormForGemma3(LigerRMSNorm):
        """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""

        def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
            super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)

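To illustrate the difference between the two approaches without pulling in torch or transformers, here is a minimal sketch. `RMSNormStandIn`, `PartialNorm`, `SubclassNorm`, and `is_norm` are hypothetical names made up for this example (they are not part of Liger-Kernel), and the stand-in only mimics the constructor arguments, not the real kernel.

```python
from functools import partial


class RMSNormStandIn:
    """Hypothetical stand-in for LigerRMSNorm (plain Python, no torch)."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"):
        self.hidden_size = hidden_size
        self.eps = eps
        self.offset = offset
        self.casting_mode = casting_mode
        self.init_fn = init_fn


# Approach 1: a partial object. It is callable, but it is not a type.
PartialNorm = partial(RMSNormStandIn, offset=1.0, casting_mode="gemma", init_fn="zeros")


# Approach 2: a subclass that bakes in the same defaults. It is a real type.
class SubclassNorm(RMSNormStandIn):
    def __init__(self, hidden_size, eps=1e-6, offset=1.0, casting_mode="gemma", init_fn="zeros"):
        super().__init__(hidden_size, eps, offset, casting_mode, init_fn)


def is_norm(module, norm_cls):
    """Mimic an isinstance-based check like the one used during weight init."""
    try:
        return isinstance(module, norm_cls)
    except TypeError as exc:
        # Return the error text instead of raising, so both cases can be compared.
        return f"TypeError: {exc}"


partial_result = is_norm(PartialNorm(8), PartialNorm)
subclass_result = is_norm(SubclassNorm(8), SubclassNorm)

print(partial_result)   # a "TypeError: ..." string, since a partial is not a type
print(subclass_result)  # True
```

Both approaches construct objects with the same defaults, but only the subclass can be used as the second argument of isinstance, which is what the patched modeling code needs.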
By the way, we probably want to move these modules into transformers/rms_norm.py instead of keeping them in a separate new module.
Reproduce
No response
Versions
liger_kernel==cd6ec32
transformers==4.52.4