@@ -3,7 +3,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Callable, Optional
+from typing import Optional
 
 import numpy as np
 from datasets import load_dataset
@@ -13,34 +13,31 @@
 MAX_CHAR = 1000
 
 
-def create_token_estimator(
-    model_name: str = "mistralai/Mistral-7B-Instruct-v0.2",
-) -> Callable[[str], int]:
-    _tokenizer: Optional[AutoTokenizer] = None
+class TokenCounter:
+    def __init__(self, model_name: str = "mistralai/Mistral-7B-Instruct-v0.2"):
+        self.model_name = model_name
+        self._tokenizer: Optional[AutoTokenizer] = None
 
-    def initialize() -> None:
-        nonlocal _tokenizer
-        if _tokenizer is None:
+    def _initialize_tokenizer(self) -> None:
+        if self._tokenizer is None:
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
             try:
-                _tokenizer = AutoTokenizer.from_pretrained(model_name)
+                self._tokenizer = AutoTokenizer.from_pretrained(self.model_name)
             except (OSError, ImportError, ValueError) as e:
                 raise RuntimeError(f"Failed to initialize tokenizer: {e}") from e
 
-    def estimate_num_tokens(text: str) -> int:
-        initialize()
+    def estimate_num_tokens(self, text: str) -> int:
+        self._initialize_tokenizer()
 
-        if _tokenizer is None:
+        if self._tokenizer is None:
             return 0
 
         try:
-            encoding = _tokenizer(text, return_tensors=None)
+            encoding = self._tokenizer(text, return_tensors=None)
             return len(encoding["input_ids"])
         except (AttributeError, TypeError, RuntimeError) as e:
             raise ValueError(f"Error processing text: {e}") from e
 
-    return estimate_num_tokens
-
 
 def extract_and_save_with_filtering(file):
     """extract human prompts and apply filtering conditions"""
@@ -93,7 +90,7 @@ def extract_and_save_with_filtering(file):
     with Path(sharegpt_file).open("r", encoding="utf-8") as file:
         data = json.load(file)
 
-    estimate_tokens = create_token_estimator()
+    counter = TokenCounter()
     num_of_ids = len(data)
     data = data[: int(num_of_ids * args.parse)]
     for d in data:
@@ -102,9 +99,9 @@ def extract_and_save_with_filtering(file):
         gpt_tokens = []
         for conv in d["conversations"]:
             if conv["from"] == "human":
-                human_tokens.append(estimate_tokens(conv["value"]))
+                human_tokens.append(counter.estimate_num_tokens(conv["value"]))
             if conv["from"] == "gpt":
-                token_number = estimate_tokens(conv["value"])
+                token_number = counter.estimate_num_tokens(conv["value"])
                 conv["num_tokens"] = token_number
                 gpt_tokens.append(token_number)
         if len(human_tokens) == 0:
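
For reference, a minimal usage sketch of the refactored TokenCounter (assuming the class is in scope, transformers is installed, and the mistralai/Mistral-7B-Instruct-v0.2 tokenizer is accessible; the sample string is illustrative only):

# Hypothetical usage, mirroring how extract_and_save_with_filtering calls it
counter = TokenCounter()  # defaults to "mistralai/Mistral-7B-Instruct-v0.2"
sample = "How many tokens does this prompt use?"
num_tokens = counter.estimate_num_tokens(sample)  # tokenizer is loaded lazily on first call
print(num_tokens)

Because _initialize_tokenizer only runs on the first estimate_num_tokens call, constructing TokenCounter itself does not download or load the tokenizer.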