diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
index 476afa928e2..799bf2f9f03 100644
--- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
+++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
@@ -42,7 +42,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
     TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
-    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
+    TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
+        "routing_logits and hidden_states must have the same number of tokens.");
+    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts.");
     TORCH_CHECK(
         routing_bias.scalar_type() == at::ScalarType::BFloat16 || routing_bias.scalar_type() == at::ScalarType::Float,
         "routing_bias must be bfloat16 or float.");
@@ -149,8 +151,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8.");
     TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float, "hidden_states_scale must be float.");
     TORCH_CHECK(hidden_states_scale.dim() == 2, "hidden_states_scale must be 2D.");
-    TORCH_CHECK(
-        hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128, "hidden_states_scale has incorrect shape.");
+    TORCH_CHECK(hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128,
+        "hidden_states_scale dim0 must match hidden_states dim1 / 128.");
+    TORCH_CHECK(hidden_states_scale.sizes()[1] == args.num_tokens, "hidden_states_scale dim1 must match num_tokens.");
     TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm1_weights must be fp8.");
     TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
     TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
index b37c4e017f7..d7f484959f9 100644
--- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -4,7 +4,8 @@
 
 import torch
 
-from tensorrt_llm._torch.utils import last_positive_power_of_2
+from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
+                                       last_positive_power_of_2)
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
@@ -123,8 +124,11 @@ def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
 
-        m_values = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)
-        round_rule = lambda x: last_positive_power_of_2(x)
+        MAX_PROFILE_BUCKET = 4096
+
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+        round_rule = lambda x: min(last_positive_power_of_2(x),
+                                   MAX_PROFILE_BUCKET)
 
         specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
                                    round_rule), )
@@ -133,7 +137,31 @@ def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
 
     @classmethod
     def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]:
-        return ()
+
+        def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
+            num_tokens = shapes[2][0]
+
+            return num_tokens
+
+        HS_SCALE_IDX = 3
+        CONSTRAINED_HS_SCALE_DIM = 1
+
+        constraint_hidden_states_scale = ConstraintSpec(
+            HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, _constrain_to_num_tokens)
+
+        ROUTER_LOGITS_IDX = 0
+        CONSTRAINED_RL_DIM = 0
+
+        constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
+                                                   CONSTRAINED_RL_DIM,
+                                                   _constrain_to_num_tokens)
+
+        constraint_specs_tuple = (
+            constraint_hidden_states_scale,
+            constraint_routing_logits,
+        )
+
+        return constraint_specs_tuple
 
     @classmethod
     @lru_cache(maxsize=None)
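
Note on the autotuner change: the sketch below illustrates, outside the real AutoTuner/ConstraintSpec machinery, what the new _constrain_to_num_tokens hook computes during profile generation. The input ordering (0 = routing_logits, 2 = hidden_states, 3 = hidden_states_scale) follows the index constants in the diff; the concrete shape values are hypothetical.

    from typing import Tuple

    import torch


    def _constrain_to_num_tokens(shapes: Tuple[torch.Size, ...]) -> int:
        # The tuned dimension is dim 0 of hidden_states (input index 2);
        # every constrained dimension is pinned to that token count.
        return shapes[2][0]


    # Hypothetical shapes for one tuning bucket of 256 tokens.
    shapes = (
        torch.Size([256, 128]),          # 0: routing_logits      [num_tokens, num_experts]
        torch.Size([128]),               # 1: routing_bias        [num_experts]
        torch.Size([256, 7168]),         # 2: hidden_states       [num_tokens, hidden_size]
        torch.Size([7168 // 128, 256]),  # 3: hidden_states_scale [hidden_size / 128, num_tokens]
    )

    assert _constrain_to_num_tokens(shapes) == 256
    # Per the ConstraintSpecs above, the autotuner would pin routing_logits dim 0
    # and hidden_states_scale dim 1 to this value, mirroring the new TORCH_CHECKs
    # in fp8BlockScaleMoe.cpp.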