diff --git a/docs/run_demo.md b/docs/run_demo.md
index 91b0871b3f4c..8bcb1e16a480 100644
--- a/docs/run_demo.md
+++ b/docs/run_demo.md
@@ -19,41 +19,54 @@
 where `HOST` is the IP address of your system. Note that we use
 the [Slurm](https://slurm.schedmd.com/documentation.html) job scheduling system here.

 ```bash
-HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./example/train_vit_2d.py ./configs/vit/vit_2d.py
+HOST=xxx.xxx.xxx.xxx srun ./scripts/slurm_dist_train.sh ./examples/run_trainer.py ./configs/vit/vit_2d.py
 ```

 `./configs/vit/vit_2d.py` is a config file, which is introduced in the [Config file](config.md) section below. These
 config files are used by ColossalAI to define all kinds of training arguments, such as the model, dataset and training
 method (optimizer, lr_scheduler, epoch, etc.). Config files are highly customizable and can be modified so as to train
 different models.
-`./example/run_trainer.py` contains a standard training script and is presented below, it reads the config file and
+`./examples/run_trainer.py` contains a standard training script and is presented below, it reads the config file and
 realizes the training process.
 ```python
 import colossalai
+from colossalai.core import global_context as gpc
 from colossalai.engine import Engine
+from colossalai.logging import get_global_dist_logger
 from colossalai.trainer import Trainer
-from colossalai.core import global_context as gpc

-model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
-engine = Engine(
-    model=model,
-    criterion=criterion,
-    optimizer=optimizer,
-    lr_scheduler=lr_scheduler,
-    schedule=schedule
-)
-
-trainer = Trainer(engine=engine,
-                  hooks_cfg=gpc.config.hooks,
-                  verbose=True)
-trainer.fit(
-    train_dataloader=train_dataloader,
-    test_dataloader=test_dataloader,
-    max_epochs=gpc.config.num_epochs,
-    display_progress=True,
-    test_interval=5
-)
+
+def run_trainer():
+    model, train_dataloader, test_dataloader, criterion, optimizer, schedule, lr_scheduler = colossalai.initialize()
+    logger = get_global_dist_logger()
+    schedule.data_sync = False
+    engine = Engine(
+        model=model,
+        criterion=criterion,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        schedule=schedule
+    )
+    logger.info("engine is built", ranks=[0])
+
+    trainer = Trainer(engine=engine,
+                      hooks_cfg=gpc.config.hooks,
+                      verbose=True)
+    logger.info("trainer is built", ranks=[0])
+
+    logger.info("start training", ranks=[0])
+    trainer.fit(
+        train_dataloader=train_dataloader,
+        test_dataloader=test_dataloader,
+        max_epochs=gpc.config.num_epochs,
+        display_progress=True,
+        test_interval=2
+    )
+
+
+if __name__ == '__main__':
+    run_trainer()
 ```

 Alternatively, the `model` variable can be substituted with a self-defined model or a pre-defined model in our Model