Skip to content

Conversation

@zeshengzong
Copy link
Contributor

@zeshengzong zeshengzong commented Feb 27, 2025

Fixes #147801

Changes

  • Change hardswish gradient compute condition as torch.nn.functional.hardswish
  • Enable cuda for test test_hardswish_grad_corner
  • Add test case for value=-3

Test Result

pytest test/test_nn.py -k test_hardswish
pytest test/test_unary_ufuncs.py -k test_hardswish
pytest test/inductor/test_torchinductor.py -k test_hardswish

image
image
image

cc @ezyang @albanD @gqchen @pearu @nikitaved @soulitzer @Varal7 @xmfan @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot
Copy link

pytorch-bot bot commented Feb 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148049

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 98ee89a with merge base 2bcc3ac (image):

NEW FAILURE - The following job has failed:

  • linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test (gh)
    RuntimeError: cuDNN version incompatibility: PyTorch was compiled against (9, 8, 0) but found runtime version (9, 7, 1). PyTorch already comes bundled with cuDNN. One option to resolving this error is to ensure PyTorch can find the bundled cuDNN. one possibility is that there is a conflicting cuDNN in LD_LIBRARY_PATH.

BROKEN TRUNK - The following job failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added module: cpu CPU specific problem (e.g., perf, algorithm) release notes: nn release notes category labels Feb 27, 2025
@zeshengzong zeshengzong marked this pull request as ready for review February 27, 2025 03:11
@nikitaved nikitaved added the module: autograd Related to torch.autograd, and the autograd engine in general label Feb 27, 2025
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Feb 27, 2025
@zou3519 zou3519 requested a review from soulitzer February 27, 2025 16:17
soulitzer
soulitzer previously approved these changes Feb 27, 2025
Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@soulitzer
Copy link
Contributor

Failures looks legit:

FAILED [0.7651s] test_jit_fuser_te.py::TestTEFuserStatic::test_hardswish_fwd_bwd - AssertionError: Tensor-likes are not close!

Mismatched elements: 1 / 20 (5.0%)
Greatest absolute difference: 0.06496521830558777 at index (7,) (up to 1e-05 allowed)
Greatest relative difference: 1.0 at index (7,) (up to 1.3e-06 allowed)

To execute this test, run the following from the base repo dir:
    PYTORCH_TEST_WITH_DYNAMO=1 python test/test_jit_fuser_te.py TestTEFuserStatic.test_hardswish_fwd_bwd

@zeshengzong
Copy link
Contributor Author

@soulitzer please check changes when available, thanks!

def backward(grad_output):
m = (self > 3.).type_as(result)
m = torch.where((self >= -3.) & (self <= 3.), self / 3. + .5, m)
m = torch.where((self > -3.) & (self < 3.), self / 3. + .5, m)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also need to change line 939 to self >= 3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed, thanks!

soulitzer
soulitzer previously approved these changes Mar 4, 2025
Copy link
Contributor

@soulitzer soulitzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 4, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@soulitzer
Copy link
Contributor

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-focal-py3.13-clang10 / test (dynamo_wrapped, 1, 3, lf.linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-13)

Details for Dev Infra team Raised by workflow job

inputs.requires_grad = True
self.assertTrue(gradcheck(F.hardswish, (inputs,)))

@onlyCPU
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like we're failing on mps on some dtypes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor test case make it works on cuda and cpu, is there should have a @onlyCUDAAndCPU annotation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see one, but I wonder if onlyNativeDeviceTypes works

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to @onlyNativeDeviceTypes

@zeshengzong zeshengzong force-pushed the fix/aten/hardswish_backward branch from 89c2cb0 to 98ee89a Compare March 11, 2025 11:26
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Mar 11, 2025
0.0,
torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
torch.where(self < 3, grad_output * ((self / 3) + 0.5), grad_output),
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @soulitzer, I change here, does the failing test can run locally? Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm there's some amount of setup which I cannot recall on the top of my head, but maybe not too hard to test a small example instead (I've also added some labels which would hopefully test on CI)

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 11, 2025

To add the ciflow label ciflow/inductor please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 11, 2025

To add the ciflow label ciflow/inductor-perf-compare please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 11, 2025

To add the ciflow label ciflow/inductor-periodic please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@zeshengzong
Copy link
Contributor Author

Hi @soulitzer, shall we try to merge again, thanks!

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorch-bot
Copy link

pytorch-bot bot commented Mar 14, 2025

This PR needs to be approved by an authorized maintainer before merge.

@soulitzer
Copy link
Contributor

Hi @soulitzer, shall we try to merge again, thanks!

thanks for the quick fix

@soulitzer
Copy link
Contributor

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 14, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test

Details for Dev Infra team Raised by workflow job

@soulitzer
Copy link
Contributor

@pytorchbot merge -i

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge), linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/trunk Trigger trunk jobs on your pull request Merged module: autograd Related to torch.autograd, and the autograd engine in general module: cpu CPU specific problem (e.g., perf, algorithm) open source release notes: nn release notes category Reverted triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Incorrect Gradients at Boundary Points for torch.nn.functional.hardswish

6 participants