-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Description
After converting my PyTorch model to a TVM model, running it multiple times with the same input produces different results. The CUDA target format is ptx; if the target format is changed back to cubin, the problem does not occur.
Expected behavior
Repeated runs with the same input should produce identical results, so the sample code should print:
max abs diff is: 0
Actual behavior
max abs diff is: 7.818208
Environment
gpu: rtx 2070
nvcc: Cuda compilation tools, release 11.1, V11.1.74
Nvidia Driver Version: 470.86
system: Linux shukun-desktop 5.13.0-27-generic #29~20.04.1-Ubuntu SMP Fri Jan 14 00:32:30 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
TVM commit: 0c836b7
Steps to reproduce
-
Change the target format from cubin to ptx in python/tvm/contrib/nvcc.py, i.e. the `target_format` argument in `tvm_callback_cuda_compile` shown below:
@tvm._ffi.register_func
def tvm_callback_cuda_compile(code):
    """use nvcc to generate fatbin code for better optimization"""
    # Reproduction step: change target_format here to "ptx" to trigger the
    # nondeterministic results described in this issue.
    ptx = compile_cuda(code, target_format="fatbin")
    return ptx
Run this code:
import math
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import tvm
from tvm import relay
from tvm.contrib import graph_executor
class BatchActivateConvLayer(nn.Module):
    """Bottleneck layer: BatchNorm3d -> 1x1x1 Conv3d -> BatchNorm3d -> 3x3x3 Conv3d.

    An optional ``Dropout3d`` is applied to the output when ``drop_ratio > 0``.
    NOTE(review): despite the class name, no activation function is applied
    between the layers.

    Args:
        channel_in: Number of input channels.
        growth_rate: Number of output channels of the final 3x3x3 convolution.
        bottleneck_size_basic_factor: Bottleneck width multiplier; the 1x1x1
            convolution outputs ``bottleneck_size_basic_factor * growth_rate``
            channels.
        drop_ratio: ``Dropout3d`` probability. Dropout is skipped when this is
            not positive.
    """

    def __init__(
        self, channel_in, growth_rate, bottleneck_size_basic_factor, drop_ratio=0.8
    ):
        super().__init__()
        self.drop_ratio = drop_ratio
        self.growth_rate = growth_rate
        self.bottleneck_channel_out = bottleneck_size_basic_factor * growth_rate
        self.mode_bn = torch.nn.BatchNorm3d(channel_in)
        self.mode_conv = nn.Conv3d(
            channel_in, self.bottleneck_channel_out, kernel_size=1, stride=1, bias=False
        )
        self.bn = torch.nn.BatchNorm3d(self.bottleneck_channel_out)
        # kernel 3 / stride 1 / padding 1 keeps the spatial size unchanged.
        self.conv = nn.Conv3d(
            self.bottleneck_channel_out,
            growth_rate,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.drop_out = nn.Dropout3d(p=self.drop_ratio)

    def forward(self, x):
        """Apply the layer to ``x``; the result has ``growth_rate`` channels."""
        current = self.mode_bn(x)
        current = self.mode_conv(current)
        current = self.bn(current)
        current = self.conv(current)
        if self.drop_ratio > 0:
            current = self.drop_out(current)
        return current
class DenseBlock(nn.Module):
    """Densely connected stack of ``BatchActivateConvLayer`` modules.

    Each layer receives the concatenation of the block input and all previous
    layer outputs. Its ``growth_rate`` output channels are then appended to
    that running feature map along dim 1.

    Args:
        current_block_layers_number: Number of layers in the block.
        channel_in: Number of channels of the block input.
        growth_rate: Channels added by each layer.
        bottleneck_size_basic_factor: Bottleneck width multiplier passed to
            every layer.
        drop_ratio: Dropout probability passed to every layer.
    """

    def __init__(
        self,
        current_block_layers_number,
        channel_in,
        growth_rate,
        bottleneck_size_basic_factor,
        drop_ratio=0.8,
    ):
        super().__init__()
        self.channel_in = channel_in
        self.growth_rate = growth_rate
        self.bottleneck_size_basic_factor = bottleneck_size_basic_factor
        # Running channel count; after the loop it equals the block's output width.
        self.current_channel_in = self.channel_in
        # Attribute name (including the "blcok" spelling) kept for compatibility.
        self.current_blcok_drop_ratio = drop_ratio
        self.current_block_layer_number = current_block_layers_number
        for i in range(self.current_block_layer_number):
            current_block_layers = BatchActivateConvLayer(
                self.current_channel_in,
                self.growth_rate,
                self.bottleneck_size_basic_factor,
                self.current_blcok_drop_ratio,
            )
            # Registered as "block_layer_<i>"; this name determines the
            # parameter names and is looked up again in forward().
            setattr(self, "block_layer_" + str(i), current_block_layers)
            self.current_channel_in += self.growth_rate

    def get_current_block_channel_out(self):
        """Return the number of channels produced by ``forward``."""
        return self.current_channel_in

    def forward(self, x):
        """Run all layers, concatenating each layer's output onto the input (dim 1)."""
        current = x
        for i in range(self.current_block_layer_number):
            layer = getattr(self, "block_layer_" + str(i))
            # The layer is fed a copy of the running features (as in the
            # original); presumably this shields `current` from in-place ops.
            new_features = layer(current.clone())
            current = torch.cat((current, new_features), 1)
        return current
class DenseNet(nn.Module):
    """3-D DenseNet feature extractor: a stem Conv3d followed by ``DenseBlock``s.

    Args:
        growth_rate: Channels added by every layer of every dense block.
        block_config: Number of layers in each dense block (one block per entry).
        compression: Stored on the instance; not used by this module.
        num_init_features: Output channels of the stem convolution.
        bottleneck_size_basic_factor: Bottleneck width multiplier for each layer.
        drop_rate: Dropout probability passed to every layer (0 disables it).
        num_classes: Stored as ``number_class``; not used by this module.
        small_inputs: Accepted for interface compatibility; unused.
        rnn_units: Stored on the instance; not used by this module.
    """

    def __init__(
        self,
        growth_rate=24,
        block_config=(2, 2),
        compression=0.5,
        num_init_features=24,
        bottleneck_size_basic_factor=2,
        drop_rate=0,
        num_classes=2,
        small_inputs=True,
        rnn_units=512,
    ):
        super().__init__()
        # Stem convolution: takes a single input channel (see forward()).
        self.features = nn.Conv3d(
            1, num_init_features, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.init_feature_channel_number = num_init_features
        self.growth_rate = growth_rate
        self.compression = compression
        self.number_class = num_classes
        self.block_config = block_config
        self.rnn_units = rnn_units
        self.drop_ratio = drop_rate
        num_features = num_init_features
        # Output channel count after each dense block.
        self.dense_trainsition_out_put_list = []
        for i, num_layers in enumerate(self.block_config):
            block = DenseBlock(
                num_layers,
                num_features,
                self.growth_rate,
                bottleneck_size_basic_factor,
                drop_rate,
            )
            setattr(self, "block_" + str(i), block)
            num_features = num_features + num_layers * growth_rate
            self.dense_trainsition_out_put_list.append(num_features)

        # Normal init scaled by sqrt(2 / n), n = out_channels * kernel volume,
        # for every parameter whose name contains "conv" and "weight". The stem
        # ("features.weight") does not match and keeps PyTorch's default init.
        # NOTE(review): the "norm" branches never match because the BatchNorm
        # layers are named "bn"/"mode_bn"; BatchNorm's own defaults
        # (weight=1, bias=0) apply instead.
        for name, param in self.named_parameters():
            if "conv" in name and "weight" in name:
                n = param.size(0) * param.size(2) * param.size(3) * param.size(4)
                param.data.normal_().mul_(math.sqrt(2.0 / n))
            elif "norm" in name and "weight" in name:
                param.data.fill_(1)
            elif "norm" in name and "bias" in name:
                param.data.fill_(0)

    def forward(self, x):
        """Return the dense-block features for ``x``.

        Only the first channel of ``x`` (``x[:, :1]``) is used. The output has
        ``dense_trainsition_out_put_list[-1]`` channels.
        """
        features = self.features(x[:, :1])
        for i in range(len(self.block_config)):
            features = getattr(self, "block_" + str(i))(features)
        return features
def run_tvm_module(module, inpt):
    """Execute ``module`` once on ``inpt`` and return output 0 as a NumPy array.

    Args:
        module: A TVM graph-executor module (``graph_executor.GraphModule``).
        inpt: Value bound to input index 0.

    Returns:
        The first output of the module, copied to host memory via ``.numpy()``.
    """
    module.set_input(0, inpt)
    module.run()
    # Synchronize CUDA device 0 before the output is read back.
    cuda_device = tvm.cuda()
    cuda_device.sync()
    first_output = module.get_output(0)
    return first_output.numpy()
if __name__ == "__main__":
    # Shape of the example/graph input: (batch, channels, D, H, W).
    input_shape = (4, 2, 64, 64, 64)

    # Trace the PyTorch model so the Relay PyTorch frontend can import it.
    model = DenseNet()
    model.eval()
    model_jit = torch.jit.trace(model, example_inputs=torch.randn(input_shape))
    print("finish gen trace model")

    relay_model, params = relay.frontend.from_pytorch(
        model_jit, [("input_0", input_shape)], default_dtype="float32"
    )
    target = tvm.target.cuda()
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(relay_model, target=target, params=params)
    # Round-trip through a shared library so the run uses the exported artifact.
    lib.export_library("./dense.so")
    del lib
    print("finish compile tvm model")

    # The graph input was declared float32; create the NumPy input with the
    # same dtype instead of NumPy's float64 default.
    inpt = np.random.random(input_shape).astype("float32")
    lib = tvm.runtime.load_module("./dense.so")
    module = graph_executor.GraphModule(lib["default"](tvm.cuda()))

    # Two runs on identical input; a deterministic build prints 0.
    res1 = run_tvm_module(module, inpt)
    res2 = run_tvm_module(module, inpt)
    diff = res1 - res2
    print("max abs diff is:", np.max(np.abs(diff)))