Conversation

@nyo16 (Contributor) commented Oct 5, 2025

Add Qwen3 Model Family Support

Summary

This PR adds comprehensive support for the Qwen3 model family from Alibaba Cloud, including text generation,
embedding, and reranking models. Qwen3 is a state-of-the-art multilingual language model with advanced features like
QK normalization and support for context lengths up to 262K tokens.

What's New

  1. Qwen3 Text Generation Models

Architectures:

  • :base - Base Qwen3 model
  • :for_causal_language_modeling - Text generation
  • :for_sequence_classification - Classification tasks
  • :for_embedding - Text embeddings (new)

Key Features:

  • QK Normalization: RMS normalization on query and key projections for improved training stability (Qwen3-specific
    innovation)
  • Grouped Query Attention (GQA): 32 query heads with 8 key-value heads for efficient inference
  • Extended Context: Supports up to 262,144 tokens
  • High RoPE Theta: 5,000,000 base frequency (vs typical 10,000) for better long-context performance
  • Large Vocabulary: 151,936 tokens for multilingual support
  • Gated FFN: SwiGLU activation

  2. Qwen3-Embedding Support

  • Last Token Pooling: Added :last_token_pooling option to Bumblebee.Text.text_embedding/3 (see the pooling sketch under Technical Implementation below)
  • Instruction-Aware: Supports custom task instructions (improves performance by 1-5% per the Qwen team)
  • Multilingual: Over 100 languages supported
  • Flexible Dimensions: 1024-dim (0.6B), 2560-dim (4B), 4096-dim (8B)

  3. Qwen3-Reranker Support

  • Document Reranking: Score query-document pairs for relevance (0-1 range)
  • Custom Instructions: Task-specific prompts for better performance
  • High Accuracy: Relevant docs score 0.99+, irrelevant docs score near 0.0

Files Changed

Core Implementation:

  • lib/bumblebee/text/qwen3.ex (730 lines) - Full Qwen3 model implementation
  • lib/bumblebee.ex - Model and tokenizer registrations
  • lib/bumblebee/text/text_embedding.ex - Added last token pooling

Examples:

  • examples/README.md - Example documentation
  • examples/qwen3.exs - Text generation example
  • examples/qwen3_embedding.exs - Embedding generation
  • examples/qwen3_embedding_prompts.exs - Instruction-aware embeddings
  • examples/qwen3_reranker.exs - Document reranking

Documentation:

  • QWEN3_IEX_GUIDE.md - Interactive IEx usage guide
  • .gitignore - Added .lexical/

Testing

Text Generation (Qwen3-4B-Instruct)

{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-4B-Instruct-2507"})
{:ok, config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-4B-Instruct-2507"})

serving = Bumblebee.Text.generation(model, tokenizer, config)
Nx.Serving.run(serving, "The future of AI")

Results: Generates coherent English text, answers questions correctly, creates stories and code.

Text Embeddings (Qwen3-Embedding-0.6B)

{:ok, model} =
  Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"},
    architecture: :for_embedding
  )

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})

serving =
  Bumblebee.Text.text_embedding(model, tokenizer,
    output_attribute: :embedding,
    embedding_processor: :l2_norm
  )

e1 = Nx.Serving.run(serving, "The cat sat on the mat")
e2 = Nx.Serving.run(serving, "A feline rested on the rug")
Nx.dot(e1.embedding, e2.embedding) |> Nx.to_number() # 0.73 (similar)

Results:

  • Generates 1024-dim normalized vectors
  • Semantic similarity: Similar texts = 0.72, different texts = 0.34
  • Instruction prompts improve relevance by ~5%
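A sketch of an instruction-aware query, reusing the serving above and the prompt format quoted in the commit messages further down this thread ("Instruct: ...\nQuery: ..."); the task wording here is only an example:

task = "Given a query, retrieve relevant documents"
query = "The cat sat on the mat"

Nx.Serving.run(serving, "Instruct: #{task}\nQuery: #{query}")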

Reranking (Qwen3-Reranker-0.6B)

{:ok, model} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})

Scores query-document relevance: relevant documents score 0.99+, irrelevant documents score ~0.0.

Results: Correctly ranks documents by relevance to queries. (See the scoring sketch under Technical Implementation below.)

Compatible Models

Text Generation:

  • Qwen/Qwen3-0.6B → Qwen/Qwen3-32B
  • Qwen/Qwen3-4B-Instruct-2507 (recommended)

Embeddings:

  • Qwen/Qwen3-Embedding-0.6B (1024-dim)
  • Qwen/Qwen3-Embedding-4B (2560-dim)
  • Qwen/Qwen3-Embedding-8B (4096-dim)

Reranking:

  • Qwen/Qwen3-Reranker-0.6B
  • Qwen/Qwen3-Reranker-4B
  • Qwen/Qwen3-Reranker-8B

Technical Implementation

QK Normalization

Unlike standard transformers, Qwen3 applies RMS normalization to query and key states:
hidden -> dense -> split_heads -> rms_norm -> rotary -> attention

Architecture Support

Custom decoder blocks implement QK normalization while maintaining compatibility with Bumblebee's transformer patterns.
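
Per the review discussion and refactor commits below, the custom blocks were later replaced by :query_norm and :key_norm options on Layers.Transformer.blocks. A hedged sketch of the wiring; the thread states only that the options are 2-arity functions, so the exact arguments (tensor and layer name) are an assumption:

Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  # Assumed 2-arity callbacks: the projected tensor and a layer name
  query_norm: fn query, name ->
    Layers.rms_norm(query, name: name, epsilon: spec.layer_norm_epsilon)
  end,
  key_norm: fn key, name ->
    Layers.rms_norm(key, name: name, epsilon: spec.layer_norm_epsilon)
  end,
  name: join(name, "blocks")
)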

Embedding Architecture

New :for_embedding architecture automatically pools the last non-padding token for text embedding tasks.
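
Conceptually, last-token pooling selects the hidden state at the last non-padding position of each sequence, using the attention mask. A minimal Nx sketch of the idea (not the literal code in text_embedding.ex):

defmodule LastTokenPoolingSketch do
  import Nx.Defn

  # hidden_state: {batch, seq_len, hidden}; attention_mask: {batch, seq_len}
  defn pool(hidden_state, attention_mask) do
    batch = Nx.axis_size(hidden_state, 0)
    hidden = Nx.axis_size(hidden_state, 2)

    # Index of the last non-padding token in each sequence
    last_index =
      attention_mask
      |> Nx.sum(axes: [1])
      |> Nx.subtract(1)
      |> Nx.as_type(:s64)

    index =
      last_index
      |> Nx.reshape({batch, 1, 1})
      |> Nx.broadcast({batch, 1, hidden})

    hidden_state
    |> Nx.take_along_axis(index, axis: 1)
    |> Nx.squeeze(axes: [1])
  end
end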

Reranking

Uses the causal LM architecture with yes/no token logit extraction and softmax scoring.
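
A hedged sketch of that scoring flow (not the literal PR code): run the causal LM once, take the logits at the final position, and softmax over the "yes"/"no" token ids. The prompt template and the single-token id lookup are assumptions for illustration:

{:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})

# Look up the "yes"/"no" ids, assuming each encodes to a single token
token_id = fn word ->
  Bumblebee.apply_tokenizer(tokenizer, word)["input_ids"][0][0] |> Nx.to_number()
end

yes_id = token_id.("yes")
no_id = token_id.("no")

prompt =
  "Instruct: Given a query, judge whether the document answers it.\n" <>
    "Query: What is the capital of France?\n" <>
    "Document: Paris is the capital of France."

inputs = Bumblebee.apply_tokenizer(tokenizer, prompt)
%{logits: logits} = Axon.predict(model_info.model, model_info.params, inputs)

# Relevance = P("yes") over the {yes, no} logit pair at the last position
last = logits[[0, -1]]

relevance =
  Nx.stack([last[yes_id], last[no_id]])
  |> Axon.Activations.softmax()
  |> then(&Nx.to_number(&1[0]))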

Breaking Changes

None. This is purely additive.

@nyo16 (Contributor, Author) commented Oct 5, 2025

I will test it tomorrow with my H200 to be sure that everything is working. On my MacBook the answers seem OK, but generation is slow.
My end goal is to add support for the embeddings and rerankers from Qwen.
Comments are really welcome; I generated most of it with Sonnet 4.5.

nyo16 added 6 commits October 5, 2025 09:41
Implements last token pooling strategy in text_embedding to support
Qwen3-Embedding models which use the last token's hidden state for
generating text embeddings.

- Add :last_token_pooling option to text_embedding
- Extract last non-padding token using attention_mask
- Add Qwen3-Embedding-0.6B example demonstrating:
  - Text embedding generation (1024-dim vectors)
  - Semantic similarity computation
  - Instruction-aware embeddings
  - Batch processing

Tested with Qwen3-Embedding-0.6B and produces correct similarity scores.
Implements :for_embedding architecture for Qwen3 models with last token
pooling, enabling direct use with Bumblebee.Text.text_embedding/3.

Changes:
- Add :for_embedding architecture to Qwen3 model
- Register Qwen3ForEmbedding in model mappings
- Add instruction prompts example showing Qwen team recommendations
- Update examples to use cleaner serving-based API
- Add .lexical/ to gitignore
- Clean up mix.exs dependencies (remove emlx, nx override)

Examples demonstrate:
- Basic embedding generation (1024-dim vectors)
- Semantic similarity computation
- Instruction-aware prompts (1-5% performance improvement)
- Custom task instructions for code search
- Multilingual embedding support

Tested with Qwen3-Embedding-0.6B, generates correct similarity scores.
Implements document reranking using Qwen3-Reranker models.
Rerankers score query-document pairs for relevance, improving
retrieval quality in RAG and search applications.

Features:
- Automatic yes/no token detection from tokenizer
- Proper input format with instruction, query, and document
- Softmax-based relevance scoring (0-1 range)
- Support for custom task instructions

Example demonstrates:
- Basic query-document scoring
- Custom instructions for code search
- Reranking search results (top-k selection)

Results show correct ranking:
- Relevant docs score 0.99+
- Irrelevant docs score near 0.0
- Custom instructions work for domain-specific tasks

Works with Qwen3-Reranker-0.6B/4B/8B models.
Move all Qwen3-related examples and documentation into examples/qwen3/
for better organization and discoverability.

Changes:
- Create examples/qwen3/ directory
- Move qwen3.exs, qwen3_embedding.exs, qwen3_embedding_prompts.exs, qwen3_reranker.exs
- Move QWEN3_IEX_GUIDE.md to examples/qwen3/
- Update examples/README.md to reference qwen3/ subdirectory

All examples now accessible under examples/qwen3/ with consistent structure.
@fire commented Oct 5, 2025

I was interested in getting a Qwen3 vision model working, like https://huggingface.co/huihui-ai/Huihui-MiniCPM-V-4_5-abliterated

@jonatanklosko (Member) left a comment

Hey @nyo16, thanks for the PR! I dropped a few comments. We also need tests with reference values generated via Python. You can see #422 for a recent complete example.

lib/bumblebee.ex Outdated
"mbart" => :mbart,
"phi" => :code_gen,
"phi3" => :llama,
"qwen3" => :gpt2,
Member

Looking at https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/tokenizer_config.json, it says "tokenizer_class": "Qwen2Tokenizer", so we should add a :qwen2 tokenizer type. In practice we just need to add it here.

For the default special tokens see https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/tokenization_qwen2_fast.py#L68-L76.
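
The resulting change to the model-type to tokenizer-type mapping (confirmed by the fix commit later in this thread), alongside the entries shown above:

"mbart" => :mbart,
"phi" => :code_gen,
"phi3" => :llama,
"qwen3" => :qwen2,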

Contributor Author

Ok cool. Fun fact: Claude initially added :qwen2 and then changed it to :gpt2.

Comment on lines 546 to 567
# QK Normalization (Qwen3-specific) - normalize over head_dim
query =
  if spec.use_qk_norm do
    Layers.rms_norm(query,
      name: join(name, "query_norm"),
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  else
    query
  end

key =
  if spec.use_qk_norm do
    Layers.rms_norm(key,
      name: join(name, "key_norm"),
      epsilon: spec.layer_norm_epsilon,
      channel_index: -1
    )
  else
    key
  end
Member

This is the only divergence from the usual logic, right? Instead of rewriting all of the implementation here, you can add a new option to Layers.Transformer.blocks. I would add :query_norm and :key_norm, both being 2-arity functions. There is already a :layer_norm option that is kinda similar (and we already have QKV-specific options: :query_use_bias, :key_use_bias, :value_use_bias).

Member

I would skip these examples, since they are not as easy to find. We could instead add a section to https://github.com/elixir-nx/bumblebee/blob/main/notebooks/llms.livemd#mistral, or, if it's more elaborate, perhaps a separate Qwen3 notebook.

Contributor Author

Awesome feedback! I will work on it later today (EST).

Niko Maroulis added 6 commits October 6, 2025 17:00
- Remove .lexical/ from project gitignore (should be in global gitignore)
- Add :qwen2 tokenizer type with correct Qwen3 special tokens
- Refactor QK normalization to use generalized approach:
  - Add :query_norm and :key_norm options to Layers.Transformer
  - Apply normalization after head splitting, before rotary embedding
  - Update Qwen3 to use Layers.Transformer.blocks instead of custom implementation
  - Remove ~200 lines of custom decoder/attention code
- Remove standalone examples directory per review feedback

The generalized QK normalization approach makes the transformer layer more
flexible and maintainable, allowing other models to use similar patterns.
Use 'decoder.blocks' as the name prefix when calling Layers.Transformer.blocks
to match the expected params mapping pattern decoder.blocks.{n}.*.
This aligns with how other models like BERT use the transformer blocks.
Fix model_type_to_tokenizer_type mapping to use :qwen2 instead of :gpt2
for qwen3 models. This ensures Qwen3 models load with the correct
tokenizer configuration including proper special tokens.
Create notebooks/qwen3.livemd demonstrating:
- Text generation using Qwen3-4B-Instruct-2507
- Embeddings using Qwen3-Embedding-0.6B with similarity examples
- Reranking using Qwen3-Reranker-0.6B with query-document scoring

This replaces the deleted standalone examples with a consolidated,
easy-to-follow notebook format as suggested in PR review.
Update the embeddings section to use the proper instruction format:
'Instruct: Given a query, retrieve relevant documents\nQuery: {query}\n{text}'

This ensures consistency with the reranker example and follows Qwen3
embedding best practices for better semantic search results.
Add comprehensive test suite for Qwen3 using tiny-random/qwen3:
- Test :base architecture with QK normalization enabled
- Test :for_causal_language_modeling with logits verification
- Test :for_sequence_classification (shape only, random params)
- Test :for_embedding architecture

Reference values generated from tiny-random/qwen3 model predictions.
All tests pass successfully (4 tests, 0 failures).
@nyo16 (Contributor, Author) commented Oct 6, 2025

Generation looking good!

iex(16)>   prompt = """
...(16)>   <|im_start|>system
...(16)>   You are a helpful assistant.<|im_end|>
...(16)>   <|im_start|>user
...(16)>   What is the capital of France?<|im_end|>
...(16)>   <|im_start|>assistant
...(16)>   """
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
iex(17)>
nil
iex(18)>   result = Nx.Serving.run(serving, prompt)
%{
  results: [
    %{
      text: "The capital of France is Paris.",
      token_summary: %{input: 26, output: 8, padding: 0}
    }
  ]
}

Still more tests to do and write!
