-# SPDX-License-Identifier: Apache-2.0
-# Standard
 import argparse
 import json
 import os
+import re
+from pathlib import Path
 
-# Third Party
-from transformers import AutoTokenizer
 import numpy as np
-
 from datasets import load_dataset
-import re
+from transformers import AutoTokenizer
 
-def extract_and_save_with_filtering():
+MIN_CHAR = 10
+MAX_CHAR = 1000
+
+
+def extract_and_save_with_filtering(file):
     """Extract human prompts and apply filtering conditions."""
-
-    dataset = load_dataset('json', data_files='./ShareGPT.json', split='train')
-
+    dataset = load_dataset("json", data_files=file, split="train")
     filtered_prompts = []
-
+
     for example in dataset:
-        conversations = example.get('conversations', [])
-
+        conversations = example.get("conversations", [])
         if isinstance(conversations, list):
             for turn in conversations:
-                if turn.get('from') in ['human', 'user']:
-                    prompt_text = turn['value'].strip()
-
-                    # apply filter conditions
-                    if (len(prompt_text) >= 10 and  # at least 10 characters
-                        len(prompt_text) <= 1000 and  # at most 1000 characters
-                        not prompt_text.startswith(('http://', 'https://')) and  # exclude URLs
-                        not re.search(r'[<>{}[\]\\]', prompt_text) and  # exclude special characters
-                        not prompt_text.isdigit()):  # exclude pure numbers
-
-                        filtered_prompts.append({
-                            'from': turn.get('from'),
-                            'text': prompt_text,
-                            'char_count': len(prompt_text),
-                            'word_count': len(prompt_text.split())
-                        })
-
+                if turn.get("from") in ["human", "user"]:
+                    prompt_text = turn["value"].strip()
+                    # apply filter conditions: at least MIN_CHAR characters
+                    if (
+                        len(prompt_text) >= MIN_CHAR
+                        and
+                        # at most MAX_CHAR characters
+                        len(prompt_text) <= MAX_CHAR
+                        and
+                        # exclude URLs
+                        not prompt_text.startswith(("http://", "https://"))
+                        and
+                        # exclude special characters
+                        not re.search(r"[<>{}[\]\\]", prompt_text)
+                        and not prompt_text.isdigit()
+                    ):  # exclude pure numbers
+                        filtered_prompts.append(
+                            {
+                                "from": turn.get("from"),
+                                "text": prompt_text,
+                                "char_count": len(prompt_text),
+                                "word_count": len(prompt_text.split()),
+                            }
+                        )
+
     return filtered_prompts
-
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process data percentage.")
     parser.add_argument(
@@ -50,13 +57,12 @@ def extract_and_save_with_filtering():
         default=1,
         help="The percentage of data to process (0 to 1). Default is 1 (100%).",
     )
-
     args = parser.parse_args()
 
-    with open("ShareGPT_V3_unfiltered_cleaned_split.json", "r", encoding="utf-8") as file:
+    sharegpt_file = "ShareGPT_V3_unfiltered_cleaned_split.json"
+    with Path(sharegpt_file).open("r", encoding="utf-8") as file:
         data = json.load(file)
 
-
     def estimate_num_tokens(text: str) -> int:
         if not hasattr(estimate_num_tokens, "tokenizer"):
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -65,15 +71,10 @@ def estimate_num_tokens(text: str) -> int:
             )
         return len(estimate_num_tokens.tokenizer.tokenize(text))
 
-
     num_of_ids = len(data)
-    print(f"Number of IDs: {num_of_ids}")
     data = data[: int(num_of_ids * args.parse)]
-
-    count = 0
-
     for d in data:
-        d["num_round"] = len(d["conversations"])  # human is one round, gpt is another round
+        d["num_round"] = len(d["conversations"])
         human_tokens = []
         gpt_tokens = []
         for conv in d["conversations"]:
@@ -96,15 +97,10 @@ def estimate_num_tokens(text: str) -> int:
             d["average_gpt_token"] = float(np.mean(gpt_tokens))
             d["max_gpt_token"] = float(np.max(gpt_tokens))
 
-        count += 1
-        print(f"Finished {count}")
-
     # save the unfiltered dataset to ShareGPT.json
-    with open("ShareGPT.json", "w", encoding="utf-8") as file:
+    with Path("ShareGPT.json").open("w", encoding="utf-8") as file:
         json.dump(data, file, ensure_ascii=False, indent=2)
     # filter the human ("from") prompts and save again
-    filtered_result = extract_and_save_with_filtering()
-    with open("ShareGPT.json", "w", encoding="utf-8") as file:
+    filtered_result = extract_and_save_with_filtering("ShareGPT.json")
+    with Path("ShareGPT.json").open("w", encoding="utf-8") as file:
         json.dump(filtered_result, file, ensure_ascii=False, indent=2)
-
-
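
The filtering rules in extract_and_save_with_filtering can be sanity-checked on their own. Below is a minimal sketch, not part of the commit: it applies the same conditions with MIN_CHAR/MAX_CHAR as defined in the diff, and the sample prompts are invented purely for illustration.

    import re

    MIN_CHAR, MAX_CHAR = 10, 1000
    samples = [
        "Explain how attention works in transformers.",  # kept
        "12345",                     # dropped: pure digits (and under MIN_CHAR)
        "https://example.com/page",  # dropped: starts with a URL scheme
        "Use <b>bold</b> here",      # dropped: contains one of the excluded <>{}[] or backslash characters
    ]
    for prompt in samples:
        keep = (
            MIN_CHAR <= len(prompt) <= MAX_CHAR
            and not prompt.startswith(("http://", "https://"))
            and not re.search(r"[<>{}[\]\\]", prompt)
            and not prompt.isdigit()
        )
        print(f"{keep!s:5} {prompt!r}")

Based on args.parse and the argument's help text, the script presumably also accepts a --parse fraction between 0 and 1 that subsamples the dataset before the per-conversation token statistics are computed; the exact flag definition is not shown in this diff and is inferred here.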