From 958bc1a3a7893bb2d1a21a3e55c15b159b854db2 Mon Sep 17 00:00:00 2001
From: Lauren Yu <6631887+laurenyu@users.noreply.github.com>
Date: Tue, 16 Jun 2020 16:43:46 -0700
Subject: [PATCH] infra: add single-instance, multi-process Horovod test for local GPU

---
 test/integration/local/test_horovod.py | 39 +++++++++++++++++---------
 1 file changed, 26 insertions(+), 13 deletions(-)

diff --git a/test/integration/local/test_horovod.py b/test/integration/local/test_horovod.py
index 506fb825..0572f98d 100644
--- a/test/integration/local/test_horovod.py
+++ b/test/integration/local/test_horovod.py
@@ -22,25 +22,38 @@
 RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
 
 
+@pytest.mark.skip_cpu
+@pytest.mark.skip_generic
+def test_distributed_training_horovod_gpu(
+    sagemaker_local_session, image_uri, tmpdir, framework_version
+):
+    _test_distributed_training_horovod(
+        1, 2, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local_gpu'
+    )
+
+
 @pytest.mark.skip_gpu
 @pytest.mark.skip_generic
-@pytest.mark.parametrize('instances, processes', [
-    [1, 2],
-    (2, 1),
-    (2, 2),
-    (5, 2)])
-def test_distributed_training_horovod_basic(instances,
-                                            processes,
-                                            sagemaker_local_session,
-                                            image_uri,
-                                            tmpdir,
-                                            framework_version):
+@pytest.mark.parametrize(
+    'instances, processes', [(1, 2), (2, 1), (2, 2), (5, 2)]
+)
+def test_distributed_training_horovod_cpu(
+    instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version
+):
+    _test_distributed_training_horovod(
+        instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local'
+    )
+
+
+def _test_distributed_training_horovod(
+    instances, processes, session, image_uri, tmpdir, framework_version, instance_type
+):
     output_path = 'file://%s' % tmpdir
     estimator = TensorFlow(
         entry_point=os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py'),
         role='SageMakerRole',
-        train_instance_type='local',
-        sagemaker_session=sagemaker_local_session,
+        train_instance_type=instance_type,
+        sagemaker_session=session,
         train_instance_count=instances,
         image_name=image_uri,
         output_path=output_path,