@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
5252 return draft_token_ids
5353
5454
def get_acceptance_sampler(
    posterior_threshold: float = 0.03,
    posterior_alpha: float = 0.9,
    disable_bonus_tokens: bool = False,
    strict_mode: bool = False,
) -> TypicalAcceptanceSampler:
    """Construct a ``TypicalAcceptanceSampler`` with test-friendly defaults.

    All four arguments are forwarded unchanged to the
    ``TypicalAcceptanceSampler`` constructor; see that class for their
    exact semantics.

    Args:
        posterior_threshold: Forwarded to the sampler constructor.
        posterior_alpha: Forwarded to the sampler constructor.
        disable_bonus_tokens: Forwarded to the sampler constructor.
        strict_mode: Forwarded to the sampler constructor.

    Returns:
        A freshly constructed ``TypicalAcceptanceSampler``.
    """
    # Keyword arguments guard against silently swapping the two float /
    # two bool parameters if the constructor's signature order changes.
    return TypicalAcceptanceSampler(
        posterior_threshold=posterior_threshold,
        posterior_alpha=posterior_alpha,
        disable_bonus_tokens=disable_bonus_tokens,
        strict_mode=strict_mode)
66+
67+
5568@pytest .mark .parametrize ("k" , list (range (1 , 6 )))
5669@pytest .mark .parametrize ("vocab_size" , [30_000 , 50_000 ])
5770@pytest .mark .parametrize ("batch_size" , list (range (1 , 32 )))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
6477 different combinations of k, vocab_size, batch_size and num devices.
6578 """
6679 torch .set_default_device (device )
67- typical_acceptance_sampler = TypicalAcceptanceSampler ()
80+ typical_acceptance_sampler = get_acceptance_sampler ()
6881 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
6982 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
7083 bonus_token_ids = torch .randint (low = 0 ,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
7689 size = (batch_size , k ),
7790 dtype = torch .int64 )
7891 # Verify that sampling succeeds for all cases.
79- typical_acceptance_sampler (target_probs , bonus_token_ids , draft_token_ids )
92+ typical_acceptance_sampler (target_probs ,
93+ bonus_token_ids ,
94+ draft_probs = None ,
95+ draft_token_ids = draft_token_ids )
8096
8197
8298@pytest .mark .parametrize ("above_or_below_vocab_range" , ["above" , "below" ])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
94110 batch_size = 5
95111 vocab_size = 30_000
96112 torch .set_default_device (device )
97- typical_acceptance_sampler = TypicalAcceptanceSampler (strict_mode = True )
113+ typical_acceptance_sampler = get_acceptance_sampler (strict_mode = True )
98114 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
99115 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
100116 bonus_token_ids = torch .randint (low = 0 ,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
125141 oob_token_ids [0 ][0 ] = rogue_token_id
126142
127143 with pytest .raises (AssertionError ):
128- typical_acceptance_sampler (target_probs , bonus_token_ids ,
129- draft_token_ids )
144+ typical_acceptance_sampler (target_probs ,
145+ bonus_token_ids ,
146+ draft_probs = None ,
147+ draft_token_ids = draft_token_ids )
130148
131149
132150@pytest .mark .parametrize ("seed" , list (range (10 )))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
151169 batch_size = 5
152170 vocab_size = 30_000
153171 torch .set_default_device (device )
154- typical_acceptance_sampler = TypicalAcceptanceSampler (
172+ typical_acceptance_sampler = get_acceptance_sampler (
155173 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
156174 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
157175 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
163181 high = vocab_size ,
164182 size = (batch_size , 1 ),
165183 dtype = torch .int64 )
166- output_token_ids = typical_acceptance_sampler (target_probs ,
167- bonus_token_ids ,
168- draft_token_ids )
184+ output_token_ids = typical_acceptance_sampler (
185+ target_probs ,
186+ bonus_token_ids ,
187+ draft_probs = None ,
188+ draft_token_ids = draft_token_ids )
169189 # We are using a uniform target probability distribution.
170190 # For a uniform distribution the entropy is very high and it
171191 # should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
203223 vocab_size = 30_000
204224 torch .set_default_device (device )
205225
206- typical_acceptance_sampler = TypicalAcceptanceSampler (
226+ typical_acceptance_sampler = get_acceptance_sampler (
207227 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
208228 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
209229 # Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
224244 # 1.0 tokens in the target distribution we will reject all of them and
225245 # fallback to the greedy sampling for selecting 1 token for each sequence.
226246 # Verify the same.
227- output_token_ids = typical_acceptance_sampler (target_probs ,
228- bonus_token_ids ,
229- draft_token_ids )
247+ output_token_ids = typical_acceptance_sampler (
248+ target_probs ,
249+ bonus_token_ids ,
250+ draft_probs = None ,
251+ draft_token_ids = draft_token_ids )
230252 assert output_token_ids .shape [0 ] == batch_size
231253 assert output_token_ids .shape [1 ] == (k + 1 )
232254 assert torch .all (output_token_ids [:, - 1 ] == - 1 )
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
261283 batch_size = 4
262284 vocab_size = 30_000
263285 torch .set_default_device (device )
264- typical_acceptance_sampler = TypicalAcceptanceSampler (
286+ typical_acceptance_sampler = get_acceptance_sampler (
265287 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
266288 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
267289 # For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
277299 high = vocab_size ,
278300 size = (batch_size , 1 ),
279301 dtype = torch .int64 )
280- output_token_ids = typical_acceptance_sampler (target_probs ,
281- bonus_token_ids ,
282- draft_token_ids )
302+ output_token_ids = typical_acceptance_sampler (
303+ target_probs ,
304+ bonus_token_ids ,
305+ draft_probs = None ,
306+ draft_token_ids = draft_token_ids )
283307 # verify the shape of output_token_ids
284308 assert output_token_ids .shape [0 ] == batch_size
285309 assert output_token_ids .shape [1 ] == (k + 1 )
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
326350 batch_size = 1
327351 vocab_size = 30_000
328352 torch .set_default_device (device )
329- typical_acceptance_sampler = TypicalAcceptanceSampler (
353+ typical_acceptance_sampler = get_acceptance_sampler (
330354 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
331355 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
332356 # Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
339363 high = vocab_size ,
340364 size = (batch_size , 1 ),
341365 dtype = torch .int64 )
342- output_token_ids = typical_acceptance_sampler (target_probs ,
343- bonus_token_ids ,
344- draft_token_ids )
366+ output_token_ids = typical_acceptance_sampler (
367+ target_probs ,
368+ bonus_token_ids ,
369+ draft_probs = None ,
370+ draft_token_ids = draft_token_ids )
345371 assert output_token_ids .shape [0 ] == batch_size
346372 assert output_token_ids .shape [1 ] == (k + 1 )
347373 assert torch .all (output_token_ids [:, 0 :- 1 ] == draft_token_ids )
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
357383 batch_size , k , vocab_size , zero_temperature_token_ids )
358384 draft_token_ids = torch .cat (
359385 (draft_token_ids [:, :2 ], draft_token_ids_to_replace [:, - 3 :]), dim = 1 )
360- output_token_ids = typical_acceptance_sampler (target_probs ,
361- bonus_token_ids ,
362- draft_token_ids )
386+ output_token_ids = typical_acceptance_sampler (
387+ target_probs ,
388+ bonus_token_ids ,
389+ draft_probs = None ,
390+ draft_token_ids = draft_token_ids )
363391 assert output_token_ids .shape [0 ] == batch_size
364392 assert output_token_ids .shape [1 ] == (k + 1 )
365393 assert torch .all (output_token_ids [:, :2 ] == draft_token_ids [:, :2 ])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
384412 batch_size = 1
385413 vocab_size = 30_000
386414 torch .set_default_device (device )
387- typical_acceptance_sampler = TypicalAcceptanceSampler (
415+ typical_acceptance_sampler = get_acceptance_sampler (
388416 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
389417 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
390418 # Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
402430 high = vocab_size ,
403431 size = (batch_size , 1 ),
404432 dtype = torch .int64 )
405- output_token_ids = typical_acceptance_sampler (target_probs ,
406- bonus_token_ids ,
407- draft_token_ids )
433+ output_token_ids = typical_acceptance_sampler (
434+ target_probs ,
435+ bonus_token_ids ,
436+ draft_probs = None ,
437+ draft_token_ids = draft_token_ids )
408438 assert output_token_ids .shape [0 ] == batch_size
409439 assert output_token_ids .shape [1 ] == (k + 1 )
410440 assert torch .all (output_token_ids [:, 1 :- 1 ] == - 1 )
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
418448 posterior_threshold = 0.0 ,
419449 posterior_alpha = 0.0 )
420450 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
421- output_token_ids = typical_acceptance_sampler (target_probs ,
422- bonus_token_ids ,
423- draft_token_ids )
451+ output_token_ids = typical_acceptance_sampler (
452+ target_probs ,
453+ bonus_token_ids ,
454+ draft_probs = None ,
455+ draft_token_ids = draft_token_ids )
424456 assert output_token_ids .shape [0 ] == batch_size
425457 assert output_token_ids .shape [1 ] == (k + 1 )
426458 assert torch .all (output_token_ids [:, 0 :- 1 ] == draft_token_ids )
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
451483 batch_size = 5
452484 vocab_size = 30_000
453485 torch .set_default_device (device )
454- typical_acceptance_sampler = TypicalAcceptanceSampler (
486+ typical_acceptance_sampler = get_acceptance_sampler (
455487 strict_mode = True , disable_bonus_tokens = disable_bonus_tokens )
456488 typical_acceptance_sampler .init_gpu_tensors (rank = 0 )
457489 target_probs = torch .rand (batch_size , k , vocab_size , dtype = torch .float32 )
0 commit comments