3 changes: 3 additions & 0 deletions RecommenderSystems/dlrm/config.py
@@ -57,6 +57,9 @@ def str_list(x):
parser.add_argument(
"--data_dir", type=str, default="/dataset/wdl_ofrecord/ofrecord"
)
parser.add_argument("--train_sub_folders", type=str_list,
default=','.join([f'day_{i}' for i in range(23)]))
parser.add_argument("--val_sub_folders", type=str_list, default="day_23")
parser.add_argument('--data_part_name_suffix_length', type=int, default=-1)
parser.add_argument('--eval_batchs', type=int, default=20)
parser.add_argument('--eval_batch_size', type=int, default=512)
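
For reference, a sketch of what the new flags parse to, assuming str_list (defined just above this hunk) splits a comma-separated string into a list; the asserts are illustrative, not part of the change:

# Hypothetical check of the new sub-folder flags; str_list is assumed to be
# roughly `lambda x: x.split(",")`, as its use with comma-joined defaults suggests.
args = parser.parse_args(["--train_sub_folders", "day_0,day_1"])
assert args.train_sub_folders == ["day_0", "day_1"]
# argparse applies the type converter to string defaults as well:
assert args.val_sub_folders == ["day_23"]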
47 changes: 29 additions & 18 deletions RecommenderSystems/dlrm/graph.py
@@ -2,19 +2,13 @@


class DLRMValGraph(flow.nn.Graph):
def __init__(self, wdl_module, dataloader, use_fp16=False):
def __init__(self, wdl_module, use_fp16=False):
super(DLRMValGraph, self).__init__()
self.module = wdl_module
self.dataloader = dataloader
if use_fp16:
self.config.enable_amp(True)

def build(self):
(
labels,
dense_fields,
sparse_fields,
) = self.dataloader()
def build(self, labels, dense_fields, sparse_fields):
labels = labels.to("cuda").to(dtype=flow.float32)
dense_fields = dense_fields.to("cuda")
sparse_fields = sparse_fields.to("cuda")
@@ -24,28 +18,45 @@ def build(self):


class DLRMTrainGraph(flow.nn.Graph):
def __init__(self, wdl_module, dataloader, bce_loss, optimizer, lr_scheduler=None, grad_scaler=None, use_fp16=False):
def __init__(self, wdl_module, bce_loss, optimizer, lr_scheduler=None, grad_scaler=None, use_fp16=False):
super(DLRMTrainGraph, self).__init__()
self.module = wdl_module
self.dataloader = dataloader
self.bce_loss = bce_loss
self.add_optimizer(optimizer, lr_sch=lr_scheduler)
self.config.allow_fuse_model_update_ops(True)
self.config.allow_fuse_add_to_output(True)
self.config.allow_fuse_cast_scale(True)
if use_fp16:
self.config.enable_amp(True)
self.set_grad_scaler(grad_scaler)

def build(self):
(
labels,
dense_fields,
sparse_fields,
) = self.dataloader()
def build(self, labels, dense_fields, sparse_fields):
labels = labels.to("cuda").to(dtype=flow.float32)
dense_fields = dense_fields.to("cuda")
sparse_fields = sparse_fields.to("cuda")
dense_fields = dense_fields.to("cuda").to(dtype=flow.float32)
sparse_fields = sparse_fields.to("cuda").to(dtype=flow.int64)

logits = self.module(dense_fields, sparse_fields)
loss = self.bce_loss(logits, labels)
reduce_loss = flow.mean(loss)
reduce_loss.backward()
return reduce_loss


class DLRMValGraphWithDataloader(DLRMValGraph):
def __init__(self, wdl_module, dataloader, use_fp16=False):
super(DLRMValGraphWithDataloader, self).__init__(wdl_module, use_fp16)
self.dataloader = dataloader

def build(self):
labels, dense_fields, sparse_fields = self.dataloader()
return super(DLRMValGraphWithDataloader, self).build(labels, dense_fields, sparse_fields)


class DLRMTrainGraphWithDataloader(DLRMTrainGraph):
def __init__(self, wdl_module, dataloader, bce_loss, optimizer, lr_scheduler=None, grad_scaler=None, use_fp16=False):
super(DLRMTrainGraphWithDataloader, self).__init__(wdl_module, bce_loss, optimizer, lr_scheduler, grad_scaler, use_fp16)
self.dataloader = dataloader

def build(self):
labels, dense_fields, sparse_fields = self.dataloader()
return super(DLRMTrainGraphWithDataloader, self).build(labels, dense_fields, sparse_fields)
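
The refactor splits each graph into a tensor-fed base class and a dataloader-owning subclass: DLRMValGraph and DLRMTrainGraph now take (labels, dense_fields, sparse_fields) as build() arguments so that batches produced outside the graph (e.g., by the petastorm loader added below) can be fed in, while the *WithDataloader variants keep the old behavior of pulling batches inside the graph. A minimal usage sketch, where dlrm_module and loader are stand-ins and the val graph is assumed to return (predictions, labels) as train.py uses it below:

# Tensor-fed flavor: the batch is produced outside and passed to __call__.
eval_graph = DLRMValGraph(dlrm_module, use_fp16=False)
labels, dense, sparse = loader()
preds, labels = eval_graph(labels, dense, sparse)

# Dataloader flavor: the graph pulls its own batch, as before this change.
eval_graph = DLRMValGraphWithDataloader(dlrm_module, loader, use_fp16=False)
preds, labels = eval_graph()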
99 changes: 51 additions & 48 deletions RecommenderSystems/dlrm/train.py
@@ -10,10 +10,11 @@
from sklearn.metrics import roc_auc_score
from config import get_args
from models.data import make_data_loader
from utils.petastorm_dataloader import make_petastorm_dataloader
from models.dlrm import make_dlrm_module
from lr_scheduler import make_lr_scheduler
from oneflow.nn.parallel import DistributedDataParallel as DDP
from graph import DLRMValGraph, DLRMTrainGraph
from graph import DLRMTrainGraphWithDataloader, DLRMValGraph, DLRMTrainGraph, DLRMValGraphWithDataloader
import warnings
import utils.logger as log
from utils.auc_calculater import calculate_auc_from_dir
@@ -45,8 +46,12 @@ def __init__(self):
self.eval_interval = args.eval_interval
self.eval_batchs = args.eval_batchs
self.init_logger()
self.train_dataloader = make_data_loader(args, "train", self.is_global, self.dataset_format)
self.val_dataloader = make_data_loader(args, "val", self.is_global, self.dataset_format)
if self.dataset_format == 'petastorm':
self.train_dataloader = make_petastorm_dataloader(args, "train")
self.val_dataloader = make_petastorm_dataloader(args, "val")
else:
self.train_dataloader = make_data_loader(args, "train", self.is_global, self.dataset_format)
self.val_dataloader = make_data_loader(args, "val", self.is_global, self.dataset_format)
self.dlrm_module = make_dlrm_module(args)
if self.is_global:
self.dlrm_module.to_global(flow.env.all_device_placement("cuda"), flow.sbp.broadcast)
@@ -71,13 +76,20 @@ def __init__(self):

self.loss = flow.nn.BCELoss(reduction="none").to("cuda")
if self.execution_mode == "graph":
self.eval_graph = DLRMValGraph(
self.dlrm_module, self.val_dataloader, args.use_fp16
)
self.train_graph = DLRMTrainGraph(
self.dlrm_module, self.train_dataloader, self.loss, self.opt,
self.lr_scheduler, self.grad_scaler, args.use_fp16
)
if self.dataset_format == 'petastorm':
self.eval_graph = DLRMValGraph(self.dlrm_module, args.use_fp16)
self.train_graph = DLRMTrainGraph(
self.dlrm_module, self.loss, self.opt,
self.lr_scheduler, self.grad_scaler, args.use_fp16
)
else:
self.eval_graph = DLRMValGraphWithDataloader(
self.dlrm_module, self.val_dataloader, args.use_fp16
)
self.train_graph = DLRMTrainGraphWithDataloader(
self.dlrm_module, self.train_dataloader, self.loss, self.opt,
self.lr_scheduler, self.grad_scaler, args.use_fp16
)

def init_model(self):
args = self.args
@@ -178,10 +190,7 @@ def eval(self, save_model=False):
labels = []
preds = []
for _ in range(self.eval_batchs):
if self.execution_mode == "graph":
pred, label = self.eval_graph()
else:
pred, label = self.inference()
pred, label = self.inference()
label_ = label.numpy().astype(np.float32)
labels.append(label_)
preds.append(pred.numpy())
@@ -203,49 +212,43 @@
self.save(sub_save_dir)
self.dlrm_module.train()

def inference(self):
(
labels,
dense_fields,
sparse_fields,
) = self.val_dataloader()
labels = labels.to("cuda")
dense_fields = dense_fields.to("cuda")
sparse_fields = sparse_fields.to("cuda")
with flow.no_grad():
predicts = self.dlrm_module(
dense_fields, sparse_fields
)
return predicts, labels

def forward(self):
(
labels,
dense_fields,
sparse_fields,
) = self.train_dataloader()
def load_data(self, dataloader):
labels, dense_fields, sparse_fields = dataloader()
labels = labels.to("cuda")
dense_fields = dense_fields.to("cuda")
sparse_fields = sparse_fields.to("cuda")
predicts = self.dlrm_module(dense_fields, sparse_fields)
loss = self.loss(predicts, labels)
reduce_loss = flow.mean(loss)
return reduce_loss
return labels, dense_fields, sparse_fields

def train_eager(self):
loss = self.forward()
loss.backward()
self.opt.step()
self.opt.zero_grad()
return loss
def inference(self):
if self.execution_mode == "graph":
if self.dataset_format == "petastorm":
labels, dense_fields, sparse_fields = self.load_data(self.val_dataloader)
return self.eval_graph(labels, dense_fields, sparse_fields)
else:
return self.eval_graph()
else:
labels, dense_fields, sparse_fields = self.load_data(self.val_dataloader)
with flow.no_grad():
predicts = self.dlrm_module(dense_fields, sparse_fields)
return predicts, labels

def train_one_step(self):
self.dlrm_module.train()
if self.execution_mode == "graph":
train_loss = self.train_graph()
if self.dataset_format == "petastorm":
labels, dense_fields, sparse_fields = self.load_data(self.train_dataloader)
return self.train_graph(labels, dense_fields, sparse_fields)
else:
return self.train_graph()
else:
train_loss = self.train_eager()
return train_loss
labels, dense_fields, sparse_fields = self.load_data(self.train_dataloader)
predicts = self.dlrm_module(dense_fields, sparse_fields)
loss = self.loss(predicts, labels)
loss = flow.mean(loss)
loss.backward()
self.opt.step()
self.opt.zero_grad()
return loss


def tol(tensor, pure_local=True):
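
Tying the pieces together, a hypothetical driver loop for the refactored trainer; the Trainer class name and max_iter are assumptions, while train_one_step, eval, and eval_interval come from the diff above:

trainer = Trainer()
for step in range(1, max_iter + 1):
    # Dispatches internally on execution_mode and dataset_format.
    loss = trainer.train_one_step()
    if step % trainer.eval_interval == 0:
        trainer.eval()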
10 changes: 7 additions & 3 deletions RecommenderSystems/dlrm/train_one_embedding_graph.sh
@@ -1,5 +1,5 @@
rm core.*
DEVICE_NUM_PER_NODE=4
DEVICE_NUM_PER_NODE=1
MASTER_ADDR=127.0.0.1
NUM_NODES=1
NODE_RANK=0
@@ -11,18 +11,19 @@ ulimit -SHn 131072
eval_batch_size=32744
eval_batchs=$(( 3274330 / eval_batch_size ))
#eval_batchs=$(( 90243072 / eval_batch_size ))
export GRADIENT_SHUFFLE_USE_FP16=1

export LD_PRELOAD=/lib/x86_64-linux-gnu/libtcmalloc.so.4:
#export LD_PRELOAD=/lib/x86_64-linux-gnu/libtcmalloc.so.4:
export BLOCK_BASED_PATH="rocks"
echo "ll BLOCK_BASED_PATH"
ls -l $BLOCK_BASED_PATH
rm -rf rocks/0-1/*
rm -rf rocks/0-4/*
rm -rf rocks/1-4/*
rm -rf rocks/2-4/*
rm -rf rocks/3-4/*

#/usr/local/cuda-11.4/bin/nsys profile --stat=true \
#numactl --interleave=all \
python3 -m oneflow.distributed.launch \
--nproc_per_node $DEVICE_NUM_PER_NODE \
--nnodes $NUM_NODES \
@@ -47,7 +48,10 @@ python3 -m oneflow.distributed.launch \
--cache_memory_budget_mb '16384,163840' \
--value_memory_kind 'device,host' \
--persistent_path $BLOCK_BASED_PATH \
--use_fp16 \
--loss_scale_policy 'dynamic' \
--column_size_array '227605432,39060,17295,7424,20265,3,7122,1543,63,130229467,3067956,405282,10,2209,11938,155,4,976,14,292775614,40790948,187188510,590152,12973,108,36' \
--test_name 'train_one_embedding_graph_'$DEVICE_NUM_PER_NODE'gpu' | tee 'train_one_embedding_graph_'$DEVICE_NUM_PER_NODE'gpu'.log
#--dataset_format "petastorm" \
#--eval_save_dir '/NVME0/guoran/auc' \
#--eval_after_training \
124 changes: 124 additions & 0 deletions RecommenderSystems/dlrm/utils/petastorm_dataloader.py
@@ -0,0 +1,124 @@
import os
import oneflow as flow
import glob
import numpy as np
from petastorm.reader import make_batch_reader


def make_petastorm_dataloader(args, mode):
assert mode in ("train", "val")
return PetastormDataLoader(
args.data_dir,
args.train_sub_folders if mode=='train' else args.val_sub_folders,
num_dense_fields=args.num_dense_fields,
num_sparse_fields=args.num_sparse_fields,
batch_size=args.batch_size_per_proc if mode=='train' else args.eval_batch_size_per_proc,
mode=mode,
is_global=args.is_global,
)


class PetastormDataLoader():
def __init__(
self,
data_dir, subfolders,
num_dense_fields: int = 13,
num_sparse_fields: int = 26,
batch_size: int = 16,
mode: str = "train",
is_global = False,
):
assert mode in ("train", "val")

self.is_global = is_global
self.placement = flow.env.all_device_placement("cpu") if is_global else None
self.sbp = flow.sbp.split(0) if is_global else None

files = []
for folder in subfolders:
files += ['file://' + name for name in glob.glob(f'{data_dir}/{folder}/*.parquet')]
files.sort()

self.reader = make_batch_reader(files, workers_count=2,
shuffle_row_groups=(mode=='train'),
num_epochs=None if mode == 'train' else 1,
shard_seed=1234,
shard_count=flow.env.get_world_size(),
cur_shard=flow.env.get_rank(),
)
self.batch_size = batch_size
# self.total_batch_size = total_batch_size
fields = ['label']
fields += [f"I{i+1}" for i in range(num_dense_fields)]
self.I_end = len(fields)
fields += [f"C{i+1}" for i in range(num_sparse_fields)]
self.C_end = len(fields)
self.fields = fields
self.batch_generator = self.get_batches()

def __call__(self):
np_label, np_denses, np_sparses = next(self.batch_generator)
np_dense = np.stack(np_denses, axis=-1)
np_sparse = np.stack(np_sparses, axis=-1)
labels = flow.tensor(np_label.reshape(-1, 1), dtype=flow.float)
dense_fields = flow.tensor(np_dense, dtype=flow.float)
sparse_fields = flow.tensor(np_sparse, dtype=flow.int32)
if self.is_global:
labels = labels.to_global(placement=self.placement, sbp=self.sbp)
dense_fields = dense_fields.to_global(placement=self.placement, sbp=self.sbp)
sparse_fields = sparse_fields.to_global(placement=self.placement, sbp=self.sbp)
return labels, dense_fields, sparse_fields

def get_batches(self, batch_size=None):
if batch_size is None:
batch_size = self.batch_size
tail = None
for rg in self.reader:
rgdict = rg._asdict()
rglist = [rgdict[field] for field in self.fields]
pos = 0
if tail is not None:
pos = self.batch_size - len(tail[0])
tail = list([np.concatenate((tail[i], rglist[i][0:(batch_size - len(tail[i]))])) for i in range(self.C_end)])
if len(tail[0]) == batch_size:
label = tail[0]
dense = tail[1:self.I_end] #np.stack(tail[1:14], axis=-1)
sparse = tail[self.I_end:self.C_end] #np.stack(tail[14:40], axis=-1)
tail = None
yield label, dense, sparse
else:
pos = 0
continue
while (pos + batch_size) <= len(rglist[0]):
label = rglist[0][pos:pos+batch_size]
dense = [rglist[j][pos:pos+batch_size] for j in range(1, self.I_end)]
sparse = [rglist[j][pos:pos+batch_size] for j in range(self.I_end, self.C_end)]
pos += batch_size
yield label, dense, sparse
if pos != len(rglist[0]):
tail = [rglist[i][pos:] for i in range(self.C_end)]

# def __exit__(self):
# self.reader.stop()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.data_dir = '/minio/sdd/dataset/criteo1t/add_slot_size_snappy_true'
args.train_sub_folders = [f'day_{i}' for i in range(23)]
args.val_sub_folders = ['day_23']
args.num_dense_fields = 13
args.num_sparse_fields = 26
args.batch_size_per_proc = 16
args.eval_batch_size_per_proc = 32
args.is_global = True

# subfolders =
m = make_petastorm_dataloader(args, mode='train')
# m = PetastormDataLoader(data_dir, subfolders)
for i in range(10):
labels, dense_fields, sparse_fields = m()
print(i, labels.shape, dense_fields.shape, sparse_fields.shape)
print(i, labels.is_global)
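
The subtle piece of the loader is get_batches: Parquet row groups rarely align with the batch size, so the leftover tail of one row group is stitched onto the head of the next. A self-contained toy version of the same algorithm, reduced to a single field with plain numpy (rebatch and the sample chunks are illustrative, not part of the loader):

import numpy as np

def rebatch(row_groups, batch_size):
    # Yield fixed-size batches from variable-size chunks, carrying the
    # leftover tail of each chunk over into the next one.
    tail = None
    for rg in row_groups:
        pos = 0
        if tail is not None:
            take = batch_size - len(tail)
            tail = np.concatenate((tail, rg[:take]))
            if len(tail) < batch_size:
                continue          # chunk too small; keep accumulating
            pos = take
            yield tail
            tail = None
        while pos + batch_size <= len(rg):
            yield rg[pos:pos + batch_size]
            pos += batch_size
        if pos != len(rg):
            tail = rg[pos:]       # remainder rolls into the next chunk

chunks = [np.arange(0, 7), np.arange(7, 12), np.arange(12, 23)]
for batch in rebatch(chunks, 4):
    print(batch)

As in the loader above, a final partial batch at the end of the stream is dropped rather than padded.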