|
1 | 1 | import logging
|
| 2 | +import operator |
2 | 3 | from typing import Dict, Sequence, Tuple, Union
|
3 | 4 | import torch
|
4 | 5 | import tensorrt as trt
|
|
9 | 10 | from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
|
10 | 11 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
11 | 12 | from torch_tensorrt.dynamo.conversion import impl
|
12 |
| -from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor |
13 |
| -from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor |
| 13 | +from torch_tensorrt.dynamo.conversion.converter_utils import ( |
| 14 | + cast_trt_tensor, |
| 15 | + cast_int_int_div_trt_tensor, |
| 16 | +) |
14 | 17 |
|
15 | 18 | _LOGGER: logging.Logger = logging.getLogger(__name__)
|
16 | 19 |
|
@@ -70,13 +73,13 @@ def aten_ops_div(
|
70 | 73 | kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
|
71 | 74 | ):
|
72 | 75 | kwargs_new["input"] = cast_trt_tensor(
|
73 |
| - network, kwargs_new["input"], trt.float32, name |
| 76 | + network, kwargs_new["input"], trt.float32, name, target |
74 | 77 | )
|
75 | 78 | elif isinstance(args[1], TRTTensor) and (
|
76 | 79 | kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
|
77 | 80 | ):
|
78 | 81 | kwargs_new["other"] = cast_trt_tensor(
|
79 |
| - network, kwargs_new["other"], trt.float32, name |
| 82 | + network, kwargs_new["other"], trt.float32, name, target |
80 | 83 | )
|
81 | 84 | rounding_mode = kwargs.get("rounding_mode")
|
82 | 85 | if rounding_mode is None:
|
@@ -377,3 +380,77 @@ def aten_ops_permute(
|
377 | 380 | args[0],
|
378 | 381 | args[1],
|
379 | 382 | )
|
| 383 | + |
| 384 | + |
def to_copy_dtype_validator(to_copy_node: Node) -> bool:
    """Report whether a ``_to_copy`` node can be handled by the converter.

    A node is accepted only when its kwargs carry a ``dtype`` entry whose
    value is one of the casts in ``allowed_casts``. Every rejection is
    logged at debug level.

    Args:
        to_copy_node: Node whose ``kwargs`` mapping is inspected.

    Returns:
        ``True`` when ``kwargs["dtype"]`` is present and allowed,
        ``False`` otherwise.
    """
    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
    node_kwargs = to_copy_node.kwargs

    # Guard: without a "dtype" kwarg there is no cast for the converter to do.
    if "dtype" not in node_kwargs:
        # Lazy %-style arguments: the node and kwargs are only stringified
        # when debug logging is actually enabled.
        _LOGGER.debug(
            "_to_copy converter rejected node %s with kwargs %s",
            to_copy_node,
            node_kwargs,
        )
        return False

    requested_dtype = node_kwargs["dtype"]
    if requested_dtype not in allowed_casts:
        _LOGGER.debug(
            "_to_copy converter rejected node %s with dtype %s",
            to_copy_node,
            requested_dtype,
        )
        return False

    return True
| 402 | + |
| 403 | + |
@dynamo_tensorrt_converter(
    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
)
def aten_ops_to_copy_dtype(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter registered for ``torch.ops.aten._to_copy.default``.

    Forwards the first positional argument and the ``dtype`` keyword
    argument to ``impl.cast.to_copy``, tagged with ``SourceIR.ATEN``.
    The registration passes ``to_copy_dtype_validator`` as the capability
    validator.

    Args:
        network: Network handed through to the implementation.
        target: Target handed through to the implementation.
        args: Positional arguments; only ``args[0]`` is used.
        kwargs: Keyword arguments; only ``kwargs["dtype"]`` is used.
        name: Name handed through to the implementation.

    Returns:
        Whatever ``impl.cast.to_copy`` returns.
    """
    source = args[0]
    requested_dtype = kwargs["dtype"]
    return impl.cast.to_copy(
        network, target, SourceIR.ATEN, name, source, requested_dtype
    )
| 422 | + |
| 423 | + |
@dynamo_tensorrt_converter(operator.getitem)
def operator_getitem(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter registered for ``operator.getitem``.

    Forwards the first two positional arguments to
    ``impl.evaluators.getitem``, tagged with ``SourceIR.ATEN``.

    Args:
        network: Network handed through to the implementation.
        target: Target handed through to the implementation.
        args: Positional arguments; ``args[0]`` and ``args[1]`` are used.
        kwargs: Keyword arguments; not read here.
        name: Name handed through to the implementation.

    Returns:
        Whatever ``impl.evaluators.getitem`` returns.
    """
    first_arg = args[0]
    second_arg = args[1]
    return impl.evaluators.getitem(
        network, target, SourceIR.ATEN, name, first_arg, second_arg
    )
| 440 | + |
| 441 | + |
@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter registered for ``torch.ops.aten.clone.default``.

    Forwards the first positional argument to ``impl.evaluators.clone``,
    tagged with ``SourceIR.ATEN``.

    Args:
        network: Network handed through to the implementation.
        target: Target handed through to the implementation.
        args: Positional arguments; only ``args[0]`` is used.
        kwargs: Keyword arguments; not read here.
        name: Name handed through to the implementation.

    Returns:
        Whatever ``impl.evaluators.clone`` returns.
    """
    source = args[0]
    return impl.evaluators.clone(network, target, SourceIR.ATEN, name, source)
0 commit comments