patching LigerRMSNorm with partial arguments given would cause error in _init_weights() #739

@Tcc0403

🐛 Describe the bug

huggingface/transformers#37070 brought a clean solution to correctly init weights for most composite models.

However, it relies on an isinstance(module, XXXRMSNorm) check to fill initial values into RMSNorm modules, which raises a TypeError if we patch RMSNorm with partial(LigerRMSNorm, ...).

Take Gemma as an example. If we create a partial object to match the default arguments of GemmaRMSNorm:

    LigerRMSNormForGemma = partial(LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma")
    _patch_rms_norm_module_for_gemma = partial(_patch_rms_norm_module, casting_mode="gemma", offset=1.0)
    if rope:
        modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
    if rms_norm:
        modeling_gemma.GemmaRMSNorm = LigerRMSNormForGemma

then modeling_gemma.GemmaRMSNorm is a functools.partial object rather than a class, and isinstance(module, GemmaRMSNorm) raises:

    TypeError: isinstance() arg 2 must be a type, a tuple of types, or a union
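
The failure can be reproduced without transformers or liger_kernel installed. The sketch below uses a plain Python class, RMSNormStandIn, as a hypothetical stand-in for LigerRMSNorm (it is not the real module and has a simplified signature). It only illustrates that a functools.partial object can still construct instances but cannot be used as the second argument of isinstance.

```python
from functools import partial


class RMSNormStandIn:
    """Hypothetical stand-in for LigerRMSNorm; not the real kernel-backed module."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0):
        self.hidden_size = hidden_size
        self.eps = eps
        self.offset = offset


# Mirrors the partial-based patch: the patched name is no longer a class.
PatchedNorm = partial(RMSNormStandIn, offset=1.0)

# Construction still works and picks up the pre-bound default.
module = PatchedNorm(8)

# The isinstance check is where things break.
try:
    isinstance(module, PatchedNorm)
    isinstance_error = None
except TypeError as exc:
    isinstance_error = exc

print(isinstance(PatchedNorm, type))    # False: a partial is not a type
print(type(isinstance_error).__name__)  # TypeError
```

Because construction succeeds, the problem only surfaces later, when some code path runs an isinstance check against the patched name.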

Solution

Instead of creating a partial object, we can define a module subclass with the desired default arguments, as is already done for gemma3.

    class LigerRMSNormForGemma3(LigerRMSNorm):
        """Gemma3RMSNorm has a dim argument not hidden_size used in q_norm and k_norm."""
        def __init__(self, dim, eps=0.000001, offset=1.0, casting_mode="gemma", init_fn="zeros", in_place=False):
            super().__init__(dim, eps, offset, casting_mode, init_fn, in_place)
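
To show why the subclass approach avoids the error, here is a minimal sketch using hypothetical stand-in classes (RMSNormStandIn and RMSNormForGemmaStandIn are illustrative names, not liger_kernel or transformers APIs, and init_weights is only a simplified imitation of an isinstance-based dispatch). A subclass is still a type, so it is a valid isinstance target while carrying the Gemma-style defaults.

```python
class RMSNormStandIn:
    """Hypothetical stand-in for LigerRMSNorm with a simplified signature."""

    def __init__(self, hidden_size, eps=1e-6, offset=0.0,
                 casting_mode="llama", init_fn="ones"):
        self.hidden_size = hidden_size
        self.eps = eps
        self.offset = offset
        self.casting_mode = casting_mode
        self.init_fn = init_fn


class RMSNormForGemmaStandIn(RMSNormStandIn):
    """Subclass that only changes the default arguments."""

    def __init__(self, hidden_size, eps=1e-6, offset=1.0,
                 casting_mode="gemma", init_fn="zeros"):
        super().__init__(hidden_size, eps, offset, casting_mode, init_fn)


def init_weights(module):
    """Simplified imitation of an isinstance-based weight-init dispatch."""
    if isinstance(module, RMSNormForGemmaStandIn):
        return "rms_norm"
    return "other"


norm = RMSNormForGemmaStandIn(8)
print(isinstance(RMSNormForGemmaStandIn, type))  # True: usable in isinstance
print(init_weights(norm))                        # rms_norm
```

The trade-off is one small class per model family instead of a one-line partial, but the patched name keeps behaving like a class everywhere it is used.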

By the way, we probably want to move these modules under transformers/rms_norm.py instead of a new module.

Reproduce

No response

Versions

liger_kernel==cd6ec32
transformers==4.52.4
