Skip to content

Commit 18b0c5a

Browse files
committed
Add src, docs and other important folders
1 parent f26488b commit 18b0c5a

File tree

36 files changed

+260
-105
lines changed

36 files changed

+260
-105
lines changed

LICENSE renamed to COPYING

File renamed without changes.

README.md

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
<p align="center">
22
<a href="https://williamfalcon.github.io/pytorch-lightning/">
3-
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/imgs/lightning_logo.png" width="50">
3+
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/lightning_logo.png" width="50">
44
</a>
55
</p>
66
<h3 align="center">
@@ -12,7 +12,7 @@
1212
<p align="center">
1313
<a href="https://badge.fury.io/py/pytorch-lightning"><img src="https://badge.fury.io/py/pytorch-lightning.svg" alt="PyPI version" height="18"></a>
1414
<!-- <a href="https://travis-ci.org/williamFalcon/test-tube"><img src="https://travis-ci.org/williamFalcon/pytorch-lightning.svg?branch=master"></a> -->
15-
<a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/LICENSE"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
15+
<a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/COPYING"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
1616
</p>
1717

1818
```bash
@@ -22,7 +22,7 @@ pip install pytorch-lightning
2222
## Docs
2323
In progress. Documenting now!
2424

25-
## Disclaimer
25+
## Disclaimer
2626
This is a research tool I built for myself internally while doing my PhD. The API is not 100% production quality, but my hope is that by open-sourcing, we can all get it there (I don't have too much time nowadays to write production-level code).
2727

2828
## What is it?
@@ -38,7 +38,7 @@ Your model.
3838
2. Run the validation loop.
3939
3. Run the testing loop.
4040
4. Early stopping.
41-
5. Learning rate annealing.
41+
5. Learning rate annealing.
4242
6. Can train complex models like GANs or anything with multiple optimizers.
4343
7. Weight checkpointing.
4444
8. Model saving.
@@ -49,7 +49,7 @@ Your model.
4949
13. Distribute memory-bound models on multiple GPUs.
5050
14. Give your model hyperparameters parsed from the command line OR a JSON file.
5151
15. Run your model in a dev environment where nothing logs.
52-
52+
5353
## Usage
5454
To use lightning do 2 things:
5555
1. [Define a trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/trainer_main.py) (which will run ALL your models).
@@ -130,32 +130,32 @@ class My_Model(RootModule):
130130
def __init__(self):
131131
# define model
132132
self.l1 = nn.Linear(200, 10)
133-
133+
134134
# ---------------
135135
# TRAINING
136136
def training_step(self, data_batch):
137137
x, y = data_batch
138138
y_hat = self.l1(x)
139139
loss = some_loss(y_hat)
140-
140+
141141
return loss_val, {'train_loss': loss}
142-
142+
143143
def validation_step(self, data_batch):
144144
x, y = data_batch
145145
y_hat = self.l1(x)
146146
loss = some_loss(y_hat)
147-
147+
148148
return loss_val, {'val_loss': loss}
149-
149+
150150
def validation_end(self, outputs):
151151
total_accs = []
152-
152+
153153
for output in outputs:
154154
total_accs.append(output['val_acc'].item())
155-
155+
156156
# return a dict
157157
return {'total_acc': np.mean(total_accs)}
158-
158+
159159
# ---------------
160160
# SAVING
161161
def get_save_dict(self):
@@ -167,15 +167,15 @@ class My_Model(RootModule):
167167
def load_model_specific(self, checkpoint):
168168
# lightning loads for you. Here's your chance to say what you want to load
169169
self.load_state_dict(checkpoint['state_dict'])
170-
170+
171171
# ---------------
172172
# TRAINING CONFIG
173173
def configure_optimizers(self):
174174
# give lightning the list of optimizers you want to use.
175175
# lightning will call automatically
176176
optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
177177
return [optimizer]
178-
178+
179179
@property
180180
def tng_dataloader(self):
181181
return pytorch_dataloader('train')
@@ -187,7 +187,7 @@ class My_Model(RootModule):
187187
@property
188188
def test_dataloader(self):
189189
return pytorch_dataloader('test')
190-
190+
191191
# ---------------
192192
# MODIFY YOUR COMMAND LINE ARGS
193193
@staticmethod
@@ -206,7 +206,7 @@ class My_Model(RootModule):
206206
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
207207
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
208208
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
209-
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
209+
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
210210

211211
#### Model training
212212
| Name | Description | Input | Return |
@@ -222,7 +222,7 @@ class My_Model(RootModule):
222222
|---|---|---|---|
223223
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
224224
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
225-
225+
226226
## Optional model hooks.
227227
Add these to the model whenever you want to configure training behavior.
228228

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

imgs/.DS_Store

-6 KB
Binary file not shown.

pytorch_lightning/root_module/__init__.py

Whitespace-only changes.

pytorch_lightning/utils/__init__.py

Whitespace-only changes.

setup.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,58 @@
22

33
from setuptools import setup, find_packages
44

5-
setup(name='pytorch-lightning',
6-
version='0.0.2',
7-
description='Rapid research framework',
8-
author='',
9-
author_email='',
10-
url='https://github.com/williamFalcon/pytorch-lightning',
11-
install_requires=['test-tube', 'torch', 'tqdm'],
12-
packages=find_packages()
13-
)
5+
# https://packaging.python.org/guides/single-sourcing-package-version/
6+
version = {}
7+
with open(os.path.join("src", "pytorch-lightning", "__init__.py")) as fp:
8+
exec(fp.read(), version)
9+
10+
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
11+
setup(
12+
name="pytorch-lightning",
13+
version=version["__version__"],
14+
description="The Keras for ML researchers using PyTorch",
15+
author="William Falcon",
16+
author_email="[email protected]",
17+
url="https://github.com/williamFalcon/pytorch-lightning",
18+
download_url="https://github.com/williamFalcon/pytorch-lightning",
19+
license="MIT",
20+
keywords=["deep learning", "pytorch", "AI"],
21+
python_requires=">=3.5",
22+
install_requires=[
23+
"torch",
24+
"tqdm",
25+
"test-tube",
26+
],
27+
extras_require={
28+
"dev": [
29+
"black ; python_version>='3.6'",
30+
"coverage",
31+
"isort",
32+
"pytest",
33+
"pytest-cov<2.6.0",
34+
"pycodestyle",
35+
"sphinx",
36+
"nbsphinx",
37+
"ipython>=5.0",
38+
"jupyter-client",
39+
]
40+
},
41+
packages=find_packages("src"),
42+
package_dir={"": "src"},
43+
entry_points={"console_scripts": ["pytorch-lightning = pytorch-lightning.cli:main"]},
44+
classifiers=[
45+
"Development Status :: 4 - Beta",
46+
"Intended Audience :: Education",
47+
"Intended Audience :: Science/Research",
48+
"License :: OSI Approved :: MIT License",
49+
"Operating System :: OS Independent",
50+
"Programming Language :: Python",
51+
"Programming Language :: Python :: 3",
52+
"Programming Language :: Python :: 3.5",
53+
"Programming Language :: Python :: 3.6",
54+
"Programming Language :: Python :: 3.7",
55+
],
56+
long_description=open("README.md", encoding="utf-8").read(),
57+
include_package_data=True,
58+
zip_safe=False,
59+
)

0 commit comments

Comments
 (0)