Commit 2ba5e0a

don't vectorize all the logits processor state
1 parent 544d80f

File tree

2 files changed (+13, -13 lines)


lib/bumblebee/text/generation.ex

Lines changed: 6 additions & 6 deletions
@@ -687,15 +687,15 @@ defmodule Bumblebee.Text.Generation do
         input_length: state.input_length
       }

-      {logits, new_logits_processor_state} =
-        logits_processor_process_fun.(logits, context, Nx.vectorize(state.logits_processor_state, :batch))
+      {logits, logits_processor_state} =
+        logits_processor_process_fun.(
+          logits,
+          context,
+          state.logits_processor_state
+        )

       logits = Nx.devectorize(logits, keep_names: false)

-      logits_processor_state =
-        Nx.devectorize(new_logits_processor_state, keep_names: false)
-
-
       {logits, %{state | logits_processor_state: logits_processor_state}}
     end

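As a side note (not part of the commit), here is a minimal sketch of the Nx.vectorize/2 and Nx.devectorize/2 round trip this change leans on: the generation loop now hands each processor its state map as-is, so the processor vectorizes only the tensor fields it needs rather than the whole container. The state shape and values below are made up for illustration.

# Sketch only; illustrative state, not from the repository.
state = %{next_enforced_token_id: Nx.tensor([5, 9])}

# Vectorize just the one field over a :batch axis...
vectorized = Nx.vectorize(state.next_enforced_token_id, :batch)

# ...update it as if it were a single value per batch entry...
updated = Nx.add(vectorized, 1)

# ...then devectorize before storing it back into the plain map.
plain = Nx.devectorize(updated, keep_names: false)
state = %{state | next_enforced_token_id: plain}
# state.next_enforced_token_id is now Nx.tensor([6, 10])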
test/bumblebee/text/generation_test.exs

Lines changed: 7 additions & 7 deletions
@@ -248,15 +248,15 @@ defmodule Bumblebee.Text.GenerationTest do

       @impl Bumblebee.LogitsProcessor
       def process(_logits_processor, state, logits, _context) do
-        sfp_state = state.sfp_state
-        logits = enforce_token(logits, sfp_state.next_enforced_token_id)
+        next_enforced_token_id = Nx.vectorize(state.sfp_state.next_enforced_token_id, :batch)

-        sfp_state = %{
-          sfp_state
-          | next_enforced_token_id: Nx.add(sfp_state.next_enforced_token_id, 1)
-        }
+        logits = enforce_token(logits, next_enforced_token_id)
+
+        next_enforced_token_id =
+          Nx.add(next_enforced_token_id, 1)
+          |> Nx.devectorize(keep_names: false)

-        state = %{state | sfp_state: sfp_state}
+        state = put_in(state.sfp_state.next_enforced_token_id, next_enforced_token_id)

         {logits, state}
       end

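A brief aside on the put_in/2 call above (sketch only; the values are made up): given a map path, put_in/2 rewrites a single nested key, which is terser than rebuilding sfp_state with map-update syntax.

# Sketch only; equivalent to:
#   %{state | sfp_state: %{state.sfp_state | next_enforced_token_id: id}}
state = %{sfp_state: %{next_enforced_token_id: Nx.tensor([6, 10])}}
state = put_in(state.sfp_state.next_enforced_token_id, Nx.tensor([7, 11]))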