You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This is a research tool I built for myself internally while doing my PhD. The API is not 100% production quality, but my hope is that by open-sourcing, we can all get it there (I don't have too much time nowadays to write production-level code).
27
27
28
28
## What is it?
@@ -38,7 +38,7 @@ Your model.
38
38
2. Run the validation loop.
39
39
3. Run the testing loop.
40
40
4. Early stopping.
41
-
5. Learning rate annealing.
41
+
5. Learning rate annealing.
42
42
6. Can train complex models like GANs or anything with multiple optimizers.
43
43
7. Weight checkpointing.
44
44
8. Model saving.
@@ -49,7 +49,7 @@ Your model.
49
49
13. Distribute memory-bound models on multiple GPUs.
50
50
14. Give your model hyperparameters parsed from the command line OR a JSON file.
51
51
15. Run your model in a dev environment where nothing logs.
52
-
52
+
53
53
## Usage
54
54
To use lightning do 2 things:
55
55
1.[Define a trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/trainer_main.py) (which will run ALL your models).
@@ -130,32 +130,32 @@ class My_Model(RootModule):
130
130
def__init__(self):
131
131
# define model
132
132
self.l1 = nn.Linear(200, 10)
133
-
133
+
134
134
# ---------------
135
135
# TRAINING
136
136
deftraining_step(self, data_batch):
137
137
x, y = data_batch
138
138
y_hat =self.l1(x)
139
139
loss = some_loss(y_hat)
140
-
140
+
141
141
return loss_val, {'train_loss': loss}
142
-
142
+
143
143
defvalidation_step(self, data_batch):
144
144
x, y = data_batch
145
145
y_hat =self.l1(x)
146
146
loss = some_loss(y_hat)
147
-
147
+
148
148
return loss_val, {'val_loss': loss}
149
-
149
+
150
150
defvalidation_end(self, outputs):
151
151
total_accs = []
152
-
152
+
153
153
for output in outputs:
154
154
total_accs.append(output['val_acc'].item())
155
-
155
+
156
156
# return a dict
157
157
return {'total_acc': np.mean(total_accs)}
158
-
158
+
159
159
# ---------------
160
160
# SAVING
161
161
defget_save_dict(self):
@@ -167,15 +167,15 @@ class My_Model(RootModule):
167
167
defload_model_specific(self, checkpoint):
168
168
# lightning loads for you. Here's your chance to say what you want to load
169
169
self.load_state_dict(checkpoint['state_dict'])
170
-
170
+
171
171
# ---------------
172
172
# TRAINING CONFIG
173
173
defconfigure_optimizers(self):
174
174
# give lightning the list of optimizers you want to use.
0 commit comments