58 changes: 58 additions & 0 deletions examples/README.md
Member comment: Let's remove this file as well :)

@@ -0,0 +1,58 @@
# Bumblebee Examples

This directory contains example scripts demonstrating how to use Bumblebee models.

## Qwen3 Examples

See the `qwen3/` subdirectory for comprehensive Qwen3 model examples:

### Text Generation
```bash
elixir examples/qwen3/qwen3.exs
```

### Text Embeddings
```bash
elixir examples/qwen3/qwen3_embedding.exs
elixir examples/qwen3/qwen3_embedding_prompts.exs
```

### Document Reranking
```bash
elixir examples/qwen3/qwen3_reranker.exs
```

### Features Demonstrated

**Text Generation** (`qwen3.exs`):
- Text completion
- Question answering
- Chat format
- Code generation

**Embeddings** (`qwen3_embedding.exs`, `qwen3_embedding_prompts.exs`):
- 1024-dimensional text embeddings
- Semantic similarity computation
- Instruction-aware prompts (recommended by the Qwen team)
- Multilingual support
- Code search

**Reranking** (`qwen3_reranker.exs`):
- Query-document relevance scoring
- Custom task instructions
- Top-k result selection

### Requirements

- **Text Generation**: ~8GB disk space, ~10GB RAM
- **Embeddings**: ~1.5GB disk space, ~4GB RAM (0.6B model)
- **Reranking**: ~1.5GB disk space, ~4GB RAM (0.6B model)
- **Backend**: EXLA (CPU or GPU)

### Documentation

See `examples/qwen3/QWEN3_IEX_GUIDE.md` for interactive IEx usage examples.
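
For a quick sense of the flow the embedding scripts above follow, here is a minimal sketch. The checkpoint name, serving setup, and similarity computation are illustrative assumptions, not excerpts from the scripts:

```elixir
# Minimal sketch (assumed checkpoint name; see the scripts for full details).
{:ok, model_info} =
  Bumblebee.load_model({:hf, "Qwen/Qwen3-Embedding-0.6B"}, architecture: :for_embedding)

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Embedding-0.6B"})

serving = Bumblebee.Text.text_embedding(model_info, tokenizer)

a = Nx.Serving.run(serving, "How do I sort a list in Elixir?").embedding
b = Nx.Serving.run(serving, "Enum.sort/1 sorts an enumerable.").embedding

# Cosine similarity between the two 1024-dimensional embeddings.
Nx.dot(a, b) |> Nx.divide(Nx.multiply(Nx.LinAlg.norm(a), Nx.LinAlg.norm(b)))
```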

## Phoenix Examples

See the `phoenix/` subdirectory for LiveView-based examples.
5 changes: 5 additions & 0 deletions lib/bumblebee.ex
@@ -178,6 +178,10 @@ defmodule Bumblebee do
"Phi3ForCausalLM" => {Bumblebee.Text.Phi3, :for_causal_language_modeling},
"Phi3ForSequenceClassification" => {Bumblebee.Text.Phi3, :for_sequence_classification},
"Phi3ForTokenClassification" => {Bumblebee.Text.Phi3, :for_token_classification},
"Qwen3Model" => {Bumblebee.Text.Qwen3, :base},
"Qwen3ForCausalLM" => {Bumblebee.Text.Qwen3, :for_causal_language_modeling},
"Qwen3ForSequenceClassification" => {Bumblebee.Text.Qwen3, :for_sequence_classification},
"Qwen3ForEmbedding" => {Bumblebee.Text.Qwen3, :for_embedding},
"ResNetForImageClassification" => {Bumblebee.Vision.ResNet, :for_image_classification},
"ResNetModel" => {Bumblebee.Vision.ResNet, :base},
"RobertaForMaskedLM" => {Bumblebee.Text.Roberta, :for_masked_language_modeling},
@@ -253,6 +257,7 @@ defmodule Bumblebee do
"mbart" => :mbart,
"phi" => :code_gen,
"phi3" => :llama,
"qwen3" => :qwen2,
"roberta" => :roberta,
"t5" => :t5,
"whisper" => :whisper,
55 changes: 52 additions & 3 deletions lib/bumblebee/layers/transformer.ex
@@ -50,7 +50,9 @@ defmodule Bumblebee.Layers.Transformer do
:block_type,
:attention_window_size,
:scale_attention_weights,
:rotary_embedding
:rotary_embedding,
:query_norm,
:key_norm
]

opts =
@@ -317,7 +319,9 @@
layer_norm: [],
attention_window_size: nil,
scale_attention_weights: true,
rotary_embedding: nil
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
])

name = opts[:name]
@@ -347,6 +351,8 @@
attention_window_size = opts[:attention_window_size]
scale_attention_weights = opts[:scale_attention_weights]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]

ffn_fun =
case ffn do
@@ -405,6 +411,8 @@
attention_window_size: attention_window_size,
scale_attention_weights: scale_attention_weights,
rotary_embedding: rotary_embedding,
query_norm: query_norm,
key_norm: key_norm,
name: join(name, "self_attention")
)

@@ -690,6 +698,14 @@

* `:max_positions` - the maximum number of distinct positions

* `:query_norm` - configuration for query normalization. If set, normalizes
the query projection before rotary embedding. Configured with the same
options as `:layer_norm` in the block function. Defaults to `nil`.
Member comment: Let's make it always a function. For `:layer_norm` we allow a keyword list because it used to always be layer norm, but it needed to be a different norm for some specific model (though perhaps it would make sense to change it to always be a function :D).


* `:key_norm` - configuration for key normalization. If set, normalizes
the key projection before rotary embedding. Configured with the same
options as `:layer_norm` in the block function. Defaults to `nil`.

* `:name` - the prefix for layer names

## References
@@ -721,7 +737,9 @@
key_use_bias: true,
value_use_bias: true,
output_use_bias: true,
rotary_embedding: nil
rotary_embedding: nil,
query_norm: nil,
key_norm: nil
])

attention_mask = opts[:attention_mask]
@@ -739,6 +757,8 @@
scale_attention_weights = opts[:scale_attention_weights]
dropout_rate = opts[:dropout_rate]
rotary_embedding = opts[:rotary_embedding]
query_norm = opts[:query_norm]
key_norm = opts[:key_norm]

query_use_bias = opts[:query_use_bias]
key_use_bias = opts[:key_use_bias]
@@ -778,6 +798,35 @@
)
|> Layers.split_heads(num_key_value_heads)

# Apply query and key normalization if configured (before rotary embedding)
query =
case query_norm do
opts when is_list(opts) ->
opts = Keyword.validate!(opts, epsilon: 1.0e-5)
# Normalize over the head dimension (channel_index: -1)
Layers.rms_norm(query, epsilon: opts[:epsilon], channel_index: -1, name: join(name, "query_norm"))

fun when is_function(fun) ->
fun.(query, join(name, "query_norm"))

nil ->
query
end

key =
case key_norm do
opts when is_list(opts) ->
opts = Keyword.validate!(opts, epsilon: 1.0e-5)
# Normalize over the head dimension (channel_index: -1)
Layers.rms_norm(key, epsilon: opts[:epsilon], channel_index: -1, name: join(name, "key_norm"))

fun when is_function(fun) ->
fun.(key, join(name, "key_norm"))

nil ->
key
end

{query, key} =
case rotary_embedding do
opts when is_list(opts) ->
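
In the shape the reviewer suggests, `:query_norm` and `:key_norm` would always be functions of the tensor and a layer name, matching the `fun when is_function(fun)` clauses above. A sketch of what a caller might pass, reusing the `Layers.rms_norm` options from this diff; the epsilon value is illustrative:

```elixir
# Sketch of the function-style configuration the reviewer suggests.
# The epsilon value is illustrative, not taken from the PR.
query_norm = fn query, name ->
  Layers.rms_norm(query, epsilon: 1.0e-6, channel_index: -1, name: name)
end

key_norm = fn key, name ->
  Layers.rms_norm(key, epsilon: 1.0e-6, channel_index: -1, name: name)
end
```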
42 changes: 42 additions & 0 deletions lib/bumblebee/text.ex
@@ -444,6 +444,48 @@ defmodule Bumblebee.Text do
defdelegate text_embedding(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextEmbedding

@type text_reranking_input :: {String.t(), String.t()} | [{String.t(), String.t()}]
@type text_reranking_output :: %{scores: text_reranking_score() | list(text_reranking_score())}
@type text_reranking_score :: %{score: number(), query: String.t(), document: String.t()}

@doc """
Builds a serving for text reranking.

The serving expects input in one of the following formats:

* `{query, document}` - a tuple with query and document text
* `[{query1, doc1}, {query2, doc2}, ...]` - a list of query-document pairs

## Options

See `Bumblebee.Text.TextReranking.text_reranking/3` for available options.

## Examples

{:ok, model_info} = Bumblebee.load_model({:hf, "Qwen/Qwen3-Reranker-0.6B"},
architecture: :for_reranker)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-Reranker-0.6B"})

serving = Bumblebee.Text.text_reranking(model_info, tokenizer)

query = "What is the capital of France?"
documents = [
"Paris is the capital of France.",
"Berlin is the capital of Germany."
]

pairs = Enum.map(documents, &{query, &1})
Nx.Serving.run(serving, pairs)

"""
@spec text_reranking(
Bumblebee.model_info(),
Bumblebee.Tokenizer.t(),
keyword()
) :: Nx.Serving.t()
defdelegate text_reranking(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextReranking

@type fill_mask_input :: String.t()
@type fill_mask_output :: %{predictions: list(fill_mask_prediction())}
@type fill_mask_prediction :: %{score: number(), token: String.t()}
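
Given the `text_reranking_score` type above, selecting the top-k documents from a batch run is a small pipeline. A sketch, assuming the `serving` and `pairs` built in the docstring example:

```elixir
# Sketch: rank the example's query-document pairs and keep the best match.
# Assumes `serving` and `pairs` as built in the docstring example above.
%{scores: scores} = Nx.Serving.run(serving, pairs)

top_k =
  scores
  |> Enum.sort_by(& &1.score, :desc)
  |> Enum.take(1)
```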
6 changes: 6 additions & 0 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -200,6 +200,12 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
},
default_template_options: [language_token: "eng_Latn"]
},
qwen2: %{
special_tokens: %{
eos: "<|im_end|>",
pad: "<|endoftext|>"
Member comment on lines +205 to +206: We want the same defaults as hf/transformers. If a particular uploaded model uses different ones, it is in the configuration files and we load those.

Suggested change:
- eos: "<|im_end|>",
- pad: "<|endoftext|>"
+ unk: "<|endoftext|>",
+ eos: "<|endoftext|>",
+ pad: "<|endoftext|>"

}
},
roberta: %{
special_tokens: %{
bos: "<s>",
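
With the suggested hf/transformers-style defaults, `<|endoftext|>` serves as both the pad and eos token for the `:qwen2` type, so batched inputs pad with it unless a checkpoint's configuration overrides the defaults. A sketch, with the repository name assumed as before:

```elixir
# Sketch: batched inputs are padded to a common length with the pad token.
# The repository name is an assumption.
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})
inputs = Bumblebee.apply_tokenizer(tokenizer, ["short", "a somewhat longer input"])
# inputs["input_ids"] now includes pad token ids on the shorter sequence.
```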