Skip to content

Conversation

@joelpaulkoch
Copy link
Contributor

Hey, this is the SmolLM3 model from huggingface. It's smol, fully open and supports reasoning, so I figured it would be a nice addition to bumblebee.

I didn't implement YaRN extrapolation.

Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @joelpaulkoch, this looks great! I dropped a few small comments and it's good to go :)

@joelpaulkoch
Copy link
Contributor Author

The implementation is basically llama + NoPE support (in the transformer block) + architectures that are supported but missing in llama (i.e. :for_question_answering and :for_token_classification). So, would you prefer to add the optional NoPE support and architectures in the llama implementation and map smollm3 to llama?

@jonatanklosko
Copy link
Member

So, would you prefer to add the optional NoPE support and architectures in the llama implementation and map smollm3 to llama?

It's separate in hf/transformers, so I would keep it separate here to for consistency. Also, I wouldn't necessarily add features to llama that are not in the hf/transformers implementation, otherwise it's harder to analyse for parity :)

Comment on lines 34 to 35
atol: 1.0e-3,
rtol: 1.0e-3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unusual, pretty much all other LLMs work with the default atol 10e-4. If the numbers are slightly off like that, it often indicates a small difference, like a missing layer norm, layer norm in a different order, or something like that.

Do you know if that deviation is only for the test models, or is it similar for any real checkpoint?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked and it's the same for HuggingFaceTB/SmolLM3-3B, not sure what is missing.
A missing layer norm would show up in the debug logs, right?
I'll investigate further when I find the time.

Copy link
Member

@jonatanklosko jonatanklosko Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A missing layer norm would show up in the debug logs, right?

A norm layer that bumblebee has and cannot be found in the checkpoint would be logged. But if there is a norm layer in the checkpoint that bumblebee doesn't use, then it is not logged, unless we pass log_params_diff: true to Bumblebee.load_model (sometimes unused layers are expected, for example if you load a checkpoint as a base model, ignoring the head layers). But I just tried with log_params_diff: true and it doesn't log anything, so in terms of layers it looks like they match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joelpaulkoch I had a look and found the issue, see bf4d8d8. I first checked the config loaded from HF and noticed double nesting in rotary_embedding_enabled: %{rotary_embedding_enabled: ...}. And the main issue was reading :rotary_embedding_enabled from opts, rather than spec, which means it was never actually applied:

-      case opts[:rotary_embedding_enabled] do
+      case spec.rotary_embedding_enabled do

Btw. Looking at this, I realised that that rotary embedding layer doesn't have any params, so we don't need the layer mapping:

                {"decoder.blocks.#{index}.self_attention.rotary_embedding",
                 "model.layers.#{index}.self_attn.rotary_emb"}

I believe it was added as part of llama unnecessarily, and then by copying we kept that across many models, but it doesn't actually do anything. I will remove this mapping everywhere in a separate commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh great, thank you @jonatanklosko !!

Copy link
Member

@jonatanklosko jonatanklosko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@jonatanklosko jonatanklosko merged commit 5d151ff into elixir-nx:main Nov 5, 2025
1 of 2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants