import argparse
- import logging
import os
import time

import torch
- from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
- from auto_gptq.nn_modules.qlinear import GeneralQuantLinear
- from transformers import AutoTokenizer, BloomForCausalLM, BloomTokenizerFast, LlamaForCausalLM, LlamaTokenizer
+ from _utils import print_perf_stats
+ from auto_gptq import AutoGPTQForCausalLM
+ from transformers import BloomTokenizerFast

import colossalai
from colossalai.inference.tensor_parallel.engine import TPInferEngine
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn

- os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-
-
- def print_perf_stats(latency_set, config, bs, warmup=3):
-     # trim warmup queries
-     latency_set = list(latency_set)
-     latency_set = latency_set[warmup:]
-     count = len(latency_set)
-
-     if count > 0:
-         latency_set.sort()
-         avg = sum(latency_set) / count
-         num_layers = getattr(config, "num_layers", config.num_hidden_layers)
-         num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
-         num_bytes = 2  # float16
-
-         print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
-         print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
-         print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * bs / 1e12))
-         print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"


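The removed print_perf_stats helper above (now imported from _utils) derives its bandwidth and throughput figures from the average per-token latency and a rough weight count of 12 * hidden_size^2 per transformer layer, ignoring embeddings. A minimal sketch of the same arithmetic, assuming BLOOM-7B1-like dimensions and an illustrative 20 ms per-token latency (both values are assumptions, not measurements from this script):

# Sketch only: assumed dimensions and latency plugged into the formulas above.
hidden_size, num_layers = 4096, 30                              # assumed BLOOM-7B1-like config
num_parameters = num_layers * hidden_size * hidden_size * 12    # ~6.04e9 weights
num_bytes = 2                                                   # float16
avg = 0.020                                                     # illustrative avg per-token latency, seconds
bs = 16                                                         # matches the script's default batch size

print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))  # ~603.98 GB/s
print("Avg Throughput: tokens/s: {}".format((1000 / (avg * 1000)) * bs))           # 800.0 tokens/s
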
def bench_bloom(args):
-
    pretrained_model_dir = args.path
    quantized_model_dir = args.quantized_path
    max_batch_size = args.batch_size
@@ -48,9 +27,9 @@ def bench_bloom(args):
    tokenizer.pad_token = tokenizer.eos_token

    # load quantized model to the first GPU
-     model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
-                                                device=torch.cuda.current_device(),
-                                                inject_fused_attention=False)
+     model = AutoGPTQForCausalLM.from_quantized(
+         quantized_model_dir, device=torch.cuda.current_device(), inject_fused_attention=False
+     )

    model = model.half()

@@ -60,22 +39,22 @@ def bench_bloom(args):
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)

    input_tokens = {
-         "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device='cuda'),
-         "attention_mask": torch.ones((max_batch_size, max_input_len), device='cuda')
+         "input_ids": torch.randint(1, 1000, (max_batch_size, max_input_len), device="cuda"),
+         "attention_mask": torch.ones((max_batch_size, max_input_len), device="cuda"),
    }

    # init TPInferEngine and shard the original model
    # To benchmark torch original, comment out the line of optimizing model
-     shard_config = ShardConfig(enable_tensor_parallelism=True if args.tp_size > 1 else False,
-                                inference_only=True,
-                                inference_gptq=True)
+     shard_config = ShardConfig(
+         enable_tensor_parallelism=True if args.tp_size > 1 else False, inference_only=True, inference_gptq=True
+     )
    infer_engine = TPInferEngine(model, shard_config, max_batch_size, max_input_len, max_output_len)

    # prepare data for generation
    generate_kwargs = dict(max_new_tokens=max_output_len, do_sample=False)
    input_tokens = {
        "input_ids": torch.randint(10, 1000, (max_batch_size, max_input_len)),
-         "attention_mask": torch.ones((max_batch_size, max_input_len))
+         "attention_mask": torch.ones((max_batch_size, max_input_len)),
    }
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
@@ -99,7 +78,7 @@ def bench_bloom(args):

def check_bloom(rank, world_size, port, args):
    disable_existing_loggers()
-     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+     colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    bench_bloom(args)


@@ -111,12 +90,12 @@ def test_bloom(args):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
-     parser.add_argument('-p', '--path', type=str, help='Model path', required=True)
-     parser.add_argument('-q', '--quantized_path', type=str, help='Model path', required=True)
-     parser.add_argument('-tp', '--tp_size', type=int, default=1, help='Tensor parallel size')
-     parser.add_argument('-b', '--batch_size', type=int, default=16, help='Maximum batch size')
-     parser.add_argument('--input_len', type=int, default=1024, help='Maximum input length')
-     parser.add_argument('--output_len', type=int, default=128, help='Maximum output length')
+     parser.add_argument("-p", "--path", type=str, help="Model path", required=True)
+     parser.add_argument("-q", "--quantized_path", type=str, help="Model path", required=True)
+     parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size")
+     parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size")
+     parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length")
+     parser.add_argument("--output_len", type=int, default=128, help="Maximum output length")

    args = parser.parse_args()

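Given the flags defined above, a typical launch of this benchmark might look like the line below (the script filename and paths are placeholders assumed for illustration; they are not shown in this diff):

    python bench_bloom_gptq.py -p /path/to/bloom -q /path/to/bloom-gptq -tp 2 -b 16 --input_len 1024 --output_len 128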