|
23 | 23 | RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', '..', 'resources') |
24 | 24 |
|
25 | 25 |
|
| 26 | +@pytest.mark.skip_cpu |
| 27 | +@pytest.mark.skip_generic |
| 28 | +def test_distributed_training_horovod_gpu( |
| 29 | + sagemaker_local_session, image_uri, tmpdir, framework_version |
| 30 | +): |
| 31 | + _test_distributed_training_horovod( |
| 32 | + 1, 2, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local_gpu' |
| 33 | + ) |
| 34 | + |
| 35 | + |
26 | 36 | @pytest.mark.skip_gpu |
27 | 37 | @pytest.mark.skip_generic |
28 | | -@pytest.mark.parametrize('instances, processes', [ |
29 | | - [1, 2], |
30 | | - (2, 1), |
31 | | - (2, 2), |
32 | | - (5, 2)]) |
33 | | -def test_distributed_training_horovod_basic(instances, |
34 | | - processes, |
35 | | - sagemaker_local_session, |
36 | | - image_uri, |
37 | | - tmpdir, |
38 | | - framework_version): |
| 38 | +@pytest.mark.parametrize( |
| 39 | + 'instances, processes', [(1, 2), (2, 1), (2, 2), (5, 2)] |
| 40 | +) |
| 41 | +def test_distributed_training_horovod_cpu( |
| 42 | + instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version |
| 43 | +): |
| 44 | + _test_distributed_training_horovod( |
| 45 | + instances, processes, sagemaker_local_session, image_uri, tmpdir, framework_version, 'local' |
| 46 | + ) |
| 47 | + |
| 48 | + |
| 49 | +def _test_distributed_training_horovod( |
| 50 | + instances, processes, session, image_uri, tmpdir, framework_version, instance_type |
| 51 | +): |
39 | 52 | output_path = 'file://%s' % tmpdir |
40 | 53 | estimator = TensorFlow( |
41 | 54 | entry_point=os.path.join(RESOURCE_PATH, 'hvdbasic', 'train_hvd_basic.py'), |
42 | 55 | role='SageMakerRole', |
43 | | - train_instance_type='local', |
44 | | - sagemaker_session=sagemaker_local_session, |
| 56 | + train_instance_type=instance_type, |
| 57 | + sagemaker_session=session, |
45 | 58 | train_instance_count=instances, |
46 | 59 | image_name=image_uri, |
47 | 60 | output_path=output_path, |
|
0 commit comments