diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py index 8b2eb06758..8d8cc0ce79 100644 --- a/torchao/prototype/awq/example.py +++ b/torchao/prototype/awq/example.py @@ -32,7 +32,7 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512): return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)] # from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py -def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): +def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda"): model.eval() tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" @@ -41,7 +41,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt') - encodings['input_ids'] = encodings['input_ids'].to('cuda') + encodings['input_ids'] = encodings['input_ids'].to(device) lls, t = [], [] for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose): @@ -55,7 +55,8 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): t1 = time.time() with torch.no_grad(): log_likelihood = model(input_ids, labels=target_ids).loss * trg_len - torch.cuda.synchronize() + if device.startswith("cuda"): + torch.cuda.synchronize() t2 = time.time() t.append((t2-t1)) lls.append(log_likelihood) @@ -71,7 +72,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True): return {'perplexity':ppl, 'prediction_time':pred_time} # adapted from Hicham Badri (@mobicham) -def benchmark(model, tokenizer, max_length, tasks=None): +def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"): import numpy as np import copy import lm_eval @@ -87,7 +88,7 @@ def benchmark(model, tokenizer, max_length, tasks=None): tasks = ["PPL","truthfulqa_mc2", "winogrande", "arc_challenge", "hellaswag", "gsm8k", "mmlu"] results = {} if "PPL" in tasks: - results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True) + results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True, device=device) ############################################ if "truthfulqa_mc2" in tasks: for task in [("truthfulqa_mc2", 0)]: @@ -192,7 +193,7 @@ def wikitext2_ppl( if compile: model = torch.compile(model) - results = benchmark(model, tokenizer, sequence_length, tasks=tasks) + return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.")