@@ -92,7 +92,7 @@ def attn_custom_op_inplace(
                     mrope_position_deltas,
                     attention_window_size,
                     attention_mask_data,
-                    False,
+                    enable_attn_nvfp4_output=False,
                     output=output)


@@ -372,6 +372,58 @@ def _attn_impl(
             return attn_output[0], attn_output[1]
         return attn_output, None

+    def forward_impl(
+            self,
+            q: torch.Tensor,
+            k: Optional[torch.Tensor],
+            v: Optional[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            attention_mask: AttentionMask,
+            attention_window_size: Optional[int],
+            attention_mask_data: Optional[torch.Tensor],
+            mrope_config: Optional[dict],
+    ):
+        mrope_rotary_cos_sin = None
+        mrope_position_deltas = None
+        if mrope_config is not None:
+            if "mrope_rotary_cos_sin" in mrope_config:
+                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
+            if "mrope_position_deltas" in mrope_config:
+                mrope_position_deltas = mrope_config["mrope_position_deltas"]
+
+        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
+        # Only enable custom inplace op when torch compiling.
+        use_custom_inplace_op = (self.register_to_config
+                                 and (self.attn_backend == "TRTLLM"
+                                      or self.attn_backend == "FLASHINFER")
+                                 and is_torch_compiling())
+
+        if use_custom_inplace_op:
+            output = self.create_output(q)
+            attn_custom_op_inplace(
+                q,
+                k,
+                v,
+                attention_mask,
+                mrope_rotary_cos_sin,
+                mrope_position_deltas,
+                attention_window_size,
+                attention_mask_data,
+                self.layer_idx_str,
+                output,
+            )
+        else:
+            output, output_sf = self._attn_impl(q, k, v, attn_metadata,
+                                                attention_mask,
+                                                mrope_rotary_cos_sin,
+                                                mrope_position_deltas,
+                                                attention_window_size,
+                                                attention_mask_data)
+            if output_sf is not None:
+                output = Fp4QuantizedTensor(output, output_sf)
+
+        return output
+
     def forward(
             self,
             position_ids: Optional[torch.IntTensor],
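Side note on the `use_custom_inplace_op` gate added above: wrapping the attention kernel in a custom op that mutates a preallocated `output` buffer is what keeps the TRTLLM/FLASHINFER paths opaque to torch.compile, so Dynamo records one fixed-schema call instead of tracing into backend internals. Below is a minimal, hypothetical sketch of that pattern using `torch.library.custom_op` (PyTorch ≥ 2.4); the `demo::attn_inplace` name and the naive softmax math are placeholders, not the real TensorRT-LLM op:

```python
import torch


# Minimal sketch: a mutating custom op that writes into a preallocated
# buffer, mirroring the attn_custom_op_inplace pattern in this diff.
# "demo::attn_inplace" and the naive SDPA math are placeholders.
@torch.library.custom_op("demo::attn_inplace", mutates_args=("output", ))
def attn_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                 output: torch.Tensor) -> None:
    scale = q.shape[-1]**-0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    output.copy_(attn @ v)


@attn_inplace.register_fake
def _(q, k, v, output) -> None:
    # The op allocates nothing and returns nothing; it only mutates
    # `output`, so the fake (meta) kernel is a no-op.
    pass


@torch.compile
def run(q, k, v):
    # torch.compile sees one opaque op with a known schema, which is
    # why the diff only takes this path under is_torch_compiling().
    out = torch.empty_like(q)
    attn_inplace(q, k, v, out)
    return out
```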
@@ -414,54 +466,18 @@ def forward(
         if qkv_lora is not None:
             qkv = qkv + qkv_lora

-        mrope_rotary_cos_sin = None
-        mrope_position_deltas = None
-        if mrope_config is not None:
-            if "mrope_rotary_cos_sin" in mrope_config:
-                mrope_rotary_cos_sin = mrope_config["mrope_rotary_cos_sin"]
-            if "mrope_position_deltas" in mrope_config:
-                mrope_position_deltas = mrope_config["mrope_position_deltas"]
-
-        output = None
-
         q, k, v = qkv, None, None
         q, k, v = self.apply_rope(q, k, v, position_ids)
         q, k, v = self.convert_qkv(q, k, v)

-        # Currently only TRTLLM and FLASHINFER are torch compile compatible backends.
-        # Only enable custom inplace op when torch compiling.
-        use_custom_inplace_op = (self.register_to_config
-                                 and (self.attn_backend == "TRTLLM"
-                                      or self.attn_backend == "FLASHINFER")
-                                 and is_torch_compiling())
-        if use_custom_inplace_op:
-            output = self.create_output(q)
-            attn_custom_op_inplace(
-                q,
-                k,
-                v,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-                self.layer_idx_str,
-                output=output,
-            )
-        else:
-            output, output_sf = self._attn_impl(
-                q,
-                k,
-                v,
-                attn_metadata,
-                attention_mask,
-                mrope_rotary_cos_sin,
-                mrope_position_deltas,
-                attention_window_size,
-                attention_mask_data,
-            )
-            if output_sf is not None:
-                output = Fp4QuantizedTensor(output, output_sf)
+        output = self.forward_impl(q,
+                                   k,
+                                   v,
+                                   attn_metadata,
+                                   attention_mask,
+                                   attention_window_size,
+                                   attention_mask_data,
+                                   mrope_config=mrope_config)

         attn_output = self.o_proj(output,
                                   all_reduce_params=all_reduce_params,
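A note on the `output_sf` handling that moved into `forward_impl`: when NVFP4 attention output is enabled (the `enable_attn_nvfp4_output` flag named in the first hunk), the backend returns the payload together with scale factors, and the pair is forwarded as one value so `o_proj` can consume it. A rough stand-in for that wrapping, with `Fp4PairSketch` as a hypothetical placeholder for the real `Fp4QuantizedTensor`:

```python
from typing import NamedTuple, Optional, Union

import torch


class Fp4PairSketch(NamedTuple):
    # Hypothetical stand-in for Fp4QuantizedTensor: packed FP4 payload
    # plus the per-block scale factors needed to dequantize it.
    fp4_tensor: torch.Tensor
    scaling_factor: torch.Tensor


def wrap_attn_output(
        out: torch.Tensor, out_sf: Optional[torch.Tensor]
) -> Union[torch.Tensor, Fp4PairSketch]:
    # Mirrors the tail of forward_impl: wrap only when the backend
    # actually produced scale factors, else pass the tensor through so
    # downstream layers (here, o_proj) can branch on the type.
    return Fp4PairSketch(out, out_sf) if out_sf is not None else out
```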