-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[Torch] Various updates for PyTorch frontend #7348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
siju-samuel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
t-vi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
Since torch.sort returns both sorted values and indices while the Relay one doesn't, torch.sort conversion is not efficient, especially for multidimensional input (currently does both sort and argsort!). Suggestions for a better implementation are welcome.
I think gather is the direct equivalent of how you use take on 1d. Maybe it's worth fixing this in this PR.
|
@t-vi Thanks, I think I tried |
|
Thanks @siju-samuel @t-vi |
* add conversion for detr * remove explicit broadcast_to before batched matmul * use take with wrap mode * add test for transformer and negative indices * add sort and argsort * add logical_and * support masked_select * add gpu targets to masked_select test * improve sort conversion
* add conversion for detr * remove explicit broadcast_to before batched matmul * use take with wrap mode * add test for transformer and negative indices * add sort and argsort * add logical_and * support masked_select * add gpu targets to masked_select test * improve sort conversion
* add conversion for detr * remove explicit broadcast_to before batched matmul * use take with wrap mode * add test for transformer and negative indices * add sort and argsort * add logical_and * support masked_select * add gpu targets to masked_select test * improve sort conversion
* add conversion for detr * remove explicit broadcast_to before batched matmul * use take with wrap mode * add test for transformer and negative indices * add sort and argsort * add logical_and * support masked_select * add gpu targets to masked_select test * improve sort conversion
This PR adds various updates to the PyTorch frontend, to fully support the recent transformer based object detection model from facebook, DETR https://github.com/facebookresearch/detr. After this PR, DETR runs on TVM and gets the correct results. TVM with auto scheduled GPU conv2d and batched matmul is 1.4x faster than PyTorch in my environment.
Also added various missing ops reported by hummingbird projects. @interesaaat
logical_and,masked_select,sortandargsort.masked_selectrequires VM to run.cumsumandmasked_fillop, to enable importing DETR https://github.com/facebookresearch/detr.cumsumis also requested by hummingbird.mode="wrap"in Relaytakeop. Without this, the result doesn't match with DETR.broadcast_tobeforebatch_matmul. Without it, memory usage blows up during constant evaluation of DETR. This is the same issue as [Topi] Allow batch_matmul to broadcast along batch dimension. #6616, please see the explanation there.torch.nn.Transformerto the tests.Since
torch.sortreturns both sorted values and indices while the Relay one doesn't,torch.sortconversion is not efficient,especially for multidimensional input (currently does both sort and argsort!). Suggestions for a better implementation are welcome.UPDATE: fixedplease review @siju-samuel @jwfromm @kevinthesun @t-vi