@@ -533,19 +533,24 @@ def __init__(self, mod, mod_extra_config, *args, **kwargs):
533533 self .fetch_from_cache = mod .fetch_from_cache
534534 self .forward = self .forward_measure
535535
536- def forward (self , input , cache , block_indices , block_offset ):
536+ def forward (self , input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset ):
537537 qinput = self .quant_input (input )
538- output_cache = self .forward_orig (qinput , cache , block_indices , block_offset )
538+ output_cache = self .forward_orig (qinput , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset )
539539 return self .quant_output (output_cache )
540540
541- def forward_measure (self , input , cache , block_indices , block_offset ):
541+ def forward_measure (self , input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset ):
542542 measure_input ((input ), self ._mod_extra_config .inputs )
543- output_cache = self .forward_orig (input , cache , block_indices , block_offset )
543+ output_cache = self .forward_orig (input , cache , num_kv_cache_passes , num_slots_available , block_indices , block_offset )
544544 measure_output ((output_cache ), self ._mod_extra_config .outputs )
545545 return output_cache
546546
547- def fetch_from_cache (self , cache , blocks ):
547+ def fetch_from_cache (self , cache , blocks , permutations = None ):
548548 quant_cache = self .quant_input (cache )
549+ if permutations :
550+ output_cache = self .orig_fetch_from_cache (quant_cache , blocks , permutations )
551+ for i in range (len (output_cache )):
552+ output_cache [i ]= self .quant_output (output_cache [i ])
553+ return output_cache
549554 output_cache = self .orig_fetch_from_cache (quant_cache , blocks )
550555 return self .quant_output (output_cache )
551556
0 commit comments