Skip to content

Commit 1f2079e

Browse files
authored
fix response content_type handling (#69)
1 parent bb2075d commit 1f2079e

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

src/sagemaker_pytorch_container/serving.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def default_output_fn(prediction, accept):
8080
if type(prediction) == torch.Tensor:
8181
prediction = prediction.detach().cpu().numpy()
8282

83-
return worker.Response(encoders.encode(prediction, accept), accept)
83+
return worker.Response(response=encoders.encode(prediction, accept), mimetype=accept)
8484

8585

8686
def _user_module_transformer(user_module):

test/unit/test_serving.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def test_default_output_fn_json(tensor):
141141
output = default_output_fn(tensor, content_types.JSON)
142142

143143
assert json.dumps(tensor.cpu().numpy().tolist()) in output.get_data(as_text=True)
144-
assert content_types.JSON in output.content_type
144+
assert content_types.JSON == output.mimetype
145145

146146

147147
def test_default_output_fn_npy(tensor):
@@ -151,23 +151,23 @@ def test_default_output_fn_npy(tensor):
151151
np.save(stream, tensor.cpu().numpy())
152152

153153
assert stream.getvalue() in output.get_data(as_text=False)
154-
assert content_types.NPY in output.content_type
154+
assert content_types.NPY == output.mimetype
155155

156156

157157
def test_default_output_fn_csv_long():
158158
tensor = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
159159
output = default_output_fn(tensor, content_types.CSV)
160160

161161
assert '1,2,3\n4,5,6\n' in output.get_data(as_text=True)
162-
assert content_types.CSV in output.content_type
162+
assert content_types.CSV == output.mimetype
163163

164164

165165
def test_default_output_fn_csv_float():
166166
tensor = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
167167
output = default_output_fn(tensor, content_types.CSV)
168168

169169
assert '1.0,2.0,3.0\n4.0,5.0,6.0\n' in output.get_data(as_text=True)
170-
assert content_types.CSV in output.content_type
170+
assert content_types.CSV == output.mimetype
171171

172172

173173
def test_default_output_fn_bad_accept():
@@ -182,7 +182,7 @@ def test_default_output_fn_gpu():
182182
output = default_output_fn(tensor_gpu, content_types.CSV)
183183

184184
assert '1,2,3\n4,5,6\n' in output.get_data(as_text=True)
185-
assert content_types.CSV in output.content_type
185+
assert content_types.CSV == output.mimetype
186186

187187

188188
@patch('sagemaker_containers.beta.framework.modules.import_module')

0 commit comments

Comments
 (0)