Commit a2dab4c

Fix pre-commit + only generate tuning config once
Signed-off-by: Dom Brown <[email protected]>
1 parent 4820e24 commit a2dab4c

File tree

2 files changed: +19 −18 lines

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 18 additions & 17 deletions
@@ -353,6 +353,7 @@ def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
         self.low_latency_kernel = low_latency_kernel
         self.tile_size = tile_size
         self.epilogue_tile_m = epilogue_tile_m
+        self.tuning_config = self.get_tuning_config()

         instance_key = (output_dtype, use_deep_seek_fp8, low_latency_kernel,
                         tile_size, epilogue_tile_m)
@@ -407,7 +408,6 @@ def get_valid_tactics(

         return tactics

-
     def get_default_valid_tactic(
             self,
             inputs: List[torch.Tensor],
@@ -425,10 +425,7 @@ def get_default_valid_tactic(

         return default_tactic

-
-    def get_dynamic_tensor_specs(
-            self
-    ) -> Tuple[DynamicTensorSpec, ...]:
+    def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:
         """Get the dynamic tensor specs for use with the AutoTuner."""

         # These indices correspond to the 0th input tensor and it's first dimension
@@ -441,10 +438,9 @@ def get_dynamic_tensor_specs(
         m_values = (8, 16, 32, 64, 128, 256, 512, 1024, 2048)
         round_rule = lambda x: last_positive_power_of_2(x)

-        spec = DynamicTensorSpec(
-            MAT1_IDX, TUNED_DIM, m_values, round_rule)
+        specs = (DynamicTensorSpec(MAT1_IDX, TUNED_DIM, m_values, round_rule), )

-        return (spec, )
+        return specs

     def get_constraint_specs(self) -> Tuple[ConstraintSpec, ...]:
         """Get the constraint specs for the dynamic tensors for use with the AutoTuner.
@@ -469,11 +465,22 @@ def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
         SFS_A_IDX = 2
         CONSTRAINED_DIM = 1

-        constraint_dq_sfs_a = (ConstraintSpec(
-            SFS_A_IDX, CONSTRAINED_DIM, _constrain_dq_sfs_a_dim1),)
+        constraint_dq_sfs_a = (ConstraintSpec(SFS_A_IDX, CONSTRAINED_DIM,
+                                              _constrain_dq_sfs_a_dim1), )

         return constraint_dq_sfs_a

+    def get_tuning_config(self) -> TuningConfig:
+        """Get the tuning configuration for the AutoTuner."""
+
+        dynamic_tensor_specs = self.get_dynamic_tensor_specs()
+        constraint_specs = self.get_constraint_specs()
+
+        tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
+                                     constraint_specs=constraint_specs)
+
+        return tuning_config
+

 @torch.library.custom_op("trtllm::fp8_batched_gemm_trtllmgen", mutates_args=())
 def fp8_batched_gemm_trtllmgen(
@@ -497,18 +504,12 @@ def fp8_batched_gemm_trtllmgen(

     tuner = AutoTuner.get()

-    dynamic_tensor_specs = kernel_runner.get_dynamic_tensor_specs()
-    constraint_specs = kernel_runner.get_constraint_specs()
-
-    tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
-                                 constraint_specs=constraint_specs)
-
     inputs = [mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c]

     _, best_tactic = tuner.choose_one(
         "trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
         [kernel_runner],
-        tuning_config,
+        kernel_runner.tuning_config,
         inputs,
     )
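
The net effect in torch_custom_ops.py is that the TuningConfig is now built once when the kernel runner is constructed and reused on every invocation of the custom op, instead of being reassembled inside fp8_batched_gemm_trtllmgen on each call. A minimal sketch of that pattern follows; TuningConfigSketch, RunnerSketch, and run_op are illustrative stand-ins, not the real TRT-LLM classes or op.

# Sketch only: simplified stand-ins for TuningConfig and the kernel runner.
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TuningConfigSketch:
    dynamic_tensor_specs: Tuple
    constraint_specs: Tuple


class RunnerSketch:

    def __init__(self) -> None:
        # Built once per runner instance, not on every op invocation.
        self.tuning_config = self.get_tuning_config()

    def get_dynamic_tensor_specs(self) -> Tuple:
        return ()

    def get_constraint_specs(self) -> Tuple:
        return ()

    def get_tuning_config(self) -> TuningConfigSketch:
        return TuningConfigSketch(self.get_dynamic_tensor_specs(),
                                  self.get_constraint_specs())


def run_op(kernel_runner: RunnerSketch) -> TuningConfigSketch:
    # The op reads the cached config instead of rebuilding it per call.
    return kernel_runner.tuning_config


runner = RunnerSketch()
assert run_op(runner) is runner.tuning_config  # same object on every call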

tests/unittest/_torch/thop/test_tllmg_bmm.py

Lines changed: 1 addition & 1 deletion
@@ -346,7 +346,7 @@ def test_thop(self, test_case: BatchedGemmTestCase) -> None:
                                    atol=1e-2,
                                    rtol=1e-2)

-    def test_autotunable_thop(self, test_case: BatchedGemmTestCase) -> None:
+    def test_autotuned_thop(self, test_case: BatchedGemmTestCase) -> None:
         torch.random.manual_seed(42)

         b = test_case.b
