Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 16 additions & 277 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3110,232 +3110,21 @@ def aten_ops_pad(
)


for op in (
torch.ops.aten.upsample_nearest1d,
torch.ops.aten.upsample_nearest2d,
torch.ops.aten.upsample_nearest3d,
torch.ops.aten.upsample_linear1d,
torch.ops.aten.upsample_bilinear2d,
torch.ops.aten.upsample_trilinear3d,
torch.ops.aten.upsample_bicubic2d,
):
for key in (
torch._C.DispatchKey.Autograd,
torch._C.DispatchKey.CompositeImplicitAutograd,
):
if key in op.default.py_kernels:
del op.default.py_kernels[key]
if key in op.vec.py_kernels:
del op.vec.py_kernels[key]


def upsample_compute_output_size(
input_size: torch.Size,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> Optional[Sequence[int]]:
spatial_dimensions = len(input_size) - 2

if output_size is None and scale_factors is None:
raise AssertionError(
"Must specify exactly one of output_size and scale_factors"
)

if output_size is not None:
torch._check(
scale_factors is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
torch._check(len(output_size) == spatial_dimensions)

if scale_factors is not None:
torch._check(
output_size is None,
lambda: "Must specify exactly one of output_size and scale_factors",
)
torch._check(len(scale_factors) == spatial_dimensions)
output_size = []
for i, s in enumerate(scale_factors):
output_size.append(int(input_size[i + 2] * s))

return output_size


@torch.ops.aten.upsample_nearest1d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest1d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest1d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest1d.default(input, osize)


@torch.ops.aten.upsample_nearest2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest2d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest2d.default(input, osize)


@torch.ops.aten.upsample_nearest3d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_nearest3d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_nearest3d.default(input, osize, *scale_factors)
return torch.ops.aten.upsample_nearest3d.default(input, osize)


@torch.ops.aten.upsample_linear1d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_linear1d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_linear1d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_linear1d.default(input, osize, align_corners)


@torch.ops.aten.upsample_bilinear2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bilinear2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_bilinear2d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_bilinear2d.default(input, osize, align_corners)


@torch.ops.aten.upsample_trilinear3d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_trilinear3d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_trilinear3d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_trilinear3d.default(input, osize, align_corners)


@torch.ops.aten.upsample_bicubic2d.vec.py_impl(
torch._C.DispatchKey.CompositeImplicitAutograd
)
def upsample_bicubic2d_vec(
input: torch.Tensor,
output_size: Optional[Sequence[int]],
align_corners: bool,
scale_factors: Optional[Sequence[float]],
) -> torch.Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
if scale_factors is not None:
return torch.ops.aten.upsample_bicubic2d.default(
input, osize, align_corners, *scale_factors
)
return torch.ops.aten.upsample_bicubic2d.default(input, osize, align_corners)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest1d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_nearest1d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 3 else [args[2]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest2d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_nearest2d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[2], args[3]],
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_nearest3d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_nearest3d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_nearest3d(
def aten_ops_upsample_nearest(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3348,78 +3137,28 @@ def aten_ops_upsample_nearest3d(
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[2], args[3], args[4]],
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 2),
mode="nearest",
align_corners=False,
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_linear1d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
torch.ops.aten.upsample_linear1d.vec, supports_dynamic_shapes=True
)
def aten_ops_upsample_linear1d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 4 else [args[3]],
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_bilinear2d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_bilinear2d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_bilinear2d(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.upsample.upsample(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[3], args[4]],
mode="bilinear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_trilinear3d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_trilinear3d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
}
)
def aten_ops_upsample_trilinear3d(
def aten_ops_upsample_linear(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
Expand All @@ -3432,15 +3171,15 @@ def aten_ops_upsample_trilinear3d(
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 6 else [args[3], args[4], args[5]],
mode="trilinear",
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 3),
mode="linear",
align_corners=args[2],
)


@dynamo_tensorrt_converter(
torch.ops.aten.upsample_bicubic2d.default, supports_dynamic_shapes=True
torch.ops.aten.upsample_bicubic2d.vec, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
Expand All @@ -3460,8 +3199,8 @@ def aten_ops_upsample_bicubic2d(
SourceIR.ATEN,
name,
args[0],
size=args[1],
scale_factor=None if len(args) < 5 else [args[3], args[4]],
size=args_bounds_check(args, 1),
scale_factor=args_bounds_check(args, 3),
mode="bicubic",
align_corners=args[2],
)
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def upsample(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
size: Sequence[int],
size: Optional[Sequence[int]],
scale_factor: Optional[Sequence[float]],
mode: str,
align_corners: bool,
) -> TRTTensor:
layer = ctx.net.add_resize(input)

if scale_factor is not None and all(s is not None for s in scale_factor):
if scale_factor is not None:
layer.scales = [1.0, 1.0] + list(scale_factor)
else:
shape = list(input.shape)[:2] + list(size)
Expand Down
7 changes: 7 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,13 @@
}
torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
aten._softmax.default,
aten.upsample_nearest1d.vec,
aten.upsample_nearest2d.vec,
aten.upsample_nearest3d.vec,
aten.upsample_linear1d.vec,
aten.upsample_bilinear2d.vec,
aten.upsample_trilinear3d.vec,
aten.upsample_bicubic2d.vec,
}


Expand Down
Loading
Loading