@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
     return probs

 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

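Note on the `sample` change: with a batch dimension, `logits[:, -1]` keeps one row of last-position logits per sequence, while the old `logits[0, -1]` collapsed everything down to the first sequence. Below is a minimal, self-contained sketch of the shape difference using plain torch ops only; `torch.multinomial` stands in for the script's `multinomial_sample_one_no_sync` helper, which is an assumption made purely for illustration.

```python
import torch

batch_size, seq_len, vocab = 4, 16, 32000
logits = torch.randn(batch_size, seq_len, vocab)

last_single = logits[0, -1]    # old indexing: (vocab,) -- only sequence 0 survives
last_batched = logits[:, -1]   # new indexing: (batch_size, vocab) -- one row per sequence

probs = torch.softmax(last_batched / 0.8, dim=-1)    # temperature = 0.8
idx_next = torch.multinomial(probs, num_samples=1)   # (batch_size, 1): one sampled token per sequence
print(last_single.shape, last_batched.shape, idx_next.shape)
```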
@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor
         new_tokens.append(next_token)
         callback(new_tokens[-1])
         new_probs.append(next_prob)
-        cur_token = next_token.view(1, -1)
+        cur_token = next_token

     return new_tokens, new_probs

@@ -88,6 +88,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     callback = lambda x: x,
@@ -102,34 +103,34 @@ def generate(

     # create an empty tensor of the expected final shape and fill in the current tokens
     device = prompt.device
-    T = prompt.numel()
+    T = prompt.size(-1)

     # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
     max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
     new_tokens = max_seq_length - T

+    # format model input
+    prompt, input_pos = prepare_inputs_for_model(prompt)
+    prompt = prompt.repeat(batch_size, 1)  # expand prompt based on batch_size
+
     # full prompt+output will be stored in seq
-    seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
-    seq[:T] = prompt.view(-1)
+    seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
+    seq[:, :T] = prompt

     # setup model caches
     with torch.device(device):
         if cache_size is None:
             cache_size = max_seq_length
         assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt"
-        model.setup_caches(max_batch_size=1, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
-
-    # format model input
-    x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

     # execute prefill
-    next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
-    seq[T] = next_token
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
+    seq[:, T] = next_token.squeeze()
     # execute token generation
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens - 1, callback=callback, **sampling_kwargs)
-
-    seq = torch.cat((seq[:T + 1], *generated_tokens))
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens - 1, callback=callback, **sampling_kwargs)
+    seq = torch.cat((seq[:, :T + 1], *generated_tokens), dim=-1)

     return seq

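To make the new buffer bookkeeping in `generate` concrete, here is a minimal sketch of the batched shapes with toy sizes; the values of `T`, `max_new_tokens`, and the vocab are illustrative assumptions, and random tensors stand in for the model's prefill/decode outputs.

```python
import torch

batch_size, T, max_new_tokens = 3, 5, 4
max_seq_length = T + max_new_tokens

# Single prompt of length T, expanded to one copy per batch element.
prompt = torch.randint(0, 32000, (1, T))
prompt = prompt.repeat(batch_size, 1)                 # (batch_size, T)

# Preallocate the full prompt+output buffer per sequence.
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype)
seq[:, :T] = prompt

# Prefill yields the first new token for every sequence at once.
next_token = torch.randint(0, 32000, (batch_size, 1))
seq[:, T] = next_token.squeeze()

# Decode returns a list of (batch_size, 1) tensors; concatenating along the
# last dim rebuilds the full (batch_size, max_seq_length) result.
generated_tokens = [torch.randint(0, 32000, (batch_size, 1)) for _ in range(max_new_tokens - 1)]
seq = torch.cat((seq[:, :T + 1], *generated_tokens), dim=-1)
print(seq.shape)  # torch.Size([3, 9])
```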
@@ -157,6 +158,7 @@ def main(
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -229,9 +231,9 @@ def main(
                 use_hqq = True
             else:
                 use_hqq = False
-            groupsize = int(quantization.split("-")[1])
-            assert groupsize in [32, 64, 128, 256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize_(model, int4_weight_only(group_size=groupsize))
+            group_size = int(quantization.split("-")[1])
+            assert group_size in [32, 64, 128, 256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
+            quantize_(model, int4_weight_only(group_size=group_size))
         if "marlin" in quantization:
             from torchao.dtypes import MarlinSparseLayout
             quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
@@ -267,9 +269,9 @@ def main(
             use_hqq = "hqq" in quantization
             quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq), is_observed_linear)
         if "uintx" in quantization:
-            # uintx-nbits-groupsize, e.g. "uintx-2-64"
+            # uintx-nbits-group_size, e.g. "uintx-2-64"
             if "hqq" in quantization:
-                # uintx-nbits-groupsize-hqq
+                # uintx-nbits-group_size-hqq
                 use_hqq = True
             else:
                 use_hqq = False
@@ -303,6 +305,7 @@ def main(
             model,
             encode_tokens(tokenizer, prompt, bos=True, device=device),
             max_new_tokens,
+            batch_size,
             interactive=False,
             temperature=temperature,
             top_k=top_k,
@@ -375,6 +378,7 @@ def callback(x):
                 model,
                 encoded,
                 max_new_tokens,
+                batch_size,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
@@ -392,13 +396,13 @@ def callback(x):
         t = time.perf_counter() - t0

         if not interactive:
-            tok_list = y.tolist()
+            tok_list = y[0].tolist()
             # truncate text after end of string token
-            tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
+            tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
             print(tokenizer.decode(tokens))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
+        tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
@@ -421,6 +425,8 @@ def callback(x):
     bandwidth = model_size * tokpersec
     mem = torch.cuda.max_memory_reserved() / 1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
+    if batch_size > 1:
+        print(f"Average tokens/sec including batches {batch_size * tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
     print(f"Model Size: {model_size:.02f} GB")
@@ -439,6 +445,7 @@ def callback(x):
         result_txt += f"--interactive " if interactive else ""
         result_txt += f"--num_samples {num_samples} "
         result_txt += f"--max_new_tokens {max_new_tokens} "
+        result_txt += f"--batch_size {batch_size} "
         result_txt += f"--top_k {top_k} "
         result_txt += f"--temperature {temperature} "
         result_txt += f"--cache_size {cache_size} " if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
-            + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
+            + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+            + 'embed-int8wo'
         )
     )
     parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
@@ -484,6 +493,6 @@ def callback(x):

     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
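A sketch of how the new flag might be exercised from the command line; the script name `generate.py` is an assumption for illustration, and the quantization argument simply follows the `int4wo-<groupsize>` pattern from the help text above:

```shell
# Benchmark 8 concurrent sequences with int4 weight-only quantization (group size 64).
python generate.py --batch_size 8 --num_samples 5 --max_new_tokens 200 -q int4wo-64
```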