diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index 51f2330f..62010053 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -1083,6 +1083,36 @@ defmodule Bumblebee do
     end
   end
 
+  @doc """
+  Initializes state for a new logits processor.
+
+  Returns `state`, an opaque `Nx.Container` that is then passed to and
+  returned from `logits_processor_process/4`.
+  """
+  @doc type: :logits_processor
+  @spec logits_processor_init(
+          Bumblebee.LogitsProcessor.t(),
+          context :: Bumblebee.LogitsProcessor.init_context()
+        ) :: Bumblebee.LogitsProcessor.state()
+  def logits_processor_init(%module{} = logits_processor, context) do
+    module.init(logits_processor, context)
+  end
+
+  @doc """
+  Processes logits, applying specific rules. Receives state, logits and
+  context, and returns updated state and logits.
+  """
+  @doc type: :logits_processor
+  @spec logits_processor_process(
+          Bumblebee.LogitsProcessor.t(),
+          Bumblebee.LogitsProcessor.state(),
+          logits :: Nx.Tensor.t(),
+          context :: Bumblebee.LogitsProcessor.process_context()
+        ) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()}
+  def logits_processor_process(%module{} = logits_processor, state, logits, context) do
+    module.process(logits_processor, state, logits, context)
+  end
+
   @doc """
   Initializes state for a new scheduler loop.
 
diff --git a/lib/bumblebee/logits_processor.ex b/lib/bumblebee/logits_processor.ex
new file mode 100644
index 00000000..21df999d
--- /dev/null
+++ b/lib/bumblebee/logits_processor.ex
@@ -0,0 +1,46 @@
+defmodule Bumblebee.LogitsProcessor do
+  @moduledoc """
+  An interface for configuring and using logits processors.
+
+  Logits processors are used during autoregressive generation to modify
+  predicted scores at each generation step. This makes it possible to
+  apply rules to the model output that control which tokens can be
+  picked at each step and which cannot.
+
+  Every module implementing this behaviour is expected to also define
+  a configuration struct.
+  """
+
+  @type t :: Bumblebee.Configurable.t()
+
+  @type state :: Nx.Container.t()
+
+  @type process_context :: %{
+          sequence: Nx.Tensor.t(),
+          length: Nx.Tensor.t(),
+          input_length: Nx.Tensor.t()
+        }
+
+  @type init_context :: map()
+
+  @doc """
+  Initializes state for a new logits processor.
+
+  Returns `state`, an opaque `Nx.Container` that is then passed to and
+  returned from `process/4`.
+
+  Oftentimes logits processors are stateless, in which case this
+  function can return an empty container, such as `{}`.
+  """
+  @callback init(t(), init_context()) :: state()
+
+  @doc """
+  Processes logits, applying specific rules.
+  """
+  @callback process(
+              t(),
+              state(),
+              logits :: Nx.Tensor.t(),
+              context :: process_context()
+            ) :: {state(), logits :: Nx.Tensor.t()}
+end
diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex
index 935c4921..669c1b7e 100644
--- a/lib/bumblebee/text/generation.ex
+++ b/lib/bumblebee/text/generation.ex
@@ -164,13 +164,15 @@ defmodule Bumblebee.Text.Generation do
 
     {_init_fun, predict_fun} = Axon.build(model, global_layer_options: global_layer_options)
 
-    logits_processor_fun = get_logits_processor(min_length_fun, config, opts[:logits_processors])
+    {logits_processor_init_fun, logits_processor_process_fun} =
+      get_logits_processor(min_length_fun, config, opts[:logits_processors])
 
     &generate_impl(
       &2,
       predict_fun,
       &1,
-      logits_processor_fun,
+      logits_processor_init_fun,
+      logits_processor_process_fun,
       prepare_inputs_fun,
       update_inputs_fun,
       traverse_cache_fun,
@@ -386,18 +388,45 @@ defmodule Bumblebee.Text.Generation do
         []
       end ++ logits_processors
 
-    fn logits, context ->
-      for processor <- processors, processor, reduce: logits do
-        logits -> processor.(logits, context)
-      end
+    processors =
+      processors
+      |> Enum.filter(fn processor -> processor != nil end)
+      |> Enum.map(fn processor ->
+        if is_function(processor, 2) do
+          %Bumblebee.Text.Generation.StatelessLogitsProcessor{fun: processor}
+        else
+          processor
+        end
+      end)
+
+    init_fun = fn context ->
+      processors
+      |> Enum.map(fn processor ->
+        Bumblebee.logits_processor_init(processor, context)
+      end)
+      |> List.to_tuple()
     end
+
+    process_fun = fn logits, context, processor_states ->
+      {processor_states, logits} =
+        processors
+        |> Enum.zip(Tuple.to_list(processor_states))
+        |> Enum.map_reduce(logits, fn {processor, processor_state}, logits ->
+          Bumblebee.logits_processor_process(processor, processor_state, logits, context)
+        end)
+
+      {List.to_tuple(processor_states), logits}
+    end
+
+    {init_fun, process_fun}
   end
 
   defnp generate_impl(
          inputs,
          predict_fun,
          params,
-         logits_processor_fun,
+         logits_processor_init_fun,
+         logits_processor_process_fun,
          prepare_inputs_fun,
          update_inputs_fun,
          traverse_cache_fun,
@@ -427,7 +456,8 @@ defmodule Bumblebee.Text.Generation do
           padded_batch_item?,
           predict_fun,
           params,
-          logits_processor_fun,
+          logits_processor_init_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           merge_options([max_length: max_length], opts)
         )
@@ -439,7 +469,8 @@ defmodule Bumblebee.Text.Generation do
           padded_batch_item?,
           predict_fun,
           params,
-          logits_processor_fun,
+          logits_processor_init_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           traverse_cache_fun,
           merge_options(
@@ -456,7 +487,8 @@ defmodule Bumblebee.Text.Generation do
           predict_fun,
           params,
           seed,
-          logits_processor_fun,
+          logits_processor_init_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           merge_options([max_length: max_length], opts)
         )
@@ -485,7 +517,8 @@ defmodule Bumblebee.Text.Generation do
          padded_batch_item?,
          predict_fun,
          params,
-         logits_processor_fun,
+         logits_processor_init_fun,
+         logits_processor_process_fun,
          update_inputs_fun,
          opts \\ []
        ) do
@@ -493,7 +526,14 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
+    state =
+      init_sequences(
+        decoder_input_ids,
+        padded_batch_item?,
+        max_length,
+        pad_token_id,
+        logits_processor_init_fun
+      )
 
     # The loop works with inputs of length 1, so if the initial input
     # is longer, we make the initial pass outside
@@ -504,7 +544,7 @@ defmodule Bumblebee.Text.Generation do
           inputs,
           predict_fun,
           params,
-          logits_processor_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           pad_token_id: pad_token_id,
           eos_token_id: eos_token_id
@@ -521,7 +561,7 @@ defmodule Bumblebee.Text.Generation do
           inputs,
           predict_fun,
           params,
-          logits_processor_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           pad_token_id: pad_token_id,
           eos_token_id: eos_token_id
@@ -533,7 +573,13 @@ defmodule Bumblebee.Text.Generation do
     state
   end
 
-  defnp init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id) do
+  defnp init_sequences(
+          decoder_input_ids,
+          padded_batch_item?,
+          max_length,
+          pad_token_id,
+          logits_processor_init_fun
+        ) do
     {batch_size, length} = Nx.shape(decoder_input_ids)
 
     sequences = Nx.broadcast(pad_token_id, {batch_size, max_length})
@@ -545,13 +591,20 @@ defmodule Bumblebee.Text.Generation do
     # they could produce arbitrary tokens until we reach max length.
     finished_length = Nx.select(padded_batch_item?, 1, 0)
 
+    context = %{
+      sequence: Nx.vectorize(sequences, :batch),
+      input_length: length,
+      length: length
+    }
+
     %{
       sequences: sequences,
       input_length: length,
       length: length,
       finished_length: finished_length,
       # The ignored return value that we attach all hooks to
-      ignored: Nx.broadcast(0, {batch_size})
+      ignored: Nx.broadcast(0, {batch_size}),
+      logits_processor_states: logits_processor_init_fun.(context)
     }
   end
@@ -564,7 +617,7 @@ defmodule Bumblebee.Text.Generation do
          inputs,
          predict_fun,
          params,
-         logits_processor_fun,
+         logits_processor_process_fun,
          update_inputs_fun,
          opts
        ) do
@@ -574,7 +627,7 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state)
     token_id = Nx.argmax(logits, axis: -1)
 
     state = update_sequences(state, token_id, pad_token_id, eos_token_id)
@@ -631,15 +684,25 @@ defmodule Bumblebee.Text.Generation do
     end
   end
 
-  defnp batch_process_logits(logits_processor_fun, logits, state) do
-    logits
-    |> Nx.vectorize(:batch)
-    |> logits_processor_fun.(%{
+  defnp batch_process_logits(logits_processor_process_fun, logits, state) do
+    logits = Nx.vectorize(logits, :batch)
+
+    context = %{
       sequence: Nx.vectorize(state.sequences, :batch),
       length: state.length,
       input_length: state.input_length
-    })
-    |> Nx.devectorize(keep_names: false)
+    }
+
+    {logits_processor_states, logits} =
+      logits_processor_process_fun.(
+        logits,
+        context,
+        state.logits_processor_states
+      )
+
+    logits = Nx.devectorize(logits, keep_names: false)
+
+    {logits, %{state | logits_processor_states: logits_processor_states}}
   end
 
   # Contrastive search
@@ -650,7 +713,8 @@ defmodule Bumblebee.Text.Generation do
          padded_batch_item?,
          predict_fun,
          params,
-         logits_processor_fun,
+         logits_processor_init_fun,
+         logits_processor_process_fun,
          update_inputs_fun,
         traverse_cache_fun,
         opts \\ []
@@ -661,7 +725,14 @@ defmodule Bumblebee.Text.Generation do
     top_k = opts[:top_k]
     penalty_alpha = opts[:penalty_alpha]
 
-    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
+    state =
+      init_sequences(
+        decoder_input_ids,
+        padded_batch_item?,
+        max_length,
+        pad_token_id,
+        logits_processor_init_fun
+      )
 
     # Step (1)
     # Initial pass to obtain hidden state and expand inputs to top-k
@@ -684,7 +755,7 @@ defmodule Bumblebee.Text.Generation do
     joint_hidden_state = Nx.put_slice(joint_hidden_state, [0, 0, 0], initial_hidden_state)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state)
     scores = Axon.Activations.softmax(logits, axis: -1)
 
     {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
@@ -727,7 +798,7 @@ defmodule Bumblebee.Text.Generation do
 
     logits = outputs.logits[[.., -1]]
     logits = Utils.Nx.chunked_take(logits, top_k, selected_idx)
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state)
     scores = Axon.Activations.softmax(logits, axis: -1)
 
     {top_k_scores, top_k_token_ids} = Nx.top_k(scores, k: top_k)
@@ -817,7 +888,8 @@ defmodule Bumblebee.Text.Generation do
          predict_fun,
          params,
          seed,
-         logits_processor_fun,
+         logits_processor_init_fun,
+         logits_processor_process_fun,
          update_inputs_fun,
          opts \\ []
        ) do
@@ -825,7 +897,14 @@ defmodule Bumblebee.Text.Generation do
     pad_token_id = opts[:pad_token_id]
     eos_token_id = opts[:eos_token_id]
 
-    state = init_sequences(decoder_input_ids, padded_batch_item?, max_length, pad_token_id)
+    state =
+      init_sequences(
+        decoder_input_ids,
+        padded_batch_item?,
+        max_length,
+        pad_token_id,
+        logits_processor_init_fun
+      )
 
     prng_key = seed |> Nx.vectorize(:batch) |> Nx.Random.key()
 
@@ -839,7 +918,7 @@ defmodule Bumblebee.Text.Generation do
           predict_fun,
           params,
           prng_key,
-          logits_processor_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           pad_token_id: pad_token_id,
           eos_token_id: eos_token_id
@@ -857,7 +936,7 @@ defmodule Bumblebee.Text.Generation do
           predict_fun,
           params,
           prng_key,
-          logits_processor_fun,
+          logits_processor_process_fun,
           update_inputs_fun,
           pad_token_id: pad_token_id,
           eos_token_id: eos_token_id
@@ -875,7 +954,7 @@ defmodule Bumblebee.Text.Generation do
          predict_fun,
          params,
          prng_key,
-         logits_processor_fun,
+         logits_processor_process_fun,
          update_inputs_fun,
          opts \\ []
        ) do
@@ -888,7 +967,7 @@ defmodule Bumblebee.Text.Generation do
     outputs = predict_fun.(params, inputs)
 
     logits = outputs.logits[[.., -1]]
-    logits = batch_process_logits(logits_processor_fun, logits, state)
+    {logits, state} = batch_process_logits(logits_processor_process_fun, logits, state)
     scores = Axon.Activations.softmax(logits)
 
     token_id = batched_choice(key, scores)
diff --git a/lib/bumblebee/text/generation/stateless_logits_processor.ex b/lib/bumblebee/text/generation/stateless_logits_processor.ex
new file mode 100644
index 00000000..8e84d6fd
--- /dev/null
+++ b/lib/bumblebee/text/generation/stateless_logits_processor.ex
@@ -0,0 +1,30 @@
+defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do
+  @moduledoc false
+
+  @behaviour Bumblebee.Configurable
+  @behaviour Bumblebee.LogitsProcessor
+
+  options = [
+    fun: [
+      default: nil,
+      doc: "a stateless function that is applied to the logits"
+    ]
+  ]
+
+  defstruct Bumblebee.Shared.option_defaults(options)
+
+  @impl Bumblebee.Configurable
+  def config(logits_processor, opts) do
+    Bumblebee.Shared.put_config_attrs(logits_processor, opts)
+  end
+
+  @impl Bumblebee.LogitsProcessor
+  def init(_logits_processor, _init_context) do
+    %{}
+  end
+
+  @impl Bumblebee.LogitsProcessor
+  def process(logits_processor, state, logits, process_context) do
+    {state, logits_processor.fun.(logits, process_context)}
+  end
+end
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index ff9854a1..10014b4b 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -106,4 +106,144 @@ defmodule Bumblebee.Text.GenerationTest do
 
     assert_equal(token_ids, Nx.tensor([[80, 1023, 1023]]))
   end
+
+  test "with stateful logits processor with different batch sizes" do
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
+
+    {:ok, generation_config} =
+      Bumblebee.load_generation_config({:hf, "hf-internal-testing/tiny-random-GPT2LMHeadModel"})
+
+    assert %Bumblebee.Text.Gpt2{architecture: :for_causal_language_modeling} = spec
+
+    input_ids = Nx.tensor([[0, 0, 10, 20, 30, 40, 50, 60, 70, 80]])
+    attention_mask = Nx.tensor([[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
+    seed = Nx.tensor([0])
+
+    #########################################################
+    # batch size of 1
+
+    inputs = %{
+      "input_ids" => input_ids,
+      "attention_mask" => attention_mask,
+      "seed" => seed
+    }
+
+    # We demonstrate the use of the state with the following example of a
+    # stateful processor (see below). On the first iteration it enforces the
+    # given initial token ID, then increments the ID to enforce on each of
+    # the following iterations. The ID of the token to be enforced is passed
+    # between iterations via logits_processor_states.
+
+    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 2)
+
+    generate =
+      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
+        logits_processors: [
+          Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing,
+            initial_enforced_token_id: 79
+          )
+        ]
+      )
+
+    # The result without the logits processor would be, as with the first
+    # decoder test above, [80, 80, 80].
+    #
+    # With the processor below, we instead expect the sequence [79, 80, 81, ...],
+    # demonstrating the use of the state in a logits processor.
+
+    %{token_ids: token_ids} =
+      Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA)
+
+    assert_equal(token_ids[[0, 0]], 79)
+    assert_equal(token_ids[[0, 1]], 80)
+
+    #########################################################
+    # batch size of 2
+
+    inputs = %{
+      "input_ids" => Nx.Batch.concatenate([input_ids, input_ids]),
+      "attention_mask" => Nx.Batch.concatenate([attention_mask, attention_mask]),
+      "seed" => Nx.Batch.concatenate([seed, seed])
+    }
+
+    # This is the same example as above, but with a batch size of 2.
+
+    generation_config = Bumblebee.configure(generation_config, max_new_tokens: 3)
+
+    generate =
+      Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
+        logits_processors: [
+          Bumblebee.configure(Bumblebee.Text.GenerationTest.StatefulLogitsProcessing,
+            initial_enforced_token_id: 78
+          )
+        ]
+      )
+
+    %{token_ids: token_ids} =
+      Nx.Defn.jit_apply(generate, [params, inputs], compiler: EXLA)
+
+    # result without the logits processor: 80, 80, 80
+
+    # first entry in batch
+    assert_equal(token_ids[[0, 0]], 78)
+    assert_equal(token_ids[[0, 1]], 79)
+    assert_equal(token_ids[[0, 2]], 80)
+
+    # second entry in batch
+    assert_equal(token_ids[[1, 0]], 78)
+    assert_equal(token_ids[[1, 1]], 79)
+    assert_equal(token_ids[[1, 2]], 80)
+  end
+
+  defmodule StatefulLogitsProcessing do
+    @moduledoc false
+
+    import Nx.Defn
+
+    @behaviour Bumblebee.Configurable
+    @behaviour Bumblebee.LogitsProcessor
+
+    options = [
+      initial_enforced_token_id: [
+        default: nil,
+        doc: "a token id to enforce on the first iteration"
+      ]
+    ]
+
+    defstruct Bumblebee.Shared.option_defaults(options)
+
+    @impl Bumblebee.Configurable
+    def config(logits_processor, opts) do
+      Bumblebee.Shared.put_config_attrs(logits_processor, opts)
+    end
+
+    @impl Bumblebee.LogitsProcessor
+    def init(logits_processor, _init_context) do
+      initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id])
+
+      %{
+        next_enforced_token_id: initial_enforced_token_id
+      }
+    end
+
+    @impl Bumblebee.LogitsProcessor
+    def process(_logits_processor, state, logits, _process_context) do
+      next_enforced_token_id = state.next_enforced_token_id
+
+      logits = enforce_token(logits, next_enforced_token_id)
+
+      next_enforced_token_id = Nx.add(next_enforced_token_id, 1)
+
+      state = put_in(state.next_enforced_token_id, next_enforced_token_id)
+
+      {state, logits}
+    end
+
+    defnp enforce_token(logits, token_id) do
+      logits
+      |> Nx.fill(Nx.Constants.neg_infinity(), type: Nx.type(logits))
+      |> Nx.indexed_put(token_id, Nx.tensor(0, type: Nx.type(logits)))
+    end
+  end
 end
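As a usage sketch (not part of the diff above), a downstream module could plug into the new behaviour as below. The module name MyApp.SuppressTokenProcessor, its :suppressed_token_id option and the token id in the usage note are illustrative assumptions; the Bumblebee.Configurable and Bumblebee.LogitsProcessor callbacks and the :logits_processors option are the ones introduced in this change.

# Hypothetical example (not part of this change set): a stateless processor
# implementing Bumblebee.LogitsProcessor that forbids a single token id.
defmodule MyApp.SuppressTokenProcessor do
  @moduledoc false

  import Nx.Defn

  @behaviour Bumblebee.Configurable
  @behaviour Bumblebee.LogitsProcessor

  options = [
    suppressed_token_id: [
      default: nil,
      doc: "a token id that should never be generated"
    ]
  ]

  defstruct Bumblebee.Shared.option_defaults(options)

  @impl Bumblebee.Configurable
  def config(logits_processor, opts) do
    Bumblebee.Shared.put_config_attrs(logits_processor, opts)
  end

  @impl Bumblebee.LogitsProcessor
  def init(_logits_processor, _init_context) do
    # No state is carried between generation steps, so an empty container is enough
    %{}
  end

  @impl Bumblebee.LogitsProcessor
  def process(logits_processor, state, logits, _process_context) do
    token_id = Nx.tensor([logits_processor.suppressed_token_id])
    {state, suppress_token(logits, token_id)}
  end

  defnp suppress_token(logits, token_id) do
    # Push the score of the suppressed token to negative infinity so it is never picked
    Nx.indexed_put(logits, token_id, Nx.Constants.neg_infinity(Nx.type(logits)))
  end
end

# It would then be passed alongside the built-in processors, mirroring the test
# above (the token id 50_256 is arbitrary):
#
#     Bumblebee.Text.Generation.build_generate(model, spec, generation_config,
#       logits_processors: [
#         Bumblebee.configure(MyApp.SuppressTokenProcessor, suppressed_token_id: 50_256)
#       ]
#     )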