|
| 1 | +from typing import List, Optional, Union |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import tensorrt as trt |
| 5 | +from torch.fx.node import Target |
| 6 | +from torch_tensorrt.dynamo.conversion import impl |
| 7 | +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext |
| 8 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 9 | + SourceIR, |
| 10 | + cast_trt_tensor, |
| 11 | + get_trt_tensor, |
| 12 | +) |
| 13 | +from torch_tensorrt.fx.types import TRTTensor |
| 14 | + |
| 15 | + |
| 16 | +def full( |
| 17 | + ctx: ConversionContext, |
| 18 | + target: Union[Target, str], |
| 19 | + source_ir: Optional[SourceIR], |
| 20 | + name: str, |
| 21 | + shape: Union[List[int], TRTTensor], |
| 22 | + fill_value: Union[int, float, bool], |
| 23 | +) -> TRTTensor: |
| 24 | + # in static shape scenario, shape is a list of int |
| 25 | + if isinstance(shape, List): |
| 26 | + return np.full(shape, fill_value) |
| 27 | + |
| 28 | + # in dynamic shape scenario, shape is a shap tensor |
| 29 | + # use IFillLayer to fill the shape tensor with LINSPACE value |
| 30 | + layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype) |
| 31 | + layer.set_input(0, shape) |
| 32 | + layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0)) |
| 33 | + delta = get_trt_tensor(ctx, 1, name + "_delta") |
| 34 | + input = [] |
| 35 | + for _ in range(shape.shape[0]): |
| 36 | + input.append(delta) |
| 37 | + delta = impl.cat.cat(ctx, target, source_ir, name + "_cat", input, dim=0) |
| 38 | + layer.set_input(2, delta) |
| 39 | + output = layer.get_output(0) |
| 40 | + |
| 41 | + # fill the output tensor with the actual fill_value |
| 42 | + output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0) |
| 43 | + if isinstance(fill_value, (int, float)): |
| 44 | + if isinstance(fill_value, float): |
| 45 | + output = cast_trt_tensor( |
| 46 | + ctx, output, trt.float32, name + "_casted", target, source_ir |
| 47 | + ) |
| 48 | + output = impl.elementwise.add( |
| 49 | + ctx, target, source_ir, name + "_add", output, fill_value |
| 50 | + ) |
| 51 | + |
| 52 | + if isinstance(fill_value, bool): |
| 53 | + output = cast_trt_tensor( |
| 54 | + ctx, output, trt.bool, name + "_casted", target, source_ir |
| 55 | + ) |
| 56 | + output = impl.elementwise.logical_or( |
| 57 | + ctx, target, source_ir, name + "_add", output, fill_value |
| 58 | + ) |
| 59 | + |
| 60 | + return output |
0 commit comments