1 file changed (+10, −0): neural_compressor/torch/algorithms/fp8_quant/_quant_common

```diff
@@ -725,6 +725,9 @@ def forward(
         is_causal=False,
         scale=None,
         softmax_mode="None",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         qinput = self.quant_q(q).detach()
         kinput = self.quant_k(k).detach()
@@ -746,6 +749,8 @@ def forward(
             q_scale_o=self.scale_output,
             d_scale_s=self.descale_amax,
             is_amax_s=False,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         d_out = self.dequant_output(output)
@@ -761,6 +766,9 @@ def forward_measure(
         is_causal=False,
         scale=None,
         softmax_mode="fast",
+        recompute=None,
+        valid_seq_len=None,
+        seq_padding_type="None",
     ):
         dq = q.detach()
         dk = k.detach()
@@ -777,6 +785,8 @@ def forward_measure(
             # fp8_fused_sdpa in bf16 can use either FastSoftmax or regular
             softmax_mode="fast",
             is_amax_s=True,
+            valid_seq_len=valid_seq_len,
+            seq_padding_type=seq_padding_type
         )
         output = results[0]
         amax = results[1]
```
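The patch threads two padding hints, `valid_seq_len` and `seq_padding_type`, through both the quantized `forward` and the calibration path `forward_measure` into the `fp8_fused_sdpa` call (`recompute` is added to both signatures but is not forwarded in these hunks). As a rough illustration of what a padding-aware SDPA call does with such hints, here is a minimal sketch: the stub below is an illustrative stand-in, not the actual `fp8_fused_sdpa` kernel API, and the tensor shapes, the `"right"` padding value, and the masking behavior are assumptions.

```python
import torch
import torch.nn.functional as F

def fused_sdpa_stub(q, k, v, attn_mask=None, is_causal=False,
                    is_amax_s=False, valid_seq_len=None,
                    seq_padding_type="None"):
    """Illustrative stand-in for a fused SDPA kernel that honors padding
    hints: runs the reference SDPA and masks out key positions past each
    batch element's valid length (assuming right-padded sequences)."""
    if valid_seq_len is not None and seq_padding_type == "right":
        seq_len = k.shape[-2]
        positions = torch.arange(seq_len, device=k.device)
        keep = positions[None, :] < valid_seq_len[:, None]   # [B, S]
        attn_mask = keep[:, None, None, :]                   # [B, 1, 1, S]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                         is_causal=is_causal)
    # Mirror the (output, amax) result tuple the diff unpacks as results[0/1].
    amax = out.abs().amax() if is_amax_s else None
    return out, amax

# Exercise the new keywords the way the patched forward_measure forwards them.
q = torch.randn(2, 4, 16, 8)   # [batch, heads, seq, head_dim] (assumed layout)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)
out, amax = fused_sdpa_stub(q, k, v, is_amax_s=True,
                            valid_seq_len=torch.tensor([16, 10]),
                            seq_padding_type="right")
print(out.shape, amax)
```

Passing the valid lengths down to the kernel lets it skip or mask padded key positions instead of attending over them, which matters both for correctness and for the amax statistics collected during measurement.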