@@ -353,6 +353,7 @@ def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
353353 self .low_latency_kernel = low_latency_kernel
354354 self .tile_size = tile_size
355355 self .epilogue_tile_m = epilogue_tile_m
356+ self .tuning_config = self .get_tuning_config ()
356357
357358 instance_key = (output_dtype , use_deep_seek_fp8 , low_latency_kernel ,
358359 tile_size , epilogue_tile_m )
@@ -407,7 +408,6 @@ def get_valid_tactics(
407408
408409 return tactics
409410
410-
411411 def get_default_valid_tactic (
412412 self ,
413413 inputs : List [torch .Tensor ],
@@ -425,10 +425,7 @@ def get_default_valid_tactic(
425425
426426 return default_tactic
427427
428-
429- def get_dynamic_tensor_specs (
430- self
431- ) -> Tuple [DynamicTensorSpec , ...]:
428+ def get_dynamic_tensor_specs (self ) -> Tuple [DynamicTensorSpec , ...]:
432429 """Get the dynamic tensor specs for use with the AutoTuner."""
433430
434431 # These indices correspond to the 0th input tensor and its first dimension
@@ -441,10 +438,9 @@ def get_dynamic_tensor_specs(
441438 m_values = (8 , 16 , 32 , 64 , 128 , 256 , 512 , 1024 , 2048 )
442439 round_rule = lambda x : last_positive_power_of_2 (x )
443440
444- spec = DynamicTensorSpec (
445- MAT1_IDX , TUNED_DIM , m_values , round_rule )
441+ specs = (DynamicTensorSpec (MAT1_IDX , TUNED_DIM , m_values , round_rule ), )
446442
447- return ( spec , )
443+ return specs
448444
449445 def get_constraint_specs (self ) -> Tuple [ConstraintSpec , ...]:
450446 """Get the constraint specs for the dynamic tensors for use with the AutoTuner.
@@ -469,11 +465,22 @@ def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
469465 SFS_A_IDX = 2
470466 CONSTRAINED_DIM = 1
471467
472- constraint_dq_sfs_a = (ConstraintSpec (
473- SFS_A_IDX , CONSTRAINED_DIM , _constrain_dq_sfs_a_dim1 ),)
468+ constraint_dq_sfs_a = (ConstraintSpec (SFS_A_IDX , CONSTRAINED_DIM ,
469+ _constrain_dq_sfs_a_dim1 ), )
474470
475471 return constraint_dq_sfs_a
476472
473+ def get_tuning_config (self ) -> TuningConfig :
474+ """Get the tuning configuration for the AutoTuner."""
475+
476+ dynamic_tensor_specs = self .get_dynamic_tensor_specs ()
477+ constraint_specs = self .get_constraint_specs ()
478+
479+ tuning_config = TuningConfig (dynamic_tensor_specs = dynamic_tensor_specs ,
480+ constraint_specs = constraint_specs )
481+
482+ return tuning_config
483+
477484
478485@torch .library .custom_op ("trtllm::fp8_batched_gemm_trtllmgen" , mutates_args = ())
479486def fp8_batched_gemm_trtllmgen (
@@ -497,18 +504,12 @@ def fp8_batched_gemm_trtllmgen(
497504
498505 tuner = AutoTuner .get ()
499506
500- dynamic_tensor_specs = kernel_runner .get_dynamic_tensor_specs ()
501- constraint_specs = kernel_runner .get_constraint_specs ()
502-
503- tuning_config = TuningConfig (dynamic_tensor_specs = dynamic_tensor_specs ,
504- constraint_specs = constraint_specs )
505-
506507 inputs = [mat1 , mat2 , dq_sfs_a , dq_sfs_b , scale_c ]
507508
508509 _ , best_tactic = tuner .choose_one (
509510 "trtllm::fp8_batched_gemm_trtllmgen::batched_gemm" ,
510511 [kernel_runner ],
511- tuning_config ,
512+ kernel_runner . tuning_config ,
512513 inputs ,
513514 )
514515
0 commit comments