Skip to content

✨[Feature] Proper bfloat16 support in elementwise converter #3458

@HolyWu

Description

@HolyWu
from __future__ import annotations

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"  # NOTE(review): set after `import torch_tensorrt` above; what reads CI_BUILD is not visible here - confirm the ordering is intended

dtype = torch.bfloat16  # precision under test: reused for the model, the example input and enabled_precisions
device = torch.device("cuda", 0)  # CUDA device index 0


class MyModule(torch.nn.Module):  # minimal reproducer: a single elementwise multiply by a Python float
    def __init__(self) -> None:  # registers no parameters or buffers
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 0.5  # shows up in the exported graph as aten.mul.Tensor(%x, 0.5), i.e. a float scalar operand


with torch.inference_mode():  # export and compilation both run with inference mode enabled
    model = MyModule().eval().to(device, dtype)  # eval mode, moved to the CUDA device and cast to bfloat16
    inputs = (torch.randn(1, 3, 224, 224, dtype=dtype, device=device),)  # one bf16 example input, shape (1, 3, 224, 224)
    exported_program = torch.export.export(model, inputs)  # export the module using the example inputs

    trt_model = torch_tensorrt.dynamo.compile(  # this call raises "TypeError: Unsupported numpy dtype" per the traceback
        exported_program,
        inputs,
        device=device,
        enabled_precisions={dtype},  # bfloat16 is the only precision listed
        debug=True,  # presumably what enables the DEBUG log records in the output - confirm
        min_block_size=1,  # presumably lets the one-operator graph be converted - confirm against the Torch-TensorRT docs
    )
Output:

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_nodes:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_num_users_is_0_nodes:Removed ops that [num_users=0] nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return (mul,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.mul.Tensor + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 3, 224, 224)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
    return mul
WARNING:py.warnings:/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/utils.py:423: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).
  return torch.tensor(tensor).dtype

DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.BF16]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.bfloat16))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /mul (kind: aten.mul.Tensor, args: ('x <Node>', '0.5 <float>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.mul.Tensor: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.mul.Tensor
Traceback (most recent call last):
  File "/home/holywu/test.py", line 27, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 693, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 897, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 90, in convert_module
    interpreter_result = interpret_module_to_result(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 69, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 725, in run
    self._construct_trt_network_def()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 393, in _construct_trt_network_def
    super().run()
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 784, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 891, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 1854, in aten_ops_mul
    return impl.elementwise.mul(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py", line 473, in mul
    return convert_binary_elementwise(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py", line 128, in convert_binary_elementwise
    rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/holywu/.local/lib/python3.12/site-packages/torch_tensorrt/_enums.py", line 426, in to
    raise TypeError("Unsupported numpy dtype")
TypeError: Unsupported numpy dtype

While executing %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x, 0.5), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]"):
         # File: /home/holywu/test.py:19 in forward, code: return x * 0.5
        mul: "bf16[1, 3, 224, 224][150528, 50176, 224, 1]" = torch.ops.aten.mul.Tensor(x, 0.5);  x = None
        return mul


Original traceback:
File "/home/holywu/test.py", line 19, in forward
    return x * 0.5

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions