From 19161ed2b67d82c1d7363bbe21a0ac90a54d7652 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 19 Dec 2024 13:42:47 -0800 Subject: [PATCH] fix: Fix null inputs --- py/torch_tensorrt/dynamo/_compiler.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index a89d7bbd2c..d355cefe77 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -263,8 +263,10 @@ def cross_compile_for_windows( "When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True" ) # Aliasing inputs to arg_inputs for better understanding - if not arg_inputs and not inputs: - raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") + if not arg_inputs and not kwarg_inputs and not inputs: + raise AssertionError( + "'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None." + ) elif arg_inputs and inputs: raise AssertionError( @@ -582,8 +584,10 @@ def compile( "When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True" ) # Aliasing inputs to arg_inputs for better understanding - if not arg_inputs and not inputs: - raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") + if not arg_inputs and not kwarg_inputs and not inputs: + raise AssertionError( + "'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None." + ) elif arg_inputs and inputs: raise AssertionError( @@ -1069,8 +1073,10 @@ def convert_exported_program_to_serialized_trt_engine( "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" ) - if arg_inputs is None and inputs is None: - raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") + if not arg_inputs and not kwarg_inputs and not inputs: + raise AssertionError( + "'arg_inputs', 'kwarg_inputs' and 'inputs' should not all be None." + ) elif arg_inputs is not None and inputs is not None: raise AssertionError(