-
Notifications
You must be signed in to change notification settings - Fork 348
Convert model inference test from pytest to unittest #2644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert model inference test from pytest to unittest #2644
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2644
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cba6d5d with merge base 7b0671c ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
test/test_ao_models.py
Outdated
for device in devices: | ||
for batch_size in batch_sizes: | ||
for is_training in training_modes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can change these to common_utils.parametrize
like:
Lines 186 to 187 in 0935f66
@common_utils.parametrize("device", COMMON_DEVICES) | |
@common_utils.parametrize("dtype", COMMON_DTYPES) |
Line 425 in 0935f66
common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love these decorators, thanks!
test/test_ao_models.py
Outdated
_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) | ||
_BATCH_SIZES = [1, 4] | ||
_TRAINING_MODES = [True, False] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: please feel free to inline these
test/test_ao_models.py
Outdated
# Define test parameters | ||
COMMON_DEVICES = common_utils.parametrize("device", _AVAILABLE_DEVICES) | ||
COMMON_DTYPES = common_utils.parametrize("dtype", [torch.float32, torch.bfloat16]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don't need to define these here, can just put these in the function itself
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks good, thanks!
test/test_ao_models.py
Outdated
from torchao._models.llama.model import Transformer | ||
|
||
_AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) | ||
from torchao.testing import common_utils |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the import is not correct, please run the tests locally first to make sure the test runs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for my carelessness. I fixed it and tested it locally.
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 7 jobs have failed, first few of them are: Run TorchAO Experimental Tests, Code Analysis with Ruff, PR Label Check, Run Regression Tests, Run 1xH100 Tests Details for Dev Infra teamRaised by workflow job |
@pytorchbot merge --ignore-current |
|
@pytorchbot merge -f "ci appears to be stuck" |
You are not authorized to force merges to this repository. Please use the regular |
we will need to use the merge button to merge in torchao |
Like this because there is no "merged" label, thanks. Also, could you check #2660 ? It replaces |
@jerryzh168 gentle ping for the reminder :) |
* convert ao inference test from pytest to unittest * refactor: `common_utils` for common parameters * incline common params * fix uncorrect library import
Summary:
Converts TorchAO model (
Transformer
) inference test from Pytest into Unittest. Annotations are added for better readability.cc @osbm