-
Notifications
You must be signed in to change notification settings - Fork 369
feat: expose IResizeLayer in dynamo.conversion.impl #2488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,4 +28,5 @@ | |
topk, | ||
unary, | ||
unsqueeze, | ||
upsample, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Optional, Sequence | ||
|
||
import tensorrt as trt | ||
from torch.fx.node import Target | ||
from torch_tensorrt.dynamo._SourceIR import SourceIR | ||
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext | ||
from torch_tensorrt.fx.converters.converter_utils import set_layer_name | ||
from torch_tensorrt.fx.types import TRTTensor | ||
|
||
|
||
def upsample( | ||
ctx: ConversionContext, | ||
target: Target, | ||
source_ir: Optional[SourceIR], | ||
name: str, | ||
input: TRTTensor, | ||
out_shape: Optional[Sequence[int]], | ||
scale_factors: Optional[Sequence[float]], | ||
resize_mode: str, | ||
align_corners: bool, | ||
) -> TRTTensor: | ||
resize_layer = ctx.net.add_resize(input) | ||
# output size calculation | ||
# Pytorch assumes that one of out_shape/scale_factor is None | ||
# Pytorch assumes that dimensions match for out_shape/scale factor | ||
if out_shape is not None: | ||
resize_layer.shape = list(input.shape)[:2] + list(out_shape) | ||
elif scale_factors is not None: | ||
resize_layer.scales = [1.0, 1.0] + list(scale_factors) | ||
else: | ||
raise RuntimeError( | ||
f"At least one of out_shape and scale_factors should be specified." | ||
) | ||
|
||
# interpolate mode | ||
if resize_mode == "nearest" or None: | ||
resize_layer.resize_mode = trt.ResizeMode.NEAREST | ||
elif resize_mode == "bilinear": | ||
resize_layer.resize_mode = trt.ResizeMode.LINEAR | ||
if align_corners is None or not align_corners: | ||
raise RuntimeError( | ||
f"Interpolation works differently is align_corners is False for {resize_mode} mode in PyTorch and TensorRT." | ||
) | ||
else: | ||
raise RuntimeError( | ||
f"Interpolation mode is {resize_mode} which is not supported by TensorRT." | ||
) | ||
|
||
if resize_mode == "nearest": | ||
resize_layer.coordinate_transformation = ( | ||
trt.ResizeCoordinateTransformation.ASYMMETRIC | ||
) | ||
elif resize_mode == "bilinear": | ||
# align corners | ||
if align_corners is not None and align_corners: | ||
resize_layer.coordinate_transformation = ( | ||
trt.ResizeCoordinateTransformation.ALIGN_CORNERS | ||
) | ||
else: | ||
resize_layer.coordinate_transformation = ( | ||
trt.ResizeCoordinateTransformation.ASYMMETRIC | ||
) | ||
|
||
set_layer_name(resize_layer, target, name, source_ir) | ||
|
||
out = resize_layer.get_output(0) | ||
return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import torch | ||
from parameterized import parameterized | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
from .harness import DispatchTestCase | ||
|
||
|
||
class TestUpsampleConverter(DispatchTestCase): | ||
# test case for nearest upsample, using output_size, scale_factors is disabled here | ||
@parameterized.expand( | ||
[ | ||
("upsample_nearest2d.vec_outshape_0", (2, 2), (4, 4)), | ||
("upsample_nearest2d.vec_outshape_1", (2, 2), (5, 5)), | ||
] | ||
) | ||
def test_upsample_nearest_output_shape(self, _, input_shape, output_shape): | ||
class Upsample(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.ops.aten.upsample_nearest2d.vec(input, output_shape, None) | ||
|
||
input = [torch.randn([1, 1] + list(input_shape))] | ||
self.run_test(Upsample(), input) | ||
|
||
# test case for nearest upsample, using scale_factors, output_size is disabled here | ||
@parameterized.expand( | ||
[ | ||
("upsample_nearest2d.vec_scale_0", (2, 2), (2, 2)), | ||
("upsample_nearest2d.vec_scale_1", (2, 2), (1.5, 1.5)), | ||
] | ||
) | ||
def test_upsample_nearest_scale_factor(self, _, input_shape, scale_factor): | ||
class Upsample(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.ops.aten.upsample_nearest2d.vec(input, None, scale_factor) | ||
|
||
input = [torch.randn([1, 1] + list(input_shape))] | ||
self.run_test(Upsample(), input) | ||
|
||
# test case for bilinear upsample, using output_size, scale_factors is disabled here | ||
@parameterized.expand( | ||
[ | ||
("upsample_bilinear2d.vec_outshape_0", (2, 2), (4, 4), True), | ||
("upsample_bilinear2d.vec_outshape_1", (2, 2), (5, 5), True), | ||
] | ||
) | ||
def test_upsample_bilinear_output_shape( | ||
self, _, input_shape, output_shape, align_corners | ||
): | ||
class Upsample(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.ops.aten.upsample_bilinear2d.vec( | ||
input, | ||
output_shape, | ||
align_corners, | ||
None, | ||
) | ||
|
||
input = [torch.randn([1, 1] + list(input_shape))] | ||
self.run_test(Upsample(), input) | ||
|
||
# test case for bilinear upsample, using scale_factors, output_shape is disabled here | ||
@parameterized.expand( | ||
[ | ||
("upsample_bilinear2d.vec_scale_0", (2, 2), (2, 2), True), | ||
("upsample_bilinear2d.vec_scale_1", (2, 2), (1.5, 1.5), True), | ||
] | ||
) | ||
def test_upsample_bilinear_scale_factors( | ||
self, _, input_shape, scale_factors, align_corners | ||
): | ||
class Upsample(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.ops.aten.upsample_bilinear2d.vec( | ||
input, | ||
None, | ||
align_corners, | ||
scale_factors, | ||
) | ||
|
||
input = [torch.randn([1, 1] + list(input_shape))] | ||
self.run_test(Upsample(), input) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.