Skip to content

Commit c3f22ce

Browse files
authored
Update mnist-distributed.py dataset
1 parent 9815556 commit c3f22ce

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

training/distributed_training/tensorflow/multi_worker_mirrored_strategy/mnist-distributed.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,15 @@ def model(x_train, y_train, x_test, y_test, strategy):
4444

4545
def _load_training_data(base_dir):
4646
"""Load MNIST training data"""
47-
x_train = np.load(os.path.join(base_dir, "train_data.npy"))
48-
y_train = np.load(os.path.join(base_dir, "train_labels.npy"))
47+
x_train = np.load(os.path.join(base_dir, "input_train.npy"))
48+
y_train = np.load(os.path.join(base_dir, "input_train_labels.npy"))
4949
return x_train, y_train
5050

5151

5252
def _load_testing_data(base_dir):
5353
"""Load MNIST testing data"""
54-
x_test = np.load(os.path.join(base_dir, "eval_data.npy"))
55-
y_test = np.load(os.path.join(base_dir, "eval_labels.npy"))
54+
x_test = np.load(os.path.join(base_dir, "input_test.npy"))
55+
y_test = np.load(os.path.join(base_dir, "input_test_labels.npy"))
5656
return x_test, y_test
5757

5858

0 commit comments

Comments
 (0)