diff --git a/examples/summarize.py b/examples/summarize.py index afd598e8418..afe037d9435 100644 --- a/examples/summarize.py +++ b/examples/summarize.py @@ -314,7 +314,7 @@ def eval_trt_llm(datapoint, theta=pretrain_config.rotary_base, ) - if batch_size == 0: + if batch_size == 0 or len(batch_input_ids) == 0: return [], [], [], {} input_lengths = [x.size(0) for x in batch_input_ids]