1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13+ import boto3
1314import os
1415import logging
1516import platform
1819import sys
1920import tempfile
2021
21-
22+ from sagemaker import Session
2223from test .utils import local_mode
2324
2425logger = logging .getLogger (__name__ )
3637def pytest_addoption (parser ):
3738 parser .addoption ('--build-image' , '-D' , action = "store_true" )
3839 parser .addoption ('--build-base-image' , '-B' , action = "store_true" )
40+ parser .addoption ('--aws-id' )
41+ parser .addoption ('--instance-type' )
3942 parser .addoption ('--install-container-support' , '-C' , action = "store_true" )
4043 parser .addoption ('--docker-base-name' , default = 'pytorch' )
4144 parser .addoption ('--region' , default = 'us-west-2' )
@@ -46,40 +49,40 @@ def pytest_addoption(parser):
4649 parser .addoption ('--tag' , default = None )
4750
4851
49- @pytest .fixture (scope = 'session' )
50- def docker_base_name (request ):
52+ @pytest .fixture (scope = 'session' , name = 'docker_base_name' )
53+ def fixture_docker_base_name (request ):
5154 return request .config .getoption ('--docker-base-name' )
5255
5356
54- @pytest .fixture (scope = 'session' )
55- def region (request ):
57+ @pytest .fixture (scope = 'session' , name = 'region' )
58+ def fixture_region (request ):
5659 return request .config .getoption ('--region' )
5760
5861
59- @pytest .fixture (scope = 'session' )
60- def framework_version (request ):
62+ @pytest .fixture (scope = 'session' , name = 'framework_version' )
63+ def fixture_framework_version (request ):
6164 return request .config .getoption ('--framework-version' )
6265
6366
64- @pytest .fixture (scope = 'session' )
65- def py_version (request ):
67+ @pytest .fixture (scope = 'session' , name = 'py_version' )
68+ def fixture_py_version (request ):
6669 return 'py{}' .format (int (request .config .getoption ('--py-version' )))
6770
6871
69- @pytest .fixture (scope = 'session' )
70- def processor (request ):
72+ @pytest .fixture (scope = 'session' , name = 'processor' )
73+ def fixture_processor (request ):
7174 return request .config .getoption ('--processor' )
7275
7376
74- @pytest .fixture (scope = 'session' )
75- def tag (request , framework_version , processor , py_version ):
77+ @pytest .fixture (scope = 'session' , name = 'tag' )
78+ def fixture_tag (request , framework_version , processor , py_version ):
7679 provided_tag = request .config .getoption ('--tag' )
7780 default_tag = '{}-{}-{}' .format (framework_version , processor , py_version )
7881 return provided_tag if provided_tag else default_tag
7982
8083
81- @pytest .fixture (scope = 'session' )
82- def docker_image (docker_base_name , tag ):
84+ @pytest .fixture (scope = 'session' , name = 'docker_image' )
85+ def fixture_docker_image (docker_base_name , tag ):
8386 return '{}:{}' .format (docker_base_name , tag )
8487
8588
@@ -96,20 +99,20 @@ def opt_ml():
9699 shutil .rmtree (tmp , True )
97100
98101
99- @pytest .fixture (scope = 'session' )
100- def use_gpu (processor ):
102+ @pytest .fixture (scope = 'session' , name = 'use_gpu' )
103+ def fixture_use_gpu (processor ):
101104 return processor == 'gpu'
102105
103106
104- @pytest .fixture (scope = 'session' , autouse = True )
105- def install_container_support (request ):
107+ @pytest .fixture (scope = 'session' , name = 'install_container_support' , autouse = True )
108+ def fixture_install_container_support (request ):
106109 install = request .config .getoption ('--install-container-support' )
107110 if install :
108111 local_mode .install_container_support ()
109112
110113
111- @pytest .fixture (scope = 'session' , autouse = True )
112- def build_base_image (request , framework_version , py_version , processor , tag , docker_base_name ):
114+ @pytest .fixture (scope = 'session' , name = 'build_base_image' , autouse = True )
115+ def fixture_build_base_image (request , framework_version , py_version , processor , tag , docker_base_name ):
113116 build_base_image = request .config .getoption ('--build-base-image' )
114117 if build_base_image :
115118 return local_mode .build_base_image (framework_name = docker_base_name ,
@@ -122,8 +125,8 @@ def build_base_image(request, framework_version, py_version, processor, tag, doc
122125 return tag
123126
124127
125- @pytest .fixture (scope = 'session' , autouse = True )
126- def build_image (request , framework_version , py_version , processor , tag , docker_base_name ):
128+ @pytest .fixture (scope = 'session' , name = 'build_image' , autouse = True )
129+ def fixture_build_image (request , framework_version , py_version , processor , tag , docker_base_name ):
127130 build_image = request .config .getoption ('--build-image' )
128131 if build_image :
129132 return local_mode .build_image (framework_name = docker_base_name ,
@@ -134,3 +137,38 @@ def build_image(request, framework_version, py_version, processor, tag, docker_b
134137 cwd = os .path .join (dir_path , '..' ))
135138
136139 return tag
140+
141+
142+ @pytest .fixture (scope = 'session' , name = 'sagemaker_session' )
143+ def fixture_sagemaker_session (region ):
144+ return Session (boto_session = boto3 .Session (region_name = region ))
145+
146+
147+ @pytest .fixture (name = 'aws_id' , scope = 'session' )
148+ def fixture_aws_id (request ):
149+ return request .config .getoption ('--aws-id' )
150+
151+
152+ @pytest .fixture (name = 'instance_type' , scope = 'session' )
153+ def fixture_instance_type (request ):
154+ return request .config .getoption ('--instance-type' )
155+
156+
157+ @pytest .fixture (name = 'docker_registry' , scope = 'session' )
158+ def fixture_docker_registry (aws_id , region ):
159+ return '{}.dkr.ecr.{}.amazonaws.com' .format (aws_id , region )
160+
161+
162+ @pytest .fixture (name = 'ecr_image' , scope = 'session' )
163+ def fixture_ecr_image (docker_registry , docker_base_name , tag ):
164+ return '{}/{}:{}' .format (docker_registry , docker_base_name , tag )
165+
166+
167+ @pytest .fixture (scope = 'session' , name = 'dist_cpu_backend' , params = ['tcp' , 'gloo' ])
168+ def fixture_dist_cpu_backend (request ):
169+ return request .param
170+
171+
172+ @pytest .fixture (scope = 'session' , name = 'dist_gpu_backend' , params = ['gloo' ])
173+ def fixture_dist_gpu_backend (request ):
174+ return request .param
0 commit comments