168 changes: 78 additions & 90 deletions scripts/swin_dataloader_compare_speed_with_pytorch.py
@@ -3,13 +3,6 @@
import numpy as np
import argparse

import oneflow as flow
from oneflow.utils.data import DataLoader

from flowvision import datasets, transforms
from flowvision.data import create_transform


ONEREC_URL = "https://oneflow-static.oss-cn-beijing.aliyuncs.com/ci-files/dataset/nanodataset.zip"
MD5 = "7f5cde8b5a6c411107517ac9b00f29db"

@@ -64,99 +57,84 @@ def ensure_dataset():
        shutil.unpack_archive(absolute_file_path)
    return str(pathlib.Path.cwd() / "nanodataset")


swin_dataloader_loop_count = 200


def print_rank_0(*args, **kwargs):
    rank = int(os.getenv("RANK", "0"))
    if rank == 0:
        print(*args, **kwargs)


class SubsetRandomSampler(flow.utils.data.Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.

    Arguments:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.epoch = 0
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in flow.randperm(len(self.indices)).tolist())

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        # Kept for DistributedSampler-style API compatibility; the epoch is not
        # used to seed the shuffle in __iter__.
        self.epoch = epoch

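# A minimal usage sketch for the sampler above. A plain Python list acts as a
# map-style dataset here; the dataset contents and sizes are hypothetical.
example_data = [(np.zeros(3), i % 2) for i in range(100)]
example_sampler = SubsetRandomSampler(list(range(100)))
example_loader = DataLoader(
    example_data,
    sampler=example_sampler,  # yields all 100 indices in a fresh random order
    batch_size=16,
    drop_last=True,
)
for samples, targets in example_loader:
    pass  # every epoch visits the same subset, differently shuffled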

def build_transform():
    # this should always dispatch to transforms_imagenet_train
    transform = create_transform(
        input_size=224,
        is_training=True,
        color_jitter=0.4,
        auto_augment="rand-m9-mstd0.5-inc1",
        re_prob=0.25,
        re_mode="pixel",
        re_count=1,
        interpolation="bicubic",
    )
    return transform

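# Sketch: create_transform(is_training=True, ...) returns a Compose-style
# callable mapping a PIL image to a normalized tensor; "example.jpg" is a
# hypothetical file path.
from PIL import Image

example_transform = build_transform()
example_img = Image.open("example.jpg").convert("RGB")
example_x = example_transform(example_img)
print(example_x.shape)  # (3, 224, 224) after resized crop, RandAugment, normalize, random erasing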

# swin-transformer imagenet dataloader
def build_dataset(imagenet_path):
    transform = build_transform()
    prefix = "train"
    root = os.path.join(imagenet_path, prefix)
    dataset = datasets.ImageFolder(root, transform=transform)
    return dataset


def build_loader(imagenet_path, batch_size, num_workers):
    dataset_train = build_dataset(imagenet_path=imagenet_path)

    indices = np.arange(
        flow.env.get_rank(), len(dataset_train), flow.env.get_world_size()
    )
    sampler_train = SubsetRandomSampler(indices)

    data_loader_train = DataLoader(
        dataset_train,
        sampler=sampler_train,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
    )

    return dataset_train, data_loader_train

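# The np.arange(rank, len(dataset), world_size) call above shards samples
# round-robin across ranks; with a hypothetical 10-sample dataset and 2 ranks:
print(np.arange(0, 10, 2))  # rank 0 -> [0 2 4 6 8]
print(np.arange(1, 10, 2))  # rank 1 -> [1 3 5 7 9]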

def run(mode, imagenet_path, batch_size, num_workers):
    if mode == "torch":
        import torch as flow
        from torch.utils.data import DataLoader

        from torchvision import datasets, transforms
        from timm.data import create_transform

        dataset_train, data_loader_train = build_loader(
            imagenet_path, batch_size, num_workers
        )
        data_loader_train_iter = iter(data_loader_train)

        # warm up
        for idx in range(5):
            samples, targets = next(data_loader_train_iter)

        start_time = time.time()
        for idx in range(swin_dataloader_loop_count):
            samples, targets = next(data_loader_train_iter)
        total_time = time.time() - start_time
        return total_time
    else:
        import oneflow as flow
        from oneflow.utils.data import DataLoader

        from flowvision import datasets, transforms
        from flowvision.data import create_transform

        def build_transform():
            # this should always dispatch to transforms_imagenet_train
            transform = create_transform(
                input_size=224,
                is_training=True,
                color_jitter=0.4,
                auto_augment="rand-m9-mstd0.5-inc1",
                re_prob=0.25,
                re_mode="pixel",
                re_count=1,
                interpolation="bicubic",
            )
            return transform

        # swin-transformer imagenet dataloader
        def build_dataset(imagenet_path):
            transform = build_transform()
            prefix = "train"
            root = os.path.join(imagenet_path, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
            return dataset

        def build_loader(imagenet_path, batch_size, num_workers):
            dataset_train = build_dataset(imagenet_path=imagenet_path)

            data_loader_train = DataLoader(
                dataset_train,
                shuffle=True,
                batch_size=batch_size,
                num_workers=num_workers,
                drop_last=True,
            )

            return dataset_train, data_loader_train

        def get_time():
            dataset_train, data_loader_train = build_loader(
                imagenet_path, batch_size, num_workers
            )
            data_loader_train_iter = iter(data_loader_train)

            # warm up
            for idx in range(5):
                samples, targets = next(data_loader_train_iter)

            start_time = time.time()
            for idx in range(swin_dataloader_loop_count):
                samples, targets = next(data_loader_train_iter)
            total_time = time.time() - start_time
            return total_time

        return get_time()


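# A minimal direct-call sketch of the benchmark above, using the nano dataset
# fetched by ensure_dataset(); batch size and worker count are arbitrary.
example_path = ensure_dataset()
oneflow_total = run("oneflow", example_path, 16, 4)  # any mode != "torch" takes the OneFlow branch
torch_total = run("torch", example_path, 16, 4)
print_rank_0(
    f"per batch: oneflow {oneflow_total / swin_dataloader_loop_count:.3f}s, "
    f"torch {torch_total / swin_dataloader_loop_count:.3f}s"
)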
if __name__ == "__main__":
@@ -173,11 +151,21 @@ def run(mode, imagenet_path, batch_size, num_workers):
    pytorch_data_loader_total_time = run(
        "torch", args.imagenet_path, args.batch_size, args.num_workers
    )
    oneflow_data_loader_time = (
        oneflow_data_loader_total_time / swin_dataloader_loop_count
    )
    pytorch_data_loader_time = (
        pytorch_data_loader_total_time / swin_dataloader_loop_count
    )

    relative_speed = pytorch_data_loader_time / oneflow_data_loader_time

print_rank_0(f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
print_rank_0(f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})")
print_rank_0(f"Relative speed: {pytorch_data_loader_time / oneflow_data_loader_time:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)")
    print_rank_0(
        f"OneFlow swin dataloader time: {oneflow_data_loader_time:.3f}s (= {oneflow_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
    )
    print_rank_0(
        f"PyTorch swin dataloader time: {pytorch_data_loader_time:.3f}s (= {pytorch_data_loader_total_time:.3f}s / {swin_dataloader_loop_count}, num_workers={args.num_workers})"
    )
    print_rank_0(
        f"Relative speed: {relative_speed:.3f} (= {pytorch_data_loader_time:.3f}s / {oneflow_data_loader_time:.3f}s)"
    )