diff --git a/test/sagemaker_tests/pytorch/training/integration/sagemaker/test_smdataparallel.py b/test/sagemaker_tests/pytorch/training/integration/sagemaker/test_smdataparallel.py index 4182b4cca8d1..42e9f6e7dc6a 100644 --- a/test/sagemaker_tests/pytorch/training/integration/sagemaker/test_smdataparallel.py +++ b/test/sagemaker_tests/pytorch/training/integration/sagemaker/test_smdataparallel.py @@ -140,14 +140,14 @@ def test_smdataparallel_mnist(n_virginia_sagemaker_session, framework_version, n @pytest.mark.integration("smdataparallel_smmodelparallel") @pytest.mark.model("mnist") @pytest.mark.parametrize('instance_types', ["ml.p3.16xlarge"]) -def test_smmodelparallel_smdataparallel_mnist(instance_types, n_virginia_ecr_image, py_version, n_virginia_sagemaker_session, tmpdir): +def test_smmodelparallel_smdataparallel_mnist(instance_types, ecr_image, py_version, sagemaker_session, tmpdir): """ Tests SM Distributed DataParallel and ModelParallel single-node via script mode This test has been added for SM DataParallelism and ModelParallelism tests for re:invent. TODO: Consider reworking these tests after re:Invent releases are done """ - can_run_modelparallel = can_run_smmodelparallel(n_virginia_ecr_image) - can_run_dataparallel = can_run_smdataparallel(n_virginia_ecr_image) + can_run_modelparallel = can_run_smmodelparallel(ecr_image) + can_run_dataparallel = can_run_smdataparallel(ecr_image) if can_run_dataparallel and can_run_modelparallel: entry_point = 'smdataparallel_smmodelparallel_mnist_script_mode.sh' elif can_run_dataparallel: @@ -160,12 +160,12 @@ def test_smmodelparallel_smdataparallel_mnist(instance_types, n_virginia_ecr_ima with timeout(minutes=DEFAULT_TIMEOUT): pytorch = PyTorch(entry_point=entry_point, role='SageMakerRole', - image_uri=n_virginia_ecr_image, + image_uri=ecr_image, source_dir=mnist_path, instance_count=1, instance_type=instance_types, - sagemaker_session=n_virginia_sagemaker_session) + sagemaker_session=sagemaker_session) - pytorch = _disable_sm_profiler(n_virginia_sagemaker_session.boto_region_name, pytorch) + pytorch = _disable_sm_profiler(sagemaker_session.boto_region_name, pytorch) pytorch.fit() diff --git a/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_mnist.py b/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_mnist.py index fe371c9eeea7..c4bf7aed779a 100755 --- a/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_mnist.py +++ b/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_mnist.py @@ -41,9 +41,9 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version): sagemaker_session=sagemaker_session, image_uri=ecr_image, framework_version=framework_version) - + estimator = _disable_sm_profiler(sagemaker_session.boto_region_name, estimator) - + inputs = estimator.sagemaker_session.upload_data( path=os.path.join(resource_path, 'mnist', 'data'), key_prefix='scriptmode/mnist') @@ -191,15 +191,15 @@ def test_smdebug(sagemaker_session, ecr_image, instance_type, framework_version) @pytest.mark.model("mnist") @pytest.mark.skip_cpu @pytest.mark.skip_py2_containers -def test_smdataparallel_smmodelparallel_mnist(n_virginia_sagemaker_session, instance_type, n_virginia_ecr_image, tmpdir, framework_version): +def test_smdataparallel_smmodelparallel_mnist(sagemaker_session, instance_type, ecr_image, tmpdir, framework_version): """ Tests SM Distributed DataParallel and ModelParallel single-node via script mode This test has been added for SM DataParallelism and ModelParallelism tests for re:invent. TODO: Consider reworking these tests after re:Invent releases are done """ instance_type = "ml.p3.16xlarge" - _, image_framework_version = get_framework_and_version_from_tag(n_virginia_ecr_image) - image_cuda_version = get_cuda_version_from_tag(n_virginia_ecr_image) + _, image_framework_version = get_framework_and_version_from_tag(ecr_image) + image_cuda_version = get_cuda_version_from_tag(ecr_image) if Version(image_framework_version) < Version("2.3.1") or image_cuda_version != "cu110": pytest.skip("SMD Model and Data Parallelism are only supported on CUDA 11, and on TensorFlow 2.3.1 or higher") smmodelparallel_path = os.path.join(RESOURCE_PATH, 'smmodelparallel') @@ -209,12 +209,12 @@ def test_smdataparallel_smmodelparallel_mnist(n_virginia_sagemaker_session, inst instance_count=1, instance_type=instance_type, source_dir=smmodelparallel_path, - sagemaker_session=n_virginia_sagemaker_session, - image_uri=n_virginia_ecr_image, + sagemaker_session=sagemaker_session, + image_uri=ecr_image, framework_version=framework_version, py_version='py3') - - estimator = _disable_sm_profiler(n_virginia_sagemaker_session.boto_region_name, estimator) + + estimator = _disable_sm_profiler(sagemaker_session.boto_region_name, estimator) estimator.fit()