Commit 107e378

Enable AWQ example on CPU (#1043)
1 parent fa6d156 commit 107e378


torchao/prototype/awq/example.py

Lines changed: 7 additions & 6 deletions
@@ -32,7 +32,7 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512):
     return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)]
 
 # from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py
-def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
+def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda"):
     model.eval()
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
@@ -41,7 +41,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
     dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
     encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt')
 
-    encodings['input_ids'] = encodings['input_ids'].to('cuda')
+    encodings['input_ids'] = encodings['input_ids'].to(device)
 
     lls, t = [], []
     for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose):
@@ -55,7 +55,8 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
         t1 = time.time()
         with torch.no_grad():
             log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
-        torch.cuda.synchronize()
+        if device.startswith("cuda"):
+            torch.cuda.synchronize()
         t2 = time.time()
         t.append((t2-t1))
         lls.append(log_likelihood)
@@ -71,7 +72,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
     return {'perplexity':ppl, 'prediction_time':pred_time}
 
 # adapted from Hicham Badri (@mobicham)
-def benchmark(model, tokenizer, max_length, tasks=None):
+def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
     import numpy as np
     import copy
     import lm_eval
@@ -87,7 +88,7 @@ def benchmark(model, tokenizer, max_length, tasks=None):
         tasks = ["PPL","truthfulqa_mc2", "winogrande", "arc_challenge", "hellaswag", "gsm8k", "mmlu"]
     results = {}
     if "PPL" in tasks:
-        results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True)
+        results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True, device=device)
     ############################################
     if "truthfulqa_mc2" in tasks:
         for task in [("truthfulqa_mc2", 0)]:
@@ -192,7 +193,7 @@ def wikitext2_ppl(
     if compile:
         model = torch.compile(model)
 
-    results = benchmark(model, tokenizer, sequence_length, tasks=tasks)
+    return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.")
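
The change follows one pattern throughout: thread a device argument through the evaluation helpers and guard CUDA-only calls (tensor placement, torch.cuda.synchronize()) behind it, so the AWQ example also runs on CPU. Below is a minimal sketch of that pattern in isolation; the helper names sync_if_cuda and timed_forward are illustrative and not part of this commit:

import time

import torch


def sync_if_cuda(device: str) -> None:
    # CUDA kernels launch asynchronously; synchronizing before reading the
    # clock ensures the measured time includes the GPU work. On "cpu" the
    # guard simply skips the call, mirroring the commit's change.
    if device.startswith("cuda"):
        torch.cuda.synchronize()


def timed_forward(model, input_ids, device="cpu"):
    # Time one forward pass in a device-agnostic way (illustrative only).
    t1 = time.time()
    with torch.no_grad():
        out = model(input_ids)
    sync_if_cuda(device)
    t2 = time.time()
    return out, t2 - t1

With device threaded through wiki2_eval and benchmark, and wikitext2_ppl now returning the benchmark results directly, the same evaluation path works with device="cpu" or device="cuda".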
