2828# map from dataset name to a local directory, or 
2929# a dataset repository on the HF hub 
3030_supported_datasets  =  {
31-     "c4_mini " : "torchtitan/datasets/c4_mini " ,
31+     "c4_test " : "test/assets/c4_test " ,
3232    "c4" : "allenai/c4" ,
3333}
3434
@@ -48,8 +48,8 @@ class HuggingFaceDataset(IterableDataset, Stateful):
4848        rank (int): rank of the current data parallel process 
4949        infinite (bool): whether to loop infinitely over the dataset 
5050
51-     We currently support the c4 dataset and a subset of it: 
52-     c4_mini (45K  training entries) 
51+     We currently support the c4 dataset,  and a subset of it for testing purposes : 
52+     c4_test (2K  training entries) 
5353    c4 (177M training entries - this dataset is streamed due to the size) 
5454
5555    >> c4 (EN) <<: 
@@ -83,12 +83,12 @@ def __init__(
8383            if  dataset_path :
8484                logger .warning (
8585                    f"Dataset { dataset_name }  
86-                     f"Recommended datasets are: { list (_supported_datasets .keys ())} . " 
86+                     f"Recommended datasets are: { list (_supported_datasets .keys ())}  
8787                )
8888            else :
8989                raise  ValueError (
9090                    f"Dataset { dataset_name }  
91-                     f"Supported datasets are: { list (_supported_datasets .keys ())} . " 
91+                     f"Supported datasets are: { list (_supported_datasets .keys ())}  
9292                )
9393
9494        if  not  dataset_path :
@@ -132,15 +132,12 @@ def __iter__(self):
132132                    yield  input , label 
133133
134134            if  not  self .infinite :
135-                 logger .warning (f"Dataset { self .dataset_name } . " )
135+                 logger .warning (f"Dataset { self .dataset_name }  )
136136                break 
137137            else :
138138                # Reset offset for the next iteration 
139139                self ._sample_idx  =  0 
140-                 logger .warning (
141-                     f"Dataset { self .dataset_name }  
142-                     "Loss related metrics might be misleading." 
143-                 )
140+                 logger .warning (f"Dataset { self .dataset_name }  )
144141
145142    def  _get_data_iter (self ):
146143        if  self ._sample_idx  ==  0 :
@@ -188,7 +185,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
188185
189186        if  self ._rank_id  not  in state_dict :
190187            logger .warning (
191-                 f"DataLoader state is empty for dp rank { self ._dp_rank } { self ._rank_id } . " 
188+                 f"DataLoader state is empty for dp rank { self ._dp_rank } { self ._rank_id }  
192189            )
193190            return 
194191        super ().load_state_dict (pickle .loads (state_dict [self ._rank_id ]))
0 commit comments