Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions py/torch_tensorrt/ts/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
def TensorRTCompileSpec(
inputs: Optional[List[torch.Tensor | Input]] = None,
input_signature: Optional[Any] = None,
device: torch.device | Device = Device._current_device(),
device: Optional[torch.device | Device] = None,
disable_tf32: bool = False,
sparse_weights: bool = False,
enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
Expand Down Expand Up @@ -365,7 +365,7 @@ def TensorRTCompileSpec(
compile_spec = {
"inputs": inputs if inputs is not None else [],
# "input_signature": input_signature,
"device": device,
"device": Device._current_device() if device is None else device,
"disable_tf32": disable_tf32, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
"sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
"enabled_precisions": (
Expand Down
Loading