Skip to content

Commit ce92584

Browse files
committed
introduced types for init_context and process_context
1 parent 70d7f65 commit ce92584

File tree

4 files changed

+18
-10
lines changed

4 files changed

+18
-10
lines changed

lib/bumblebee.ex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ defmodule Bumblebee do
10921092
@doc type: :logits_processor
10931093
@spec logits_processor_init(
10941094
Bumblebee.LogitsProcessor.t(),
1095-
context :: term()
1095+
context :: Bumblebee.LogitsProcessor.init_context()
10961096
) :: Bumblebee.LogitsProcessor.state()
10971097
def logits_processor_init(%module{} = logits_processor, context) do
10981098
module.init(logits_processor, context)
@@ -1107,7 +1107,7 @@ defmodule Bumblebee do
11071107
Bumblebee.LogitsProcessor.t(),
11081108
Bumblebee.LogitsProcessor.state(),
11091109
logits :: Nx.Tensor.t(),
1110-
context :: term()
1110+
context :: Bumblebee.LogitsProcessor.process_context()
11111111
) :: {Bumblebee.LogitsProcessor.state(), logits :: Nx.Tensor.t()}
11121112
def logits_processor_process(%module{} = logits_processor, state, logits, context) do
11131113
module.process(logits_processor, state, logits, context)

lib/bumblebee/logits_processor.ex

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ defmodule Bumblebee.LogitsProcessor do
1515

1616
@type state :: Nx.Container.t()
1717

18+
@type process_context :: %{
19+
sequence: Nx.Tensor.t(),
20+
length: Nx.Tensor.t(),
21+
input_length: Nx.Tensor.t()
22+
}
23+
24+
@type init_context :: %{}
25+
1826
@doc """
1927
Initializes state for a new logits processor.
2028
@@ -24,7 +32,7 @@ defmodule Bumblebee.LogitsProcessor do
2432
Oftentimes logits processors are stateless, in which case this
2533
function can return an empty container, such as `{}`.
2634
"""
27-
@callback init(t(), any()) :: state()
35+
@callback init(t(), init_context()) :: state()
2836

2937
@doc """
3038
Processes logits, applying specific rules.
@@ -33,6 +41,6 @@ defmodule Bumblebee.LogitsProcessor do
3341
t(),
3442
state(),
3543
logits :: Nx.Tensor.t(),
36-
context :: term()
44+
context :: process_context()
3745
) :: {state :: map(), logits :: Nx.Tensor.t()}
3846
end

lib/bumblebee/text/generation/stateless_logits_processor.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ defmodule Bumblebee.Text.Generation.StatelessLogitsProcessor do
1919
end
2020

2121
@impl Bumblebee.LogitsProcessor
22-
def init(_logits_processor, _context) do
22+
def init(_logits_processor, _init_context) do
2323
%{}
2424
end
2525

2626
@impl Bumblebee.LogitsProcessor
27-
def process(logits_processor, state, logits, context) do
28-
{state, logits_processor.fun.(logits, context)}
27+
def process(logits_processor, state, logits, process_context) do
28+
{state, logits_processor.fun.(logits, process_context)}
2929
end
3030
end

test/bumblebee/text/generation_test.exs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,19 +220,19 @@ defmodule Bumblebee.Text.GenerationTest do
220220
end
221221

222222
@impl Bumblebee.LogitsProcessor
223-
def init(logits_processor, context) do
223+
def init(logits_processor, init_context) do
224224
initial_enforced_token_id = Nx.tensor([logits_processor.initial_enforced_token_id])
225225

226226
[initial_enforced_batch_token_id, _sequence] =
227-
Nx.broadcast_vectors([initial_enforced_token_id, context.sequence])
227+
Nx.broadcast_vectors([initial_enforced_token_id, init_context.sequence])
228228

229229
%{
230230
next_enforced_token_id: initial_enforced_batch_token_id
231231
}
232232
end
233233

234234
@impl Bumblebee.LogitsProcessor
235-
def process(_logits_processor, state, logits, _context) do
235+
def process(_logits_processor, state, logits, _process_context) do
236236
next_enforced_token_id = state.next_enforced_token_id
237237

238238
logits = enforce_token(logits, next_enforced_token_id)

0 commit comments

Comments
 (0)