1 change: 1 addition & 0 deletions RecommenderSystems/dlrm/config.py
@@ -60,6 +60,7 @@ def str_list(x):
     parser.add_argument('--eval_batch_size', type=int, default=512)
     parser.add_argument("--eval_batch_size_per_proc", type=int, default=None)
     parser.add_argument('--eval_interval', type=int, default=1000)
+    parser.add_argument("--eval_save_dir", type=str, default='', help="if set, save eval labels and preds to this dir for offline AUC calculation")
     parser.add_argument("--batch_size", type=int, default=16384)
     parser.add_argument("--batch_size_per_proc", type=int, default=None)
     parser.add_argument("--learning_rate", type=float, default=1e-3)
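The empty-string default makes the new flag an opt-in switch: evaluation behaves as before unless a directory is passed. Below is a minimal sketch of the intended use, assuming the flag is consumed as args.eval_save_dir (as in train.py further down); the path is made up, and nothing in this diff creates the directory, so it has to exist before eval writes to it.

# Sketch only: the path is hypothetical, and directory creation is not part of this PR.
import os

eval_save_dir = "/tmp/dlrm_eval_results"   # value that would arrive via --eval_save_dir
if eval_save_dir != '':                    # '' (the default) keeps the old online-AUC behavior
    os.makedirs(eval_save_dir, exist_ok=True)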
26 changes: 20 additions & 6 deletions RecommenderSystems/dlrm/train.py
@@ -1,5 +1,6 @@
 import os
 import sys
+import pickle
 
 sys.path.append(
     os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
@@ -15,6 +16,7 @@
 from graph import DLRMValGraph, DLRMTrainGraph
 import warnings
 import utils.logger as log
+from utils.auc_calculater import calculate_auc_from_dir
 
 
 class Trainer(object):
@@ -153,7 +155,6 @@ def __call__(self):
         self.train()
 
     def train(self):
-        losses = []
         self.dlrm_module.train()
         for _ in range(self.max_iter):
             self.cur_iter += 1
@@ -168,21 +169,34 @@ def train(self):
         if self.eval_after_training:
             self.eval(True)
 
+        if self.args.eval_save_dir != '' and self.eval_after_training:
+            calculate_auc_from_dir(self.args.eval_save_dir)
+
     def eval(self, save_model=False):
         if self.eval_batchs <= 0:
             return
         self.dlrm_module.eval()
-        labels = np.array([[0]])
-        preds = np.array([[0]])
+        labels = []
+        preds = []
         for _ in range(self.eval_batchs):
             if self.execution_mode == "graph":
                 pred, label = self.eval_graph()
             else:
                 pred, label = self.inference()
             label_ = label.numpy().astype(np.float32)
-            labels = np.concatenate((labels, label_), axis=0)
-            preds = np.concatenate((preds, pred.numpy()), axis=0)
-        auc = roc_auc_score(labels[1:], preds[1:])
+            labels.append(label_)
+            preds.append(pred.numpy())
+        if self.args.eval_save_dir != '':
+            pf = os.path.join(self.args.eval_save_dir, f'eval_results_iter_{self.cur_iter}.pkl')
+            with open(pf, 'wb') as f:
+                obj = {'labels': labels, 'preds': preds, 'iter': self.cur_iter}
+                pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
+            # auc = roc_auc_score(label_, pred.numpy())
+            auc = 'nc'
+        else:
+            labels = np.concatenate(labels, axis=0)
+            preds = np.concatenate(preds, axis=0)
+            auc = roc_auc_score(labels, preds)
         self.meter_eval(auc)
         if save_model:
             sub_save_dir = f"iter_{self.cur_iter}_val_auc_{auc}"
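With --eval_save_dir set, each call to eval() now dumps the raw per-batch labels and predictions instead of scoring them online, and passes the placeholder string 'nc' to meter_eval. A quick sketch of reading one of those dumps back, using only the keys written above; the file name is an example.

# Sketch: inspect one saved eval dump (file name is an example).
# 'labels' and 'preds' are lists of per-batch numpy arrays, one entry per eval batch.
import pickle
import numpy as np

with open("eval_results_iter_1000.pkl", "rb") as f:
    results = pickle.load(f)
labels = np.concatenate(results["labels"], axis=0)
preds = np.concatenate(results["preds"], axis=0)
print(results["iter"], labels.shape, preds.shape)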
32 changes: 32 additions & 0 deletions RecommenderSystems/dlrm/utils/auc_calculater.py
@@ -0,0 +1,32 @@
+import os
+import sys
+import time
+import pickle
+import numpy as np
+from sklearn.metrics import roc_auc_score
+
+
+def calculate_auc_from_file(pkl):
+    results = pickle.load(open(pkl, 'rb'))
+    labels = results['labels']
+    preds = results['preds']
+    iter = results['iter']
+    labels = np.concatenate(labels, axis=0)
+    preds = np.concatenate(preds, axis=0)
+    start = time.time()
+    auc = roc_auc_score(labels, preds)
+    duration = time.time() - start
+    print(f'Iter {iter} AUC: {auc:0.4f}, Num of Evaluation: {labels.shape[0]}, time:{duration:0.3f}')
+
+
+def calculate_auc_from_dir(directory, startswith='eval_results_iter'):
+    print('calculate AUC from folder:', directory)
+    for file in os.listdir(directory):
+        filename = os.fsdecode(file)
+        if filename.startswith(startswith) and filename.endswith(".pkl"):
+            calculate_auc_from_file(os.path.join(directory, filename))
+
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 2, 'please input directory'
+    calculate_auc_from_dir(sys.argv[1])
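Besides being called from train.py after training, the module can be run standalone on a save directory via the __main__ guard above, e.g. python utils/auc_calculater.py <eval_save_dir> from the dlrm directory. The equivalent in-process call, with an example path:

# Offline scoring of a finished run (the path is an example).
from utils.auc_calculater import calculate_auc_from_dir

calculate_auc_from_dir("/tmp/dlrm_eval_results")

Note that os.listdir() returns entries in arbitrary order, so the per-iteration AUC lines are not guaranteed to print in iteration order.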