diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py
index 090bcd42b2f..f981c2d768d 100644
--- a/tensorrt_llm/_torch/autotuner.py
+++ b/tensorrt_llm/_torch/autotuner.py
@@ -25,8 +25,8 @@ class DynamicTensorSpec:
     """
     input_idx: int
     dim_idx: int
-    gen_tuning_buckets: Union[Tuple[int], Callable]
-    map_to_tuning_buckets: Callable
+    gen_tuning_buckets: Union[Tuple[int], Callable] = ()
+    map_to_tuning_buckets: Callable = lambda x: x


 @dataclass(slots=True, unsafe_hash=True)
@@ -43,7 +43,7 @@ class ConstraintSpec:
     infer_shape: Callable


-@dataclass(kw_only=True, unsafe_hash=True)
+@dataclass(kw_only=True)
 class TuningConfig:
     """Configuration for autotuning.

@@ -82,8 +82,35 @@ class TuningConfig:
         ...     )
         ... )
     """
+    name: Union[str, Tuple[str, ...]] = ""
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
+    configs: Dict[str, Any] = field(default_factory=dict)
+    tune_max_num_tokens: int = None
+
+
+def tuning_config(
+    name: Union[str, Tuple[str, ...]] = "",
+    dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = (),
+    constraint_specs: Tuple[ConstraintSpec, ...] = (),
+    configs: Dict[str, Any] = {},
+    tune_max_num_tokens: int = None,
+):
+
+    def decorator(func):
+        tuner = AutoTuner.get()
+        tuning_config = TuningConfig(
+            name=name,
+            dynamic_tensor_specs=dynamic_tensor_specs,
+            constraint_specs=constraint_specs,
+            configs=configs,
+            tune_max_num_tokens=tune_max_num_tokens,
+        )
+        tuner.register_tuning_config(tuning_config)
+
+        return func
+
+    return decorator


 @dataclass(unsafe_hash=True)
@@ -139,7 +166,7 @@ class TunableRunner(ABC):

     @abstractmethod
     def get_valid_tactics(self, inputs: List[torch.Tensor],
-                          profile: OptimizationProfile) -> List[int]:
+                          profile: OptimizationProfile, **kwargs) -> List[int]:
         """One tactic corresponding to one cuda kernel normally,
         but how to interpret the meaning of tactic is pure internal details of the runner.
@@ -167,7 +194,8 @@ def forward(
             inputs: List[torch.Tensor],
             *,  # all others are keyword args only
             tactic: int = -1,
-            do_preparation: bool = False) -> Any:
+            do_preparation: bool = False,
+            **kwargs) -> Any:
         """Forward pass for tunable runners.

         Args:
@@ -277,6 +305,7 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
         self.warmup = warmup
         self.stream_delay_micro_secs = stream_delay_micro_secs
         self.profiling_cache = {}
+        self.registered_tuning_configs = {}
         self.is_tuning_mode = False

         # Add statistics tracking
@@ -296,7 +325,7 @@ def search_cache(
         runners: List[TunableRunner],
         input_shapes: Tuple[torch.Size],
         tuning_config: TuningConfig,
-    ) -> Tuple[bool, int, int, OptimizationProfile]:
+    ) -> Tuple[bool, int, int, Dict[str, Any], OptimizationProfile]:
         """Search for cached profiling results matching the current configuration.
Args: @@ -306,7 +335,7 @@ def search_cache( Returns: A tuple containing: - [is_cache_hit, runner_id, tactic, stored_profile] + [is_cache_hit, runner_id, tactic, best_config, stored_profile] """ for r in runners: if (cache_key := AutoTuner._get_cache_key( @@ -314,10 +343,14 @@ def search_cache( tuning_config)) in self.profiling_cache: return True, *self.profiling_cache[cache_key] - return False, 0, -1, None + return False, 0, -1, {}, None - def choose_one(self, custom_op: str, runners: List[TunableRunner], - tuning_config: TuningConfig, inputs: List[torch.Tensor], + def choose_one(self, + custom_op: str, + runners: List[TunableRunner], + inputs: List[torch.Tensor], + tuning_config: TuningConfig = None, + tune_max_num_tokens: int = None, **kwargs) -> Tuple[TunableRunner, int]: """Choose the best runner and tactic combination through performance profiling. @@ -343,11 +376,16 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner], input_shapes = tuple(self._get_input_sizes(inputs)) + tuning_config = self.get_tuning_config( + name=custom_op) if tuning_config is None else tuning_config + if tune_max_num_tokens is not None: + tuning_config.tune_max_num_tokens = tune_max_num_tokens + # Early return if it's not tuning, use cache found one or fallback one if not self.is_tuning_mode: - is_cache_hit, runner_id, tactic, stored_profile = self.search_cache( + is_cache_hit, best_runner_id, best_tactic, best_configs, stored_profile = self.search_cache( custom_op, runners, input_shapes, tuning_config) - runner = runners[runner_id] + best_runner = runners[best_runner_id] # TODO: check the stored runner and tactic can implement this shape here # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf. @@ -360,81 +398,111 @@ def choose_one(self, custom_op: str, runners: List[TunableRunner], logger.debug( f"[AutoTunner]: Generated key{AutoTuner._get_cache_key(custom_op, runners[0], input_shapes, tuning_config)}" ) - return runner, tactic + + if tuning_config.configs: + return (best_runner, best_tactic, best_configs) + else: + return (best_runner, best_tactic) assert len(runners) > 0, "At least one runner is required" assert all([isinstance(r, TunableRunner) for r in runners]), \ "All Given runners must be subclass of TunableRunner" profiles = self._optimization_profiles(tuning_config, inputs) + configs = self._generate_all_configs(tuning_config) + # Record the total configs to try self.stats.tuned_op_total_configs[custom_op] = len(profiles) for p in profiles: + # This can depend on the configs, which is only looped over in profile_runners tensors = self._prepare_input_tensors(p, inputs) - is_cache_hit, runner_id, tactic, _ = self.search_cache( - custom_op, runners, p.get_opt_shapes(), tuning_config) + is_cache_hit, *_ = self.search_cache(custom_op, runners, + p.get_opt_shapes(), + tuning_config) if not is_cache_hit: - min_time = float('inf') # Initialize runner and tactic as None in case of no valid tactic or runners are found - runner_id, tactic = None, None - for r_id, r in enumerate(runners): - # TODO: use FakeTensor here. 
- valid_tactics = r.get_valid_tactics(tensors, p) - runner_arg_names = { - p.name - for p in inspect.signature( - r.forward).parameters.values() - } - if "do_preparation" in runner_arg_names and len( - valid_tactics) > 0: - r(tensors, tactic=-1, do_preparation=True, **kwargs) - for tac in valid_tactics: - try: - time_measured = self._profile_single_kernel( - r, tensors, tac, **kwargs) - except Exception as e: - shapes = self._get_input_sizes(tensors) - - logger.error( - f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}" - ) - - # Record the failed profiling combinations - if custom_op not in self.stats.failed_profiling_count: - self.stats.failed_profiling_count[ - custom_op] = set() - self.stats.failed_profiling_count[custom_op].add( - AutoTuner._get_cache_key( - custom_op, r, p.get_opt_shapes(), - tuning_config)) - - # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics - # or some runtime error occurs during profiling. - time_measured = float('inf') - if time_measured < min_time: - min_time = time_measured - runner_id, tactic = r_id, tac - if runner_id is not None: + best_runner_id, best_tactic, best_config = self._profile_runners( + custom_op, runners, tensors, p, tuning_config, configs, + **kwargs) + if best_runner_id is not None: # At least one valid (runner, tactic) pair is found cache_key = AutoTuner._get_cache_key( - custom_op, runners[runner_id], p.get_opt_shapes(), + custom_op, runners[best_runner_id], p.get_opt_shapes(), tuning_config) # inspect call stack - self.profiling_cache[cache_key] = (runner_id, tactic, p) + self.profiling_cache[cache_key] = (best_runner_id, + best_tactic, best_config, + p) self.stats.tuned_op_successful_configs[ custom_op] = self.stats.tuned_op_successful_configs.get( custom_op, 0) + 1 logger.debug( - f"[Autotuner]: profiling chosen runner: {runners[runner_id]} {tactic} for {cache_key}" + f"[Autotuner]: profiling chosen runner: {runners[best_runner_id]} {best_tactic}{f' {best_config}' if best_config else ''} for {cache_key}" ) # Get the best runner and tactic from cache # If no valid tactic is found, the fallback runner and tactic will be used - _, runner_id, tactic, _ = self.search_cache(custom_op, runners, - input_shapes, tuning_config) - - return runners[runner_id], tactic + _, runner_id, tactic, config, _ = self.search_cache( + custom_op, runners, input_shapes, tuning_config) + + if tuning_config.configs: + return (runners[runner_id], tactic, config) + else: + return (runners[runner_id], tactic) + + def _profile_runners(self, custom_op: str, runners: List[TunableRunner], + input_tensors: List[torch.Tensor], + profile: OptimizationProfile, + tuning_config: TuningConfig, + configs: List[Dict[str, Any]], **kwargs) -> float: + min_time = float('inf') + best_runner_id, best_tactic, best_config = None, None, None + for runner_id, runner in enumerate(runners): + # TODO: use FakeTensor here. 
+ runner_arg_names = { + p.name + for p in inspect.signature(runner.forward).parameters.values() + } + for config in configs: + valid_tactics = runner.get_valid_tactics( + input_tensors, profile, **config) + if "do_preparation" in runner_arg_names and len( + valid_tactics) > 0: + runner(input_tensors, + tactic=-1, + do_preparation=True, + **config, + **kwargs) + + for tac in valid_tactics: + try: + time_measured = self._profile_single_kernel( + runner, input_tensors, tac, config, **kwargs) + except Exception as e: + # Handle None tensors for optional inputs + shapes = self._get_input_sizes(input_tensors) + + logger.error( + f"[Autotuner]: Failed when profiling {runner} {tac}, shapes={shapes}{f', {config}' if config else ''}. Error occurred: {e}" + ) + + # Record the failed profiling combinations + if custom_op not in self.stats.failed_profiling_count: + self.stats.failed_profiling_count[custom_op] = set() + self.stats.failed_profiling_count[custom_op].add( + AutoTuner._get_cache_key(custom_op, runner, + profile.get_opt_shapes(), + tuning_config)) + + # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics + # or some runtime error occurs during profiling. + time_measured = float('inf') + if time_measured < min_time: + min_time = time_measured + best_runner_id, best_tactic, best_config = runner_id, tac, config + + return best_runner_id, best_tactic, best_config def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: @@ -448,7 +516,7 @@ def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: def _profile_single_kernel(self, runner: TunableRunner, inputs: List[torch.Tensor], tactic: int, - **kwargs) -> float: + config: Dict[str, Any], **kwargs) -> float: """Profile a single kernel implementation for performance measurement. Args: @@ -467,7 +535,7 @@ def _profile_single_kernel(self, runner: TunableRunner, stream = torch.cuda.current_stream() # warm up, no timing for _ in range(self.warmup): - runner(inputs, tactic=tactic, **kwargs) + runner(inputs, tactic=tactic, **config, **kwargs) stream.synchronize() # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling. @@ -479,7 +547,7 @@ def _profile_single_kernel(self, runner: TunableRunner, start.record(stream=stream) for _ in range(self.repeat): - runner(inputs, tactic=tactic, **kwargs) + runner(inputs, tactic=tactic, **config, **kwargs) end.record(stream=stream) stream.synchronize() @@ -487,7 +555,7 @@ def _profile_single_kernel(self, runner: TunableRunner, shapes = self._get_input_sizes(inputs) logger.debug( - f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}" + f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes} {f', {config}' if config else ''}. 
avg_time: {avg_time:.6f}ms" ) return avg_time @@ -525,10 +593,23 @@ def _optimization_profiles( assert inspect.isfunction(spec.gen_tuning_buckets) or isinstance(spec.gen_tuning_buckets, (list, tuple)), \ "The given dynamic dimension must provide a opt value generation function or a list of opt values" if inspect.isfunction(spec.gen_tuning_buckets): - opt_shapes = spec.gen_tuning_buckets( - base_profile.shapes[spec.input_idx][spec.dim_idx].val) + if tuning_config.tune_max_num_tokens is None: + # Use the current input size as the opt value + opt_shapes = spec.gen_tuning_buckets( + base_profile.shapes[spec.input_idx][spec.dim_idx].val) + else: + # Use the tune_max_num_tokens as the opt value + opt_shapes = spec.gen_tuning_buckets( + tuning_config.tune_max_num_tokens) else: + # Default values is an empty tuple, means that user does not want to tune this dimension. opt_shapes = spec.gen_tuning_buckets + # Add the current input value as one of the opt values + opt_shapes = set(opt_shapes) + opt_shapes.add( + spec.map_to_tuning_buckets( + base_profile.shapes[spec.input_idx][spec.dim_idx].val)) + opt_shapes = sorted(list(opt_shapes)) opt_shapes_max = tuple(opt_shapes[1:]) + (float('inf'), ) opt_shapes_max = { v1: v2 @@ -554,6 +635,8 @@ def _optimization_profiles( for spec in tuning_config.constraint_specs: min_value = opt_value = max_value = spec.infer_shape( p.get_opt_shapes()) + if p.shapes[spec.input_idx] == [StaticDim(0)]: + continue p.shapes[spec.input_idx][spec.dim_idx] = DynamicDim( min_value, opt_value, max_value) generated_profiles.append(p) @@ -562,8 +645,12 @@ def _optimization_profiles( @classmethod @lru_cache(maxsize=None) - def _find_nearest_profile(cls, shapes: Tuple[torch.Size], - tuning_config: TuningConfig) -> Tuple: + def _find_nearest_profile(cls, + shapes: Tuple[torch.Size], + dynamic_tensor_specs: Tuple[DynamicTensorSpec, + ...], + constraint_specs: Tuple[ConstraintSpec, ...], + tune_max_num_tokens: int = None) -> Tuple: """Find the nearest optimization profile for given inputs User can define their own nearest profile generation method to reduce the host overhead. 
@@ -578,13 +665,20 @@ def _find_nearest_profile(cls, shapes: Tuple[torch.Size], """ base_profile = list(list(shape) for shape in shapes) - for spec in tuning_config.dynamic_tensor_specs: + for spec in dynamic_tensor_specs: base_profile[spec.input_idx][ spec.dim_idx] = spec.map_to_tuning_buckets( base_profile[spec.input_idx][spec.dim_idx]) + if tune_max_num_tokens is not None: + base_profile[spec.input_idx][spec.dim_idx] = min( + base_profile[spec.input_idx][spec.dim_idx], + tune_max_num_tokens) + # associated dimensions dependent on other free dynamic dimensions, so assign -1 in the profile - for spec in tuning_config.constraint_specs: + for spec in constraint_specs: + if base_profile[spec.input_idx] == [0]: + continue base_profile[spec.input_idx][spec.dim_idx] = -1 return tuple(tuple(shape) for shape in base_profile) @@ -598,7 +692,10 @@ def _get_cache_key( tuning_config: TuningConfig, ) -> Tuple: return (custom_op, runner.__class__.__name__, hash(runner), - cls._find_nearest_profile(input_shapes, tuning_config)) + cls._find_nearest_profile(input_shapes, + tuning_config.dynamic_tensor_specs, + tuning_config.constraint_specs, + tuning_config.tune_max_num_tokens)) def _create_tensor_like(self, origin_tensor: torch.Tensor, dims: List[Dim]) -> torch.Tensor: @@ -642,6 +739,17 @@ def _prepare_input_tensors( tensors.append(tensor) return tensors + def _generate_all_configs( + self, tuning_config: TuningConfig) -> List[Dict[str, Any]]: + # If there is no config, return a list with an empty dict + if not tuning_config.configs: + return [{}] + logger.debug( + f"[Autotuner]: {tuning_config.name} all configs: {tuning_config.configs}" + ) + prod = itertools.product(*tuning_config.configs.values()) + return list(dict(zip(tuning_config.configs.keys(), p)) for p in prod) + def clear_cache(self) -> None: """Clear the profiling cache.""" self.profiling_cache.clear() @@ -649,3 +757,12 @@ def clear_cache(self) -> None: def reset_statistics(self) -> None: """Reset all statistics counters.""" self.stats = AutoTunerStatistics() + + def register_tuning_config(self, tuning_config: TuningConfig): + name = tuning_config.name + names = [name] if isinstance(name, str) else name + for name in names: + self.registered_tuning_configs[name] = tuning_config + + def get_tuning_config(self, name: str) -> TuningConfig: + return self.registered_tuning_configs.get(name, TuningConfig()) diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index ffeb90c2fd3..558bc213bcc 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -6,6 +6,7 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils from tensorrt_llm._utils import get_sm_version +from .. 
import autotuner from ..attention_backend.interface import AttentionInputType from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) @@ -23,9 +24,6 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None: class MoERunner(TunableRunner): # avoid overhead of creating a new runner in forward pass runner_dict = dict() - tuning_config = TuningConfig(dynamic_tensor_specs=( - DynamicTensorSpec(0, 0, get_last_power_of_2_num_tokens_buckets(8192), - lambda x: min(last_positive_power_of_2(x), 8192)), )) def __init__( self, @@ -87,6 +85,7 @@ def forward( do_preparation: bool = False, ): x, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases = inputs + # determine if we should use min latency mode according to the profiled seq len self.fused_moe_runner.run_gemm_profile( x, fc1_expert_weights, @@ -107,17 +106,15 @@ def forward( do_preparation, ) - @classmethod - @lru_cache(maxsize=None) - def refine_tuning_config(cls, tune_max_num_tokens: int): - cls.tuning_config = TuningConfig( - dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets( - tune_max_num_tokens), lambda x: min( - last_positive_power_of_2(x), tune_max_num_tokens)), )) - @torch.library.custom_op("trtllm::fused_moe", mutates_args=()) +@autotuner.tuning_config( + name=("trtllm::fused_moe::gemm1", "trtllm::fused_moe::gemm2"), + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + tune_max_num_tokens=8192, +) def fused_moe( input: torch.Tensor, token_selected_experts: torch.Tensor, @@ -144,7 +141,6 @@ def fused_moe( ) -> List[torch.Tensor]: tuner = AutoTuner.get() - MoERunner.refine_tuning_config(tune_max_num_tokens) # allocate workspace for profiling moe_runner = MoERunner( @@ -168,23 +164,23 @@ def fused_moe( _, gemm_tactic_1 = tuner.choose_one( "trtllm::fused_moe::gemm1", [moe_runner], - MoERunner.tuning_config, [ input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases ], gemm_idx=1, + tune_max_num_tokens=tune_max_num_tokens, ) _, gemm_tactic_2 = tuner.choose_one( "trtllm::fused_moe::gemm2", [moe_runner], - MoERunner.tuning_config, [ input, fc1_expert_weights, fc1_expert_biases, fc2_expert_weights, fc2_expert_biases ], gemm_idx=2, + tune_max_num_tokens=tune_max_num_tokens, ) run_moe = moe_runner.fused_moe_runner.run_moe_min_latency if min_latency_mode else moe_runner.fused_moe_runner.run_moe @@ -257,14 +253,6 @@ def _( class FP8RowwiseGemmRunner(TunableRunner): runner_dict = dict() - tuning_config = TuningConfig( - dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2), ), - constraint_specs=( - ConstraintSpec(2, 0, lambda shapes: shapes[0][0]), - ConstraintSpec(3, 0, lambda shapes: shapes[1][0]), - )) def __init__( self, @@ -305,6 +293,16 @@ def forward( @torch.library.custom_op("trtllm::fp8_rowwise_gemm", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::fp8_rowwise_gemm::gemm", + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=( + ConstraintSpec(2, 0, lambda shapes: shapes[0][0]), + ConstraintSpec(3, 0, lambda shapes: shapes[1][0]), + ), +) def fp8_rowwise_gemm( act: torch.Tensor, weight: torch.Tensor, @@ -322,7 +320,6 @@ def fp8_rowwise_gemm( _, best_tactic = tuner.choose_one( "trtllm::fp8_rowwise_gemm::gemm", [fp8_rowwise_gemm_runner], - 
FP8RowwiseGemmRunner.tuning_config, [act, weight, act_scale, weight_scale], ) @@ -344,11 +341,6 @@ def _( class FP4GemmRunner(TunableRunner): runner_dict = dict() - tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( - 0, 0, get_last_power_of_2_num_tokens_buckets, - last_positive_power_of_2), ), - constraint_specs=(ConstraintSpec( - 2, 0, fp4_scale_infer_shape), )) def __init__( self, @@ -391,6 +383,13 @@ def forward( @torch.library.custom_op("trtllm::nvfp4_gemm", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::nvfp4_gemm::gemm", + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ), +) def nvfp4_gemm( act_fp4: torch.Tensor, weight: torch.Tensor, @@ -410,13 +409,13 @@ def nvfp4_gemm( _, best_tactic = tuner.choose_one( "trtllm::fp4_gemm::gemm", [nvfp4_gemm_runner], - FP4GemmRunner.tuning_config, [act_fp4, weight, act_sf, weight_scale, alpha], ) return nvfp4_gemm_runner( inputs=[act_fp4, weight, act_sf, weight_scale, alpha], - tactic=best_tactic) + tactic=best_tactic, + ) @nvfp4_gemm.register_fake @@ -589,8 +588,8 @@ def fp8_batched_gemm_trtllmgen( _, best_tactic = tuner.choose_one( "trtllm::fp8_batched_gemm_trtllmgen::batched_gemm", [kernel_runner], - FP8BatchedGemmRunner.tuning_config, inputs, + tuning_config=FP8BatchedGemmRunner.tuning_config, ) return kernel_runner( @@ -633,6 +632,13 @@ def _( @torch.library.custom_op("trtllm::w4a8_mxfp4_fp8_gemm", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::w4a8_mxfp4_fp8_gemm::gemm", + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ), +) def w4a8_mxfp4_fp8_gemm( act_fp8: torch.Tensor, weight: torch.Tensor, @@ -652,7 +658,6 @@ def w4a8_mxfp4_fp8_gemm( _, best_tactic = tuner.choose_one( "trtllm::w4a8_mxfp4_fp8_gemm::gemm", [w4a8_mxfp4_fp8_gemm_runner], - FP4GemmRunner.tuning_config, [act_fp8, weight, act_sf, weight_scale, alpha], ) @@ -719,6 +724,12 @@ def forward(self, @torch.library.custom_op("trtllm::w4a16_gemm", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::w4a16_gemm::gemm", + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets(8192), + last_positive_power_of_2), ), +) def w4a16_gemm(input: torch.Tensor, weight: torch.Tensor, scales: torch.Tensor, @@ -731,11 +742,6 @@ def w4a16_gemm(input: torch.Tensor, tuner = AutoTuner.get() - tuning_config = TuningConfig(dynamic_tensor_specs=( - # For tensor index 0 (input A), tune dimension 0 (M dimension) - DynamicTensorSpec(0, 0, (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, - 16, 8, 4, 2, 1), last_positive_power_of_2), )) - # NOTE: qunant_mode equals 0 it means we use scale only (FINEGRAINED_SCALE_ONLY), zeros is not used, else we use scale and zero point quant_mode = 1 if has_zero_point else 0 if quant_mode == 0: @@ -745,7 +751,7 @@ def w4a16_gemm(input: torch.Tensor, kwargs = {"group_size": group_size, "zeros": zeros, "bias": bias} _, best_tactic = tuner.choose_one("trtllm::w4a16_gemm::gemm", - [w4a16_gemm_runner], tuning_config, + [w4a16_gemm_runner], [input, weight, scales], **kwargs) return w4a16_gemm_runner(inputs=[input, weight, scales], diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 067680025dd..b0325d876ac 100644 --- 
a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import lru_cache from typing import List, Optional, Tuple import torch @@ -7,8 +6,9 @@ from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2) +from .. import autotuner from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, - OptimizationProfile, TunableRunner, TuningConfig) + OptimizationProfile, TunableRunner) @dataclass(frozen=True) @@ -50,9 +50,6 @@ def __init__(self, num_experts: int, top_k: int, n_group: Optional[int], self.routing_method_type = routing_method_type self.do_finalize = do_finalize - FP4BlockScaleMoERunner.tuning_config = FP4BlockScaleMoERunner.get_tuning_config( - ) - instance_key = ( self.top_k, self.intermediate_size, @@ -132,19 +129,16 @@ def get_valid_tactics( def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 - - m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) - round_rule = lambda x: min(last_positive_power_of_2(x), - MAX_PROFILE_BUCKET) + m_values = get_last_power_of_2_num_tokens_buckets + round_rule = last_positive_power_of_2 specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values, round_rule), ) return specs - @classmethod - def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]: + @staticmethod + def get_constraint_specs() -> Tuple[ConstraintSpec, ...]: def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int: HIDDEN_STATES_IDX = 2 @@ -189,20 +183,14 @@ def _constrain_fp4_linear_layout(shapes: Tuple[torch.Size]) -> int: return constraint_specs_tuple - @classmethod - @lru_cache(maxsize=None) - def get_tuning_config(cls) -> TuningConfig: - - dynamic_tensor_specs = cls.get_dynamic_tensor_specs() - constraint_specs = cls.get_constraint_specs() - - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) - - return tuning_config - @torch.library.custom_op("trtllm::fp4_block_scale_moe_runner", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::fp4_block_scale_moe_runner", + dynamic_tensor_specs=FP4BlockScaleMoERunner.get_dynamic_tensor_specs(), + constraint_specs=FP4BlockScaleMoERunner.get_constraint_specs(), + tune_max_num_tokens=4096, +) def fp4_block_scale_moe_runner(routing_logits: torch.Tensor, routing_bias: Optional[torch.Tensor], hidden_states: torch.Tensor, @@ -247,7 +235,6 @@ def fp4_block_scale_moe_runner(routing_logits: torch.Tensor, _, best_tactic = tuner.choose_one( "trtllm::fp4_block_scale_moe_runner", [kernel_runner], - kernel_runner.tuning_config, inputs, ) @@ -289,9 +276,6 @@ def __init__(self, num_experts: int, top_k: int, n_group: int, self.tile_tokens_dim = tile_tokens_dim self.routing_method_type = routing_method_type - FP8BlockScaleMoERunner.tuning_config = FP8BlockScaleMoERunner.get_tuning_config( - ) - instance_key = ( self.top_k, self.intermediate_size, @@ -367,8 +351,8 @@ def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - m_values = get_last_power_of_2_num_tokens_buckets(2048) - round_rule = lambda x: min(last_positive_power_of_2(x), 2048) + m_values = get_last_power_of_2_num_tokens_buckets + round_rule = last_positive_power_of_2 specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values, round_rule), ) @@ -379,20 +363,14 @@ def 
get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]: return () - @classmethod - @lru_cache(maxsize=None) - def get_tuning_config(cls) -> TuningConfig: - - dynamic_tensor_specs = cls.get_dynamic_tensor_specs() - constraint_specs = cls.get_constraint_specs() - - tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs, - constraint_specs=constraint_specs) - - return tuning_config - @torch.library.custom_op("trtllm::fp8_block_scale_moe_runner", mutates_args=()) +@autotuner.tuning_config( + name="trtllm::fp8_block_scale_moe_runner", + dynamic_tensor_specs=FP8BlockScaleMoERunner.get_dynamic_tensor_specs(), + constraint_specs=FP8BlockScaleMoERunner.get_constraint_specs(), + tune_max_num_tokens=2048, +) def fp8_block_scale_moe_runner(routing_logits: torch.Tensor, routing_bias: torch.Tensor, hidden_states: torch.Tensor, @@ -431,8 +409,8 @@ def fp8_block_scale_moe_runner(routing_logits: torch.Tensor, _, best_tactic = tuner.choose_one( "trtllm::fp8_block_scale_moe_runner", [kernel_runner], - kernel_runner.tuning_config, inputs, + tuning_config=kernel_runner.tuning_config, ) return kernel_runner(inputs, tactic=best_tactic) diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index f687e9d9f55..403e4081258 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -226,7 +226,6 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]: while m >= 1: num_token_buckets.append(m) m //= 2 - return tuple(num_token_buckets) diff --git a/tests/unittest/_torch/test_autotuner.py b/tests/unittest/_torch/test_autotuner.py index 21eb0a96260..091d6491d29 100644 --- a/tests/unittest/_torch/test_autotuner.py +++ b/tests/unittest/_torch/test_autotuner.py @@ -19,14 +19,16 @@ def test_multi_dynamic_dims(): x = torch.rand([5, 1024]) w = torch.rand([7, 19]) dynamic_tensor_specs = ( - DynamicTensorSpec(0, 0, [1, 3, 5], lambda x: x // 2), - DynamicTensorSpec(0, 1, [16, 24, 1024], lambda x: x // 2), + DynamicTensorSpec(0, 0, [1, 3, 5]), + DynamicTensorSpec(0, 1, [16, 24, 1024]), DynamicTensorSpec(1, 1, [3, 7, 9], lambda x: x // 2), ) profiles = tuner._optimization_profiles( tuning_config=TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs), inputs=[x, w]) + # choice(0, 0) * choice(0, 1) * choice(1, 1) + # 3 * 3 * 3 = 27, because 19 is mapped to 9 and already inside the bucket assert len(profiles) == 27 sample_0 = OptimizationProfile(shapes=[[ DynamicDim(min=1, opt=1, max=3), @@ -90,7 +92,7 @@ def check_gemm_tactic_valid(tactic: int, m: int) -> bool: class GemmRunner(TunableRunner): def get_valid_tactics(self, inputs: List[FakeTensor], - profile: OptimizationProfile) -> List[int]: + profile: OptimizationProfile, **kwargs) -> List[int]: # The simulated delay is not deterministic, so we need to return specific tactics here return [-1, 0, 1] @@ -98,25 +100,28 @@ def forward(self, /, inputs: List[torch.Tensor], *, - tactic: int = -1) -> torch.Tensor: + tactic: int = -1, + **kwargs) -> torch.Tensor: assert tactic in [-1, 0, 1] return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs) @torch.library.custom_op("autotuner_test::get_best_gemm_tactic", mutates_args=()) -def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: - runners = [GemmRunner()] - tunner = AutoTuner.get() - tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( +@autotuner.tuning_config( + name="test_autotuner_get_best_gemm_tactic", + dynamic_tensor_specs=(DynamicTensorSpec( input_idx=0, 
dim_idx=0, gen_tuning_buckets=get_power_of_2_num_tokens_buckets, - map_to_tuning_buckets=next_positive_power_of_2), ), ) + map_to_tuning_buckets=next_positive_power_of_2), ), +) +def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + runners = [GemmRunner()] + tunner = AutoTuner.get() runner, tactic = tunner.choose_one( "autotuner_test::get_best_gemm_tactic", runners, - tuning_config, [x, w], ) return torch.tensor(tactic) @@ -173,12 +178,12 @@ def forward(self, map_to_tuning_buckets=next_positive_power_of_2), ), ) with autotune(): runner, tactic = tunner.choose_one("test_autotuner_try_block", runners, - tuning_config, [x, w]) + [x, w], tuning_config) m = M // 2 while m >= 1: _, tactic = tunner.choose_one("test_autotuner_try_block", runners, - tuning_config, [torch.randn(m, 64), w]) + [torch.randn(m, 64), w], tuning_config) assert tactic in [ -1, 0 ], f"Expect only tactic -1, 0 being chosen, but got tactic {tactic}." @@ -187,6 +192,14 @@ def forward(self, @torch.library.custom_op("autotuner_test::recursive_get_best_gemm_tactic", mutates_args=()) +@autotuner.tuning_config( + name="test_autotuner_recursive_get_best_gemm_tactic", + dynamic_tensor_specs=(DynamicTensorSpec( + input_idx=0, + dim_idx=0, + gen_tuning_buckets=get_power_of_2_num_tokens_buckets, + map_to_tuning_buckets=next_positive_power_of_2), ), +) def recursive_get_best_gemm_tactic(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor: # Only the first custom_op is tuned, the second one uses the tuned result in cache @@ -235,6 +248,14 @@ def forward(self, return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs) +@autotuner.tuning_config( + name="test_multiple_runners", + dynamic_tensor_specs=(DynamicTensorSpec( + input_idx=0, + dim_idx=0, + gen_tuning_buckets=get_power_of_2_num_tokens_buckets, + map_to_tuning_buckets=next_positive_power_of_2), ), +) def test_multiple_runners_different_attributes(): """Test that runners with different attributes get different cache entries""" x, w = torch.randn(16, 64), torch.randn(64, 128) @@ -244,39 +265,32 @@ def test_multiple_runners_different_attributes(): runner_1 = GemmRunnerWithAttributes(block_size=256, num_warps=8) runners = [runner_0, runner_1] - tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec( - input_idx=0, - dim_idx=0, - gen_tuning_buckets=get_power_of_2_num_tokens_buckets, - map_to_tuning_buckets=next_positive_power_of_2), ), ) - # Do tuning with autotune(): tuner = AutoTuner.get() runner_a, tactic_a = tuner.choose_one("test_multiple_runners", runners, - tuning_config, [x, w]) + [x, w]) # Verify different cache keys are generated shapes = (x.shape, w.shape) - cache_key_0 = tuner._get_cache_key(custom_op="test_multiple_runners", - input_shapes=shapes, - runner=runner_0, - tuning_config=tuning_config) - cache_key_1 = tuner._get_cache_key(custom_op="test_multiple_runners", - input_shapes=shapes, - runner=runner_1, - tuning_config=tuning_config) + cache_key_0 = tuner._get_cache_key( + custom_op="test_multiple_runners", + input_shapes=shapes, + runner=runner_0, + tuning_config=tuner.get_tuning_config("test_multiple_runners")) + cache_key_1 = tuner._get_cache_key( + custom_op="test_multiple_runners", + input_shapes=shapes, + runner=runner_1, + tuning_config=tuner.get_tuning_config("test_multiple_runners")) assert cache_key_0 != cache_key_1, "Runners with different attributes should have different cache keys" -def test_multiple_dynamic_shapes_cache(): - """Test that different dynamic shape combinations are properly cached""" - w = 
torch.randn(64, 128)
-    runners = [GemmRunner()]
-
-    # Define dynamic ranges for both dimensions
-    tuning_config = TuningConfig(dynamic_tensor_specs=(
+# Define dynamic ranges for both dimensions
+@autotuner.tuning_config(
+    name="test_multiple_dynamic_shapes",
+    dynamic_tensor_specs=(
         DynamicTensorSpec(input_idx=0,
                           dim_idx=0,
                           gen_tuning_buckets=(3, 4, 5),
@@ -285,14 +299,19 @@ def forward(self,
                           dim_idx=1,
                           gen_tuning_buckets=(64, 128, 256, 512),
                           map_to_tuning_buckets=lambda x: x),
-    ), )
+    ),
+)
+def test_multiple_dynamic_shapes_cache():
+    """Test that different dynamic shape combinations are properly cached"""
+    w = torch.randn(64, 128)
+    runners = [GemmRunner()]

     # Do tuning with a sample input
     x = torch.randn(3, 64)
     with autotune():
         tuner = AutoTuner.get()
         runner, tactic = tuner.choose_one("test_multiple_dynamic_shapes",
-                                          runners, tuning_config, [x, w])
+                                          runners, [x, w])

     # Verify cache size - should have 12 entries (3x4 combinations)
     cache_entries = [
@@ -301,3 +320,39 @@ def forward(self,
     ]
     assert len(cache_entries) == 12, \
         f"Expected 12 cache entries for 3x4 shape combinations, got {len(cache_entries)}"
+
+
+class GemmRunnerWithTacticConfigs(TunableRunner):
+
+    def get_valid_tactics(self, inputs: List[FakeTensor],
+                          profile: OptimizationProfile,
+                          block_size: int) -> List[int]:
+        # The simulated delay is not deterministic, so we need to return specific tactics here
+        return [-1, 0, 1]
+
+    def forward(self,
+                /,
+                inputs: List[torch.Tensor],
+                *,
+                tactic: int = -1,
+                block_size: int = 128) -> torch.Tensor:
+        assert tactic in [-1, 0, 1]
+        return [gemm_0, gemm_1, gemm_fallback][tactic](*inputs)
+
+
+@autotuner.tuning_config(
+    name="test_autotuner_tactic_configs",
+    configs={
+        "block_size": [128, 256],
+    },
+)
+def test_autotuner_tactic_configs():
+    runner_0 = GemmRunnerWithTacticConfigs()
+    runners = [runner_0]
+    x, w = torch.randn(64, 64), torch.randn(64, 128)
+    with autotune():
+        tuner = AutoTuner.get()
+        runner, tactic, configs = tuner.choose_one(
+            "test_autotuner_tactic_configs", runners, [x, w])
+
+    runner_0.forward(inputs=[x, w], tactic=tactic, **configs)
diff --git a/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py b/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py
index 09fb1a85420..f8b15a559ea 100644
--- a/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py
+++ b/tests/unittest/_torch/thop/test_w4a8_mxfp4_mxfp8_gemm.py
@@ -20,6 +20,7 @@
 from utils.util import skip_pre_blackwell_unittest, unittest_name_func

 import tensorrt_llm
+from tensorrt_llm._torch.autotuner import autotune


 class TestFunctional(unittest.TestCase):
@@ -73,9 +74,10 @@ def random_noise(pos, mat_ref, mat):
         a_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(a_block_sf)
         b_block_sf = torch.ops.trtllm.nvfp4_block_scale_interleave(b_block_sf)

-        c = (torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(mat_a, mat_b, a_block_sf,
-                                                  b_block_sf, a_sf,
-                                                  torch.bfloat16))
+        with autotune():
+            c = (torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(mat_a, mat_b, a_block_sf,
+                                                      b_block_sf, a_sf,
+                                                      torch.bfloat16))

         c_ref = (mat_a_ref @ mat_b_ref.T * a_sf).to(torch.bfloat16)
         assert torch.allclose(c_ref, c, atol=1e-2, rtol=1e-2)
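
Addendum: a minimal usage sketch of the decorator-based registration this patch introduces, modeled on test_autotuner_tactic_configs above. It only assumes the APIs shown in the diff (autotuner.tuning_config, AutoTuner.choose_one, autotune); MyGemmRunner, the op name "example::my_gemm", and the block_size grid are illustrative placeholders, not part of the patch.

# Usage sketch only: the runner, op name and config grid below are hypothetical.
from typing import List

import torch

from tensorrt_llm._torch import autotuner
from tensorrt_llm._torch.autotuner import (AutoTuner, DynamicTensorSpec,
                                           OptimizationProfile, TunableRunner,
                                           autotune)
from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
                                       last_positive_power_of_2)


class MyGemmRunner(TunableRunner):
    # Hypothetical runner; a real one would dispatch to a different kernel per tactic.

    def get_valid_tactics(self, inputs: List[torch.Tensor],
                          profile: OptimizationProfile,
                          **kwargs) -> List[int]:
        return [-1, 0]  # -1 is the fallback tactic

    def forward(self,
                inputs: List[torch.Tensor],
                *,
                tactic: int = -1,
                block_size: int = 128,
                **kwargs) -> torch.Tensor:
        a, b = inputs
        return a @ b  # stand-in for the tunable kernel


# Register the TuningConfig once at definition time; choose_one() later looks
# it up by name instead of receiving a TuningConfig argument.
@autotuner.tuning_config(
    name="example::my_gemm",
    dynamic_tensor_specs=(DynamicTensorSpec(
        0, 0, get_last_power_of_2_num_tokens_buckets,
        last_positive_power_of_2), ),
    configs={"block_size": [128, 256]},  # swept as extra kwargs during tuning
    tune_max_num_tokens=8192,
)
def my_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    tuner = AutoTuner.get()
    # With a non-empty `configs` grid, choose_one returns a third element
    # holding the best kwargs found for the chosen runner/tactic.
    runner, tactic, best_cfg = tuner.choose_one("example::my_gemm",
                                                [MyGemmRunner()], [a, b])
    return runner([a, b], tactic=tactic, **best_cfg)


a, b = torch.randn(16, 64), torch.randn(64, 32)
with autotune():
    my_gemm(a, b)  # profiles and fills the cache
my_gemm(a, b)  # replays the cached (runner, tactic, config)

As in the fused_moe change above, a caller can still pass tune_max_num_tokens to choose_one to override the registered value at call time.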