@@ -61,8 +61,8 @@ Here's a simple PyTorch example:
 .. code-block:: python

     # regular PyTorch
-    test_data = MNIST(PATH, train=False, download=True)
-    train_data = MNIST(PATH, train=True, download=True)
+    test_data = MNIST(my_path, train=False, download=True)
+    train_data = MNIST(my_path, train=True, download=True)
     train_data, val_data = random_split(train_data, [55000, 5000])

     train_loader = DataLoader(train_data, batch_size=32)
@@ -75,8 +75,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa

     class MNISTDataModule(pl.LightningDataModule):

-        def __init__(self, data_dir: str = PATH, batch_size):
+        def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
             super().__init__()
+            self.data_dir = data_dir
             self.batch_size = batch_size

         def setup(self, stage=None):
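
For context, a minimal sketch of what the full DataModule might look like once this change is applied. Only ``__init__`` and the ``setup`` signature appear in the hunk above; the ``setup`` body and the dataloader hooks below are assumptions pieced together from the regular-PyTorch snippet, not lines touched by this commit.

.. code-block:: python

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST


    class MNISTDataModule(pl.LightningDataModule):

        def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
            super().__init__()
            # keep both constructor arguments around for setup()/dataloaders
            self.data_dir = data_dir
            self.batch_size = batch_size

        def setup(self, stage=None):
            # same splits as the plain-PyTorch example; ToTensor() added so the
            # default DataLoader collate works on the PIL images MNIST returns
            full = MNIST(self.data_dir, train=True, download=True,
                         transform=transforms.ToTensor())
            self.train_data, self.val_data = random_split(full, [55000, 5000])
            self.test_data = MNIST(self.data_dir, train=False, download=True,
                                   transform=transforms.ToTensor())

        def train_dataloader(self):
            return DataLoader(self.train_data, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.val_data, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.test_data, batch_size=self.batch_size)
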
@@ -99,7 +100,7 @@ colleagues or use in different projects.

 .. code-block:: python

-    mnist = MNISTDataModule(PATH)
+    mnist = MNISTDataModule(my_path)
     model = LitClassifier()

     trainer = Trainer()
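
The hunk ends before the actual training call. As a usage sketch (not lines from this commit), the instantiated DataModule is typically passed straight to the Trainer, which then calls ``setup`` and the ``*_dataloader`` hooks itself:

.. code-block:: python

    # Lightning drives the DataModule: no manual DataLoader wiring needed
    trainer.fit(model, datamodule=mnist)
    trainer.test(model, datamodule=mnist)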