11import os
22import sys
3- import torch
43import time
54
65import habana_frameworks .torch .core as htcore
7-
8- from torch .utils .data import DataLoader
9- from torchvision import transforms , datasets
6+ import torch
107import torch .nn as nn
118import torch .nn .functional as F
9+ from torch .utils .data import DataLoader
10+ from torchvision import datasets , transforms
11+
1212
1313class Net (nn .Module ):
1414 def __init__ (self ):
1515 super (Net , self ).__init__ ()
16- self .fc1 = nn .Linear (784 , 256 )
17- self .fc2 = nn .Linear (256 , 64 )
18- self .fc3 = nn .Linear (64 , 10 )
16+ self .fc1 = nn .Linear (784 , 256 )
17+ self .fc2 = nn .Linear (256 , 64 )
18+ self .fc3 = nn .Linear (64 , 10 )
19+
1920 def forward (self , x ):
20- out = x .view (- 1 ,28 * 28 )
21+ out = x .view (- 1 , 28 * 28 )
2122 out = F .relu (self .fc1 (out ))
2223 out = F .relu (self .fc2 (out ))
2324 out = self .fc3 (out )
2425 out = F .log_softmax (out , dim = 1 )
2526 return out
2627
28+
2729model = Net ()
2830model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth"
2931model_path = "/tmp/.neural_compressor/mnist-epoch_20.pth"
@@ -36,14 +38,12 @@ def forward(self, x):
3638model = model .to ("hpu" )
3739
3840
39- transform = transforms .Compose ([
40- transforms .ToTensor (),
41- transforms .Normalize ((0.1307 ,), (0.3081 ,))])
41+ transform = transforms .Compose ([transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))])
4242
43- data_path = ' ./data'
44- test_kwargs = {' batch_size' : 32 }
43+ data_path = " ./data"
44+ test_kwargs = {" batch_size" : 32 }
4545dataset1 = datasets .MNIST (data_path , train = False , download = True , transform = transform )
46- test_loader = torch .utils .data .DataLoader (dataset1 ,** test_kwargs )
46+ test_loader = torch .utils .data .DataLoader (dataset1 , ** test_kwargs )
4747
4848correct = 0
4949for batch_idx , (data , label ) in enumerate (test_loader ):
@@ -56,4 +56,4 @@ def forward(self, x):
5656
5757 correct += output .max (1 )[1 ].eq (label ).sum ()
5858
59- print (' Accuracy: {:.2f}%' .format (100. * correct / (len (test_loader ) * 32 )))
59+ print (" Accuracy: {:.2f}%" .format (100.0 * correct / (len (test_loader ) * 32 )))
0 commit comments