Skip to content

Commit 9f24eb3

Browse files
AmintorDuskoatqy
authored andcommitted
Fix/frameworks pytorch deploy (aws#3168)
* Fix model deployment Fix reference to model_fn inference function and present consistent results. * Fix model_fn inference function * fix grammar issues * Trigger CI * Fix format * update to appease the grammar check * fix some existing errors
1 parent 1802b60 commit 9f24eb3

File tree

2 files changed

+66
-128
lines changed

2 files changed

+66
-128
lines changed

frameworks/pytorch/code/inference.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import logging
33
import sys
4+
import os
45

56
import torch
67
import torch.nn as nn
@@ -32,10 +33,12 @@ def forward(self, x):
3233

3334
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3435

35-
36-
def model_fn(model_dir):
37-
model = Net().to(device)
38-
model.eval()
36+
# defining model and loading weights to it.
37+
def model_fn(model_dir):
38+
model = Net()
39+
with open(os.path.join(model_dir, "model.pth"), "rb") as f:
40+
model.load_state_dict(torch.load(f))
41+
model.to(device).eval()
3942
return model
4043

4144

0 commit comments

Comments
 (0)