
Commit aa86005

adding batchsize support for torchao llama benchmarks (#1182)
Summary: added a batch_size argument to the torchao llama benchmarks
Test Plan: see benchmarks.sh
1 parent cbd90e3 commit aa86005

File tree: 2 files changed, +48 -25 lines

  torchao/_models/llama/benchmarks.sh
  torchao/_models/llama/generate.py

torchao/_models/llama/benchmarks.sh

Lines changed: 15 additions & 1 deletion
@@ -64,4 +64,18 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
+# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
+
+# Different Batch Size Benchmarks
+export MODEL_REPO=meta-llama/Meta-Llama-3-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128
+
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128
+
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128

torchao/_models/llama/generate.py

Lines changed: 33 additions & 24 deletions
@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
     return probs
 
 def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
-    probs = logits_to_probs(logits[0, -1], temperature, top_k)
+    probs = logits_to_probs(logits[:, -1], temperature, top_k)
     idx_next = multinomial_sample_one_no_sync(probs)
     return idx_next, probs

@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
             new_tokens.append(next_token)
             callback(new_tokens[-1])
             new_probs.append(next_prob)
-            cur_token = next_token.view(1, -1)
+            cur_token = next_token
 
     return new_tokens, new_probs
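Why these two hunks go together: indexing logits[:, -1] instead of logits[0, -1] keeps the batch dimension, so one next token is sampled per sequence, and the sampled token tensor already carries that batch dimension, which appears to be why the view(1, -1) reshape on cur_token is dropped. A standalone sketch of the shape behavior (made-up sizes, plain torch calls rather than this repo's helpers):

import torch

batch_size, seq_len, vocab = 4, 7, 32                   # hypothetical sizes
logits = torch.randn(batch_size, seq_len, vocab)

single = logits[0, -1]                                  # (vocab,): last step of the first sequence only
batched = logits[:, -1]                                 # (batch_size, vocab): last step of every sequence

probs = torch.softmax(batched, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)   # (batch_size, 1): one sampled token per row
print(single.shape, batched.shape, next_tokens.shape)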

@@ -88,6 +88,7 @@ def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
+    batch_size: int,
     *,
     interactive: bool,
     callback = lambda x: x,
@@ -102,34 +103,34 @@ def generate(
 
     # create an empty tensor of the expected final shape and fill in the current tokens
     device = prompt.device
-    T = prompt.numel()
+    T = prompt.size(-1)
 
     # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
     max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
     new_tokens = max_seq_length - T
 
+    # format model input
+    prompt, input_pos = prepare_inputs_for_model(prompt)
+    prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize
+
     # full prompt+output will be stored in seq
-    seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
-    seq[:T] = prompt.view(-1)
+    seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
+    seq[:, :T] = prompt
 
     # setup model caches
     with torch.device(device):
         if cache_size is None:
             cache_size = max_seq_length
         assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt"
-        model.setup_caches(max_batch_size=1, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
-
-    # format model input
-    x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
+        model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
 
     # execute prefill
-    next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
-    seq[T] = next_token
+    next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
+    seq[:, T] = next_token.squeeze()
     # execute token generation
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
-
-    seq = torch.cat((seq[:T+1], *generated_tokens))
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
+    seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
 
     return seq
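In short, generate() now tiles the single encoded prompt across the batch, allocates seq with a leading batch dimension, and sizes the KV caches with max_batch_size=batch_size. A rough standalone sketch of just that bookkeeping (made-up sizes, not the repo's generate()):

import torch

batch_size, max_new_tokens = 8, 3
prompt = torch.tensor([[1, 2, 3, 4, 5]])       # (1, T): a single encoded prompt
T = prompt.size(-1)

prompt = prompt.repeat(batch_size, 1)          # (batch_size, T): same prompt in every row
max_seq_length = T + max_new_tokens

seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype)
seq[:, :T] = prompt                            # the prompt fills the first T slots of every row
print(prompt.shape, seq.shape)                 # torch.Size([8, 5]) torch.Size([8, 8])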

@@ -157,6 +158,7 @@ def main(
     interactive: bool = False,
     num_samples: int = 5,
     max_new_tokens: int = 100,
+    batch_size: int = 1,
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -229,9 +231,9 @@ def main(
                 use_hqq=True
             else:
                 use_hqq=False
-            groupsize=int(quantization.split("-")[1])
-            assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
-            quantize_(model, int4_weight_only(group_size=groupsize))
+            group_size=int(quantization.split("-")[1])
+            assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
+            quantize_(model, int4_weight_only(group_size=group_size))
         if "marlin" in quantization:
             from torchao.dtypes import MarlinSparseLayout
             quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
@@ -267,9 +269,9 @@ def main(
             use_hqq = "hqq" in quantization
             quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
         if "uintx" in quantization:
-            # uintx-nbits-groupsize, e.g. "uintx-2-64"
+            # uintx-nbits-group_size, e.g. "uintx-2-64"
             if "hqq" in quantization:
-                # uintx-nbits-groupsize-hqq
+                # uintx-nbits-group_size-hqq
                 use_hqq = True
             else:
                 use_hqq = False
@@ -303,6 +305,7 @@ def main(
             model,
             encode_tokens(tokenizer, prompt, bos=True, device=device),
             max_new_tokens,
+            batch_size,
             interactive=False,
             temperature=temperature,
             top_k=top_k,
@@ -375,6 +378,7 @@ def callback(x):
             model,
             encoded,
             max_new_tokens,
+            batch_size,
             interactive=interactive,
             callback=callback,
             temperature=temperature,
@@ -392,13 +396,13 @@ def callback(x):
         t = time.perf_counter() - t0
 
         if not interactive:
-            tok_list = y.tolist()
+            tok_list = y[0].tolist()
             # truncate text after end of string token
-            tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
+            tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
             print(tokenizer.decode(tokens))
         else:
             print()
-        tokens_generated = y.size(0) - prompt_length
+        tokens_generated = (y.size(-1) - prompt_length)
         tokens_sec = tokens_generated / t
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
@@ -421,6 +425,8 @@ def callback(x):
     bandwidth = model_size * tokpersec
     mem = torch.cuda.max_memory_reserved() /1e9
     print(f"Average tokens/sec: {tokpersec:.2f}")
+    if batch_size > 1:
+        print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
     print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
     print(f"Peak Memory Usage: {mem:.02f} GB")
     print(f"Model Size: {model_size:.02f} GB")
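The new print reports aggregate throughput: tokens/sec is measured per sequence, so with batching the total token rate scales by batch_size (the batch_size*tokpersec term above). A toy calculation with made-up numbers:

batch_size = 32
tokpersec = 20.0                                                             # hypothetical per-sequence decode rate
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")    # prints 640.00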
@@ -439,6 +445,7 @@ def callback(x):
         result_txt += f"--interactive " if interactive else ""
         result_txt += f"--num_samples {num_samples} "
         result_txt += f"--max_new_tokens {max_new_tokens} "
+        result_txt += f"--batch_size {batch_size} "
         result_txt += f"--top_k {top_k} "
         result_txt += f"--temperature {temperature} "
         result_txt += f"--cache_size {cache_size}" if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
     parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
     parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
     parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
+    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
     parser.add_argument('-q', '--quantization', type=str,
         help=(
             'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
-            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
+            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+            +'embed-int8wo'
         )
     )
     parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
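The new flag plugs into the existing argparse CLI; a minimal standalone sketch (a bare parser for illustration, not the full option list) of how --batch_size is parsed and defaulted:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')

args = parser.parse_args(['--batch_size', '32'])   # one of the values swept in benchmarks.sh
print(args.batch_size)                             # 32

Left off the command line, the default of 1 preserves the previous single-sequence behavior.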
@@ -484,6 +493,6 @@ def callback(x):
 
     args = parser.parse_args()
     main(
-        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
+        args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
         args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
     )
