Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 35 additions & 22 deletions docs/run_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,54 @@ where `HOST` is the IP address of your system. Note that we use
the [Slurm](https://slurm.schedmd.com/documentation.html) job scheduling system here.

```bash
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./example/train_vit_2d.py ./configs/vit/vit_2d.py
HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
```

`./configs/vit/vit_2d.py` is a config file, which is introduced in the [Config file](config.md) section below. These
config files are used by ColossalAI to define all kinds of training arguments, such as the model, dataset and training
method (optimizer, lr_scheduler, epoch, etc.). Config files are highly customizable and can be modified so as to train
different models.
`./example/run_trainer.py` contains a standard training script and is presented below, it reads the config file and
`./examples/run_trainer.py` contains a standard training script and is presented below, it reads the config file and
realizes the training process.

```python
import colossalai
from colossalai.core import global_context as gpc
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc

model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)

trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
display_progress=True,
test_interval=5
)

def run_trainer():
model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
logger = get_global_dist_logger()
schedule.data_sync = False
engine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
schedule=schedule
)
logger.info("engine is built", ranks=[0])

trainer = Trainer(engine=engine,
hooks_cfg=gpc.config.hooks,
verbose=True)
logger.info("trainer is built", ranks=[0])

logger.info("start training", ranks=[0])
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
max_epochs=gpc.config.num_epochs,
display_progress=True,
test_interval=2
)


if __name__ == '__main__':
run_trainer()
```

Alternatively, the `model` variable can be substituted with a self-defined model or a pre-defined model in our Model
Expand Down