|  | 
|  | 1 | +# Standard Library | 
|  | 2 | +# Third Party | 
|  | 3 | +import pytest | 
|  | 4 | +import tensorflow.compat.v2 as tf | 
|  | 5 | +from tests.zero_code_change.tf_utils import get_estimator, get_input_fns | 
|  | 6 | + | 
|  | 7 | +# First Party | 
|  | 8 | +import smdebug.tensorflow as smd | 
|  | 9 | +from smdebug.core.collection import CollectionKeys | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +@pytest.mark.parametrize("saveall", [True, False]) | 
|  | 13 | +def test_estimator(out_dir, tf_eager_mode, saveall): | 
|  | 14 | +    """ Works as intended. """ | 
|  | 15 | +    if tf_eager_mode is False: | 
|  | 16 | +        tf.compat.v1.disable_eager_execution() | 
|  | 17 | +        tf.compat.v1.reset_default_graph() | 
|  | 18 | +    tf.keras.backend.clear_session() | 
|  | 19 | +    mnist_classifier = get_estimator() | 
|  | 20 | +    train_input_fn, eval_input_fn = get_input_fns() | 
|  | 21 | + | 
|  | 22 | +    # Train and evaluate | 
|  | 23 | +    train_steps, eval_steps = 8, 2 | 
|  | 24 | +    hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall) | 
|  | 25 | +    hook.set_mode(mode=smd.modes.TRAIN) | 
|  | 26 | +    mnist_classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=[hook]) | 
|  | 27 | +    hook.set_mode(mode=smd.modes.EVAL) | 
|  | 28 | +    mnist_classifier.evaluate(input_fn=eval_input_fn, steps=eval_steps, hooks=[hook]) | 
|  | 29 | + | 
|  | 30 | +    # Check that hook created and tensors saved | 
|  | 31 | +    trial = smd.create_trial(path=out_dir) | 
|  | 32 | +    tnames = trial.tensor_names() | 
|  | 33 | +    assert len(trial.steps()) > 0 | 
|  | 34 | +    if saveall: | 
|  | 35 | +        # Number of tensors in each collection | 
|  | 36 | +        # vanilla TF 2.2: all = 300, loss = 1, weights = 4, gradients = 0, biases = 18, optimizer variables = 0, metrics = 0, others = 277 | 
|  | 37 | +        # AWS-TF 2.2 : all = 300, loss = 1, weights = 4, gradients = 8, biases = 18, optimizer variables = 0, metrics = 0, others = 269 | 
|  | 38 | +        # AWS-TF 2.1 : all = 309, loss = 1, weights = 4, gradients = 8, biases = 18, optimizer variables = 0, metrics = 0, others = 278 | 
|  | 39 | +        assert len(tnames) >= 300 | 
|  | 40 | +        assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1 | 
|  | 41 | +        assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 4 | 
|  | 42 | +        assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 18 | 
|  | 43 | +        assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) >= 0 | 
|  | 44 | +        assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) >= 0 | 
|  | 45 | +    else: | 
|  | 46 | +        assert len(tnames) == 1 | 
|  | 47 | +        assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1 | 
|  | 48 | + | 
|  | 49 | + | 
|  | 50 | +@pytest.mark.parametrize("saveall", [True, False]) | 
|  | 51 | +def test_linear_classifier(out_dir, tf_eager_mode, saveall): | 
|  | 52 | +    """ Works as intended. """ | 
|  | 53 | +    if tf_eager_mode is False: | 
|  | 54 | +        tf.compat.v1.disable_eager_execution() | 
|  | 55 | +        tf.compat.v1.reset_default_graph() | 
|  | 56 | +    tf.keras.backend.clear_session() | 
|  | 57 | +    train_input_fn, eval_input_fn = get_input_fns() | 
|  | 58 | +    x_feature = tf.feature_column.numeric_column("x", shape=(28, 28)) | 
|  | 59 | +    estimator = tf.estimator.LinearClassifier( | 
|  | 60 | +        feature_columns=[x_feature], model_dir="/tmp/mnist_linear_classifier", n_classes=10 | 
|  | 61 | +    ) | 
|  | 62 | +    hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall) | 
|  | 63 | +    estimator.train(input_fn=train_input_fn, steps=10, hooks=[hook]) | 
|  | 64 | + | 
|  | 65 | +    # Check that hook created and tensors saved | 
|  | 66 | +    trial = smd.create_trial(path=out_dir) | 
|  | 67 | +    tnames = trial.tensor_names() | 
|  | 68 | +    assert len(trial.steps()) > 0 | 
|  | 69 | +    if saveall: | 
|  | 70 | +        # Number of tensors in each collection | 
|  | 71 | +        # vanilla TF 2.2: all = 214, loss = 2, weights = 1, gradients = 0, biases = 12, optimizer variables = 0, metrics = 0, others = 199 | 
|  | 72 | +        # AWS-TF 2.2: all = 219, loss = 2, weights = 1, gradients = 2, biases = 12, optimizer variables = 5, metrics = 0, others = 197 | 
|  | 73 | +        # AWS-TF 2.1: all = 226, loss = 2, weights = 1, gradients = 2, biases = 12, optimizer variables = 5, metrics = 0, others = 204 | 
|  | 74 | +        assert len(tnames) >= 214 | 
|  | 75 | +        assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 2 | 
|  | 76 | +        assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 1 | 
|  | 77 | +        assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 12 | 
|  | 78 | +        assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) >= 0 | 
|  | 79 | +        assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) >= 0 | 
|  | 80 | +    else: | 
|  | 81 | +        assert len(tnames) == 2 | 
|  | 82 | +        assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 2 | 
0 commit comments