26 changes: 25 additions & 1 deletion pytorch/launch_nv.py
@@ -165,7 +165,27 @@
'--no-bn-wd',
'--num-tasks', 4,
'--ami-name', DEFAULT_PYTORCH_SOURCE,
'--env-name', 'pytorch_c10d'
]



# Current best settings 4x p3 - 34.5 minutes
lr = 0.50 * 4 # 4 = num tasks
scale_224 = 224/256
scale_288 = 128/256
c10d = [
'--phases', [
{'ep':0, 'sz':128, 'bs':256, 'trndir':'-sz/160',
'lr':lr*2}
],
'--num-tasks', 4,
'--ami-name', DEFAULT_PYTORCH_SOURCE,
'--env-name', 'pytorch_c10d',
'--c10d',
# '--dist-url', 'file:///home/ubuntu/data/file.sync', # single instances are faster with file sync
# '--dist-url', 'tcp://localhost:6006', # single instances are faster with file sync
# '--dist-url', 'env://',
]
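
For context, a minimal sketch (not part of this diff) of how a mixed params list such as the c10d config above might be flattened into a single command line. render_args is a hypothetical helper; the actual launcher may serialize --phases differently.

import json, shlex

def render_args(params):
    # Hypothetical helper: turn a mixed list of flags and values into CLI tokens.
    # Structured values such as the --phases schedule are JSON-encoded so the
    # training script can recover them from a single argument.
    tokens = []
    for p in params:
        tokens.append(json.dumps(p) if isinstance(p, (list, dict)) else str(p))
    return ' '.join(shlex.quote(t) for t in tokens)

# e.g. cmd = 'python train_imagenet_nv.py ' + render_args(c10d)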

# Current benchmark for 8x p3's - with Aspect Ratio Validation - Works right now for under 30 min (25:45, memory-eight.06, 25:03 sun-eight, 24:31 release-eight.02)
@@ -356,7 +376,11 @@ def start_training(job, params):
default_params = [
'~/data/imagenet',
'--fp16',
'--logdir', job.logdir
'--logdir', job.logdir,
'--dist-url', f'tcp://{world_0_ip}:6006', # single instances are faster with file sync
# '--dist-url', 'file:///home/ubuntu/data/file.sync', # single instances are faster with file sync
# '--dist-url', 'tcp://localhost:6006', # single instances are faster with file sync
# '--dist-url', 'env://',
]
if world_size > 1: default_params.append('--distributed')
training_args = default_params + params
12 changes: 12 additions & 0 deletions pytorch/training/dist_utils.py
@@ -1,6 +1,7 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import os
from torch.nn.parallel import distributed_c10d

class DDP(DistributedDataParallel):
# Distributed wrapper. Supports asynchronous evaluation and model saving
@@ -15,7 +16,18 @@ def load_state_dict(self, *args, **kwargs):
def state_dict(self, *args, **kwargs):
return self.module.state_dict(*args, **kwargs)

class DDPC10d(distributed_c10d._DistributedDataParallelC10d):
# Distributed wrapper. Supports asynchronous evaluation and model saving
def forward(self, *args, **kwargs):
# DDP has a sync point on forward. No need to do this for eval. This allows us to have different batch sizes
if self.training: return super().forward(*args, **kwargs)
else: return self.module(*args, **kwargs)

def load_state_dict(self, *args, **kwargs):
self.module.load_state_dict(*args, **kwargs)

def state_dict(self, *args, **kwargs):
return self.module.state_dict(*args, **kwargs)


def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size()
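
A minimal usage sketch for the DDPC10d wrapper added above (illustrative only: it assumes a c10d process group has already been initialized, as done in train_imagenet_nv.py below, and uses a toy model):

import torch
from dist_utils import DDPC10d  # the wrapper defined above

local_rank = 0  # normally comes from --local_rank / the launcher
model = torch.nn.Linear(10, 10).cuda()
model = DDPC10d(model, device_ids=[local_rank], output_device=local_rank)

model.train()
out = model(torch.randn(4, 10).cuda())   # training forward goes through the c10d sync path

model.eval()
out = model(torch.randn(8, 10).cuda())   # eval forward bypasses the sync point, so batch sizes may differ per worker

# state_dict()/load_state_dict() delegate to the wrapped module, so saved
# checkpoints carry no DDP wrapper prefix on parameter names.
torch.save(model.state_dict(), '/tmp/checkpoint.pth')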
41 changes: 34 additions & 7 deletions pytorch/training/train_imagenet_nv.py
@@ -10,7 +10,6 @@
from torch.autograd import Variable
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
@@ -49,6 +48,7 @@ def get_parser():
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode. Default True')
parser.add_argument('--c10d', action='store_true', help='Run model c10d mode. Default True')
parser.add_argument('--loss-scale', type=float, default=1024,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--distributed', action='store_true', help='Run distributed training. Default True')
@@ -75,6 +75,14 @@ def get_parser():
cudnn.benchmark = True
args = get_parser().parse_args()

if args.c10d:
assert(args.distributed)
import torch.distributed.c10d as dist
# from torch.distributed import c10d
from torch.nn.parallel import distributed_c10d
elif args.distributed:
import torch.distributed as dist

# Only want master rank logging to tensorboard
is_master = (not args.distributed) or (dist_utils.env_rank()==0)
is_rank0 = args.local_rank == 0
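
The is_master check above relies on dist_utils.env_rank(), and later code on dist_utils.env_world_size(); their definitions are not shown in this diff. A typical implementation (an assumption, not necessarily what dist_utils does) simply reads the variables the launcher exports:

import os

def env_world_size():
    # Assumed behaviour: total number of workers, as exported by the launcher.
    return int(os.environ.get('WORLD_SIZE', 1))

def env_rank():
    # Assumed behaviour: this worker's global rank.
    return int(os.environ.get('RANK', 0))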
@@ -91,15 +99,28 @@ def main():
if args.distributed:
log.console('Distributed initializing process group')
torch.cuda.set_device(args.local_rank)
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=dist_utils.env_world_size())
dist_url = args.dist_url

if args.c10d and (('file:///' in dist_url) or ('tcp://' in dist_url)):
dist_url = args.dist_url+f'?rank={dist_utils.env_rank()}&world_size={dist_utils.env_world_size()}'
dist.init_process_group(backend=args.dist_backend, init_method=dist_url, world_size=dist_utils.env_world_size())
assert(dist_utils.env_world_size() == dist.get_world_size())
log.console("Distributed: success (%d/%d)"%(args.local_rank, dist.get_world_size()))

log.console('After distributed - test tensor creation works')
tt = torch.tensor([1]).float().cuda()

log.console("Loading model")
model = resnet.resnet50(bn0=args.init_bn0).cuda()
if args.fp16: model = network_to_half(model)
if args.distributed: model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

if args.c10d:
# model = dist_utils.DDPC10d(model, device_ids=[args.local_rank], output_device=args.local_rank)
model = distributed_c10d._DistributedDataParallelC10d(model, device_ids=[args.local_rank], output_device=args.local_rank)
log.console('Sanity check to make sure tensor creation works')
tt = torch.tensor([1]).float().cuda() # dead lock when trying to create tensor
log.console(f'Woot able to reduce tensor: {sum_tensor(tt)}')
elif args.distributed: model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)
best_top5 = 93 # only save models over 93%. Otherwise it stops to save every time

global model_params, master_params
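
Pulling the new pieces of main() together, a condensed sketch of the c10d path shown above (init_c10d and wrap_model_c10d are names invented for this illustration; error handling and the fp16 conversion are omitted). For tcp:// and file:// URLs, c10d reads rank and world size from query parameters appended to the rendezvous URL, and the model is wrapped with the c10d variant of DistributedDataParallel:

import torch
import torch.distributed.c10d as dist
from torch.nn.parallel import distributed_c10d

def init_c10d(dist_url, backend, local_rank, rank, world_size):
    torch.cuda.set_device(local_rank)
    if dist_url.startswith(('file://', 'tcp://')):
        # c10d rendezvous takes rank/world_size from the URL query string
        dist_url = f'{dist_url}?rank={rank}&world_size={world_size}'
    dist.init_process_group(backend=backend, init_method=dist_url, world_size=world_size)
    assert dist.get_world_size() == world_size

def wrap_model_c10d(model, local_rank):
    return distributed_c10d._DistributedDataParallelC10d(
        model, device_ids=[local_rank], output_device=local_rank)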
@@ -132,7 +153,7 @@ def main():

if args.distributed:
log.console('Syncing machines before training')
dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())
sum_tensor(torch.tensor([1.0]).float().cuda())

log.event("~~epoch\thours\ttop1\ttop5\n")
for epoch in range(args.start_epoch, scheduler.tot_epochs):
@@ -151,7 +172,6 @@
phase = dm.get_phase(epoch)
if phase: save_checkpoint(epoch, model, best_top5, optimizer, filename=f'sz{phase["bs"]}_checkpoint.path.tar')


def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
net_meter = NetworkMeter()
timer = TimeMeter()
@@ -192,7 +212,7 @@ def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
reduced_loss, batch_total = to_python_float(loss.data), to_python_float(input.size(0))
if args.distributed: # Must keep track of global batch size, since not all machines are guaranteed equal batches at the end of an epoch
metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
batch_total, reduced_loss, corr1, corr5 = sum_tensor(metrics).cpu().numpy()
reduced_loss = reduced_loss/dist_utils.env_world_size()
top1acc = to_python_float(corr1)*(100.0/batch_total)
top5acc = to_python_float(corr5)*(100.0/batch_total)
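
To make the reduction above concrete: per-worker counts are summed before computing percentages because workers are not guaranteed equal batch sizes at the end of an epoch. A toy two-worker example with made-up numbers:

# Worker 0 saw 256 images with 198 top-1 correct; worker 1 saw only 200 with 151 correct.
batch_total = 256 + 200                  # summed across workers via all_reduce in the real code
corr1 = 198 + 151
top1acc = corr1 * (100.0 / batch_total)  # ~76.5%, correctly weighted by samples actually seen
# Averaging the two local accuracies instead would over-weight the short final batch.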
@@ -279,7 +299,7 @@ def distributed_predict(input, target, model, criterion):
corr1, corr5 = correct(output.data, target, topk=(1, 5))

metrics = torch.tensor([batch_size, valid_batches, loss, corr1, corr5]).float().cuda()
batch_total, valid_batches, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
batch_total, valid_batches, reduced_loss, corr1, corr5 = sum_tensor(metrics).cpu().numpy()
reduced_loss = reduced_loss/valid_batches

top1 = corr1*(100.0/batch_total)
@@ -389,6 +409,13 @@ def update_lr(self, epoch, batch_num, batch_tot):
tb.log("sizes/lr", lr)
tb.log("sizes/momentum", args.momentum)


def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size()
def sum_tensor(tensor):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
return rt
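
A small end-to-end sketch of how these helpers behave (illustrative only: two local processes, the gloo backend, and a file rendezvous, which is not how the launcher above runs training):

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def worker(rank, world_size):
    dist.init_process_group('gloo', init_method='file:///tmp/demo.sync',
                            rank=rank, world_size=world_size)
    # Mirrors the metrics tensor reduced in train(): [batch, loss, corr1, corr5]
    metrics = torch.tensor([256.0 if rank == 0 else 200.0, 0.9, 198.0, 240.0])
    summed = metrics.clone()
    dist.all_reduce(summed, op=dist.reduce_op.SUM)  # dist.ReduceOp.SUM on newer PyTorch
    if rank == 0:
        print(summed)  # sums across both ranks: [456., 1.8, 396., 480.]

if __name__ == '__main__':
    procs = [Process(target=worker, args=(r, 2)) for r in range(2)]
    for p in procs: p.start()
    for p in procs: p.join()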

# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if isinstance(t, (float, int)): return t