Skip to content

Commit 82cf197

Browse files
Fix prelu bug in pytorch frontend (#8192)
* Fix prelu bug in pytorch frontend * Fix lint error * fix lint error * Fix lint error * Try to fix lint error * Fix lint error Co-authored-by: huangyuheng <[email protected]>
1 parent c9db3d0 commit 82cf197

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,9 +754,13 @@ def relu(self, inputs, input_types):
754754
return _op.nn.relu(data)
755755

756756
def prelu(self, inputs, input_types):
757+
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU
757758
data = inputs[0]
758-
alpha = inputs[1]
759-
return _op.nn.prelu(data, alpha)
759+
dim = self.get_dims(data)
760+
ndims = len(dim)
761+
axis = 0 if ndims == 1 else 1
762+
alpha = _op.broadcast_to(inputs[1], (dim[axis]))
763+
return _op.nn.prelu(data, alpha, axis)
760764

761765
def leaky_relu(self, inputs, input_types):
762766
data = inputs[0]

tests/python/frontend/pytorch/test_forward.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,10 @@ def test_forward_prelu():
643643
input_shape = [1, 3, 10, 10]
644644
input_data = torch.rand(input_shape).float()
645645
verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data)
646+
# Test when input channel > 1 and num parameters = 1
647+
verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=input_data)
648+
# Test when input dims < 2
649+
verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=torch.randn(2))
646650

647651

648652
@tvm.testing.uses_gpu

0 commit comments

Comments
 (0)