Skip to content

Conversation

@masahi
Copy link
Member

@masahi masahi commented Feb 13, 2022

Closes #10223

According to https://github.com/pytorch/pytorch/blob/7b8f73dd32a8a893dfb794433ce501e76c53bc89/torch/nn/modules/conv.py#L127-L129,

to correctly decide the output channels attribute from the weight shape, we need to multiply weight_shape[1] by groups.

cc @comaniac @junrushao1994

@junrushao
Copy link
Member

right. I noticed the same issue very recently. Thanks for fixing this!!

@junrushao junrushao merged commit de73b99 into apache:main Feb 15, 2022
ylc pushed a commit to ylc/tvm that referenced this pull request Feb 16, 2022
* [Torch] Fix conv2d transpose with group

* lint

* wrong issue number

* do not run test on cuda
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
* [Torch] Fix conv2d transpose with group

* lint

* wrong issue number

* do not run test on cuda
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] Conv2DTranspose with groups not working correctly

2 participants