@@ -32,7 +32,7 @@ def get_calib_dataset(tokenizer=None, n_samples=100, block_size=512):
     return [cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_samples)]

 # from https://github.com/mobiusml/hqq/blob/master/examples/llama2_benchmark/eval_model.py
-def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
+def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True, device="cuda"):
     model.eval()
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.padding_side = "right"
@@ -41,7 +41,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
     dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
     encodings = tokenizer('\n\n'.join(dataset['text']), return_tensors='pt')

-    encodings['input_ids'] = encodings['input_ids'].to('cuda')
+    encodings['input_ids'] = encodings['input_ids'].to(device)

     lls, t = [], []
     for i in tqdm(range(0, encodings['input_ids'].size(1), stride), disable=not verbose):
@@ -55,7 +55,8 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
         t1 = time.time()
         with torch.no_grad():
            log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
-        torch.cuda.synchronize()
+        if device.startswith("cuda"):
+            torch.cuda.synchronize()
         t2 = time.time()
         t.append((t2 - t1))
         lls.append(log_likelihood)
@@ -71,7 +72,7 @@ def wiki2_eval(model, tokenizer, sequence_length, stride=512, verbose=True):
     return {'perplexity': ppl, 'prediction_time': pred_time}

 # adapted from Hicham Badri (@mobicham)
-def benchmark(model, tokenizer, max_length, tasks=None):
+def benchmark(model, tokenizer, max_length, tasks=None, device="cuda"):
     import numpy as np
     import copy
     import lm_eval
@@ -87,7 +88,7 @@ def benchmark(model, tokenizer, max_length, tasks=None):
         tasks = ["PPL", "truthfulqa_mc2", "winogrande", "arc_challenge", "hellaswag", "gsm8k", "mmlu"]
     results = {}
     if "PPL" in tasks:
-        results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True)
+        results["perplexity"] = wiki2_eval(model, tokenizer, 512, verbose=True, device=device)
     ############################################
     if "truthfulqa_mc2" in tasks:
         for task in [("truthfulqa_mc2", 0)]:
@@ -192,7 +193,7 @@ def wikitext2_ppl(
     if compile:
         model = torch.compile(model)

-    results = benchmark(model, tokenizer, sequence_length, tasks=tasks)
+    return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device)

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Evaluate a model with the specified parameters.")
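For orientation, here is a minimal sketch of how the new device argument might be driven from calling code, assuming wiki2_eval is importable from the script above. The model ID, the loading boilerplate, and the CPU fallback are illustrative assumptions, not part of this commit:

# Hypothetical driver script; MODEL_ID is a placeholder, not from this commit.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "facebook/opt-125m"  # any causal LM works here
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# With the guard added above, wiki2_eval skips torch.cuda.synchronize()
# whenever device does not start with "cuda", so this also runs on CPU.
results = wiki2_eval(model, tokenizer, 512, device=device)
print(results["perplexity"], results["prediction_time"])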