Skip to content

Commit 9815556

Browse files
authored
Update mnist.py dataset
1 parent 08cec8f commit 9815556

File tree

1 file changed

+4
-4
lines changed
  • training/distributed_training/tensorflow/multi_worker_mirrored_strategy

1 file changed

+4
-4
lines changed

training/distributed_training/tensorflow/multi_worker_mirrored_strategy/mnist.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ def model(x_train, y_train, x_test, y_test):
3939

4040
def _load_training_data(base_dir):
4141
"""Load MNIST training data"""
42-
x_train = np.load(os.path.join(base_dir, "train_data.npy"))
43-
y_train = np.load(os.path.join(base_dir, "train_labels.npy"))
42+
x_train = np.load(os.path.join(base_dir, "input_train.npy"))
43+
y_train = np.load(os.path.join(base_dir, "input_train_labels.npy"))
4444
return x_train, y_train
4545

4646

4747
def _load_testing_data(base_dir):
4848
"""Load MNIST testing data"""
49-
x_test = np.load(os.path.join(base_dir, "eval_data.npy"))
50-
y_test = np.load(os.path.join(base_dir, "eval_labels.npy"))
49+
x_test = np.load(os.path.join(base_dir, "input_test.npy"))
50+
y_test = np.load(os.path.join(base_dir, "input_test_labels.npy"))
5151
return x_test, y_test
5252

5353

0 commit comments

Comments
 (0)