9 changes: 6 additions & 3 deletions cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
@@ -42,7 +42,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(sm == 100, "Only SM100 is supported by FP8 block scale MOE");
     TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float, "routing_logits must be float.");
     TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
-    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
+    TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
+        "routing_logits and hidden_states must have the same number of tokens.");
+    TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits dim1 must match num_experts.");
     TORCH_CHECK(
         routing_bias.scalar_type() == at::ScalarType::BFloat16 || routing_bias.scalar_type() == at::ScalarType::Float,
         "routing_bias must be bfloat16 or float.");
@@ -149,8 +151,9 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
     TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8.");
     TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float, "hidden_states_scale must be float.");
     TORCH_CHECK(hidden_states_scale.dim() == 2, "hidden_states_scale must be 2D.");
-    TORCH_CHECK(
-        hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128, "hidden_states_scale has incorrect shape.");
+    TORCH_CHECK(hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128,
+        "hidden_states_scale dim0 must match hidden_states dim1 / 128.");
+    TORCH_CHECK(hidden_states_scale.sizes()[1] == args.num_tokens, "hidden_states_scale dim1 must match num_tokens.");
     TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn, "gemm1_weights must be fp8.");
     TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
     TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
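Taken together, the updated checks pin down one consistent layout for the FP8 block-scale inputs. A minimal Python sketch of the implied shapes, assuming hidden_states is [num_tokens, hidden_size] with one float scale per 128-element block stored transposed; the names mirror the C++ arguments, and this is an illustration of the invariants, not the kernel's actual validation code:

    # Illustration only: shape relationships enforced by the checks above.
    def check_moe_input_shapes(routing_logits, hidden_states,
                               hidden_states_scale, num_experts):
        num_tokens, hidden_size = hidden_states.shape  # fp8 (e4m3) activations
        # One routing logit per (token, expert) pair.
        assert routing_logits.shape == (num_tokens, num_experts)
        # One float scale per 128-wide block, stored transposed.
        assert hidden_states_scale.shape == (hidden_size // 128, num_tokens)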
36 changes: 32 additions & 4 deletions tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py
@@ -4,7 +4,8 @@

 import torch
 
-from tensorrt_llm._torch.utils import last_positive_power_of_2
+from tensorrt_llm._torch.utils import (get_last_power_of_2_num_tokens_buckets,
+                                       last_positive_power_of_2)
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
@@ -123,8 +124,11 @@ def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:
         HIDDEN_STATES_IDX = 2
         TUNED_DIM = 0
 
-        m_values = (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)
-        round_rule = lambda x: last_positive_power_of_2(x)
+        MAX_PROFILE_BUCKET = 4096
+
+        m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET)
+        round_rule = lambda x: min(last_positive_power_of_2(x),
+                                   MAX_PROFILE_BUCKET)
 
         specs = (DynamicTensorSpec(HIDDEN_STATES_IDX, TUNED_DIM, m_values,
                                    round_rule), )
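For reference, a sketch of the bucketing this change relies on, assuming last_positive_power_of_2(x) returns the largest power of two not exceeding x and get_last_power_of_2_num_tokens_buckets(max) enumerates the power-of-two bucket sizes up to that cap; these semantics are inferred from the call sites here, not from the helpers' source:

    # Assumed semantics of the two helpers (sketch, not the real utils).
    def last_positive_power_of_2_sketch(x: int) -> int:
        p = 1
        while p * 2 <= x:
            p *= 2
        return p  # largest power of two <= x, for x >= 1

    def num_tokens_buckets_sketch(max_bucket: int) -> tuple:
        p, buckets = 1, []
        while p <= max_bucket:
            buckets.append(p)
            p *= 2
        return tuple(buckets)  # (1, 2, 4, ..., max_bucket)

Under these assumptions, num_tokens_buckets_sketch(4096) yields the same bucket values the old hard-coded m_values tuple spelled out (possibly in a different order), and the min(...) clamp in round_rule keeps a runtime num_tokens above 4096 from rounding to a bucket that was never profiled.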
@@ -133,7 +137,31 @@ def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]:

     @classmethod
     def get_constraint_specs(cls) -> Tuple[ConstraintSpec, ...]:
-        return ()
+
+        def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int:
+            num_tokens = shapes[2][0]
+
+            return num_tokens
+
+        HS_SCALE_IDX = 3
+        CONSTRAINED_HS_SCALE_DIM = 1
+
+        constraint_hidden_states_scale = ConstraintSpec(
+            HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, _constrain_to_num_tokens)
+
+        ROUTER_LOGITS_IDX = 0
+        CONSTRAINED_RL_DIM = 0
+
+        constraint_routing_logits = ConstraintSpec(ROUTER_LOGITS_IDX,
+                                                   CONSTRAINED_RL_DIM,
+                                                   _constrain_to_num_tokens)
+
+        constraint_specs_tuple = (
+            constraint_hidden_states_scale,
+            constraint_routing_logits,
+        )
+
+        return constraint_specs_tuple
 
     @classmethod
     @lru_cache(maxsize=None)
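The two new ConstraintSpec entries mirror the C++ shape checks added above: during profiling, dim 1 of input 3 (hidden_states_scale) and dim 0 of input 0 (routing_logits) are both forced to track num_tokens, which _constrain_to_num_tokens reads off dim 0 of input 2 (hidden_states). A rough sketch of how such a spec could be applied when materializing profile shapes, as hypothetical driver code; the actual AutoTuner internals may differ:

    # Hypothetical illustration of applying one constraint spec while
    # building an autotuning profile; not the real AutoTuner code.
    def apply_constraint_spec(shapes, input_idx, dim_idx, infer_fn):
        # shapes: per-input shape tuples, e.g. shapes[2] == (num_tokens, hidden)
        sizes = list(shapes[input_idx])
        sizes[dim_idx] = infer_fn(shapes)  # e.g. _constrain_to_num_tokens
        shapes[input_idx] = tuple(sizes)
        return shapes

Whenever the tuner rounds num_tokens to a new bucket, re-applying both constraints keeps the hidden_states_scale and routing_logits shapes consistent with the bucketed hidden_states shape, so the profiled shapes always satisfy the kernel's TORCH_CHECKs.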