diff --git a/python/mlc_llm/model/gemma/gemma_model.py b/python/mlc_llm/model/gemma/gemma_model.py
index 9d62d85129..d04d4f54a0 100644
--- a/python/mlc_llm/model/gemma/gemma_model.py
+++ b/python/mlc_llm/model/gemma/gemma_model.py
@@ -22,7 +22,7 @@ class GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     """Configuration of the Gemma model."""
 
     hidden_size: int
-    hidden_act: str
+    hidden_activation: Optional[str]
     intermediate_size: int
     attention_bias: bool
     num_attention_heads: int
@@ -39,7 +39,9 @@ class GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
 
     def __post_init__(self):
-        if self.hidden_act not in ("gelu", "gelu_pytorch_tanh"):
+        if self.hidden_activation is None:
+            self.hidden_activation = self.kwargs.get("hidden_act", None)
+        if self.hidden_activation not in ("gelu", "gelu_pytorch_tanh"):
             raise ValueError("Only GeLU is supported as the activation for gemma.")
         if self.attention_bias:
             raise ValueError('Only "False" attention_bias is supported for gemma')
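
The following is a minimal, self-contained sketch (not part of the patch) illustrating the fallback the diff introduces: newer Gemma config files name the field "hidden_activation", while older ones used "hidden_act"; since the dataclass no longer declares the old name, the legacy key lands in `kwargs` and is recovered in `__post_init__`. The `GemmaConfigSketch` class below is a hypothetical stand-in that keeps only the fields relevant to this change.

import dataclasses
from typing import Any, Dict, Optional


@dataclasses.dataclass
class GemmaConfigSketch:
    hidden_activation: Optional[str] = None
    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __post_init__(self):
        if self.hidden_activation is None:
            # Fall back to the legacy field name captured in kwargs.
            self.hidden_activation = self.kwargs.get("hidden_act", None)
        if self.hidden_activation not in ("gelu", "gelu_pytorch_tanh"):
            raise ValueError("Only GeLU is supported as the activation for gemma.")


# New-style config: the field is named "hidden_activation".
new_style = GemmaConfigSketch(hidden_activation="gelu_pytorch_tanh")
assert new_style.hidden_activation == "gelu_pytorch_tanh"

# Old-style config: only "hidden_act" is present, routed through kwargs.
old_style = GemmaConfigSketch(kwargs={"hidden_act": "gelu"})
assert old_style.hidden_activation == "gelu"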